reset_all_states
- brainstate.nn.reset_all_states(target, *reset_args, node_to_exclude=None, **reset_kwargs)
Reset states for all module nodes within the target.
This is a convenience wrapper around call_all_functions that specifically calls the reset_state method on all module nodes. The execution order respects any @call_order() decorators on the reset_state methods. This is typically used to reset recurrent neural network states between sequences.
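The traversal described above can be illustrated with a small, self-contained sketch. It does not use brainstate: `ToyModule`, `ToyReadout`, the `call_order` stand-in and `toy_reset_all_states` are all hypothetical. The ordering rule used here (decorated methods run first by ascending level, undecorated ones run afterwards in traversal order) is an assumption made for illustration and does not describe brainstate's internals.

```python
from typing import Callable, Iterator, List, Optional, Tuple

Path = Tuple[str, ...]


def call_order(level: int = 0) -> Callable:
    """Hypothetical stand-in for @call_order(): tags a method with a level."""
    def decorator(fn: Callable) -> Callable:
        fn.call_order_level = level
        return fn
    return decorator


class ToyModule:
    """Hypothetical module: a name, child modules, and one scalar state."""

    def __init__(self, name: str, children: Optional[List["ToyModule"]] = None):
        self.name = name
        self.children = children or []
        self.state = 1.0

    def reset_state(self, log: List[str], value: float = 0.0) -> None:
        self.state = value
        log.append(self.name)  # record the call order for inspection


class ToyReadout(ToyModule):
    @call_order(0)  # assumed: decorated resets run before undecorated ones
    def reset_state(self, log: List[str], value: float = 0.0) -> None:
        super().reset_state(log, value)


def iter_nodes(module: ToyModule, path: Path = ()) -> Iterator[Tuple[Path, ToyModule]]:
    """Depth-first walk yielding (path, node) pairs, starting at the root."""
    yield path, module
    for child in module.children:
        yield from iter_nodes(child, path + (child.name,))


def toy_reset_all_states(target: ToyModule, *reset_args,
                         node_to_exclude: Optional[Callable[[Path, ToyModule], bool]] = None,
                         **reset_kwargs) -> None:
    # 1. Collect every node, dropping the ones the predicate excludes.
    nodes = [(p, n) for p, n in iter_nodes(target)
             if node_to_exclude is None or not node_to_exclude(p, n)]

    # 2. Sort key: decorated methods first, by ascending level. sorted() is
    #    stable, so traversal order is kept among nodes with equal keys.
    def key(item: Tuple[Path, ToyModule]) -> Tuple[int, int]:
        level = getattr(item[1].reset_state, "call_order_level", None)
        return (0, level) if level is not None else (1, 0)

    # 3. Call reset_state on each node with the shared arguments.
    for _, node in sorted(nodes, key=key):
        node.reset_state(*reset_args, **reset_kwargs)


net = ToyModule("net", [ToyModule("rnn"), ToyReadout("readout")])
log: List[str] = []
toy_reset_all_states(net, log, value=-1.0)
print(log)  # ['readout', 'net', 'rnn']
```

The positional and keyword arguments are forwarded unchanged to every `reset_state` call, which is why all reset methods in such a tree need compatible signatures.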
- Parameters:
target (TypeVar(T, bound=Module)) – The target module whose states are to be reset.
reset_args – Positional arguments to pass to each reset_state method. A single non-tuple argument will be automatically wrapped in a tuple. Default is ().
reset_kwargs – Keyword arguments to pass to each reset_state method. Defaults to no keyword arguments (an empty dict).
node_to_exclude (Filter) – A filter to exclude certain nodes from reset. Can be a type, predicate function, or any filter supported by the graph API. Here Filter is the recursive type type | str | Callable[[Tuple[Key, ...], Any], bool] | bool | Any | None | Tuple[Filter, ...] | List[Filter].
- Return type:
Examples
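>>> import brainstate
>>> import jax.numpy as jnp
>>>
>>> rnn = brainstate.nn.RNNCell(10, 20)
>>> brainstate.nn.init_all_states(rnn, batch_size=32)
>>>
>>> # A dummy sequence: 5 time steps, each of shape (batch_size, num_in)
>>> sequence = jnp.zeros((5, 32, 10))
>>>
>>> # Process the sequence step by step
>>> for x in sequence:
...     output = rnn(x)
>>>
>>> # Reset states before processing the next sequence
>>> brainstate.nn.reset_all_states(rnn)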
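The node_to_exclude parameter accepts a union of filter kinds: types, (path, node) predicates, booleans, None, and tuples or lists of further filters. The sketch below shows one plausible way such a union could be collapsed into a single predicate. The helper `to_predicate_sketch` and the classes `ToyDense` and `ToyLSTM` are hypothetical and are not part of brainstate's graph API. String filters are omitted because their matching rule is not described on this page, and treating None or False as "exclude nothing", True as "exclude everything", and a tuple or list as "match any member" are assumptions.

```python
from typing import Any, Callable, Tuple

Path = Tuple[str, ...]
Predicate = Callable[[Path, Any], bool]


def to_predicate_sketch(flt: Any) -> Predicate:
    """Hypothetical: turn a filter into a (path, node) -> bool predicate."""
    if flt is None or flt is False:
        return lambda path, node: False          # assumed: exclude nothing
    if flt is True:
        return lambda path, node: True           # assumed: exclude everything
    if isinstance(flt, type):                    # checked before callable(),
        return lambda path, node: isinstance(node, flt)  # since types are callable
    if isinstance(flt, (tuple, list)):
        preds = [to_predicate_sketch(f) for f in flt]
        return lambda path, node: any(p(path, node) for p in preds)
    if callable(flt):
        return flt                               # already a (path, node) predicate
    raise TypeError(f"unsupported filter: {flt!r}")


class ToyDense:  # hypothetical layer type
    pass


class ToyLSTM:   # hypothetical layer type
    pass


# Exclude every ToyLSTM node, plus any node whose last path key is "head".
exclude = to_predicate_sketch((ToyLSTM, lambda path, node: path[-1:] == ("head",)))
print(exclude(("encoder",), ToyLSTM()))  # True
print(exclude(("head",), ToyDense()))    # True
print(exclude(("body",), ToyDense()))    # False
```

Resolving the filter once into a predicate keeps the traversal itself simple: each visited node is tested with one call, regardless of how the filter was written.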
See also
call_all_functions – The underlying function that executes the calls.
vmap_reset_all_states – Vectorized version for batched reset.