GRUCell
- class brainstate.nn.GRUCell(num_in, num_out, w_init=Orthogonal(scale=1.0, axis=-1, rng=RandomState([ 900 9244]), unit=Unit("1")), b_init=ZeroInit(unit=Unit("1")), state_init=ZeroInit(unit=Unit("1")), activation='tanh', name=None)
Gated Recurrent Unit (GRU) cell implementation.
The GRU is a gated recurrent architecture designed to mitigate the vanishing gradient problem. Its gates control how much of the previous hidden state is kept and how much new information is written at each step. It has fewer parameters than an LSTM because it merges the forget and input gates into a single update gate and keeps no separate cell state.
The GRU cell follows the mathematical formulation:
\[\begin{split}r_t &= \sigma(W_r [x_t, h_{t-1}] + b_r) \\ z_t &= \sigma(W_z [x_t, h_{t-1}] + b_z) \\ \tilde{h}_t &= \phi(W_h [x_t, (r_t \odot h_{t-1})] + b_h) \\ h_t &= (1 - z_t) \odot h_{t-1} + z_t \odot \tilde{h}_t\end{split}\]
where:
\(x_t\) is the input vector at time t
\(h_t\) is the hidden state at time t
\(r_t\) is the reset gate vector
\(z_t\) is the update gate vector
\(\tilde{h}_t\) is the candidate hidden state
\(\odot\) represents element-wise multiplication
\(\sigma\) is the sigmoid activation function
\(\phi\) is the activation function applied to the candidate state, set by the activation parameter (tanh by default)
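The four equations above can be sketched as a single step function. This is a minimal NumPy illustration of the formulation, not the brainstate implementation: the `gru_step` function, the `params` dictionary, and the choice to store each weight matrix with shape `(num_in + num_out, num_out)` are assumptions made for the example. The same calls exist in `jax.numpy`.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def gru_step(x, h_prev, params, phi=np.tanh):
    """One GRU step (hypothetical helper, not part of brainstate).

    x:      (batch, num_in)  input at time t
    h_prev: (batch, num_out) hidden state at time t-1
    params: dict with W_r, W_z, W_h of shape (num_in + num_out, num_out)
            and b_r, b_z, b_h of shape (num_out,)
    """
    xh = np.concatenate([x, h_prev], axis=-1)            # [x_t, h_{t-1}]
    r = sigmoid(xh @ params["W_r"] + params["b_r"])      # reset gate
    z = sigmoid(xh @ params["W_z"] + params["b_z"])      # update gate
    xrh = np.concatenate([x, r * h_prev], axis=-1)       # [x_t, r_t * h_{t-1}]
    h_tilde = phi(xrh @ params["W_h"] + params["b_h"])   # candidate state
    return (1.0 - z) * h_prev + z * h_tilde              # gated interpolation

# Illustrative sizes and random weights (assumed values)
rng = np.random.default_rng(0)
num_in, num_out, batch = 10, 20, 32
params = {}
for gate in ("r", "z", "h"):
    params[f"W_{gate}"] = rng.normal(scale=0.1, size=(num_in + num_out, num_out))
    params[f"b_{gate}"] = np.zeros(num_out)

x = np.ones((batch, num_in))
h0 = np.zeros((batch, num_out))
h1 = gru_step(x, h0, params)
print(h1.shape)  # (32, 20)
```

The last line of `gru_step` shows the role of the update gate: where \(z_t\) is close to 0 the previous hidden state passes through unchanged, and where it is close to 1 the state is overwritten by the candidate.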
- Parameters:
num_in (int) – The number of input units.
num_out (int) – The number of hidden units.
w_init (Array|ndarray|bool|number|int|float|complex|Quantity|Callable) – Initializer for the weight matrices.
b_init (Array|ndarray|bool|number|int|float|complex|Quantity|Callable) – Initializer for the bias vectors.
state_init (Array|ndarray|bool|number|int|float|complex|Quantity|Callable) – Initializer for the hidden state.
activation (str|Callable) – Activation function to use. Can be a string (e.g., 'tanh', 'relu') or a callable function.
name (str) – Name of the module.
- State Variables
- h
Hidden state of the GRU cell.
- Type:
Examples
>>> import brainstate as bs
>>> import jax.numpy as jnp
>>>
>>> # Create a GRU cell
>>> cell = bs.nn.GRUCell(num_in=10, num_out=20)
>>>
>>> # Initialize state for batch size 32
>>> cell.init_state(batch_size=32)
>>>
>>> # Process a single time step
>>> x = jnp.ones((32, 10))  # batch_size x num_in
>>> output = cell.update(x)
>>> print(output.shape)  # (32, 20)
>>>
>>> # Process a sequence
>>> sequence = jnp.ones((100, 32, 10))  # time_steps x batch_size x num_in
>>> outputs = []
>>> for t in range(100):
...     output = cell.update(sequence[t])
...     outputs.append(output)
>>> outputs = jnp.stack(outputs)
>>> print(outputs.shape)  # (100, 32, 20)
>>>
>>> # Reset state with different batch size
>>> cell.reset_state(batch_size=16)
>>> x_new = jnp.ones((16, 10))
>>> output_new = cell.update(x_new)
>>> print(output_new.shape)  # (16, 20)
Notes
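A GRU cell has three weight blocks (reset gate, update gate, candidate state) against the four of an LSTM cell, so at equal input and hidden sizes it has roughly a quarter fewer parameters and correspondingly less computation per step. On many tasks it still achieves performance comparable to an LSTM.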
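The parameter saving can be checked with a quick count. The sketch below assumes the textbook layout implied by the equations on this page, in which every gate or candidate block is one affine map on \([x_t, h_{t-1}]\) with its own bias. The helper `gated_rnn_param_count` is hypothetical, and the actual brainstate modules may organize or fuse their weights differently.

```python
def gated_rnn_param_count(num_in, num_out, num_blocks):
    """Weights plus biases for `num_blocks` affine maps on [x_t, h_{t-1}].

    Hypothetical helper for illustration; not part of brainstate.
    """
    per_block = (num_in + num_out) * num_out + num_out
    return num_blocks * per_block

num_in, num_out = 10, 20
gru = gated_rnn_param_count(num_in, num_out, num_blocks=3)   # reset, update, candidate
lstm = gated_rnn_param_count(num_in, num_out, num_blocks=4)  # input, forget, output, candidate
print(gru, lstm, gru / lstm)  # 1860 2480 0.75
```

Under this assumption the ratio is 3/4 for any choice of `num_in` and `num_out`, since both cells use blocks of the same shape.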
References