URLSTMCell

class brainstate.nn.URLSTMCell(num_in, num_out, w_init=XavierNormal(scale=1.0, mode='fan_avg', in_axis=-2, out_axis=-1, distribution='truncated_normal', rng=RandomState([ 900 9244]), unit=Unit("1")), state_init=ZeroInit(unit=Unit("1")), activation='tanh', name=None)

LSTM with UR gating mechanism.

URLSTM modifies the standard LSTM by adding a retention gate alongside the forget gate. The two gates share a single bias vector with opposite signs, and their outputs are merged into one unified gate, which gives finer control over how much of the cell state is kept at each step. This implementation is based on the paper “Improving the Gating Mechanism of Recurrent Neural Networks” by Gu et al. (2020).

The URLSTM cell follows the mathematical formulation:

\[\begin{split}f_t &= \sigma(W_f [x_t, h_{t-1}] + b_f) \\ r_t &= \sigma(W_r [x_t, h_{t-1}] - b_f) \\ g_t &= 2 r_t \odot f_t + (1 - 2 r_t) \odot f_t^2 \\ \tilde{c}_t &= \phi(W_c [x_t, h_{t-1}]) \\ c_t &= g_t \odot c_{t-1} + (1 - g_t) \odot \tilde{c}_t \\ o_t &= \sigma(W_o [x_t, h_{t-1}]) \\ h_t &= o_t \odot \phi(c_t)\end{split}\]

where:

  • \(x_t\) is the input vector at time t

  • \(h_t\) is the hidden state at time t

  • \(c_t\) is the cell state at time t

  • \(f_t\) is the forget gate, with bias \(+b_f\)

  • \(r_t\) is the retention gate, with the negated bias \(-b_f\)

  • \(g_t\) is the unified gate combining forget and retention

  • \(\tilde{c}_t\) is the candidate cell state

  • \(o_t\) is the output gate

  • \(\odot\) represents element-wise multiplication

  • \(\sigma\) is the sigmoid activation function

  • \(\phi\) is the activation function set by activation (tanh by default)
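
Expanding the \(g_t\) line is plain algebra on the equations above; it shows that the retention gate adjusts the forget gate rather than replacing it:

\[\begin{split}g_t &= 2 r_t \odot f_t + (1 - 2 r_t) \odot f_t^2 \\ &= f_t + (2 r_t - 1) \odot f_t \odot (1 - f_t) \\ &= r_t \odot \left(1 - (1 - f_t)^2\right) + (1 - r_t) \odot f_t^2\end{split}\]

With \(r_t = 1/2\) the unified gate reduces to \(f_t\); as \(r_t \to 1\) it approaches \(1 - (1 - f_t)^2\), and as \(r_t \to 0\) it approaches \(f_t^2\). Because \(f_t (1 - f_t) \ge 0\), \(g_t\) always lies between \(f_t^2\) and \(1 - (1 - f_t)^2\), so the retention gate can push the effective gate closer to 0 or 1 than \(f_t\) alone. For example, \(f_t = 0.9\) with \(r_t \to 1\) gives \(g_t \to 0.99\).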

The key innovation is this shared bias: the forget and retention gates use the same bias vector \(b_f\) with opposite signs, and \(b_f\) is initialized from a uniform distribution so that different units start with different gating behavior.
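
A minimal NumPy sketch of how such a uniform bias could be drawn, and what it implies for the gates at initialization. The range \((1/d, 1 - 1/d)\) for the initial forget-gate activations, with \(d\) = num_out, is an assumption taken from the uniform gate initialization described by Gu et al.; this page does not state brainstate's exact distribution. uniform_forget_bias and sigmoid are hypothetical helpers, not brainstate functions.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def uniform_forget_bias(num_out, rng):
    """Hypothetical helper: draw b_f so that sigmoid(b_f) ~ Uniform(1/d, 1 - 1/d).

    The range follows the uniform gate initialization of Gu et al.
    (assumption: brainstate's own initializer may use a different one).
    """
    if num_out < 2:
        raise ValueError("num_out must be at least 2 for a valid range")
    u = rng.uniform(1.0 / num_out, 1.0 - 1.0 / num_out, size=num_out)
    return np.log(u) - np.log1p(-u)  # logit(u), the inverse of the sigmoid


rng = np.random.default_rng(0)
b_f = uniform_forget_bias(20, rng)

# Ignoring the weight terms W[x_t, h_{t-1}], the two gates see opposite biases.
f0 = sigmoid(b_f)    # forget gate:    sigma(+b_f)
r0 = sigmoid(-b_f)   # retention gate: sigma(-b_f), which equals 1 - f0
g0 = 2 * r0 * f0 + (1 - 2 * r0) * f0**2  # unified gate at initialization
```

Because \(\sigma(-b) = 1 - \sigma(b)\), a unit whose forget gate starts near 1 has its retention gate start near 0 (and vice versa) when the weight terms are ignored, while the uniform draw spreads the units' initial forget activations over the whole range instead of clustering them at a single value.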

Parameters:
  • num_in (int) – The number of input units.

  • num_out (int) – The number of hidden/output units.

  • w_init (Array | ndarray | bool | number | int | float | complex | Quantity | Callable) – Initializer for the weight matrix.

  • state_init (Array | ndarray | bool | number | int | float | complex | Quantity | Callable) – Initializer for the hidden and cell states.

  • activation (str | Callable) – Activation function to use. Can be a string (e.g., ‘relu’, ‘tanh’) or a callable function.

  • name (str) – Name of the module.

Variables (State):
  • h (HiddenState) – Hidden state of the URLSTM cell.

  • c (HiddenState) – Cell state of the URLSTM cell.

Examples

>>> import brainstate as bs
>>> import jax.numpy as jnp
>>>
>>> # Create a URLSTM cell
>>> cell = bs.nn.URLSTMCell(num_in=10, num_out=20)
>>>
>>> # Initialize the state for batch size 32
>>> cell.init_state(batch_size=32)
>>>
>>> # Process a single time step
>>> x = jnp.ones((32, 10))  # batch_size x num_in
>>> output = cell.update(x)
>>> print(output.shape)  # (32, 20)
>>>
>>> # Process a sequence of time steps, starting again from the initial state
>>> cell.reset_state(batch_size=32)
>>> sequence = jnp.ones((100, 32, 10))  # time_steps x batch_size x num_in
>>> outputs = []
>>> for t in range(100):
...     output = cell.update(sequence[t])
...     outputs.append(output)
>>> outputs = jnp.stack(outputs)
>>> print(outputs.shape)  # (100, 32, 20)
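
The example above calls the brainstate cell directly. To make explicit what each update call computes, the following framework-free NumPy sketch implements the URLSTM equations and unrolls them over the same shapes (batch 32, 10 inputs, 20 hidden units, 100 steps). The parameter layout (four separate gate matrices acting on \([x_t, h_{t-1}]\)), the plain-normal Xavier scale, and the uniform bias range are assumptions for illustration; init_params and urlstm_step are hypothetical helpers, not brainstate functions, and the numbers will not match brainstate's output.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def init_params(num_in, num_out, rng):
    # Hypothetical layout: one weight matrix per gate acting on [x_t, h_{t-1}],
    # plus the single forget bias b_f shared (with opposite signs) by f and r.
    shape = (num_in + num_out, num_out)
    scale = np.sqrt(2.0 / sum(shape))  # Xavier fan_avg scale, plain normal (assumption)
    W = {k: rng.normal(0.0, scale, size=shape) for k in ("f", "r", "c", "o")}
    u = rng.uniform(1.0 / num_out, 1.0 - 1.0 / num_out, size=num_out)
    b_f = np.log(u) - np.log1p(-u)  # inverse sigmoid of a uniform draw (assumption)
    return W, b_f


def urlstm_step(params, h, c, x, phi=np.tanh):
    """One URLSTM time step; returns the new (h, c)."""
    W, b_f = params
    xh = np.concatenate([x, h], axis=-1)   # [x_t, h_{t-1}]
    f = sigmoid(xh @ W["f"] + b_f)         # forget gate, bias +b_f
    r = sigmoid(xh @ W["r"] - b_f)         # retention gate, bias -b_f
    g = 2 * r * f + (1 - 2 * r) * f**2     # unified gate
    c_tilde = phi(xh @ W["c"])             # candidate cell state
    c_new = g * c + (1 - g) * c_tilde      # cell state update
    o = sigmoid(xh @ W["o"])               # output gate
    h_new = o * phi(c_new)                 # hidden state / output
    return h_new, c_new


rng = np.random.default_rng(0)
params = init_params(num_in=10, num_out=20, rng=rng)
batch, steps = 32, 100
h = np.zeros((batch, 20))  # zero initial states, matching the default ZeroInit
c = np.zeros((batch, 20))
sequence = np.ones((steps, batch, 10))  # time_steps x batch_size x num_in

outputs = []
for t in range(steps):
    h, c = urlstm_step(params, h, c, sequence[t])
    outputs.append(h)
outputs = np.stack(outputs)  # time_steps x batch_size x num_out
```

Because \(g_t\) stays in \([0, 1]\), \(c_t\) is a convex combination of \(c_{t-1}\) and a tanh-bounded candidate, so with \(\phi = \tanh\) and zero initial states every entry of \(c_t\) and \(h_t\) stays inside \((-1, 1)\).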

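References

Gu, A., Gulcehre, C., Paine, T. L., Hoffman, M., and Pascanu, R. (2020). Improving the Gating Mechanism of Recurrent Neural Networks. In Proceedings of the 37th International Conference on Machine Learning (ICML 2020).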

init_state(batch_size=None, **kwargs)

Initialize the cell and hidden states.

Parameters:
  • batch_size (int) – The batch size for state initialization.

  • **kwargs – Additional keyword arguments.

reset_state(batch_size=None, **kwargs)

Reset the cell and hidden states to their initial values.

Parameters:
  • batch_size (int) – The batch size for state reset.

  • **kwargs – Additional keyword arguments.

update(x)

Update the URLSTM cell for one time step.

Parameters:
  • x (Array | ndarray | bool | number | int | float | complex | Quantity) – Input tensor with shape (batch_size, num_in).

Returns:
  Hidden state tensor with shape (batch_size, num_out).

Return type:
  Array | ndarray | bool | number | int | float | complex | Quantity