LeakyRateReadout

class brainpy.state.LeakyRateReadout(in_size, out_size, tau=Quantity(5., 'ms'), w_init=KaimingNormal(mode=fan_in, nonlinearity=relu, unit=1), name=None)

Leaky dynamics for the read-out module.

This module implements a leaky integrator with the following dynamics:

\[r_{t} = \alpha r_{t-1} + x_{t} W\]

where:

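  • \(r_{t}\) is the output (rate) at time step \(t\)

  • \(\alpha = e^{-\Delta t / \tau}\) is the decay factor, with \(\Delta t\) the simulation time step (environment dt) and \(\tau\) the time constant tau

  • \(x_{t}\) is the input at time step \(t\)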

  • \(W\) is the weight matrix
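The recurrence above can be sketched in plain NumPy to make the update order explicit. This is a minimal illustration, not the library's implementation: the function name `leaky_readout` is hypothetical, `tau` and `dt` are plain floats in the same unit (ms) instead of saiunit quantities, and the initial state \(r_{0}\) is assumed to be zero.

```python
import numpy as np


def leaky_readout(x_seq, W, tau, dt):
    """Hypothetical sketch: run r_t = alpha * r_{t-1} + x_t @ W over a sequence.

    x_seq: array of shape (T, in_features)
    W:     array of shape (in_features, out_features)
    tau, dt: plain floats in the same time unit (e.g. ms)
    Returns an array of shape (T, out_features) holding r_1 ... r_T.
    """
    alpha = np.exp(-dt / tau)      # decay factor, fixed for the whole run
    r = np.zeros(W.shape[1])       # r_0 = 0 (assumed initial state)
    outputs = []
    for x_t in x_seq:
        r = alpha * r + x_t @ W    # leak the previous state, add the new drive
        outputs.append(r)
    return np.stack(outputs)


rng = np.random.default_rng(0)
x = rng.normal(size=(20, 4))       # 20 time steps, 4 input features
W = rng.normal(size=(4, 3))        # 4 inputs -> 3 outputs
r = leaky_readout(x, W, tau=5.0, dt=1.0)
print(r.shape)                     # (20, 3)
```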

The leaky integrator acts as a low-pass filter: an input presented \(n\) steps ago contributes with weight \(\alpha^{n} = e^{-n \Delta t / \tau}\), so the network keeps an exponentially fading memory of past inputs whose time scale is set by the time constant tau.
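The low-pass behaviour can be checked numerically with a scalar version of the recurrence (single input, single output, \(W = 1\)). The values dt = 1 ms and tau = 5 ms are assumptions chosen to match the defaults; the sketch shows that an impulse falls to \(1/e\) after \(\tau / \Delta t\) steps and that a constant input settles at \(x / (1 - \alpha)\).

```python
import numpy as np

dt, tau = 1.0, 5.0                 # ms; assumed values matching the defaults
alpha = np.exp(-dt / tau)

# Impulse response: a single unit input at t = 0, zero afterwards.
r, impulse = 0.0, []
for t in range(11):
    x_t = 1.0 if t == 0 else 0.0
    r = alpha * r + x_t
    impulse.append(r)

# After n steps the impulse has decayed to alpha**n = exp(-n * dt / tau);
# at n = tau / dt = 5 steps it has fallen to 1/e of its initial value.
print(impulse[5], np.exp(-1.0))

# Constant input: r_t approaches the fixed point x / (1 - alpha),
# i.e. the filter passes DC with a gain of 1 / (1 - alpha).
r_const = 0.0
for _ in range(200):
    r_const = alpha * r_const + 1.0
print(r_const, 1.0 / (1.0 - alpha))
```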

Parameters:
  • in_size (int or sequence of int) – Size of the input dimension(s)

  • out_size (int or sequence of int) – Size of the output dimension(s)

  • tau (ArrayLike, optional) – Time constant of the leaky dynamics, by default 5 ms

  • w_init (Callable, optional) – Weight initialization function, by default KaimingNormal()

  • name (str, optional) – Name of the module, by default None
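The default w_init draws the weights with Kaiming (He) normal initialization in fan-in mode with the ReLU gain. A minimal sketch of that scheme follows; `kaiming_normal_fan_in` is a hypothetical stand-in, not the library's initializer (which may differ in details such as truncation), and the (in_size, out_size) weight shape with fan_in = in_size is an assumption for the integer-size case.

```python
import numpy as np


def kaiming_normal_fan_in(in_features, out_features, rng):
    """Hypothetical stand-in for the default w_init, not the library's code.

    Kaiming normal, fan-in mode, ReLU gain sqrt(2):
    weights ~ N(0, std**2) with std = sqrt(2 / fan_in).
    """
    fan_in = in_features                     # assumed: fan-in = number of inputs
    std = np.sqrt(2.0 / fan_in)
    return rng.normal(0.0, std, size=(in_features, out_features))


rng = np.random.default_rng(42)
W = kaiming_normal_fan_in(128, 10, rng)      # same sizes as the Examples section
print(W.shape, W.std(), np.sqrt(2.0 / 128))  # sample std should be close to 0.125
```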

decay

Decay factor computed as exp(-dt/tau)

Type:

float
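As a worked example of the decay formula: with dt = 1 ms and tau = 5 ms, decay = exp(-0.2) ≈ 0.819. The short sketch below tabulates a few more values using plain floats (an assumption; in the library both dt and tau carry time units, so their ratio is dimensionless). Larger tau pushes the decay toward 1, i.e. longer memory.

```python
import math

dt = 1.0                                  # ms, the value used in the Examples section
decays = {}
for tau in (2.0, 5.0, 20.0, 100.0):       # ms
    decays[tau] = math.exp(-dt / tau)     # decay = exp(-dt / tau)
    print(f"tau = {tau:5.1f} ms  ->  decay = {decays[tau]:.4f}")
```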

weight

Weight matrix connecting input to output

Type:

ParamState

r

Hidden state representing the output values

Type:

HiddenState

Notes

  • The decay factor \(\alpha = e^{-\Delta t / \tau}\) is computed once at construction time using the environment dt in effect at that moment; changing dt afterwards does not update it.

  • By default, the weight matrix is initialized with Kaiming normal initialization (fan-in mode, ReLU gain), which is suitable for layers whose inputs come from ReLU-like activations.

  • This module does not produce spikes; it outputs continuous rate values, making it suitable as the final readout layer in spiking networks trained with surrogate gradients.
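The points in these notes can be illustrated with a small, self-contained NumPy class. Everything here is hypothetical: `TinyLeakyReadout` is not part of brainpy, the module-level `ENV_DT` stands in for the environment dt that brainstate manages, and the zero initial state in `init_state`/`reset_state` is an assumption. The sketch shows decay being frozen at construction and the update returning continuous rates rather than spikes.

```python
import math

import numpy as np

ENV_DT = 1.0  # ms; hypothetical stand-in for the environment dt


class TinyLeakyReadout:
    """Hypothetical NumPy sketch of the behaviour described in the Notes."""

    def __init__(self, in_features, out_features, tau=5.0, seed=0):
        rng = np.random.default_rng(seed)
        # decay is evaluated once, with whatever dt is current at construction
        self.decay = math.exp(-ENV_DT / tau)
        # Kaiming-style normal weights (fan-in, ReLU gain), shape (in, out)
        self.weight = rng.normal(0.0, math.sqrt(2.0 / in_features),
                                 size=(in_features, out_features))
        self.out_features = out_features
        self.r = None

    def init_state(self, batch_size=None):
        # assumed: zero initial output, with an optional leading batch axis
        if batch_size is None:
            shape = (self.out_features,)
        else:
            shape = (batch_size, self.out_features)
        self.r = np.zeros(shape)

    def reset_state(self, batch_size=None):
        # assumed: resetting restores the zero initial state
        self.init_state(batch_size)

    def update(self, x):
        # r_t = alpha * r_{t-1} + x_t @ W; a continuous rate, no thresholding
        self.r = self.decay * self.r + x @ self.weight
        return self.r


readout = TinyLeakyReadout(128, 10)
readout.init_state(batch_size=1)
ENV_DT = 0.1                      # changing dt after construction...
print(readout.decay)              # ...leaves decay at exp(-1.0 / 5.0)
out = readout.update(np.ones((1, 128)))
print(out.shape)                  # (1, 10)
```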

Examples

>>> import brainpy
>>> import brainstate
>>> import saiunit as u
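>>> # dt must be in effect when the module is built, because decay is computed at construction
>>> with brainstate.environ.context(dt=1. * u.ms):
...     readout = brainpy.state.LeakyRateReadout(128, 10, tau=5.*u.ms)
...     readout.init_state(batch_size=1)  # initialize the hidden state r for a batch of 1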
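
init_state(batch_size=None, **kwargs)

Initialize the hidden state r, optionally with a leading batch dimension of size batch_size.

reset_state(batch_size=None, **kwargs)

Reset the hidden state r to its initial value, optionally with a leading batch dimension of size batch_size.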