URLSTMCell
- class brainstate.nn.URLSTMCell(num_in, num_out, w_init=XavierNormal(scale=1.0, mode='fan_avg', in_axis=-2, out_axis=-1, distribution='truncated_normal', rng=RandomState([ 900 9244]), unit=Unit("1")), state_init=ZeroInit(unit=Unit("1")), activation='tanh', name=None)
LSTM with UR gating mechanism.
URLSTM is a modification of the standard LSTM that uses untied (separate) biases for the forget and retention mechanisms, allowing for more flexible gating control. This implementation is based on the paper “Improving the Gating Mechanism of Recurrent Neural Networks” by Gu et al. (2020).
The URLSTM cell follows the mathematical formulation:
\[\begin{split}f_t &= \sigma(W_f [x_t, h_{t-1}] + b_f) \\ r_t &= \sigma(W_r [x_t, h_{t-1}] - b_f) \\ g_t &= 2 r_t \odot f_t + (1 - 2 r_t) \odot f_t^2 \\ \tilde{c}_t &= \phi(W_c [x_t, h_{t-1}]) \\ c_t &= g_t \odot c_{t-1} + (1 - g_t) \odot \tilde{c}_t \\ o_t &= \sigma(W_o [x_t, h_{t-1}]) \\ h_t &= o_t \odot \phi(c_t)\end{split}\]
where:
\(x_t\) is the input vector at time t
\(h_t\) is the hidden state at time t
\(c_t\) is the cell state at time t
\(f_t\) is the forget gate with positive bias
\(r_t\) is the retention gate with negative bias
\(g_t\) is the unified gate combining forget and retention
\(\tilde{c}_t\) is the candidate cell state
\(o_t\) is the output gate
\(\odot\) represents element-wise multiplication
\(\sigma\) is the sigmoid activation function
\(\phi\) is the activation function (typically tanh)
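The formulation above can be sketched as a single update step in plain NumPy. This is only an illustration of the equations, not brainstate's implementation: the function name `urlstm_step`, the packing of the four weight blocks into one matrix `W`, and the order of the gate columns are assumptions made for the example.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def urlstm_step(x, h_prev, c_prev, W, b_f, phi=np.tanh):
    """One URLSTM step (hypothetical helper).

    W has shape (num_in + num_out, 4 * num_out); its columns are assumed
    to be ordered as [forget, retention, candidate, output].
    """
    xh = np.concatenate([x, h_prev], axis=-1)        # [x_t, h_{t-1}]
    f_pre, r_pre, c_pre, o_pre = np.split(xh @ W, 4, axis=-1)
    f = sigmoid(f_pre + b_f)                         # forget gate, bias +b_f
    r = sigmoid(r_pre - b_f)                         # retention gate, bias -b_f
    g = 2.0 * r * f + (1.0 - 2.0 * r) * f ** 2       # unified gate
    c_tilde = phi(c_pre)                             # candidate cell state
    c = g * c_prev + (1.0 - g) * c_tilde             # new cell state
    o = sigmoid(o_pre)                               # output gate
    h = o * phi(c)                                   # new hidden state
    return h, c

# Toy shapes matching the usage example: batch 32, 10 inputs, 20 hidden units.
rng = np.random.default_rng(0)
num_in, num_out, batch = 10, 20, 32
W = rng.normal(scale=0.1, size=(num_in + num_out, 4 * num_out))
b_f = np.zeros(num_out)
x = np.ones((batch, num_in))
h = np.zeros((batch, num_out))
c = np.zeros((batch, num_out))
h, c = urlstm_step(x, h, c, W, b_f)
print(h.shape, c.shape)  # (32, 20) (32, 20)
```

Because \(g_t\) interpolates between \(f_t^2\) (when \(r_t = 0\)) and \(1 - (1 - f_t)^2\) (when \(r_t = 1\)), the retention gate can push the effective forget gate toward either extreme without changing \(f_t\) itself.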
The key innovation is the untied bias mechanism: the forget and retention gates use biases of opposite sign (\(+b_f\) and \(-b_f\)), and \(b_f\) is initialized from a uniform distribution to encourage diverse gating behavior across units.
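One way to realize such an initialization, sketched here under stated assumptions, is to draw \(u\) uniformly from \((\epsilon, 1 - \epsilon)\) and set \(b_f = \log(u / (1 - u))\), so that \(\sigma(b_f)\) is itself uniformly distributed. The helper name `uniform_gate_bias` and the value of `eps` are hypothetical, and the exact range used inside brainstate may differ.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def uniform_gate_bias(num_out, rng, eps=1e-3):
    """Sample b_f so that sigmoid(b_f) is uniform on (eps, 1 - eps).

    Hypothetical helper; eps is an illustrative choice, not a library default.
    """
    u = rng.uniform(eps, 1.0 - eps, size=num_out)
    return np.log(u) - np.log1p(-u)   # logit(u) = log(u / (1 - u))

rng = np.random.default_rng(0)
b_f = uniform_gate_bias(20, rng)

# With zero pre-activation, the two gates start at complementary values
# because sigmoid(b) + sigmoid(-b) = 1.
forget_at_zero = sigmoid(b_f)      # uses +b_f
retain_at_zero = sigmoid(-b_f)     # uses -b_f
print(np.allclose(forget_at_zero + retain_at_zero, 1.0))  # True
```

Spreading the initial gate values across \((0, 1)\) gives different units different effective time scales from the start, which is the "diverse gating behavior" described above.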
- Parameters:
num_in (int) – The number of input units.
num_out (int) – The number of hidden/output units.
w_init (Array|ndarray|bool|number|bool|int|float|complex|Quantity|Callable) – Initializer for the weight matrix.
state_init (Array|ndarray|bool|number|bool|int|float|complex|Quantity|Callable) – Initializer for the hidden and cell states.
activation (str|Callable) – Activation function to use. Can be a string (e.g., ‘relu’, ‘tanh’) or a callable function.
name (str) – Name of the module.
- Variables (State):
h (HiddenState) – Hidden state of the URLSTM cell.
c (HiddenState) – Cell state of the URLSTM cell.
Examples
>>> import brainstate as bs
>>> import jax.numpy as jnp
>>>
>>> # Create a URLSTM cell
>>> cell = bs.nn.URLSTMCell(num_in=10, num_out=20)
>>>
>>> # Initialize the state for batch size 32
>>> cell.init_state(batch_size=32)
>>>
>>> # Process a single time step
>>> x = jnp.ones((32, 10))  # batch_size x num_in
>>> output = cell.update(x)
>>> print(output.shape)  # (32, 20)
>>>
>>> # Process multiple time steps
>>> sequence = jnp.ones((100, 32, 10))  # time_steps x batch_size x num_in
>>> outputs = []
>>> for t in range(100):
...     output = cell.update(sequence[t])
...     outputs.append(output)
>>> outputs = jnp.stack(outputs)
>>> print(outputs.shape)  # (100, 32, 20)
- init_state(batch_size=None, **kwargs)
Initialize the cell and hidden states.
- Parameters:
batch_size (int) – The batch size for state initialization.
**kwargs – Additional keyword arguments.
- reset_state(batch_size=None, **kwargs)
Reset the cell and hidden states to their initial values.
- Parameters:
batch_size (int) – The batch size for state reset.
**kwargs – Additional keyword arguments.