ValinaRNNCell
- class brainstate.nn.ValinaRNNCell(num_in, num_out, state_init=ZeroInit( unit=Unit("1") ), w_init=XavierNormal( scale=1.0, mode='fan_avg', in_axis=-2, out_axis=-1, distribution='truncated_normal', rng=RandomState([ 900 9244]), unit=Unit("1") ), b_init=ZeroInit( unit=Unit("1") ), activation='relu', name=None)
Vanilla Recurrent Neural Network (RNN) cell implementation.
This class implements the basic RNN model that updates a hidden state based on the current input and previous hidden state. The standard RNN cell follows the mathematical formulation:
\[h_t = \phi(W [x_t, h_{t-1}] + b)\]
where:
\(x_t\) is the input vector at time t
\(h_t\) is the hidden state at time t
\(h_{t-1}\) is the hidden state at the previous time step
\(W\) is the weight matrix for the combined input-hidden linear transformation
\(b\) is the bias vector
\(\phi\) is the activation function
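As a rough illustration of the update rule above, the following plain-NumPy sketch computes one step for a batch of inputs. It is not brainstate's implementation: the function name `rnn_step`, the weight layout `(num_in + num_out, num_out)`, and the random initialization are assumptions made for this example. Because the batch is stored row-wise, the product is written as `[x_t, h_{t-1}] @ W` rather than `W [x_t, h_{t-1}]`.

```python
import numpy as np


def relu(z):
    return np.maximum(z, 0.0)


def rnn_step(x_t, h_prev, W, b, phi=relu):
    """One vanilla RNN update: h_t = phi([x_t, h_{t-1}] @ W + b)."""
    # Concatenate input and previous hidden state along the feature axis.
    xh = np.concatenate([x_t, h_prev], axis=-1)  # (batch, num_in + num_out)
    return phi(xh @ W + b)                       # (batch, num_out)


num_in, num_out, batch = 10, 20, 32
rng = np.random.default_rng(0)

# Hypothetical parameters; the layout is an assumption of this sketch.
W = rng.normal(scale=0.1, size=(num_in + num_out, num_out))
b = np.zeros(num_out)

h = np.zeros((batch, num_out))  # zero initial hidden state
x = np.ones((batch, num_in))
h = rnn_step(x, h, W, b)
print(h.shape)  # (32, 20)
```

A single combined matrix `W` is equivalent to keeping separate input-to-hidden and hidden-to-hidden matrices and summing their products; stacking them just turns two matrix multiplications into one.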
- Parameters:
num_in (int) – The number of input units.
num_out (int) – The number of hidden units.
state_init (Array|ndarray|bool|number|bool|int|float|complex|Quantity|Callable) – Initializer for the hidden state.
w_init (Array|ndarray|bool|number|bool|int|float|complex|Quantity|Callable) – Initializer for the weight matrix.
b_init (Array|ndarray|bool|number|bool|int|float|complex|Quantity|Callable) – Initializer for the bias vector.
activation (str|Callable) – Activation function to use. Can be a string (e.g., 'relu', 'tanh') or a callable function.
name (str) – Name of the module.
- State Variables
- h
Hidden state of the RNN cell.
- Type:
Examples
>>> import brainstate as bs
>>> import jax.numpy as jnp
>>>
>>> # Create a vanilla RNN cell
>>> cell = bs.nn.ValinaRNNCell(num_in=10, num_out=20)
>>>
>>> # Initialize state for batch size 32
>>> cell.init_state(batch_size=32)
>>>
>>> # Process a single time step
>>> x = jnp.ones((32, 10))  # batch_size x num_in
>>> output = cell.update(x)
>>> print(output.shape)  # (32, 20)
>>>
>>> # Process a sequence of inputs
>>> sequence = jnp.ones((100, 32, 10))  # time_steps x batch_size x num_in
>>> outputs = []
>>> for t in range(100):
...     output = cell.update(sequence[t])
...     outputs.append(output)
>>> outputs = jnp.stack(outputs)
>>> print(outputs.shape)  # (100, 32, 20)
Notes
Vanilla RNNs can suffer from vanishing or exploding gradient problems when processing long sequences. For better performance on long sequences, consider using gated architectures like GRU or LSTM.
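To make the vanishing-gradient remark concrete, here is a small NumPy sketch that tracks the Jacobian of the hidden state with respect to the initial state over time. It is a simplified setting chosen for illustration, not part of brainstate: the input and bias are dropped, the activation is tanh, and the recurrent matrix `w_h` is assumed to be 0.5 times the identity. The helper name `jacobian_norms` is hypothetical.

```python
import numpy as np


def jacobian_norms(w_h, steps, h0):
    """Spectral norm of d h_t / d h_0 for h_t = tanh(w_h @ h_{t-1}), t = 1..steps."""
    h = h0
    jac = np.eye(h0.shape[0])
    norms = []
    for _ in range(steps):
        h = np.tanh(w_h @ h)
        # Chain rule: d h_t / d h_{t-1} = diag(1 - h_t**2) @ w_h
        jac = (np.diag(1.0 - h**2) @ w_h) @ jac
        norms.append(np.linalg.norm(jac, 2))
    return norms


# Assumed toy setup: small recurrent weights, 4 hidden units, 20 time steps.
norms = jacobian_norms(0.5 * np.eye(4), steps=20, h0=np.full(4, 0.1))
print(norms[0], norms[-1])
```

Each per-step factor here has norm below 0.5, so the norm of the accumulated Jacobian is bounded by 0.5 raised to the number of steps and shrinks geometrically. Gradients propagated back through many steps shrink the same way, which is why early inputs have little influence on the learning signal in long sequences.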