MGUCell

class brainstate.nn.MGUCell(num_in, num_out, w_init=Orthogonal(scale=1.0, axis=-1, rng=RandomState([ 900 9244]), unit=Unit("1")), b_init=ZeroInit(unit=Unit("1")), state_init=ZeroInit(unit=Unit("1")), activation='tanh', name=None)

Minimal Gated Unit (MGU) cell implementation.

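MGU is a simplified version of GRU that replaces GRU's separate reset and update gates with a single forget gate. That one gate both scales the previous hidden state inside the candidate computation and interpolates between the previous and candidate states. The cell therefore needs two weight blocks (forget gate and candidate) instead of GRU's three, or roughly two-thirds of GRU's parameters, trading some gating flexibility for a smaller model.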
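The parameter saving can be checked by counting affine blocks. The sketch below assumes the textbook formulation in which each gate and the candidate is one affine map of the concatenated input and previous hidden state, with one bias vector per block. `gated_rnn_param_count` is a hypothetical helper, and brainstate's internal weight layout (and that of any particular GRU or LSTM implementation) may differ.

```python
def gated_rnn_param_count(num_in: int, num_out: int, num_blocks: int) -> int:
    """Count parameters of a gated cell built from `num_blocks` affine maps.

    Assumption: each block maps the concatenation [x_t, h_{t-1}] with a
    weight of shape (num_in + num_out, num_out) plus a bias of shape (num_out,).
    """
    per_block = (num_in + num_out) * num_out + num_out
    return num_blocks * per_block


# Textbook block counts: MGU (forget gate + candidate) = 2,
# GRU (reset, update, candidate) = 3, LSTM (input, forget, output, candidate) = 4.
num_in, num_out = 10, 20
mgu = gated_rnn_param_count(num_in, num_out, 2)
gru = gated_rnn_param_count(num_in, num_out, 3)
lstm = gated_rnn_param_count(num_in, num_out, 4)
print(mgu, gru, lstm)  # 1240 1860 2480
```

With the sizes used in the Examples below (10 inputs, 20 hidden units), this count gives MGU exactly two-thirds of the GRU total under these assumptions.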

The MGU cell follows the mathematical formulation:

\[\begin{split}f_t &= \sigma(W_f [x_t, h_{t-1}] + b_f) \\ \tilde{h}_t &= \phi(W_h [x_t, (f_t \odot h_{t-1})] + b_h) \\ h_t &= (1 - f_t) \odot h_{t-1} + f_t \odot \tilde{h}_t\end{split}\]

where:

  • \(x_t\) is the input vector at time t

  • \(h_t\) is the hidden state at time t

  • \(f_t\) is the forget gate vector

  • \(\tilde{h}_t\) is the candidate hidden state

  • \(\odot\) represents element-wise multiplication

  • \(\sigma\) is the sigmoid activation function

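  • \(\phi\) is the candidate activation function, set by the activation parameter (tanh by default)

  • \(W_f, W_h\) and \(b_f, b_h\) are the weight matrices and bias vectors of the forget gate and the candidate, and \([\cdot, \cdot]\) denotes concatenation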
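To make the equations concrete, here is a minimal NumPy sketch of one MGU step for a batch. It assumes that \([\cdot, \cdot]\) is concatenation along the feature axis and that \(W_f\) and \(W_h\) are stored as (num_in + num_out, num_out) matrices whose first num_in rows act on \(x_t\). `mgu_step` is a hypothetical helper for illustration, not brainstate's `MGUCell.update`.

```python
import numpy as np


def sigmoid(z: np.ndarray) -> np.ndarray:
    return 1.0 / (1.0 + np.exp(-z))


def mgu_step(x, h_prev, W_f, b_f, W_h, b_h, phi=np.tanh):
    """One MGU step.

    x: (batch, num_in), h_prev: (batch, num_out)
    W_f, W_h: (num_in + num_out, num_out); b_f, b_h: (num_out,)
    """
    # Forget gate f_t from the concatenation [x_t, h_{t-1}]
    f = sigmoid(np.concatenate([x, h_prev], axis=-1) @ W_f + b_f)
    # Candidate: the gate scales h_{t-1} before it enters the affine map
    h_tilde = phi(np.concatenate([x, f * h_prev], axis=-1) @ W_h + b_h)
    # Interpolate between the previous state and the candidate
    return (1.0 - f) * h_prev + f * h_tilde


rng = np.random.default_rng(0)
batch, num_in, num_out = 4, 3, 5
W_f = rng.normal(scale=0.1, size=(num_in + num_out, num_out))
W_h = rng.normal(scale=0.1, size=(num_in + num_out, num_out))
b_f = np.zeros(num_out)
b_h = np.zeros(num_out)

x = rng.normal(size=(batch, num_in))
h = np.zeros((batch, num_out))  # zero initial state
h = mgu_step(x, h, W_f, b_f, W_h, b_h)
print(h.shape)  # (4, 5)
```

When \(f_t\) saturates at 1, the update keeps only the candidate; when it is near 0, the previous state passes through almost unchanged.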

Parameters:
  • num_in (int) – The number of input units.

  • num_out (int) – The number of hidden units.

  • w_init (Array | ndarray | bool | number | int | float | complex | Quantity | Callable) – Initializer for the weight matrices. Defaults to an Orthogonal initializer.

  • b_init (Array | ndarray | bool | number | int | float | complex | Quantity | Callable) – Initializer for the bias vectors. Defaults to ZeroInit.

  • state_init (Array | ndarray | bool | number | int | float | complex | Quantity | Callable) – Initializer for the hidden state. Defaults to ZeroInit.

  • activation (str | Callable) – Activation function \(\phi\) for the candidate state, given either as a string (e.g., ‘tanh’, ‘relu’) or as a callable. Defaults to ‘tanh’.

  • name (str | None) – Optional name of the module.
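The default w_init is an Orthogonal initializer. As an illustration of what orthogonal initialization means, the sketch below builds a weight matrix with orthonormal columns via a QR decomposition, sized for a concatenated weight of shape (num_in + num_out, num_out). This is an assumption-laden stand-in: `orthogonal_init` is hypothetical, and brainstate's Orthogonal (including how it uses its axis and scale arguments) may construct the matrix differently.

```python
import numpy as np


def orthogonal_init(shape, scale=1.0, rng=None):
    """Hypothetical orthogonal initializer for a 2-D weight matrix.

    Draws a Gaussian matrix and orthonormalizes it with QR, so the vectors
    along the smaller dimension are orthonormal (then multiplied by `scale`).
    """
    rng = np.random.default_rng() if rng is None else rng
    rows, cols = shape
    a = rng.normal(size=(max(rows, cols), min(rows, cols)))
    q, r = np.linalg.qr(a)  # reduced QR: q has orthonormal columns
    # Flip column signs using diag(r) so the result is uniformly distributed
    q = q * np.sign(np.diag(r))
    if rows < cols:
        q = q.T  # wide matrix: rows are orthonormal instead
    return scale * q


num_in, num_out = 10, 20
W = orthogonal_init((num_in + num_out, num_out), rng=np.random.default_rng(0))
print(W.shape)  # (30, 20)
print(np.allclose(W.T @ W, np.eye(num_out)))  # True
```

Orthonormal columns preserve the norm of vectors they act on, which is the usual motivation for orthogonal initialization in recurrent cells.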

num_in

Number of input features.

Type:

int

num_out

Number of hidden units.

Type:

int

in_size

Shape of input (num_in,).

Type:

tuple

out_size

Shape of output (num_out,).

Type:

tuple

State Variables

h

Hidden state of the MGU cell.

Type:

HiddenState

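init_state(batch_size=None, **kwargs)

Initialize the hidden state h using state_init. When batch_size is given, the state has a leading batch dimension, i.e. shape (batch_size, num_out).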

reset_state(batch_size=None, **kwargs)

Reset the cell hidden state to its initial value.

update(x)

Advance the cell by one time step: compute the new hidden state \(h_t\) from the input x and the stored state \(h_{t-1}\), store it in h, and return it.
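The three methods above share one stored state. The NumPy sketch below shows that pattern with a hypothetical `TinyMGUCell`, not brainstate's class, under simple assumptions: a zero initial state (matching the default state_init=ZeroInit()), small random Gaussian weights, and the concatenated-weight layout of the equations.

```python
import numpy as np


class TinyMGUCell:
    """Hypothetical stand-in mirroring init_state / reset_state / update."""

    def __init__(self, num_in, num_out, seed=0):
        rng = np.random.default_rng(seed)
        self.num_in, self.num_out = num_in, num_out
        shape = (num_in + num_out, num_out)  # acts on [x_t, h_{t-1}]
        self.W_f = rng.normal(scale=0.1, size=shape)
        self.W_h = rng.normal(scale=0.1, size=shape)
        self.b_f = np.zeros(num_out)
        self.b_h = np.zeros(num_out)
        self.h = None

    def init_state(self, batch_size=None):
        # Zero state; add a leading batch axis only when batch_size is given
        shape = (self.num_out,) if batch_size is None else (batch_size, self.num_out)
        self.h = np.zeros(shape)

    def reset_state(self, batch_size=None):
        # Return the state to its initial (zero) value
        self.init_state(batch_size)

    def update(self, x):
        # One MGU step that reads and overwrites the stored state
        z_f = np.concatenate([x, self.h], axis=-1) @ self.W_f + self.b_f
        f = 1.0 / (1.0 + np.exp(-z_f))
        z_h = np.concatenate([x, f * self.h], axis=-1) @ self.W_h + self.b_h
        self.h = (1.0 - f) * self.h + f * np.tanh(z_h)
        return self.h


cell = TinyMGUCell(num_in=10, num_out=20)
cell.init_state(batch_size=32)
out = cell.update(np.ones((32, 10)))
print(out.shape)  # (32, 20)
```

Keeping the state inside the object lets a sequence be processed by calling update once per step, with reset_state marking the start of a new sequence.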

Examples

>>> import brainstate as bs
>>> import jax.numpy as jnp
>>>
>>> # Create an MGU cell
>>> cell = bs.nn.MGUCell(num_in=10, num_out=20)
>>>
>>> # Initialize state for batch size 32
>>> cell.init_state(batch_size=32)
>>>
>>> # Process a single time step
>>> x = jnp.ones((32, 10))  # batch_size x num_in
>>> output = cell.update(x)
>>> print(output.shape)
(32, 20)
>>>
>>> # Reset the state, then process a sequence of 100 time steps
>>> cell.reset_state(batch_size=32)
>>> sequence = jnp.ones((100, 32, 10))  # time_steps x batch_size x num_in
>>> outputs = []
>>> for t in range(100):
...     output = cell.update(sequence[t])
...     outputs.append(output)
...
>>> outputs = jnp.stack(outputs)
>>> print(outputs.shape)
(100, 32, 20)

Notes

MGU provides a lightweight alternative to GRU and LSTM, making it suitable for resource-constrained applications or when model simplicity is preferred.
