ValinaRNNCell

class brainstate.nn.ValinaRNNCell(num_in, num_out, state_init=ZeroInit(unit=Unit("1")), w_init=XavierNormal(scale=1.0, mode='fan_avg', in_axis=-2, out_axis=-1, distribution='truncated_normal', rng=RandomState([ 900 9244]), unit=Unit("1")), b_init=ZeroInit(unit=Unit("1")), activation='relu', name=None)

Vanilla Recurrent Neural Network (RNN) cell implementation.

This class implements the basic RNN model, which updates a hidden state from the current input and the previous hidden state. The standard RNN cell is defined by:

\[h_t = \phi(W [x_t, h_{t-1}] + b)\]

where:

  • \(x_t\) is the input vector at time step \(t\)

  • \(h_t\) is the hidden state at time step \(t\)

  • \(h_{t-1}\) is the hidden state at the previous time step

  • \(W\) is the weight matrix applied to the concatenated vector \([x_t, h_{t-1}]\)

  • \(b\) is the bias vector

  • \(\phi\) is the activation function
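The update above can be sketched in NumPy for a single time step. Splitting \(W\) into an input block \(W_x\) (the first num_in rows) and a recurrent block \(W_h\) (the last num_out rows) shows that multiplying the concatenation \([x_t, h_{t-1}]\) is the same as summing separate input and recurrent projections. The row-vector layout (`xh @ w`, with `w` of shape `(num_in + num_out, num_out)`) and the helper name `rnn_step` are assumptions for illustration, not brainstate's internals.

```python
import numpy as np


def relu(z):
    return np.maximum(z, 0.0)


def rnn_step(x_t, h_prev, w, b, phi=relu):
    """One vanilla RNN step: h_t = phi([x_t, h_prev] @ w + b).

    x_t:    (batch, num_in)
    h_prev: (batch, num_out)
    w:      (num_in + num_out, num_out)  -- assumed layout
    b:      (num_out,)
    """
    xh = np.concatenate([x_t, h_prev], axis=-1)  # (batch, num_in + num_out)
    return phi(xh @ w + b)


rng = np.random.default_rng(0)
num_in, num_out, batch = 10, 20, 32
x_t = rng.standard_normal((batch, num_in))
h_prev = rng.standard_normal((batch, num_out))
w = 0.1 * rng.standard_normal((num_in + num_out, num_out))
b = np.zeros(num_out)

h_t = rnn_step(x_t, h_prev, w, b)

# Same result with the weight matrix split into input and recurrent blocks.
w_x, w_h = w[:num_in], w[num_in:]
h_split = relu(x_t @ w_x + h_prev @ w_h + b)
print(h_t.shape, np.allclose(h_t, h_split))  # (32, 20) True
```

The concatenated form needs one matrix multiply per step; the split form is often easier to reason about when analysing the recurrent weights on their own.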

Parameters:
  • num_in (int) – The number of input units.

  • num_out (int) – The number of hidden units, which is also the output size.

  • state_init (Array | ndarray | bool | number | bool | int | float | complex | Quantity | Callable) – Initializer for the hidden state.

  • w_init (Array | ndarray | bool | number | bool | int | float | complex | Quantity | Callable) – Initializer for the weight matrix.

  • b_init (Array | ndarray | bool | number | bool | int | float | complex | Quantity | Callable) – Initializer for the bias vector.

  • activation (str | Callable) – Activation function \(\phi\). Either a name (e.g., ‘relu’, ‘tanh’) or a callable. Defaults to ‘relu’.

  • name (str) – Name of the module.
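The default w_init in the signature is XavierNormal(scale=1.0, mode='fan_avg', in_axis=-2, out_axis=-1, distribution='truncated_normal'). The sketch below follows the usual Glorot / variance-scaling recipe (the one `jax.nn.initializers.variance_scaling` uses) to show which standard deviation that configuration targets. It assumes the weight has shape `(num_in + num_out, num_out)`, so axis -2 is the fan-in and axis -1 the fan-out; the helper names are hypothetical, and brainstate's own initializer may differ in details such as how it truncates.

```python
import numpy as np

# Std of a standard normal truncated to [-2, 2]; dividing by it restores the target std.
TRUNC_STD = 0.87962566103423978


def xavier_normal_std(shape, scale=1.0):
    """Target std for 'fan_avg' variance scaling with in_axis=-2, out_axis=-1."""
    fan_in, fan_out = shape[-2], shape[-1]
    variance = scale / ((fan_in + fan_out) / 2.0)
    return float(np.sqrt(variance))


def xavier_truncated_normal(shape, rng, scale=1.0):
    """Sample N(0, 1) truncated to [-2, 2], then rescale to the Xavier std."""
    z = rng.standard_normal(shape)
    bad = np.abs(z) > 2.0
    while bad.any():  # resample entries outside the truncation window
        z[bad] = rng.standard_normal(int(bad.sum()))
        bad = np.abs(z) > 2.0
    return z * (xavier_normal_std(shape, scale) / TRUNC_STD)


num_in, num_out = 10, 20
shape = (num_in + num_out, num_out)  # assumed layout of W
print(xavier_normal_std(shape))  # sqrt(1 / 25), i.e. about 0.2
w = xavier_truncated_normal(shape, np.random.default_rng(0))
```

Because \(W\) stacks input and recurrent rows, its fan-in is num_in + num_out, so adding hidden units shrinks the initial weight scale.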

num_in

Number of input features.

Type:

int

num_out

Number of hidden units.

Type:

int

in_size

Shape of the input, (num_in,).

Type:

tuple

out_size

Shape of the output, (num_out,).

Type:

tuple

State Variables

h

Hidden state of the RNN cell.

Type:

HiddenState

init_state(batch_size=None, **kwargs)

Initialize the cell hidden state.

reset_state(batch_size=None, **kwargs)

Reset the cell hidden state to its initial value.

update(x)

Update the hidden state for one time step and return the new state.
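To make the state life cycle concrete, here is a toy NumPy class, `ToyRNNCell` (a hypothetical name, not part of brainstate), that mirrors the three methods above: `init_state` allocates a zero hidden state, `update` applies the RNN equation and overwrites the state, and `reset_state` restores the initial value. The weight layout, the small activation lookup table, and treating `batch_size=None` as an unbatched state are all assumptions made for this sketch.

```python
import numpy as np

# Hypothetical name -> function table; brainstate's own lookup may differ.
ACTIVATIONS = {"relu": lambda z: np.maximum(z, 0.0), "tanh": np.tanh}


class ToyRNNCell:
    """NumPy stand-in (not brainstate code) for the init/update/reset life cycle."""

    def __init__(self, num_in, num_out, activation="relu", seed=0):
        rng = np.random.default_rng(seed)
        self.num_out = num_out
        # Assumed layout: rows [0, num_in) act on x, rows [num_in, num_in + num_out) on h.
        self.w = 0.1 * rng.standard_normal((num_in + num_out, num_out))
        self.b = np.zeros(num_out)
        # Accept a name or any callable, like the `activation` parameter.
        self.phi = ACTIVATIONS[activation] if isinstance(activation, str) else activation
        self.h = None

    def _initial_h(self, batch_size):
        # Assumption: batch_size=None gives an unbatched state of shape (num_out,).
        shape = (self.num_out,) if batch_size is None else (batch_size, self.num_out)
        return np.zeros(shape)

    def init_state(self, batch_size=None):
        self.h = self._initial_h(batch_size)

    def reset_state(self, batch_size=None):
        # With a zero initializer, resetting just rebuilds the initial zeros.
        self.h = self._initial_h(batch_size)

    def update(self, x):
        # h_t = phi([x_t, h_{t-1}] @ W + b); the new state is also the output.
        xh = np.concatenate([x, self.h], axis=-1)
        self.h = self.phi(xh @ self.w + self.b)
        return self.h


cell = ToyRNNCell(num_in=10, num_out=20, activation="tanh")
cell.init_state(batch_size=32)
out = cell.update(np.ones((32, 10)))
print(out.shape)  # (32, 20)
cell.reset_state(batch_size=32)
print(np.abs(cell.h).max())  # 0.0
```

Keeping the state on the object is what lets update(x) take only the current input: each call reads the previous state and replaces it with the new one.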

Examples

>>> import brainstate as bs
>>> import jax.numpy as jnp
>>>
>>> # Create a vanilla RNN cell
>>> cell = bs.nn.ValinaRNNCell(num_in=10, num_out=20)
>>>
>>> # Initialize state for batch size 32
>>> cell.init_state(batch_size=32)
>>>
>>> # Process a single time step
>>> x = jnp.ones((32, 10))  # batch_size x num_in
>>> output = cell.update(x)
>>> print(output.shape)
(32, 20)
>>>
>>> # Process a sequence of inputs (time_steps x batch_size x num_in)
>>> sequence = jnp.ones((100, 32, 10))
>>> outputs = [cell.update(x_t) for x_t in sequence]
>>> outputs = jnp.stack(outputs)
>>> print(outputs.shape)
(100, 32, 20)

Notes

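Vanilla RNNs can suffer from vanishing or exploding gradients when trained on long sequences: backpropagation through time multiplies one Jacobian per step, each built from the recurrent weights and the activation's derivative, so gradient magnitudes tend to shrink or grow exponentially with sequence length. For long sequences, consider gated architectures such as GRU or LSTM instead.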
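A linearized sketch makes the vanishing/exploding behaviour concrete. Assume the identity activation and the row-vector form \(h_k = h_{k-1} W_h + \dots\), where \(W_h\) is the recurrent block of \(W\). A perturbation of \(h_0\) then reaches \(h_k\) through right-multiplication by \(W_h^k\). If \(W_h = \alpha Q\) with \(Q\) orthogonal, the spectral norm of \(W_h^k\) is exactly \(\alpha^k\), so it vanishes for \(\alpha < 1\) and explodes for \(\alpha > 1\). With ReLU or tanh, each step also contributes derivative factors in \([0, 1]\), which can shrink the product further but never raise this bound. The helper name `jacobian_norms` and the 20-unit, 50-step setup are illustrative choices, not part of brainstate.

```python
import numpy as np


def jacobian_norms(w_h, num_steps):
    """Spectral norms of the map dh_0 -> dh_k for a linear RNN h_k = h_{k-1} @ w_h + ..."""
    jac = np.eye(w_h.shape[0])
    norms = []
    for _ in range(num_steps):
        jac = jac @ w_h  # chain rule: one more factor of w_h per step
        norms.append(np.linalg.norm(jac, ord=2))  # largest singular value
    return norms


rng = np.random.default_rng(0)
q, _ = np.linalg.qr(rng.standard_normal((20, 20)))  # random orthogonal matrix

shrink = jacobian_norms(0.9 * q, num_steps=50)
grow = jacobian_norms(1.1 * q, num_steps=50)
print(f"{shrink[-1]:.2e} {grow[-1]:.2e}")  # 5.15e-03 1.17e+02
```

A 10% change in the recurrent scale turns into a factor of roughly 20,000 between the two cases after only 50 steps, which is why long-horizon training of vanilla RNNs is fragile.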

init_state(batch_size=None, **kwargs)

Initialize the hidden state.

Parameters:
  • batch_size (int, optional) – The batch size for state initialization. Defaults to None.

  • **kwargs – Additional keyword arguments.

reset_state(batch_size=None, **kwargs)

Reset the hidden state to its initial value.

Parameters:
  • batch_size (int, optional) – The batch size of the reset state. Defaults to None.

  • **kwargs – Additional keyword arguments.