MGUCell
- class braintrace.nn.MGUCell(in_size, out_size, w_init=Orthogonal(scale=1.0), b_init=ZeroInit(unit=1), state_init=ZeroInit(unit=1), activation='tanh', name=None)
Minimal Gated Recurrent Unit (MGU) cell.
The cell follows the formulation in "Minimal Gated Unit for Recurrent Neural Networks":
\[\begin{split}\begin{aligned}
f_{t}&=\sigma (W_{f}x_{t}+U_{f}h_{t-1}+b_{f})\\
{\hat {h}}_{t}&=\phi (W_{h}x_{t}+U_{h}(f_{t}\odot h_{t-1})+b_{h})\\
h_{t}&=(1-f_{t})\odot h_{t-1}+f_{t}\odot {\hat {h}}_{t}
\end{aligned}\end{split}\]

where:
\(x_{t}\): input vector
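\(h_{t}\): hidden state vector, which is also the cell output
\({\hat {h}}_{t}\): candidate activation vector
\(f_{t}\): forget gate vector
\(W, U, b\): parameter matrices and bias vectors
\(\sigma\): logistic sigmoid function
\(\phi\): activation function selected by activation ('tanh' by default)
\(\odot\): element-wise product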
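To make the update rule concrete, the following is a minimal NumPy sketch of one MGU step. It is an illustration of the equations only, not the braintrace implementation: the function names mgu_step and sigmoid, the row-vector convention (x @ W rather than W x), and the random initial weights are all assumptions made for this example.

```python
import numpy as np


def sigmoid(z):
    """Logistic sigmoid, applied element-wise."""
    return 1.0 / (1.0 + np.exp(-z))


def mgu_step(x, h_prev, params, phi=np.tanh):
    """One MGU update (hypothetical helper, not part of braintrace).

    x:      (batch, in_size) inputs x_t
    h_prev: (batch, out_size) previous hidden state h_{t-1}
    params: tuple (W_f, U_f, b_f, W_h, U_h, b_h), stored so that
            inputs multiply from the left (row-vector convention).
    """
    W_f, U_f, b_f, W_h, U_h, b_h = params
    # Forget gate f_t
    f = sigmoid(x @ W_f + h_prev @ U_f + b_f)
    # Candidate activation: the gate scales h_{t-1} before U_h is applied
    h_hat = phi(x @ W_h + (f * h_prev) @ U_h + b_h)
    # Interpolate between the old state and the candidate
    return (1.0 - f) * h_prev + f * h_hat


# Illustrative sizes matching the example further down the page
rng = np.random.default_rng(0)
in_size, out_size, batch = 96, 192, 12
params = (
    rng.normal(scale=0.1, size=(in_size, out_size)),   # W_f
    rng.normal(scale=0.1, size=(out_size, out_size)),  # U_f
    np.zeros(out_size),                                # b_f
    rng.normal(scale=0.1, size=(in_size, out_size)),   # W_h
    rng.normal(scale=0.1, size=(out_size, out_size)),  # U_h
    np.zeros(out_size),                                # b_h
)

h = np.zeros((batch, out_size))        # zero initial state
x = rng.normal(size=(batch, in_size))  # one time step of inputs
h = mgu_step(x, h, params)
print(h.shape)  # (12, 192)
```

Because the same gate \(f_{t}\) both filters the previous state inside the candidate and mixes the candidate into the new state, the cell needs only two weight blocks.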
- Parameters:
in_size (int | Sequence[int] | integer | Sequence[integer]) – The number of input units.
out_size (int | Sequence[int] | integer | Sequence[integer]) – The number of hidden units.
w_init (Array | ndarray | bool | number | bool | int | float | complex | Quantity | Callable) – The input weight initializer. Default is Orthogonal().
b_init (Array | ndarray | bool | number | bool | int | float | complex | Quantity | Callable) – The bias weight initializer. Default is ZeroInit().
state_init (Array | ndarray | bool | number | bool | int | float | complex | Quantity | Callable) – The state initializer. Default is ZeroInit().
activation (str | Callable) – The activation function. It can be a string or a callable function. Default is 'tanh'.
name (str) – The name of the module. Default is None.
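As a rough guide to how in_size and out_size determine model size, the sketch below counts the parameters implied by the MGU equations: two blocks, each with an input matrix, a recurrent matrix, and a bias vector. The helper mgu_param_count is hypothetical, and the count assumes integer sizes and the unfused layout of the equations; the library may store or fuse its weights differently.

```python
def mgu_param_count(in_size: int, out_size: int) -> int:
    """Parameters implied by the MGU equations (hypothetical helper).

    Each of the two blocks (forget gate and candidate) has:
      W: in_size * out_size, U: out_size * out_size, b: out_size
    """
    per_block = in_size * out_size + out_size * out_size + out_size
    return 2 * per_block


print(mgu_param_count(96, 192))  # 110976
```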
Examples
>>> import braintrace
>>> import brainstate
>>>
>>> # Create an MGU cell
>>> mgu_cell = braintrace.nn.MGUCell(in_size=96, out_size=192)
>>> mgu_cell.init_state(batch_size=12)
>>>
>>> # Process one batch of inputs (a single time step)
>>> x = brainstate.random.randn(12, 96)
>>> h = mgu_cell(x)
>>> print(h.shape)
(12, 192)