LSTMCell

class brainstate.nn.LSTMCell(num_in, num_out, w_init=XavierNormal(scale=1.0, mode='fan_avg', in_axis=-2, out_axis=-1, distribution='truncated_normal', rng=RandomState([ 900 9244]), unit=Unit("1")), b_init=ZeroInit(unit=Unit("1")), state_init=ZeroInit(unit=Unit("1")), activation='tanh', name=None)

Long Short-Term Memory (LSTM) cell implementation.

LSTM is a recurrent neural network (RNN) architecture designed to mitigate the vanishing gradient problem and learn long-term dependencies. It uses a cell state to carry information across time steps and three gates (input, forget, output) to control what is written to, kept in, and read from that state.

The LSTM cell follows the mathematical formulation:

\[\begin{split}i_t &= \sigma(W_i [x_t, h_{t-1}] + b_i) \\ f_t &= \sigma(W_f [x_t, h_{t-1}] + b_f) \\ g_t &= \phi(W_g [x_t, h_{t-1}] + b_g) \\ o_t &= \sigma(W_o [x_t, h_{t-1}] + b_o) \\ c_t &= f_t \odot c_{t-1} + i_t \odot g_t \\ h_t &= o_t \odot \phi(c_t)\end{split}\]

where:

  • \(x_t\) is the input vector at time t

  • \(h_t\) is the hidden state at time t

  • \(c_t\) is the cell state at time t

  • \(i_t\) is the input gate activation

  • \(f_t\) is the forget gate activation

  • \(o_t\) is the output gate activation

  • \(g_t\) is the cell update (candidate) vector

  • \(\odot\) represents element-wise multiplication

  • \(\sigma\) is the sigmoid activation function

  • \(\phi\) is the activation function (typically tanh)
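The update equations above can be sketched in NumPy as follows. This is a minimal illustration, not brainstate's implementation: the function name `lstm_step`, the fused weight matrix `W`, and the i, f, g, o block order are assumptions made for the example. Inputs are treated as row vectors, so the product is written `xh @ W` rather than `W [x_t, h_{t-1}]`.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def lstm_step(x, h_prev, c_prev, W, b, phi=np.tanh):
    """One LSTM step following the six equations above.

    W has shape (num_in + num_out, 4 * num_out) and b has shape
    (4 * num_out,). The four column blocks are ordered i, f, g, o,
    which is an assumed layout chosen for this sketch.
    """
    xh = np.concatenate([x, h_prev], axis=-1)   # [x_t, h_{t-1}]
    z = xh @ W + b                              # all four pre-activations at once
    zi, zf, zg, zo = np.split(z, 4, axis=-1)
    i = sigmoid(zi)                             # input gate
    f = sigmoid(zf)                             # forget gate
    g = phi(zg)                                 # candidate update
    o = sigmoid(zo)                             # output gate
    c = f * c_prev + i * g                      # new cell state
    h = o * phi(c)                              # new hidden state
    return h, c


rng = np.random.default_rng(0)
batch, num_in, num_out = 32, 10, 20
W = rng.normal(scale=0.1, size=(num_in + num_out, 4 * num_out))
b = np.zeros(4 * num_out)

x = np.ones((batch, num_in))
h0 = np.zeros((batch, num_out))
c0 = np.zeros((batch, num_out))

h1, c1 = lstm_step(x, h0, c0, W, b)
print(h1.shape, c1.shape)  # (32, 20) (32, 20)
```

Computing the four pre-activations with a single matrix product is a common design choice because it replaces four small products with one larger one; separate per-gate matrices would give the same result.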

Parameters:
  • num_in (int) – The number of input units.

  • num_out (int) – The number of hidden/cell units.

  • w_init (Array | ndarray | bool | number | int | float | complex | Quantity | Callable) – Initializer for the weight matrices.

  • b_init (Array | ndarray | bool | number | int | float | complex | Quantity | Callable) – Initializer for the bias vectors.

  • state_init (Array | ndarray | bool | number | int | float | complex | Quantity | Callable) – Initializer for the hidden and cell states.

  • activation (str | Callable) – Activation function \(\phi\) applied to the candidate vector and to the cell state. Either a string name (e.g., 'tanh', 'relu') or a callable.

  • name (str) – Name of the module.
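As a rough guide to how num_in and num_out determine model size, the sketch below counts the weights and biases implied by the four gate equations. The helper `lstm_param_count` is hypothetical, and the count assumes one weight matrix over \([x_t, h_{t-1}]\) and one bias vector per gate, as in the formulation; how brainstate actually stores these parameters is not specified here.

```python
def lstm_param_count(num_in: int, num_out: int) -> int:
    """Count the weights and biases implied by the four gate equations."""
    per_gate_weights = (num_in + num_out) * num_out  # W acts on [x_t, h_{t-1}]
    per_gate_biases = num_out
    return 4 * (per_gate_weights + per_gate_biases)


print(lstm_param_count(10, 20))  # 4 * (30 * 20 + 20) = 2480
```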

num_in

Number of input features.

Type: int

num_out

Number of hidden/cell units.

Type: int

in_size

Shape of input (num_in,).

Type: tuple

out_size

Shape of output (num_out,).

Type: tuple

State Variables
---------------
h

Hidden state of the LSTM cell.

Type: HiddenState

c

Cell state of the LSTM cell.

Type: HiddenState

init_state(batch_size=None, **kwargs)

Initialize the cell and hidden states.

reset_state(batch_size=None, **kwargs)

Reset the cell and hidden states to their initial values.

update(x)

Update the states for one time step and return the new hidden state.
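The lifecycle of these three methods can be illustrated with a toy stand-in written in NumPy. `ToyLSTMCell` is hypothetical and is not brainstate's implementation. It assumes zero-initialized states, a state shape of (num_out,) when batch_size is None, and plain array attributes in place of HiddenState objects.

```python
import numpy as np


def _sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


class ToyLSTMCell:
    """Hypothetical stand-in for illustration only."""

    def __init__(self, num_in, num_out, seed=0):
        rng = np.random.default_rng(seed)
        self.num_in, self.num_out = num_in, num_out
        # Fused weights for the i, f, g, o blocks (assumed layout).
        self.W = rng.normal(scale=0.1, size=(num_in + num_out, 4 * num_out))
        self.b = np.zeros(4 * num_out)
        self.h = None
        self.c = None

    def init_state(self, batch_size=None):
        # Assumed convention: no batch axis when batch_size is None.
        shape = (self.num_out,) if batch_size is None else (batch_size, self.num_out)
        self.h = np.zeros(shape)
        self.c = np.zeros(shape)

    def reset_state(self, batch_size=None):
        # Overwrite the current states with their initial (zero) values.
        self.init_state(batch_size)

    def update(self, x):
        # Advance both states by one time step and return the new hidden state.
        z = np.concatenate([x, self.h], axis=-1) @ self.W + self.b
        i, f, g, o = np.split(z, 4, axis=-1)
        self.c = _sigmoid(f) * self.c + _sigmoid(i) * np.tanh(g)
        self.h = _sigmoid(o) * np.tanh(self.c)
        return self.h


cell = ToyLSTMCell(num_in=10, num_out=20)
cell.init_state(batch_size=32)           # allocate h and c
out = cell.update(np.ones((32, 10)))     # h and c now hold step-1 values
print(out.shape)                         # (32, 20)

cell.reset_state(batch_size=32)          # back to the initial values
print(bool(np.all(cell.c == 0)))         # True
```

In this sketch `update` mutates the stored states in place, so states must be reset between independent sequences; otherwise the final state of one sequence leaks into the next.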

Examples

>>> import brainstate as bs
>>> import jax.numpy as jnp
>>>
>>> # Create an LSTM cell
>>> cell = bs.nn.LSTMCell(num_in=10, num_out=20)
>>>
>>> # Initialize states for batch size 32
>>> cell.init_state(batch_size=32)
>>>
>>> # Process a single time step
>>> x = jnp.ones((32, 10))  # batch_size x num_in
>>> output = cell.update(x)
>>> print(output.shape)  # (32, 20)
>>>
>>> # Process a sequence
>>> sequence = jnp.ones((100, 32, 10))  # time_steps x batch_size x num_in
>>> outputs = []
>>> for t in range(100):
...     output = cell.update(sequence[t])
...     outputs.append(output)
>>> outputs = jnp.stack(outputs)
>>> print(outputs.shape)  # (100, 32, 20)
>>>
>>> # Access cell state
>>> print(cell.c.value.shape)  # (32, 20)
>>> print(cell.h.value.shape)  # (32, 20)

Notes

  • The forget gate bias is initialized with +1.0 following Jozefowicz et al. (2015) to reduce forgetting at the beginning of training.

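  • LSTM cells are effective for learning long-term dependencies, but each cell carries four sets of weights and biases (three gates plus the candidate update), so it needs roughly four times the parameters and computation of a plain RNN cell of the same size.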
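The effect of the forget gate bias mentioned in the first note can be seen with a small calculation. The sketch below assumes an idealized case in which the forget gate pre-activation equals its bias (weights contribute nothing) and no new content is written, so the cell state simply decays as \(c_t = f \cdot c_{t-1}\). The helper `retained_fraction` is hypothetical.

```python
import math


def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))


def retained_fraction(forget_bias: float, steps: int) -> float:
    """Fraction of c_0 remaining after `steps` updates in the idealized case
    where the forget gate value is sigmoid(forget_bias) at every step."""
    f = sigmoid(forget_bias)
    return f ** steps


for bias in (0.0, 1.0):
    f = sigmoid(bias)
    kept = retained_fraction(bias, 10)
    print(f"bias={bias:+.1f}  forget gate={f:.3f}  kept after 10 steps={kept:.4f}")
```

With a zero bias the gate starts at 0.5 and halves the cell state at every step, while a bias of +1.0 starts it near 0.73, so considerably more of the initial state survives over the same number of steps.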
