vmap_reset_all_states
- class brainstate.nn.vmap_reset_all_states(target, *reset_args, axis_size=None, node_to_exclude=None, state_tag=None, **reset_kwargs)
Reset states with vectorized mapping across batched module instances.
This function applies vmap to the reset process, resetting states across all batched instances of the module. Each batch element will have its state reset independently with its own random key. This is useful when working with batched recurrent models or ensembles.
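The idea of resetting each batched instance independently, with its own random key, can be illustrated with plain JAX. This is a minimal sketch of the concept only, not brainstate's implementation: `reset_one`, `AXIS_SIZE`, and `STATE_SHAPE` are hypothetical names and values chosen for illustration.

```python
import jax
import jax.numpy as jnp

AXIS_SIZE = 16       # number of batched instances (hypothetical value)
STATE_SHAPE = (20,)  # shape of one instance's state (hypothetical value)


def reset_one(key):
    """Hypothetical per-instance reset: draw a fresh state from one key."""
    return jax.random.normal(key, STATE_SHAPE)


# Split one base key into one independent key per batched instance.
keys = jax.random.split(jax.random.PRNGKey(0), AXIS_SIZE)

# vmap maps reset_one over the leading axis of `keys`, so every
# instance is reset with its own key in a single vectorized call.
states = jax.vmap(reset_one)(keys)

print(states.shape)  # leading axis has length AXIS_SIZE
```

Because each instance receives a different key, the reset states differ across the batch axis instead of being 16 copies of the same draw.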
- Parameters:
target (TypeVar(T, bound=Module)) – The target module whose states are to be reset.
reset_args – Positional arguments to pass to each reset_state method. A single non-tuple argument will be automatically wrapped in a tuple. Default is ().
reset_kwargs – Keyword arguments to pass to each reset_state method. Default is None.
axis_size (int) – The size of the batch dimension. Must be a positive integer.
node_to_exclude (type | str | Callable[[Tuple[Key, ...], Any], bool] | bool | Any | None | Tuple[type | str | Callable[[Tuple[Key, ...], Any], bool] | bool | Any | None | Tuple[Filter, ...] | List[Filter], ...] | List[type | str | Callable[[Tuple[Key, ...], Any], bool] | bool | Any | None | Tuple[Filter, ...] | List[Filter]]) – A filter to exclude certain nodes from reset.
state_tag (str | None) – An optional tag to categorize newly created states during the reset.
- Raises:
ValueError – If axis_size is None or not a positive integer.
TypeError – If reset_kwargs is not a mapping.
- Return type:
Examples
>>> import brainstate
>>>
>>> rnn = brainstate.nn.RNNCell(10, 20)
>>> # Initialize with 16 batched instances
>>> brainstate.nn.vmap_init_all_states(rnn, batch_size=32, axis_size=16)
>>>
>>> # Process sequences...
>>>
>>> # Reset all 16 batched instances
>>> brainstate.nn.vmap_reset_all_states(rnn, axis_size=16)
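The argument rules listed under Parameters and Raises can be sketched as a small standalone helper. This is only an illustration of the documented behavior, under the assumption that the checks run before any state is reset; `normalize_reset_inputs` is a hypothetical name and is not part of brainstate.

```python
from collections.abc import Mapping


def normalize_reset_inputs(reset_args=(), reset_kwargs=None, axis_size=None):
    """Hypothetical helper mirroring the documented argument rules."""
    # axis_size must be a positive integer. Rejecting bool here is an
    # assumption of this sketch, not something the documentation states.
    if (
        isinstance(axis_size, bool)
        or not isinstance(axis_size, int)
        or axis_size <= 0
    ):
        raise ValueError(
            f"axis_size must be a positive integer, got {axis_size!r}"
        )

    # A single non-tuple argument is wrapped in a tuple.
    if not isinstance(reset_args, tuple):
        reset_args = (reset_args,)

    # reset_kwargs defaults to None; anything else must be a mapping.
    if reset_kwargs is None:
        reset_kwargs = {}
    elif not isinstance(reset_kwargs, Mapping):
        raise TypeError(
            f"reset_kwargs must be a mapping, got {type(reset_kwargs).__name__}"
        )

    return reset_args, dict(reset_kwargs), axis_size


# A bare value passed as reset_args comes back as a one-element tuple.
args, kwargs, size = normalize_reset_inputs(32, None, axis_size=16)
print(args, kwargs, size)
```

Validating axis_size up front means a missing or invalid batch size fails immediately with a clear message, rather than surfacing later as a shape error inside the vectorized call.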
See also
reset_all_states – Non-vectorized version.
vmap_call_all_functions – The underlying vmap function call mechanism.