MiniLSTM

class braintrace.nn.MiniLSTM(in_size, out_size, w_init=Orthogonal(scale=1.0), b_init=ZeroInit(unit=1), state_init=ZeroInit(unit=1), name=None)

Minimal LSTM cell.

A simplified version of the LSTM cell, implemented as described in "MinimalRNN: Toward More Interpretable and Trainable Recurrent Neural Networks".

This simplified LSTM uses forget and input gates to control the flow of information, updating the hidden state as:

\[\mathbf{h}_t = \mathbf{f}_t \odot \mathbf{h}_{t-1} + \mathbf{i}_t \odot \mathbf{W}_x \mathbf{x}_t\]

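where \(\mathbf{f}_t\) and \(\mathbf{i}_t\) are the forget and input gates, respectively, \(\mathbf{W}_x\) is the input weight matrix, \(\mathbf{x}_t\) is the input at step \(t\), \(\mathbf{h}_{t-1}\) is the previous hidden state, and \(\odot\) denotes elementwise multiplication.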
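As a minimal sketch of the update rule above, the following dependency-free Python applies one step to plain lists. The helper names `matvec` and `minilstm_update` and all numeric values are hypothetical. The gate values are passed in directly because this page does not specify how the gates are computed, so this should not be read as the library's implementation.

```python
def matvec(W, x):
    """Multiply a matrix (list of rows) by a vector (list of floats)."""
    return [sum(w * v for w, v in zip(row, x)) for row in W]


def minilstm_update(h_prev, f, i, W_x, x):
    """Apply h_t = f * h_prev + i * (W_x @ x), elementwise over hidden units.

    The gates f and i are taken as given inputs (an assumption of this
    sketch); computing them from x is outside its scope.
    """
    candidate = matvec(W_x, x)
    return [f_k * h_k + i_k * c_k
            for f_k, h_k, i_k, c_k in zip(f, h_prev, i, candidate)]


# Hypothetical toy values: 2 hidden units, 3 input units.
h_prev = [1.0, -2.0]
f = [0.5, 0.0]   # the second unit discards its previous state entirely
i = [0.5, 1.0]   # the second unit takes the full candidate value
W_x = [[1.0, 0.0, 1.0],
       [0.0, 2.0, 0.0]]
x = [1.0, 2.0, 3.0]

h_t = minilstm_update(h_prev, f, i, W_x, x)
print(h_t)  # [2.5, 4.0]
```

With a forget gate of 0 and an input gate of 1, a unit simply copies the projected input, which is why the second entry equals the second entry of the candidate vector.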

Parameters:
  • in_size (int | Sequence[int] | integer | Sequence[integer]) – The number of input units.

  • out_size (int | Sequence[int] | integer | Sequence[integer]) – The number of hidden units.

  • w_init (Array | ndarray | bool | number | bool | int | float | complex | Quantity | Callable) – The input weight initializer. Default is Orthogonal().

  • b_init (Array | ndarray | bool | number | bool | int | float | complex | Quantity | Callable) – The bias initializer. Default is ZeroInit().

  • state_init (Array | ndarray | bool | number | bool | int | float | complex | Quantity | Callable) – The state initializer. Default is ZeroInit().

  • name (str) – The name of the module. Default is None.

Examples

>>> import braintrace
>>> import brainstate
>>>
>>> # Create a Mini LSTM cell
>>> minilstm_cell = braintrace.nn.MiniLSTM(in_size=150, out_size=300)
>>> minilstm_cell.init_state(batch_size=40)
>>>
>>> # Process one time step for a batch of 40 inputs
>>> x = brainstate.random.randn(40, 150)
>>> h = minilstm_cell(x)
>>> print(h.shape)
(40, 300)

init_state(batch_size=None, **kwargs)

State initialization function.

reset_state(batch_size=None, **kwargs)

State resetting function.