GRUCell

class brainstate.nn.GRUCell(num_in, num_out, w_init=Orthogonal(scale=1.0, axis=-1, rng=RandomState([ 900 9244]), unit=Unit("1")), b_init=ZeroInit(unit=Unit("1")), state_init=ZeroInit(unit=Unit("1")), activation='tanh', name=None)

Gated Recurrent Unit (GRU) cell implementation.

The GRU is a gated recurrent architecture designed to mitigate the vanishing gradient problem. Its gates control how much of the previous hidden state is kept or overwritten at each step. Compared with an LSTM, it merges the forget and input gates into a single update gate and has no separate cell state, so it needs fewer parameters.

The GRU cell follows the mathematical formulation:

\[\begin{split}r_t &= \sigma(W_r [x_t, h_{t-1}] + b_r) \\ z_t &= \sigma(W_z [x_t, h_{t-1}] + b_z) \\ \tilde{h}_t &= \phi(W_h [x_t, (r_t \odot h_{t-1})] + b_h) \\ h_t &= (1 - z_t) \odot h_{t-1} + z_t \odot \tilde{h}_t\end{split}\]

where:

  • \(x_t\) is the input vector at time t

  • \(h_t\) is the hidden state at time t

  • \(r_t\) is the reset gate vector

  • \(z_t\) is the update gate vector

  • \(\tilde{h}_t\) is the candidate hidden state

  • \(\odot\) represents element-wise multiplication

  • \(\sigma\) is the sigmoid activation function

  • \(\phi\) is the candidate activation, set by the activation parameter (tanh by default)
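The four equations can be traced step by step in plain NumPy. This is a minimal sketch, not brainstate's implementation: the weight names `W_r`, `W_z`, `W_h`, the helper `gru_step`, and the row-vector layout (inputs multiply weights from the left, concatenation along the last axis) are assumptions made for illustration.

```python
import numpy as np


def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))


def gru_step(x, h_prev, W_r, W_z, W_h, b_r, b_z, b_h, phi=np.tanh):
    """One GRU step following the equations above (hypothetical helper).

    Assumed shapes: x is (batch, num_in), h_prev is (batch, num_out),
    each W_* is (num_in + num_out, num_out), each b_* is (num_out,).
    """
    xh = np.concatenate([x, h_prev], axis=-1)       # [x_t, h_{t-1}]
    r = sigmoid(xh @ W_r + b_r)                      # reset gate r_t
    z = sigmoid(xh @ W_z + b_z)                      # update gate z_t
    xrh = np.concatenate([x, r * h_prev], axis=-1)   # [x_t, r_t * h_{t-1}]
    h_tilde = phi(xrh @ W_h + b_h)                   # candidate hidden state
    return (1.0 - z) * h_prev + z * h_tilde          # interpolate old state and candidate


# Usage: small random weights, batch of 4, num_in=3, num_out=5
rng = np.random.default_rng(0)
num_in, num_out, batch = 3, 5, 4
W_r, W_z, W_h = (rng.normal(scale=0.1, size=(num_in + num_out, num_out)) for _ in range(3))
b_r = b_z = b_h = np.zeros(num_out)
x = rng.normal(size=(batch, num_in))
h0 = np.zeros((batch, num_out))
h1 = gru_step(x, h0, W_r, W_z, W_h, b_r, b_z, b_h)
print(h1.shape)  # (4, 5)
```

A quick sanity check of the last equation: with all weights and biases set to zero, both gates evaluate to 0.5 and the candidate is tanh(0) = 0, so the step returns exactly half of the previous hidden state.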

Parameters:
  • num_in (int) – The number of input units.

  • num_out (int) – The number of hidden units.

  • w_init (Array | ndarray | bool | number | int | float | complex | Quantity | Callable) – Initializer for the weight matrices. Defaults to orthogonal initialization.

  • b_init (Array | ndarray | bool | number | int | float | complex | Quantity | Callable) – Initializer for the bias vectors. Defaults to zeros.

  • state_init (Array | ndarray | bool | number | int | float | complex | Quantity | Callable) – Initializer for the hidden state. Defaults to zeros.

  • activation (str | Callable) – Activation function to use. Can be a string (e.g., ‘tanh’, ‘relu’) or a callable function.

  • name (str | None) – Optional name of the module. Defaults to None.

num_in

Number of input features.

Type:

int

num_out

Number of hidden units.

Type:

int

in_size

Shape of input (num_in,).

Type:

tuple

out_size

Shape of output (num_out,).

Type:

tuple

State Variables

h

Hidden state of the GRU cell.

Type:

HiddenState

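init_state(batch_size=None, **kwargs)

Initialize the hidden state h using state_init. If batch_size is given, the state gets a leading batch dimension of that size, giving shape (batch_size, num_out).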

reset_state(batch_size=None, **kwargs)

Reset the hidden state to its initial value. Passing a different batch_size re-allocates the state with the new batch dimension.

update(x)

Advance the cell by one time step: compute the new hidden state from the input x (shape (batch_size, num_in) when batched) and the current state, store it in h, and return it.
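The calling convention of init_state, reset_state, and update can be mimicked with a small stateful class. The sketch below is hypothetical: `TinyGRUCell` is not part of brainstate, it uses NumPy instead of JAX, and it assumes zero-initialized state, a leading batch axis only when batch_size is given, and the same gate equations as the class description.

```python
import numpy as np


def _sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))


class TinyGRUCell:
    """Hypothetical NumPy stand-in mirroring the init_state/reset_state/update pattern."""

    def __init__(self, num_in, num_out, seed=0):
        rng = np.random.default_rng(seed)
        self.num_in, self.num_out = num_in, num_out
        shape = (num_in + num_out, num_out)
        # One weight matrix and one zero bias per gate/candidate (assumed layout).
        self.W = {k: rng.normal(scale=0.1, size=shape) for k in ("r", "z", "h")}
        self.b = {k: np.zeros(num_out) for k in ("r", "z", "h")}
        self.h = None

    def init_state(self, batch_size=None):
        # Zero hidden state; add a leading batch axis only when batch_size is given.
        shape = (self.num_out,) if batch_size is None else (batch_size, self.num_out)
        self.h = np.zeros(shape)

    def reset_state(self, batch_size=None):
        # Resetting re-runs initialization, possibly with a new batch size.
        self.init_state(batch_size)

    def update(self, x):
        # One time step: gates from [x, h], candidate from [x, r * h], then interpolate.
        xh = np.concatenate([x, self.h], axis=-1)
        r = _sigmoid(xh @ self.W["r"] + self.b["r"])
        z = _sigmoid(xh @ self.W["z"] + self.b["z"])
        xrh = np.concatenate([x, r * self.h], axis=-1)
        cand = np.tanh(xrh @ self.W["h"] + self.b["h"])
        self.h = (1.0 - z) * self.h + z * cand
        return self.h


cell = TinyGRUCell(num_in=10, num_out=20)
cell.init_state(batch_size=32)
out = cell.update(np.ones((32, 10)))         # shape (32, 20)
cell.reset_state(batch_size=16)
out_small = cell.update(np.ones((16, 10)))   # shape (16, 20)
```

Keeping the state on the object is what lets update take only x; the trade-off is that the state must be (re)initialized with the right batch size before the first call.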

Examples

>>> import brainstate as bs
>>> import jax.numpy as jnp
>>>
>>> # Create a GRU cell
>>> cell = bs.nn.GRUCell(num_in=10, num_out=20)
>>>
>>> # Initialize state for batch size 32
>>> cell.init_state(batch_size=32)
>>>
>>> # Process a single time step
>>> x = jnp.ones((32, 10))  # batch_size x num_in
>>> output = cell.update(x)
>>> print(output.shape)
(32, 20)
>>>
>>> # Process a sequence
>>> sequence = jnp.ones((100, 32, 10))  # time_steps x batch_size x num_in
>>> outputs = []
>>> for t in range(100):
...     output = cell.update(sequence[t])
...     outputs.append(output)
>>> outputs = jnp.stack(outputs)
>>> print(outputs.shape)
(100, 32, 20)
>>>
>>> # Reset state with different batch size
>>> cell.reset_state(batch_size=16)
>>> x_new = jnp.ones((16, 10))
>>> output_new = cell.update(x_new)
>>> print(output_new.shape)
(16, 20)

Notes

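A GRU cell has three weight blocks (reset gate, update gate, candidate state) where a standard LSTM cell has four, so for the same num_in and num_out a GRU has about three quarters of the parameters and does correspondingly less computation per step. Empirically, GRUs often reach performance comparable to LSTMs.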
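The three-quarters figure can be checked with a short count. This assumes each gate and the candidate have a full weight matrix over the concatenated input [x_t, h_{t-1}] plus one bias vector, as in the equations above, and that a standard LSTM uses four such blocks; real implementations may differ (for example by omitting or splitting biases). The helper `rnn_cell_params` is hypothetical.

```python
def rnn_cell_params(num_in: int, num_out: int, num_blocks: int) -> int:
    """Count weights plus biases for a gated cell with num_blocks gate/candidate blocks.

    Each block is assumed to be a (num_in + num_out) x num_out matrix plus a bias of size num_out.
    """
    per_block = (num_in + num_out) * num_out + num_out
    return num_blocks * per_block


gru = rnn_cell_params(10, 20, num_blocks=3)   # reset gate, update gate, candidate
lstm = rnn_cell_params(10, 20, num_blocks=4)  # input, forget, output gates, candidate
print(gru, lstm, gru / lstm)  # 1860 2480 0.75
```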
