RNNCell

class brainstate.nn.RNNCell(name=None)

Base class for all recurrent neural network (RNN) cell implementations.

This abstract class is the foundation for implementing RNN cell types such as vanilla RNN, GRU, LSTM, and other recurrent architectures. It extends the Module class and provides the common functionality and interface shared by recurrent units.

All RNN cell implementations should inherit from this class and implement the required methods: init_state() and reset_state(), which define state initialization, and update(), which defines the recurrent dynamics.

An RNNCell typically maintains one or more hidden states that are updated at each time step from the current input and the previous state values.

init_state(batch_size=None, **kwargs)

Initialize the cell state variables with appropriate dimensions.

reset_state(batch_size=None, **kwargs)

Reset the cell state variables to their initial values.

update(x)

Update the cell state for one time step based on the input x and return the output.
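The three methods above can be illustrated with a minimal sketch in plain Python. ToyRNNCell is a hypothetical stand-in: it does not inherit from brainstate.nn.RNNCell or use any brainstate API, and its scalar weights (w_in, w_rec) and tanh update rule are assumptions chosen only to show how init_state(), reset_state(), and update() fit together.

```python
import math


class ToyRNNCell:
    """Hypothetical stand-in that mirrors the RNNCell method names.

    It does not inherit from brainstate.nn.RNNCell; it only illustrates
    the init_state / reset_state / update contract in plain Python.
    """

    def __init__(self, w_in=0.5, w_rec=0.9):
        self.w_in = w_in    # input weight (assumed scalar for brevity)
        self.w_rec = w_rec  # recurrent weight (assumed scalar)
        self.h = None       # hidden state, created by init_state()

    def init_state(self, batch_size=None, **kwargs):
        # One scalar hidden value per batch element, starting at zero.
        n = 1 if batch_size is None else batch_size
        self.h = [0.0] * n

    def reset_state(self, batch_size=None, **kwargs):
        # Resetting restores the initial (zero) values.
        self.init_state(batch_size, **kwargs)

    def update(self, x):
        # x holds one input value per batch element.
        # New state depends on the current input and the previous state.
        self.h = [
            math.tanh(self.w_in * xi + self.w_rec * hi)
            for xi, hi in zip(x, self.h)
        ]
        return self.h


cell = ToyRNNCell()
cell.init_state(batch_size=2)

# Three time steps, each with a batch of two inputs.
sequence = [[1.0, -1.0], [0.5, 0.5], [0.0, 2.0]]
outputs = [cell.update(x) for x in sequence]

# Clear the hidden state before processing an unrelated sequence.
cell.reset_state(batch_size=2)
```

In this sketch the caller drives the time loop and calls update() once per step; the hidden state lives on the cell between calls, which is why reset_state() is needed between independent sequences.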

See also

ValinaRNNCell

Vanilla RNN cell implementation

GRUCell

Gated Recurrent Unit cell implementation

LSTMCell

Long Short-Term Memory cell implementation

URLSTMCell

LSTM with the UR (uniform gate initialization and refine gate) gating mechanism

MGUCell

Minimal Gated Unit cell implementation