LRUCell

class braintrace.nn.LRUCell(d_model, d_hidden, r_min=0.0, r_max=1.0, max_phase=6.28)

Linear Recurrent Unit (LRU) layer.

The LRU uses a diagonal, complex-valued state-transition matrix. Each recurrent update therefore reduces to an elementwise product over the hidden units, which keeps sequence modeling cheap.

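\[\begin{split}h_{t+1} &= \lambda h_t + \exp(\gamma^{\mathrm{log}}) \odot (B x_{t+1}) \\ \lambda &= \operatorname{diag}\left(\exp\left(-\exp(\nu^{\mathrm{log}}) + i \exp(\theta^{\mathrm{log}})\right)\right) \\ y_t &= \mathrm{Re}\left[C h_t + D x_t\right]\end{split}\]

Here h_t is the complex hidden state of size d_hidden, x_t and y_t are the real input and output of size d_model, ν^log, θ^log and γ^log are parameter vectors of length d_hidden, B and C are complex input and output projections, D is a direct feedthrough from input to output, and Re[·] takes the real part. Because each diagonal entry of λ has modulus exp(−exp(ν^log)) < 1, the recurrence stays stable for any value of ν^log.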
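As a minimal sketch of the update above, one step of the recurrence and its readout can be written in NumPy. It assumes real inputs and explicitly passed parameter arrays; it does not reflect braintrace's internal parameter names or storage. The helper `lru_step` and the dense real matrix `D` are hypothetical choices for illustration (the original LRU paper uses a diagonal `D`).

```python
import numpy as np


def lru_step(h, x, nu_log, theta_log, gamma_log, B, C, D):
    """Hypothetical helper: one LRU update h_t -> h_{t+1} and readout y_{t+1}.

    Shapes (assumed for this sketch):
      h: complex (d_hidden,), x: real (d_model,)
      nu_log, theta_log, gamma_log: real (d_hidden,)
      B: complex (d_hidden, d_model), C: complex (d_model, d_hidden)
      D: real (d_model, d_model)
    """
    # Diagonal entries of lambda: modulus exp(-exp(nu_log)) < 1, phase exp(theta_log)
    lam = np.exp(-np.exp(nu_log) + 1j * np.exp(theta_log))
    # A diagonal transition is an elementwise product; exp(gamma_log) scales B x
    h_next = lam * h + np.exp(gamma_log) * (B @ x)
    # Readout keeps only the real part of C h + D x
    y = np.real(C @ h_next + D @ x)
    return h_next, y


# Usage with random parameters (hypothetical sizes)
rng = np.random.default_rng(0)
d_model, d_hidden = 4, 8
nu_log = rng.normal(size=d_hidden)
theta_log = rng.normal(size=d_hidden)
gamma_log = rng.normal(size=d_hidden)
B = rng.normal(size=(d_hidden, d_model)) + 1j * rng.normal(size=(d_hidden, d_model))
C = rng.normal(size=(d_model, d_hidden)) + 1j * rng.normal(size=(d_model, d_hidden))
D = rng.normal(size=(d_model, d_model))

h0 = np.zeros(d_hidden, dtype=complex)
h1, y1 = lru_step(h0, rng.normal(size=d_model), nu_log, theta_log, gamma_log, B, C, D)
# With zero input the state only decays, since every |lambda_j| < 1
h2, _ = lru_step(h1, np.zeros(d_model), nu_log, theta_log, gamma_log, B, C, D)
```

The elementwise `lam * h` is the whole point of the diagonal parameterization: the per-step cost grows linearly with d_hidden instead of quadratically.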
Parameters:
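  • d_model (int) – Size of the input and output features (shared by x_t and y_t).

  • d_hidden (int) – Size of the complex hidden state h_t.

  • r_min (float) – Lower bound on the modulus |λ| of the diagonal entries of λ at initialization. Default is 0.0.

  • r_max (float) – Upper bound on the modulus |λ| of the diagonal entries of λ at initialization. Default is 1.0.

  • max_phase (float) – Upper bound, in radians, on the phase arg(λ) at initialization. Default is 6.28 (approximately 2π).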
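The r_min, r_max and max_phase arguments describe where the eigenvalues λ start out. Below is a sketch of the ring initialization from the original LRU paper (Orvieto et al., 2023). It assumes, without confirming this against braintrace's source, that LRUCell samples λ the same way. The helper `init_lru_params` and its `seed` argument are hypothetical.

```python
import numpy as np


def init_lru_params(d_hidden, r_min=0.0, r_max=1.0, max_phase=6.28, seed=0):
    """Hypothetical helper: sample nu_log, theta_log, gamma_log for d_hidden units.

    lambda is drawn uniformly on the ring r_min <= |lambda| <= r_max with
    phase in [0, max_phase), then stored in the log parameterization used
    by the update equation.
    """
    rng = np.random.default_rng(seed)
    u1 = rng.uniform(size=d_hidden)
    u2 = rng.uniform(size=d_hidden)
    # |lambda|^2 uniform in [r_min^2, r_max^2] gives a uniform density on the ring;
    # since |lambda| = exp(-exp(nu_log)), solve for nu_log
    nu_log = np.log(-0.5 * np.log(u1 * (r_max**2 - r_min**2) + r_min**2))
    # Phase exp(theta_log) uniform in [0, max_phase)
    theta_log = np.log(max_phase * u2)
    # Input normalization gamma = sqrt(1 - |lambda|^2), stored as gamma_log
    modulus = np.exp(-np.exp(nu_log))
    gamma_log = np.log(np.sqrt(1.0 - modulus**2))
    return nu_log, theta_log, gamma_log


nu_log, theta_log, gamma_log = init_lru_params(128, r_min=0.4, r_max=0.9, max_phase=3.14)
lam = np.exp(-np.exp(nu_log) + 1j * np.exp(theta_log))
```

With this choice of γ, a hidden unit driven by white-noise input has the same expected squared magnitude as its input, because γ² / (1 − |λ|²) = 1.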

Examples

>>> import braintrace
>>> import brainstate
>>>
>>> # Create an LRU cell
>>> lru_cell = braintrace.nn.LRUCell(d_model=64, d_hidden=128)
>>> lru_cell.init_state(batch_size=16)
>>>
>>> # Process one time step for a batch of 16 inputs
>>> x = brainstate.random.randn(16, 64)
>>> y = lru_cell(x)
>>> print(y.shape)
(16, 64)
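The doctest above advances the cell by a single time step. Because the transition is diagonal and linear, the hidden state after T steps from h_0 = 0 also has the closed form h_T = Σ_{k=1}^{T} λ^{T−k} u_k, with u_k = exp(γ^log) ⊙ (B x_k). For this reason, LRU-style layers can in principle replace the step-by-step loop with a parallel scan. The NumPy sketch below checks that closed form against a plain loop. It does not use braintrace, all sizes and values are made up, and it uses sqrt(1 − |λ|²) as a stand-in for exp(γ^log).

```python
import numpy as np

rng = np.random.default_rng(1)
T, d_model, d_hidden = 50, 3, 6

# Diagonal of lambda with |lambda| = 0.9 and random phases (illustrative values)
lam = 0.9 * np.exp(1j * rng.uniform(0, 2 * np.pi, size=d_hidden))
# Stand-in for exp(gamma_log): the LRU normalization sqrt(1 - |lambda|^2)
gamma = np.sqrt(1.0 - np.abs(lam) ** 2)
B = rng.normal(size=(d_hidden, d_model)) + 1j * rng.normal(size=(d_hidden, d_model))
xs = rng.normal(size=(T, d_model))  # rows are x_1, ..., x_T

# Projected, scaled inputs u_k = gamma * (B x_k), shape (T, d_hidden)
us = (xs @ B.T) * gamma

# 1) Sequential evaluation: h_{t+1} = lam * h_t + u_{t+1}, starting from h_0 = 0
h = np.zeros(d_hidden, dtype=complex)
for u in us:
    h = lam * h + u
h_loop = h

# 2) Closed form: h_T = sum over k = 1..T of lam**(T - k) * u_k
powers = lam[None, :] ** np.arange(T - 1, -1, -1)[:, None]  # row k-1 holds lam**(T-k)
h_closed = (powers * us).sum(axis=0)
```

Both routes give the same state. The closed form is only practical because λ is diagonal: the powers λ^{T−k} are elementwise powers rather than matrix powers.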
init_state(batch_size=None, **kwargs)

Initialize the hidden state. If batch_size is given, the state gets a leading batch dimension of that size.

reset_state(batch_size=None, **kwargs)

Reset the hidden state to its initial value, for example before processing a new, independent sequence.