LeakyRateReadout

class braintrace.nn.LeakyRateReadout(in_size, out_size, tau=Quantity(5., 'ms'), w_init=KaimingNormal(mode=fan_in, nonlinearity=relu, unit=1), r_init=ZeroInit(unit=1), name=None)

Readout module with leaky dynamics, used in Real-Time Recurrent Learning (RTRL).

The LeakyRateReadout class applies leaky integration to continuous input signals, modelling the dynamics of rate-based neurons. The readout state decays toward zero with time constant tau while the weighted input is added to it, so the module produces a continuous-valued output signal.
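A minimal NumPy sketch of such a leaky readout loop follows. It assumes the common discrete-time form r ← decay · r + x W with decay = exp(-dt / tau); neither the exact update nor the discretization is confirmed by this page. The function `leaky_readout_step` and the step size `dt` are hypothetical, and this is not braintrace's implementation.

```python
import numpy as np

def leaky_readout_step(r, x, weight, decay):
    """One step of a leaky rate readout (illustrative sketch, not braintrace code).

    r:      (batch, out) readout state from the previous step
    x:      (batch, in)  input at the current step
    weight: (in, out)    readout weights
    decay:  scalar in (0, 1) controlling how fast the old state leaks away
    """
    # The old state leaks toward zero; the projected input is added on top.
    return decay * r + x @ weight

# Assumed discretization: decay = exp(-dt / tau), with dt and tau both in ms.
dt, tau = 0.1, 5.0  # dt is a hypothetical simulation step
decay = np.exp(-dt / tau)

rng = np.random.default_rng(0)
batch, n_in, n_out, n_steps = 4, 8, 3, 20
weight = rng.normal(size=(n_in, n_out))
r = np.zeros((batch, n_out))  # zero initial state, mirroring the ZeroInit default
for _ in range(n_steps):
    x = rng.normal(size=(batch, n_in))
    r = leaky_readout_step(r, x, weight, decay)
print(r.shape)  # (4, 3)
```

Because the state is carried from step to step, the output at any time reflects an exponentially weighted history of past inputs rather than only the current one.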

This class is part of the BrainTrace project and integrates with the Brain Dynamics Programming ecosystem, providing a biologically inspired approach to neural computation.

Parameters:
  • in_size (int | Sequence[int] | integer | Sequence[integer]) – The size of the input to the readout module.

  • out_size (int | Sequence[int] | integer | Sequence[integer]) – The size of the output from the readout module.

  • tau (Array | ndarray | bool | number | int | float | complex | Quantity) – The time constant of the leaky integration dynamics. Default is 5 ms.

  • w_init (Callable) – A callable for initializing the weights of the readout module. Default is KaimingNormal().

  • r_init (Callable) – A callable for initializing the state of the readout module. Default is ZeroInit().

  • name (str | None) – An optional name for the module. Default is None.
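The default w_init is printed as KaimingNormal(mode=fan_in, nonlinearity=relu). Assuming it follows the standard Kaiming (He) normal rule, std = gain / sqrt(fan_in) with gain = sqrt(2) for ReLU, and assuming fan_in equals in_size (weights stored as (in_size, out_size)), the scale of the initial weights for in_size=256 can be worked out as below. The initializer used by the library may differ in details such as truncation; this standard-library sketch and its helper `kaiming_normal_std` are hypothetical.

```python
import math
import random
import statistics

def kaiming_normal_std(fan_in, gain=math.sqrt(2.0)):
    """Standard deviation of the classic Kaiming-normal rule (fan_in mode, ReLU gain)."""
    return gain / math.sqrt(fan_in)

in_size, out_size = 256, 10
std = kaiming_normal_std(in_size)
print(round(std, 4))  # 0.0884

# Draw an (in_size, out_size) weight matrix with plain Gaussian sampling (no truncation).
random.seed(0)
weights = [[random.gauss(0.0, std) for _ in range(out_size)] for _ in range(in_size)]
flat = [w for row in weights for w in row]
print(statistics.stdev(flat))  # empirical std, close to the target above
```

Scaling by 1/sqrt(fan_in) keeps the variance of x W roughly independent of in_size, so a wide input layer does not blow up the readout at initialization.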

in_size

The size of the input.

Type:

tuple of int

out_size

The size of the output.

Type:

tuple of int
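Both in_size and out_size are documented as tuples of int, although the constructor also accepts a bare integer. A hypothetical helper `to_size_tuple` that performs this kind of normalization is sketched below; the constructor's actual code may differ.

```python
import numbers
from typing import Sequence, Tuple, Union

def to_size_tuple(size: Union[int, Sequence[int]]) -> Tuple[int, ...]:
    """Normalize an int or a sequence of ints to a tuple of Python ints (hypothetical helper)."""
    if isinstance(size, numbers.Integral):
        # A bare integer (NumPy integers included) becomes a 1-tuple.
        return (int(size),)
    # Any sequence of integers becomes a tuple, converting each entry to int.
    return tuple(int(s) for s in size)

print(to_size_tuple(256))       # (256,)
print(to_size_tuple([16, 16]))  # (16, 16)
```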

tau

The time constant for leaky integration.

Type:

ArrayLike

decay

The per-step decay factor applied to the readout state, computed from tau.

Type:

ArrayLike
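How decay is derived from tau is not spelled out here. A common choice for a leaky integrator simulated with step dt is decay = exp(-dt / tau), the exact solution of dr/dt = -r / tau over one step. Assuming that convention and a hypothetical step of dt = 0.1 ms, the default tau = 5 ms gives the numbers below; `leak_decay` is an illustrative helper, not part of braintrace.

```python
import math

def leak_decay(dt_ms, tau_ms):
    """Per-step decay of a leaky integrator, assuming decay = exp(-dt / tau)."""
    return math.exp(-dt_ms / tau_ms)

dt_ms, tau_ms = 0.1, 5.0           # dt is a hypothetical simulation step
decay = leak_decay(dt_ms, tau_ms)
print(round(decay, 4))             # 0.9802

# After tau / dt steps with no input, the state has shrunk by a factor of e.
n_steps = round(tau_ms / dt_ms)    # 50
print(round(decay ** n_steps, 4))  # 0.3679
```

Under this convention a larger tau gives a decay closer to 1, so the readout integrates over a longer window of past inputs.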

r

The readout state variable.

Type:

HiddenState

weight_op

The parameter object holding the readout weights and the operation that applies them to the input.

Type:

ParamState

Examples

>>> import braintrace
>>> import brainstate
>>> import saiunit as u
>>>
>>> # Set the global simulation time step
>>> brainstate.environ.set(dt=0.1 * u.ms)
>>>
>>> # Create a leaky rate readout layer
>>> readout = braintrace.nn.LeakyRateReadout(
...     in_size=256,
...     out_size=10,
...     tau=5.0 * u.ms
... )
>>> readout.init_state(batch_size=32)
>>>
>>> # Process input through the readout layer
>>> x = brainstate.random.randn(32, 256)
>>> output = readout(x)
>>> print(output.shape)
(32, 10)

init_state(batch_size=None, **kwargs)

Initialize the readout state r using r_init. If batch_size is given, the state gets a leading batch dimension of that size.

reset_state(batch_size=None, **kwargs)

Reset the readout state r to its initial value by re-applying r_init, optionally with a leading batch dimension of size batch_size.
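Resetting matters when independent sequences are fed one after another: otherwise the leaky state left over from the previous sequence carries into the next one. The toy class below is a hypothetical NumPy stand-in that only mirrors the method names documented here (init_state, reset_state, and calling the module on an input). It assumes a zero initial state, as with the ZeroInit default, and the update r ← decay · r + x W; it is not braintrace code.

```python
import numpy as np

class ToyLeakyReadout:
    """Hypothetical stand-in mirroring init_state/reset_state; not braintrace code."""

    def __init__(self, in_size, out_size, decay=0.98, seed=0):
        rng = np.random.default_rng(seed)
        self.weight = rng.normal(scale=np.sqrt(2.0 / in_size), size=(in_size, out_size))
        self.out_size = out_size
        self.decay = decay  # assumed fixed per-step decay factor
        self.r = None

    def _zeros(self, batch_size):
        # Zero state; a leading batch dimension is added only when batch_size is given.
        shape = (self.out_size,) if batch_size is None else (batch_size, self.out_size)
        return np.zeros(shape)

    def init_state(self, batch_size=None):
        self.r = self._zeros(batch_size)

    def reset_state(self, batch_size=None):
        self.r = self._zeros(batch_size)

    def __call__(self, x):
        self.r = self.decay * self.r + x @ self.weight
        return self.r

def run(model, seq):
    """Feed a (time, batch, in) sequence step by step and return the last output."""
    out = None
    for x in seq:
        out = model(x)
    return out

readout = ToyLeakyReadout(in_size=8, out_size=3)
readout.init_state(batch_size=2)
seq = np.ones((5, 2, 8))  # 5 time steps, batch of 2, 8 inputs

first = run(readout, seq)
second_no_reset = run(readout, seq)  # state carried over from the first pass
readout.reset_state(batch_size=2)
second_reset = run(readout, seq)     # starts again from zero

print(np.allclose(first, second_reset))     # True
print(np.allclose(first, second_no_reset))  # False
```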