ValinaRNNCell

class braintrace.nn.ValinaRNNCell(in_size, out_size, state_init=ZeroInit(unit=1), w_init=XavierNormal(scale=1.0, unit=1), b_init=ZeroInit(unit=1), activation='relu', name=None)

Vanilla RNN cell.

A basic recurrent neural network cell that applies a simple recurrent transformation to the input and previous hidden state.
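
For orientation, the recurrence of a vanilla RNN cell is commonly written as below. This is a sketch of the textbook formulation, not a statement of how braintrace parameterizes the cell: the single weight matrix W acting on the concatenated input and hidden state, and the symbols x_t, h_t, W, b and \phi, are assumptions introduced here for illustration.

```latex
h_t = \phi\!\left( W \begin{bmatrix} x_t \\ h_{t-1} \end{bmatrix} + b \right),
\qquad
x_t \in \mathbb{R}^{\text{in\_size}},\;
h_t \in \mathbb{R}^{\text{out\_size}},\;
W \in \mathbb{R}^{\text{out\_size} \times (\text{in\_size} + \text{out\_size})},\;
b \in \mathbb{R}^{\text{out\_size}}
```

Under this reading, \phi corresponds to the activation argument (ReLU by default), W to w_init, b to b_init, and the initial state h_0 to state_init.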

Parameters:
  • in_size (int | Sequence[int] | integer | Sequence[integer]) – The number of input units.

  • out_size (int | Sequence[int] | integer | Sequence[integer]) – The number of hidden units.

  • state_init (Array | ndarray | bool | number | bool | int | float | complex | Quantity | Callable) – The state initializer. Default is ZeroInit().

  • w_init (Array | ndarray | bool | number | bool | int | float | complex | Quantity | Callable) – The input weight initializer. Default is XavierNormal().

  • b_init (Array | ndarray | bool | number | bool | int | float | complex | Quantity | Callable) – The bias initializer. Default is ZeroInit().

  • activation (str | Callable) – The activation function, given either as a string name or as a callable. Default is 'relu'.

  • name (str) – The name of the module. Default is None.

Examples

>>> import braintrace
>>> import brainstate
>>>
>>> # Create a Vanilla RNN cell
>>> rnn_cell = braintrace.nn.ValinaRNNCell(in_size=32, out_size=64)
>>> rnn_cell.init_state(batch_size=8)
>>>
>>> # Process one time step for a batch of 8 inputs
>>> x = brainstate.random.randn(8, 32)
>>> h = rnn_cell(x)
>>> print(h.shape)
(8, 64)

init_state(batch_size=None, **kwargs)

Initializes the cell's hidden state using state_init. Pass batch_size to create a batched state.

reset_state(batch_size=None, **kwargs)

Resets the cell's hidden state to its initial value, for example between independent sequences.
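
The calling pattern around init_state and reset_state can be illustrated with a small NumPy stand-in. This is a minimal sketch, not braintrace code: ToyRNNCell, its seed argument, the Xavier-style weight scale, and the choice of one weight matrix over the concatenated input and state are all hypothetical and chosen only to mirror the signature documented above.

```python
import numpy as np


class ToyRNNCell:
    """Hypothetical stand-in for a stateful RNN cell (not braintrace code)."""

    def __init__(self, in_size, out_size, seed=0):
        rng = np.random.default_rng(seed)
        # Assumption: a single weight matrix acts on the concatenation [x, h].
        fan_in = in_size + out_size
        std = np.sqrt(2.0 / (fan_in + out_size))  # Xavier-style scale
        self.W = rng.normal(0.0, std, size=(fan_in, out_size))
        self.b = np.zeros(out_size)
        self.out_size = out_size
        self.h = None  # hidden state, allocated by init_state

    def init_state(self, batch_size=None):
        # Allocate a zero hidden state, with a leading batch axis if requested.
        if batch_size is None:
            shape = (self.out_size,)
        else:
            shape = (batch_size, self.out_size)
        self.h = np.zeros(shape)

    def reset_state(self, batch_size=None):
        # Return the hidden state to its initial value (zeros in this sketch).
        self.init_state(batch_size)

    def __call__(self, x):
        # One time step: h <- relu([x, h] @ W + b)
        xh = np.concatenate([x, self.h], axis=-1)
        self.h = np.maximum(xh @ self.W + self.b, 0.0)
        return self.h


cell = ToyRNNCell(in_size=32, out_size=64)
cell.init_state(batch_size=8)

# Step through a sequence laid out as (time, batch, features).
xs = np.random.default_rng(1).normal(size=(5, 8, 32))
for x_t in xs:
    h = cell(x_t)
final_shape = h.shape

# Reset before an unrelated sequence so no history carries over.
cell.reset_state(batch_size=8)
state_is_zero = bool(np.all(cell.h == 0.0))
```

The cell is called once per time step and carries its hidden state between calls, which is why the state has to be initialized before the first call and reset before starting a new, independent sequence.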