MiniGRU
- class braintrace.nn.MiniGRU(in_size, out_size, w_init=Orthogonal(scale=1.0), b_init=ZeroInit(unit=1), state_init=ZeroInit(unit=1), name=None)
Minimal GRU cell.
A simplified version of the GRU, implemented as in "MinimalRNN: Toward More Interpretable and Trainable Recurrent Neural Networks".
At each time step \(t\), an update gate \(\mathbf{z}_t\) controls how much of the previous hidden state is kept and how much is replaced by a linear projection of the current input. The hidden state is updated as:
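\[\mathbf{h}_t = (1 - \mathbf{z}_t) \odot \mathbf{h}_{t-1} + \mathbf{z}_t \odot \mathbf{W}_x \mathbf{x}_t\]
where \(\mathbf{z}_t=\sigma(\mathbf{W}_z[\mathbf{x}_t; \mathbf{h}_{t-1}])\) is the update gate, \(\sigma\) is the logistic sigmoid, \([\mathbf{x}_t; \mathbf{h}_{t-1}]\) denotes concatenation, and \(\odot\) is elementwise multiplication.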
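The update can be sketched in plain NumPy to make the data flow concrete. This is a minimal illustration under stated assumptions, not braintrace's implementation: the function name `minigru_step` and the weight shapes are hypothetical, the code uses the row-vector convention (`x @ W` in place of \(\mathbf{W}\mathbf{x}\)), and, like the equation, it omits bias terms.

```python
import numpy as np


def sigmoid(v):
    """Logistic sigmoid: 1 / (1 + exp(-v))."""
    return 1.0 / (1.0 + np.exp(-v))


def minigru_step(x_t, h_prev, W_x, W_z):
    """One hypothetical MiniGRU step (bias terms omitted, as in the equation).

    x_t:    (batch, in_size)               input at step t
    h_prev: (batch, out_size)              previous hidden state h_{t-1}
    W_x:    (in_size, out_size)            input projection
    W_z:    (in_size + out_size, out_size) update-gate weights
    """
    # Update gate: z_t = sigma(W_z [x_t; h_{t-1}]), values in (0, 1)
    z_t = sigmoid(np.concatenate([x_t, h_prev], axis=-1) @ W_z)
    # Candidate state: a linear projection of the current input only
    candidate = x_t @ W_x
    # Elementwise interpolation between the old state and the candidate
    return (1.0 - z_t) * h_prev + z_t * candidate


# Run the cell over a short random sequence, starting from a zero state
rng = np.random.default_rng(0)
batch, in_size, out_size, steps = 4, 3, 5, 10
W_x = rng.standard_normal((in_size, out_size))
W_z = rng.standard_normal((in_size + out_size, out_size))
h = np.zeros((batch, out_size))
for _ in range(steps):
    x_t = rng.standard_normal((batch, in_size))
    h = minigru_step(x_t, h, W_x, W_z)
print(h.shape)  # (4, 5)
```

Because every entry of \(\mathbf{z}_t\) lies in \((0, 1)\), each entry of \(\mathbf{h}_t\) is a convex combination of the matching entries of \(\mathbf{h}_{t-1}\) and \(\mathbf{W}_x \mathbf{x}_t\).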
- Parameters:
  - in_size (int | Sequence[int] | integer | Sequence[integer]) – The number of input units.
  - out_size (int | Sequence[int] | integer | Sequence[integer]) – The number of hidden units.
  - w_init (Array | ndarray | bool | number | int | float | complex | Quantity | Callable) – The input weight initializer. Default is Orthogonal().
  - b_init (Array | ndarray | bool | number | int | float | complex | Quantity | Callable) – The bias initializer. Default is ZeroInit().
  - state_init (Array | ndarray | bool | number | int | float | complex | Quantity | Callable) – The hidden-state initializer. Default is ZeroInit().
  - name (str) – The name of the module. Default is None.
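The defaults give orthogonal weights and zero biases and states. The sketch below shows what that can look like for the shapes implied by the update equation above. It is an illustration under assumptions, not braintrace's code: `orthogonal` is the common QR-based construction and may differ in detail from `Orthogonal`, and the parameter names and layout in `init_minigru_params` (including a single gate bias `b_z`) are hypothetical.

```python
import numpy as np


def orthogonal(shape, scale=1.0, rng=None):
    """Sample a (rows, cols) matrix with orthonormal rows or columns."""
    rng = np.random.default_rng() if rng is None else rng
    rows, cols = shape
    # QR of a tall Gaussian matrix yields q with orthonormal columns
    a = rng.standard_normal((max(rows, cols), min(rows, cols)))
    q, r = np.linalg.qr(a)
    # Flip column signs by sign(diag(r)) so q is uniformly distributed
    q = q * np.sign(np.diag(r))
    if rows < cols:
        q = q.T  # wide shape: orthonormal rows instead of columns
    return scale * q


def init_minigru_params(in_size, out_size, batch_size, rng=None):
    """Hypothetical parameter layout for the MiniGRU update equation."""
    rng = np.random.default_rng() if rng is None else rng
    return {
        "W_x": orthogonal((in_size, out_size), rng=rng),
        "W_z": orthogonal((in_size + out_size, out_size), rng=rng),
        "b_z": np.zeros(out_size),               # zero bias, like ZeroInit
        "h0": np.zeros((batch_size, out_size)),  # zero initial hidden state
    }


params = init_minigru_params(80, 160, batch_size=32, rng=np.random.default_rng(0))
print({k: v.shape for k, v in params.items()})
# {'W_x': (80, 160), 'W_z': (240, 160), 'b_z': (160,), 'h0': (32, 160)}
```

With `in_size=80` and `out_size=160`, \(\mathbf{W}_x\) is wide (80 × 160), so its rows are orthonormal, while \(\mathbf{W}_z\) is tall (240 × 160), so its columns are.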
Examples
>>> import braintrace
>>> import brainstate
>>>
>>> # Create a MiniGRU cell
>>> minigru_cell = braintrace.nn.MiniGRU(in_size=80, out_size=160)
>>> minigru_cell.init_state(batch_size=32)
>>>
>>> # Process one time step for a batch of 32 inputs
>>> x = brainstate.random.randn(32, 80)
>>> h = minigru_cell(x)
>>> print(h.shape)
(32, 160)