MinimalRNNCell
- class braintrace.nn.MinimalRNNCell(in_size, out_size, w_init=Orthogonal(scale=1.0), b_init=ZeroInit(unit=1), state_init=ZeroInit(unit=1), phi=None, name=None)
Minimal RNN cell, implemented as described in the paper "MinimalRNN: Toward More Interpretable and Trainable Recurrent Neural Networks".
At each step \(t\), the model first maps its input \(\mathbf{x}_t\) to a latent space through \(\mathbf{z}_t=\Phi(\mathbf{x}_t)\). Here \(\Phi(\cdot)\) can be any flexible function, such as a neural network. By default, \(\Phi(\cdot)\) is a fully connected layer with tanh activation.
Given the latent representation \(\mathbf{z}_t\) of the input, MinimalRNN then updates its state as:
\[\mathbf{h}_t=\mathbf{u}_t\odot\mathbf{h}_{t-1}+(\mathbf{1}-\mathbf{u}_t)\odot\mathbf{z}_t\]
where \(\mathbf{u}_t=\sigma(\mathbf{U}_h\mathbf{h}_{t-1}+\mathbf{U}_z\mathbf{z}_t+\mathbf{b}_u)\) is the update gate.
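The update rule above can be sketched in plain NumPy to make the data flow concrete. This is an illustrative sketch, not the braintrace implementation: the function `minimal_rnn_step` and the parameter names `W_phi`, `b_phi`, `U_h`, `U_z`, and `b_u` are hypothetical, the random initialization is arbitrary, and \(\Phi\) is assumed to be a single dense layer followed by tanh, as in the default described above.

```python
import numpy as np

def sigmoid(a):
    """Logistic function used for the update gate."""
    return 1.0 / (1.0 + np.exp(-a))

def minimal_rnn_step(x, h_prev, W_phi, b_phi, U_h, U_z, b_u):
    """One MinimalRNN step (hypothetical helper, not part of braintrace).

    Uses the row-vector convention: each row of x and h_prev is one
    batch element, so U_h h in the formula becomes h_prev @ U_h here.
    """
    # z_t = Phi(x_t): assumed to be a dense layer with tanh activation
    z = np.tanh(x @ W_phi + b_phi)
    # u_t = sigma(U_h h_{t-1} + U_z z_t + b_u): the update gate
    u = sigmoid(h_prev @ U_h + z @ U_z + b_u)
    # h_t = u_t * h_{t-1} + (1 - u_t) * z_t (element-wise)
    return u * h_prev + (1.0 - u) * z

# Illustrative sizes matching the usage example in this page
batch, in_size, out_size = 24, 100, 200
rng = np.random.default_rng(0)

# Arbitrary initial values, chosen only for this demonstration
W_phi = rng.standard_normal((in_size, out_size)) / np.sqrt(in_size)
b_phi = np.zeros(out_size)
U_h = rng.standard_normal((out_size, out_size)) / np.sqrt(out_size)
U_z = rng.standard_normal((out_size, out_size)) / np.sqrt(out_size)
b_u = np.zeros(out_size)

x = rng.standard_normal((batch, in_size))
h0 = np.zeros((batch, out_size))  # zero initial state
h1 = minimal_rnn_step(x, h0, W_phi, b_phi, U_h, U_z, b_u)
```

Because the gate \(\mathbf{u}_t\) lies in \((0, 1)\), each new state entry is a convex combination of the previous state and the latent input, so the state stays within the range of tanh when it starts at zero.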
- Parameters:
in_size (int|Sequence[int]|integer|Sequence[integer]) – The number of input units.
out_size (int|Sequence[int]|integer|Sequence[integer]) – The number of hidden units.
w_init (Array|ndarray|bool|number|bool|int|float|complex|Quantity|Callable) – The input weight initializer. Default is Orthogonal().
b_init (Array|ndarray|bool|number|bool|int|float|complex|Quantity|Callable) – The bias weight initializer. Default is ZeroInit().
state_init (Array|ndarray|bool|number|bool|int|float|complex|Quantity|Callable) – The state initializer. Default is ZeroInit().
phi (Callable) – The input activation function. Default is None.
name (str) – The name of the module. Default is None.
Examples
>>> import braintrace
>>> import brainstate
>>>
>>> # Create a Minimal RNN cell
>>> minrnn_cell = braintrace.nn.MinimalRNNCell(in_size=100, out_size=200)
>>> minrnn_cell.init_state(batch_size=24)
>>>
>>> # Process one time step for a batch of 24 inputs
>>> x = brainstate.random.randn(24, 100)
>>> h = minrnn_cell(x)
>>> print(h.shape)
(24, 200)