LSTMCell
- class brainstate.nn.LSTMCell(num_in, num_out, w_init=XavierNormal(scale=1.0, mode='fan_avg', in_axis=-2, out_axis=-1, distribution='truncated_normal', rng=RandomState([900 9244]), unit=Unit("1")), b_init=ZeroInit(unit=Unit("1")), state_init=ZeroInit(unit=Unit("1")), activation='tanh', name=None)
Long Short-Term Memory (LSTM) cell implementation.
LSTM is a type of RNN architecture designed to address the vanishing gradient problem and learn long-term dependencies. It uses a cell state to carry information across time steps and three gates (input, forget, output) to control information flow.
The LSTM cell follows the mathematical formulation:
\[\begin{split}
i_t &= \sigma(W_i [x_t, h_{t-1}] + b_i) \\
f_t &= \sigma(W_f [x_t, h_{t-1}] + b_f) \\
g_t &= \phi(W_g [x_t, h_{t-1}] + b_g) \\
o_t &= \sigma(W_o [x_t, h_{t-1}] + b_o) \\
c_t &= f_t \odot c_{t-1} + i_t \odot g_t \\
h_t &= o_t \odot \phi(c_t)
\end{split}\]
where:
\(x_t\) is the input vector at time t
\(h_t\) is the hidden state at time t
\(c_t\) is the cell state at time t
\(i_t\) is the input gate activation
\(f_t\) is the forget gate activation
\(o_t\) is the output gate activation
\(g_t\) is the cell update (candidate) vector
\(\odot\) represents element-wise multiplication
\(\sigma\) is the sigmoid activation function
\(\phi\) is the activation function (typically tanh)
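The equations above can be sketched as a single time step in plain NumPy. This is an illustrative sketch, not brainstate's implementation: the function name `lstm_step`, the fused weight matrix `W` (all four gates stacked along the last axis), and the gate ordering are assumptions made for the example. It also uses the row-batch convention `[x, h] @ W` rather than the column-vector form `W [x, h]` written in the formulas.

```python
import numpy as np


def sigmoid(z):
    """Logistic sigmoid, the sigma in the gate equations."""
    return 1.0 / (1.0 + np.exp(-z))


def lstm_step(x, h_prev, c_prev, W, b, phi=np.tanh):
    """One LSTM time step (hypothetical helper, not a brainstate API).

    x:      (batch, num_in)
    h_prev: (batch, num_out)
    c_prev: (batch, num_out)
    W:      (num_in + num_out, 4 * num_out), gates stacked as [i, f, g, o]
    b:      (4 * num_out,)
    """
    xh = np.concatenate([x, h_prev], axis=-1)    # [x_t, h_{t-1}]
    z = xh @ W + b                               # all four pre-activations at once
    zi, zf, zg, zo = np.split(z, 4, axis=-1)
    i = sigmoid(zi)                              # input gate
    f = sigmoid(zf)                              # forget gate
    g = phi(zg)                                  # candidate update
    o = sigmoid(zo)                              # output gate
    c = f * c_prev + i * g                       # new cell state
    h = o * phi(c)                               # new hidden state
    return h, c


# Example shapes matching the usage elsewhere on this page.
rng = np.random.default_rng(0)
batch, num_in, num_out = 32, 10, 20
W = rng.normal(scale=0.1, size=(num_in + num_out, 4 * num_out))
b = np.zeros(4 * num_out)
x = np.ones((batch, num_in))
h0 = np.zeros((batch, num_out))
c0 = np.zeros((batch, num_out))
h1, c1 = lstm_step(x, h0, c0, W, b)
```

Because the output gate lies in (0, 1) and tanh lies in (-1, 1), every entry of `h1` has magnitude below 1, while the cell state `c` is not bounded in general and can accumulate across steps.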
- Parameters:
num_in (int) – The number of input units.
num_out (int) – The number of hidden/cell units.
w_init (Array|ndarray|bool|number|int|float|complex|Quantity|Callable) – Initializer for the weight matrices.
b_init (Array|ndarray|bool|number|int|float|complex|Quantity|Callable) – Initializer for the bias vectors.
state_init (Array|ndarray|bool|number|int|float|complex|Quantity|Callable) – Initializer for the hidden and cell states.
activation (str|Callable) – Activation function to use. Can be a string (e.g., 'tanh', 'relu') or a callable function.
name (str) – Name of the module.
- State Variables
- h
Hidden state of the LSTM cell.
- c
Cell state of the LSTM cell.
- reset_state(batch_size=None, **kwargs)
Reset the cell and hidden states to their initial values.
Examples
>>> import brainstate as bs
>>> import jax.numpy as jnp
>>>
>>> # Create an LSTM cell
>>> cell = bs.nn.LSTMCell(num_in=10, num_out=20)
>>>
>>> # Initialize states for batch size 32
>>> cell.init_state(batch_size=32)
>>>
>>> # Process a single time step
>>> x = jnp.ones((32, 10))  # batch_size x num_in
>>> output = cell.update(x)
>>> print(output.shape)  # (32, 20)
>>>
>>> # Process a sequence
>>> sequence = jnp.ones((100, 32, 10))  # time_steps x batch_size x num_in
>>> outputs = []
>>> for t in range(100):
...     output = cell.update(sequence[t])
...     outputs.append(output)
>>> outputs = jnp.stack(outputs)
>>> print(outputs.shape)  # (100, 32, 20)
>>>
>>> # Access cell state
>>> print(cell.c.value.shape)  # (32, 20)
>>> print(cell.h.value.shape)  # (32, 20)
Notes
The forget gate bias is initialized with +1.0 following Jozefowicz et al. (2015) to reduce forgetting at the beginning of training.
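A small worked calculation shows why the +1.0 bias reduces early forgetting. It assumes the forget-gate pre-activation is close to zero at the start of training (a simplification for illustration) and uses only the standard library; the variable names are hypothetical.

```python
import math


def sigmoid(z):
    """Logistic sigmoid used by the forget gate."""
    return 1.0 / (1.0 + math.exp(-z))


# Forget gate value when the pre-activation is roughly zero.
f_zero_bias = sigmoid(0.0)   # bias 0.0  -> 0.5
f_unit_bias = sigmoid(1.0)   # bias +1.0 -> about 0.731

# Fraction of c_{t-1} that survives 10 steps if the gate stays at that value
# and the input-gate contribution is ignored.
steps = 10
retained_zero_bias = f_zero_bias ** steps
retained_unit_bias = f_unit_bias ** steps
```

With a zero bias the gate keeps half of the cell state per step, so the retained fraction shrinks much faster than with the +1.0 bias.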
LSTM cells are effective for learning long-term dependencies but require more parameters and computation than simpler RNN variants.
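The parameter cost can be estimated directly from the gate equations. The sketch below counts one weight matrix over `[x_t, h_{t-1}]` plus one bias vector per gate; the helper names are hypothetical, and the actual parameter layout inside brainstate may differ (for example, a single fused matrix), although the total implied by the formulation is the same.

```python
def lstm_param_count(num_in, num_out):
    """Four gates (i, f, g, o), each with a weight over [x, h] and a bias."""
    per_gate = (num_in + num_out) * num_out + num_out
    return 4 * per_gate


def vanilla_rnn_param_count(num_in, num_out):
    """A single tanh layer over [x, h] with a bias, for comparison."""
    return (num_in + num_out) * num_out + num_out


n_lstm = lstm_param_count(10, 20)        # 4 * 620 = 2480
n_rnn = vanilla_rnn_param_count(10, 20)  # 620
```

For the same `num_in` and `num_out`, the LSTM formulation therefore carries four times the parameters of a single-layer recurrent cell.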
References