reset_all_states


brainstate.nn.reset_all_states(target, *reset_args, node_to_exclude=None, **reset_kwargs)

Reset the states of all module nodes within target.

This is a convenience wrapper around call_all_functions that calls the reset_state method on every module node. The execution order respects any @call_order() decorators on the reset_state methods. A typical use is clearing the hidden state of a recurrent neural network between independent sequences.
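The traversal-and-ordering behavior described above can be pictured with a small pure-Python model. Everything in it is hypothetical: Module, call_order, and reset_all_states_sketch are toy stand-ins, not brainstate's classes, and the ordering policy (undecorated reset_state methods run first, then decorated ones by ascending level) is an assumption made for illustration only.

```python
from typing import Callable, Iterator, List


def call_order(level: int) -> Callable:
    """Hypothetical stand-in for @call_order(): tag a method with a priority level."""
    def decorator(fn: Callable) -> Callable:
        fn.call_order = level  # stored on the function, visible via bound methods
        return fn
    return decorator


class Module:
    """Toy module: a node that may own child modules as attributes."""

    def nodes(self) -> Iterator["Module"]:
        # Depth-first traversal yielding self and every nested Module.
        yield self
        for value in vars(self).values():
            if isinstance(value, Module):
                yield from value.nodes()


def reset_all_states_sketch(target: Module) -> Module:
    unordered: List[Module] = []
    ordered: List[Module] = []
    for node in target.nodes():
        method = getattr(node, "reset_state", None)
        if method is None:
            continue  # nodes without reset_state are skipped
        (ordered if hasattr(method, "call_order") else unordered).append(node)
    # ASSUMED policy: undecorated methods first, then decorated ones by ascending level.
    for node in unordered:
        node.reset_state()
    for node in sorted(ordered, key=lambda n: n.reset_state.call_order):
        node.reset_state()
    return target


# Usage: record the order in which reset_state methods run.
log: List[str] = []


class Cell(Module):
    def __init__(self, name: str) -> None:
        self.name = name
        self.h = 1.0  # pretend hidden state left over from a previous sequence

    def reset_state(self) -> None:
        self.h = 0.0
        log.append(self.name)


class Readout(Module):
    @call_order(1)
    def reset_state(self) -> None:
        log.append("readout")


class Net(Module):
    def __init__(self) -> None:
        self.readout = Readout()
        self.cell = Cell("cell")


net = reset_all_states_sketch(Net())
```

Under these assumptions, cell is reset before readout even though readout is visited first, because only readout's method carries a call_order level; Net itself defines no reset_state and is skipped.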

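Parameters:

target - The module whose nodes are reset.
*reset_args - Positional arguments passed to each reset_state call.
node_to_exclude - Optional selector for nodes that should not be reset; None (the default) excludes nothing.
**reset_kwargs - Keyword arguments passed to each reset_state call.

Return type:

TypeVar(T, bound=Module)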
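A minimal sketch of how the *reset_args, node_to_exclude, and **reset_kwargs parameters in the signature could interact, assuming the extra arguments are forwarded unchanged to every reset_state call and that node_to_exclude behaves like a predicate. The Cell class, its batch_size and fill arguments, and reset_nodes_sketch are hypothetical illustrations, not brainstate code; in particular, treating node_to_exclude as a plain callable is an assumption.

```python
from typing import Any, Callable, Iterable, List, Optional


class Cell:
    """Hypothetical node whose reset_state takes a batch size, as RNN cells often do."""

    def __init__(self, name: str) -> None:
        self.name = name
        self.state: List[float] = []

    def reset_state(self, batch_size: int = 1, fill: float = 0.0) -> None:
        self.state = [fill] * batch_size


def reset_nodes_sketch(
    nodes: Iterable[Any],
    *reset_args: Any,
    node_to_exclude: Optional[Callable[[Any], bool]] = None,  # ASSUMED: a predicate
    **reset_kwargs: Any,
) -> None:
    for node in nodes:
        if node_to_exclude is not None and node_to_exclude(node):
            continue  # excluded nodes keep their current state
        reset = getattr(node, "reset_state", None)
        if reset is not None:
            reset(*reset_args, **reset_kwargs)  # same arguments for every node


a, b = Cell("a"), Cell("b")
b.state = [9.0]
reset_nodes_sketch([a, b], 3, fill=0.5, node_to_exclude=lambda n: n.name == "b")
```

Here a is reset with batch_size=3 and fill=0.5, while b is excluded and keeps its old state.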

Examples

>>> import brainstate
>>>
>>> rnn = brainstate.nn.RNNCell(10, 20)
>>> brainstate.nn.init_all_states(rnn, batch_size=32)
>>>
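>>> # Process a sequence of 5 steps, each of shape (batch_size, in_size) = (32, 10)
>>> sequence = brainstate.random.randn(5, 32, 10)
>>> for x in sequence:
...     output = rnn(x)
...
>>> # Reset states before processing the next sequence
>>> brainstate.nn.reset_all_states(rnn)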

See also

call_all_functions

The underlying function that executes the calls.

vmap_reset_all_states

Vectorized version for batched reset.