LeakyRateReadout
- class braintrace.nn.LeakyRateReadout(in_size, out_size, tau=Quantity(5., 'ms'), w_init=KaimingNormal(mode=fan_in, nonlinearity=relu, unit=1), r_init=ZeroInit(unit=1), name=None)
Leaky dynamics for the readout module used in Real-Time Recurrent Learning.
The LeakyRateReadout class applies leaky integration to continuous input signals, mimicking the dynamics of rate-based neurons. At each time step, the readout state r decays by a factor derived from the time constant tau while accumulating a weighted projection of the input, so the module produces a continuous output signal that reflects recent input history.
This class is part of the BrainTrace project and integrates with the Brain Dynamics Programming ecosystem, providing a biologically inspired approach to neural computation.
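To make the update concrete, here is a minimal NumPy sketch of a single-step leaky readout. It is not the braintrace implementation: the class name `LeakyReadoutSketch`, the time step `dt = 0.1` ms, the discretization `decay = exp(-dt / tau)`, the update `r <- decay * r + x @ W`, and the plain-normal Kaiming-style initialization are all assumptions chosen for illustration.

```python
import numpy as np


class LeakyReadoutSketch:
    """Hypothetical NumPy stand-in for a leaky rate readout (not the braintrace API)."""

    def __init__(self, in_size, out_size, tau=5.0, dt=0.1, seed=0):
        rng = np.random.default_rng(seed)
        # Assumed discretization: fraction of the state kept over one step dt (both in ms).
        self.decay = np.exp(-dt / tau)
        # Kaiming-normal-style init with fan_in = in_size and ReLU gain: std = sqrt(2 / fan_in).
        self.weight = rng.normal(0.0, np.sqrt(2.0 / in_size), size=(in_size, out_size))
        self.out_size = out_size
        self.r = None

    def init_state(self, batch_size):
        # Zero-initialized readout state, one row per batch element.
        self.r = np.zeros((batch_size, self.out_size))

    def __call__(self, x):
        # Leaky integration: decay the previous state, then add the projected input.
        self.r = self.decay * self.r + x @ self.weight
        return self.r


readout = LeakyReadoutSketch(in_size=256, out_size=10)
readout.init_state(batch_size=32)
x = np.random.default_rng(1).standard_normal((32, 256))
out = readout(x)
print(out.shape)  # (32, 10)
```

The shapes mirror the braintrace example in the Examples section: a batch of 32 inputs of size 256 is mapped to a (32, 10) readout state, and repeated calls keep integrating into that state.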
- Parameters:
  - in_size (int | Sequence[int] | integer | Sequence[integer]) – The size of the input to the readout module.
  - out_size (int | Sequence[int] | integer | Sequence[integer]) – The size of the output from the readout module.
  - tau (Array | ndarray | bool | number | int | float | complex | Quantity) – The time constant for the leaky integration dynamics. Default is 5 ms.
  - w_init (Callable) – A callable for initializing the weights of the readout module. Default is KaimingNormal().
  - r_init (Callable) – A callable for initializing the state of the readout module. Default is ZeroInit().
  - name (str | None) – An optional name for the module. Default is None.
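If tau enters the dynamics only through the per-step decay factor (as the decay attribute suggests), it helps to see how it translates into memory length. The helper functions below are hypothetical and assume decay = exp(-dt / tau) with a time step dt = 0.1 ms; the half-life is the number of steps after which an isolated input has shrunk to half its size.

```python
import math


def decay_factor(tau_ms, dt_ms):
    # Assumed mapping: fraction of the state retained after one step of length dt.
    return math.exp(-dt_ms / tau_ms)


def half_life_steps(tau_ms, dt_ms):
    # Solve decay**n = 0.5 for n, i.e. n = ln(2) * tau / dt.
    return math.log(2.0) * tau_ms / dt_ms


for tau in (2.0, 5.0, 20.0):
    print(f"tau={tau:5.1f} ms  decay={decay_factor(tau, 0.1):.4f}  "
          f"half-life={half_life_steps(tau, 0.1):.1f} steps")
```

Under these assumptions the default tau = 5 ms gives decay = exp(-0.02) ≈ 0.980 and a half-life of about 35 steps (roughly 3.5 ms); a larger tau retains input longer.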
- tau
The time constant for leaky integration.
- Type:
ArrayLike
- decay
The decay factor computed from tau.
- Type:
ArrayLike
- r
The readout state variable.
- Type:
HiddenState
- weight_op
The parameter object that holds the readout weights and the operation that applies them to the input.
- Type:
ParamState
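The interplay between the state r and decay can be illustrated by iterating the assumed update r <- decay * r + drive from a zero state, where drive stands in for the projected input x @ W. This is a NumPy sketch with hypothetical values (dt = 0.1 ms, tau = 5 ms, a two-unit drive), not braintrace code: a single pulse fades geometrically, while a constant drive settles at drive / (1 - decay).

```python
import numpy as np

decay = np.exp(-0.1 / 5.0)     # assumed: dt = 0.1 ms, tau = 5 ms
drive = np.array([1.0, -0.5])  # hypothetical projected input x @ W for two output units


def run(inputs, decay):
    """Iterate r <- decay * r + input from a zero state and return every r_t."""
    r = np.zeros_like(inputs[0])
    history = []
    for inp in inputs:
        r = decay * r + inp
        history.append(r)
    return np.stack(history)


T = 500
# Impulse: drive at t = 0 only; afterwards the state shrinks by `decay` each step.
pulse = np.zeros((T, 2))
pulse[0] = drive
impulse = run(pulse, decay)

# Constant drive: the state approaches the fixed point drive / (1 - decay).
constant = run(np.tile(drive, (T, 1)), decay)
print(impulse[10], constant[-1], drive / (1.0 - decay))
```

For the assumed defaults, 1 / (1 - decay) ≈ 50.5, close to tau / dt = 50, so a slowly varying input is effectively amplified by about tau / dt in the readout state.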
Examples
>>> import braintrace
>>> import brainstate
>>> import saiunit as u
>>>
>>> # Set the simulation time step used by the leaky dynamics
>>> brainstate.environ.set(dt=0.1 * u.ms)
>>>
>>> # Create a leaky rate readout layer
>>> readout = braintrace.nn.LeakyRateReadout(
...     in_size=256,
...     out_size=10,
...     tau=5.0 * u.ms
... )
>>> readout.init_state(batch_size=32)
>>>
>>> # Process input through the readout layer
>>> x = brainstate.random.randn(32, 256)
>>> output = readout(x)
>>> print(output.shape)
(32, 10)