MGUCell
- class brainstate.nn.MGUCell(num_in, num_out, w_init=Orthogonal(scale=1.0, axis=-1, rng=RandomState([ 900 9244]), unit=Unit("1")), b_init=ZeroInit(unit=Unit("1")), state_init=ZeroInit(unit=Unit("1")), activation='tanh', name=None)
Minimal Gated Unit (MGU) cell implementation.
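MGU is a simplified version of the GRU that uses a single forget gate in place of the GRU's separate reset and update gates. The forget gate therefore plays both roles: it masks the previous state when the candidate state is computed, and it interpolates between the previous state and that candidate. With two weight blocks (forget gate and candidate) instead of the GRU's three, an MGU has roughly two-thirds of the parameters of a GRU with the same input and hidden sizes, while retaining much of its gating capability. This makes MGU a good trade-off between model complexity and performance.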
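To make the parameter saving concrete, the arithmetic below counts trainable weights for one cell. It assumes that each gate or candidate block has one weight matrix acting on the concatenated [x_t, h_{t-1}] (as in the equations below) plus a single bias vector; the helper count_params and the per-architecture block counts are illustrative assumptions, not values read from brainstate. Libraries that keep two bias vectors per block (for example PyTorch's GRU) report slightly larger totals.

```python
def count_params(num_in: int, num_out: int, num_blocks: int) -> int:
    """Parameters of a gated cell with `num_blocks` weight blocks.

    Assumption: each block (a gate or the candidate state) has a weight
    matrix of shape (num_in + num_out, num_out) acting on [x_t, h_{t-1}]
    and one bias vector of length num_out.
    """
    per_block = (num_in + num_out) * num_out + num_out
    return num_blocks * per_block


# Assumed block counts: MGU = forget gate + candidate,
# GRU = update + reset + candidate, LSTM = input + forget + output + candidate.
BLOCKS = {"MGU": 2, "GRU": 3, "LSTM": 4}

num_in, num_out = 10, 20  # same sizes as in the Examples section
counts = {name: count_params(num_in, num_out, b) for name, b in BLOCKS.items()}
for name, n in counts.items():
    print(f"{name}: {n}")
# prints MGU: 1240, GRU: 1860, LSTM: 2480 (one per line)
```

Under these assumptions the ratio does not depend on the layer sizes: an MGU always has 2/3 the parameters of a GRU and 1/2 those of an LSTM of the same width.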
The MGU cell follows the mathematical formulation:
\[
\begin{split}
f_t &= \sigma(W_f [x_t, h_{t-1}] + b_f) \\
\tilde{h}_t &= \phi(W_h [x_t, f_t \odot h_{t-1}] + b_h) \\
h_t &= (1 - f_t) \odot h_{t-1} + f_t \odot \tilde{h}_t
\end{split}
\]
where:
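- \(x_t\) is the input vector at time t
- \(h_t\) is the hidden state at time t
- \(f_t\) is the forget gate vector, which acts as a reset gate inside \(\tilde{h}_t\) and as an update gate in \(h_t\)
- \(\tilde{h}_t\) is the candidate hidden state
- \(W_f, W_h\) and \(b_f, b_h\) are the weight matrices and bias vectors of the forget gate and the candidate, and \([\cdot, \cdot]\) denotes concatenation
- \(\odot\) denotes element-wise multiplication
- \(\sigma\) is the logistic sigmoid function
- \(\phi\) is the candidate activation set by the activation argument (tanh by default)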
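The three equations map directly onto array operations. The NumPy sketch below is an independent re-statement of the formulas, not brainstate's implementation; mgu_step and sigmoid are hypothetical names. Because inputs are batched as rows, each product W[x_t, h_{t-1}] is written as [x_t, h_{t-1}] @ W, so every weight matrix is assumed to have shape (num_in + num_out, num_out).

```python
import numpy as np


def sigmoid(z: np.ndarray) -> np.ndarray:
    """Logistic sigmoid, the sigma in the gate equation."""
    return 1.0 / (1.0 + np.exp(-z))


def mgu_step(x, h_prev, W_f, b_f, W_h, b_h, phi=np.tanh):
    """One MGU update following the equations above (hypothetical sketch).

    x:        (batch, num_in)              input x_t
    h_prev:   (batch, num_out)             previous hidden state h_{t-1}
    W_f, W_h: (num_in + num_out, num_out)  weights on the concatenation
    b_f, b_h: (num_out,)                   biases
    """
    # f_t = sigma(W_f [x_t, h_{t-1}] + b_f)
    f = sigmoid(np.concatenate([x, h_prev], axis=-1) @ W_f + b_f)
    # candidate: the gate masks h_{t-1} here (reset role)
    h_cand = phi(np.concatenate([x, f * h_prev], axis=-1) @ W_h + b_h)
    # interpolation: the same gate mixes old state and candidate (update role)
    return (1.0 - f) * h_prev + f * h_cand


rng = np.random.default_rng(0)
batch, num_in, num_out = 4, 10, 20
x = rng.normal(size=(batch, num_in))
h0 = np.zeros((batch, num_out))
W_f = rng.normal(scale=0.1, size=(num_in + num_out, num_out))
W_h = rng.normal(scale=0.1, size=(num_in + num_out, num_out))
b_f, b_h = np.zeros(num_out), np.zeros(num_out)

h1 = mgu_step(x, h0, W_f, b_f, W_h, b_h)
print(h1.shape)  # (4, 20)
```

A quick sanity check on the formulas: with all weights and biases zero, \(f_t = 0.5\) and \(\tilde{h}_t = \tanh(0) = 0\), so the step simply halves the previous state.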
- Parameters:
  - num_in (int) – The number of input units (size of \(x_t\)).
  - num_out (int) – The number of hidden units (size of \(h_t\)).
  - w_init (Array | ndarray | bool | number | int | float | complex | Quantity | Callable) – Initializer for the weight matrices. Defaults to Orthogonal.
  - b_init (Array | ndarray | bool | number | int | float | complex | Quantity | Callable) – Initializer for the bias vectors. Defaults to ZeroInit.
  - state_init (Array | ndarray | bool | number | int | float | complex | Quantity | Callable) – Initializer for the hidden state. Defaults to ZeroInit.
  - activation (str | Callable) – Activation function \(\phi\) for the candidate state. Can be a string (e.g., 'tanh', 'relu') or a callable function. Defaults to 'tanh'.
  - name (str | None) – Name of the module. Defaults to None.
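The sketch below shows, in plain NumPy, what each of these arguments might fill in. It assumes separate forget and candidate weight matrices on the concatenated input (the real class may fuse or lay them out differently), uses the standard QR construction as a stand-in for brainstate's Orthogonal initializer, and resolves activation names through a small lookup table. The names orthogonal, resolve_activation, init_mgu_params and ACTIVATIONS are hypothetical and introduced only for illustration.

```python
import numpy as np

# Hypothetical name-to-function table standing in for the `activation` argument.
ACTIVATIONS = {
    "tanh": np.tanh,
    "relu": lambda z: np.maximum(z, 0.0),
}


def resolve_activation(activation):
    """Accept either a name such as 'tanh' or any callable."""
    if callable(activation):
        return activation
    return ACTIVATIONS[activation]


def orthogonal(rng, shape, scale=1.0):
    """QR-based orthogonal draw (stand-in for an Orthogonal initializer)."""
    rows, cols = shape
    a = rng.normal(size=(max(rows, cols), min(rows, cols)))
    q, r = np.linalg.qr(a)           # reduced QR: q has orthonormal columns
    q = q * np.sign(np.diag(r))      # usual sign correction of the QR recipe
    if rows < cols:
        q = q.T                      # wide matrix: orthonormal rows instead
    return scale * q


def init_mgu_params(num_in, num_out, seed=0, activation="tanh"):
    """Hypothetical parameter set for one MGU cell."""
    rng = np.random.default_rng(seed)
    shape = (num_in + num_out, num_out)  # acts on [x_t, h_{t-1}]
    return {
        "W_f": orthogonal(rng, shape),
        "W_h": orthogonal(rng, shape),
        "b_f": np.zeros(num_out),        # zero biases, like ZeroInit
        "b_h": np.zeros(num_out),
        "phi": resolve_activation(activation),
    }


params = init_mgu_params(num_in=10, num_out=20)
W_f = params["W_f"]
print(W_f.shape)                             # (30, 20)
print(np.allclose(W_f.T @ W_f, np.eye(20)))  # True: orthonormal columns
```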
State Variables
- h – Hidden state \(h_t\) of the MGU cell, initialized with state_init.
Examples
>>> import brainstate as bs
>>> import jax.numpy as jnp
>>>
>>> # Create an MGU cell
>>> cell = bs.nn.MGUCell(num_in=10, num_out=20)
>>>
>>> # Initialize state for batch size 32
>>> cell.init_state(batch_size=32)
>>>
>>> # Process a single time step
>>> x = jnp.ones((32, 10))  # batch_size x num_in
>>> output = cell.update(x)
>>> print(output.shape)
(32, 20)
>>>
>>> # Process a sequence
>>> sequence = jnp.ones((100, 32, 10))  # time_steps x batch_size x num_in
>>> outputs = []
>>> for t in range(100):
...     output = cell.update(sequence[t])
...     outputs.append(output)
>>> outputs = jnp.stack(outputs)
>>> print(outputs.shape)
(100, 32, 20)
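The sequence loop in the example relies on the cell carrying its hidden state from one update call to the next. The NumPy sketch below makes that carry explicit outside brainstate, using the same shapes; run_sequence is a hypothetical function that re-states the equations above, and small random weights stand in for the configured initializers.

```python
import numpy as np


def run_sequence(xs, W_f, b_f, W_h, b_h, h0):
    """Unroll an MGU over xs of shape (time, batch, num_in).

    Hypothetical sketch; returns all hidden states, shape (time, batch, num_out).
    """
    def sigmoid(z):
        return 1.0 / (1.0 + np.exp(-z))

    h = h0                                    # plays the role of the state variable h
    outputs = []
    for x in xs:                              # one iteration per time step
        f = sigmoid(np.concatenate([x, h], axis=-1) @ W_f + b_f)
        h_cand = np.tanh(np.concatenate([x, f * h], axis=-1) @ W_h + b_h)
        h = (1.0 - f) * h + f * h_cand        # new state is carried to the next step
        outputs.append(h)
    return np.stack(outputs)


rng = np.random.default_rng(0)
T, B, n_in, n_out = 100, 32, 10, 20
W_f = rng.normal(scale=0.1, size=(n_in + n_out, n_out))
W_h = rng.normal(scale=0.1, size=(n_in + n_out, n_out))
b_f, b_h = np.zeros(n_out), np.zeros(n_out)

xs = np.ones((T, B, n_in))
h0 = np.zeros((B, n_out))                     # zero initial state, like ZeroInit
hs = run_sequence(xs, W_f, b_f, W_h, b_h, h0)
print(hs.shape)                               # (100, 32, 20)
```

Since each \(h_t\) is a convex combination of \(h_{t-1}\) and a tanh output, a state that starts in [-1, 1] stays there for every step. In the sketch, starting an independent sequence just means passing a fresh h0.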
Notes
MGU is a lightweight alternative to GRU and LSTM: it needs two weight blocks per cell, versus three for a GRU and four for an LSTM, which makes it suitable for resource-constrained applications or for cases where model simplicity is preferred.
References