vmap_reset_all_states


class brainstate.nn.vmap_reset_all_states(target, *reset_args, axis_size=None, node_to_exclude=None, state_tag=None, **reset_kwargs)

Reset states with vectorized mapping across batched module instances.

This function applies vmap to the reset process, so that the states of every batched instance of the module are reset in a single vectorized call. Each batch element has its state reset independently, using its own random key. This is useful when working with batched recurrent models or ensembles.
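Conceptually, the per-instance reset described above amounts to mapping a reset function over a stacked axis of random keys. The sketch below uses plain JAX (jax.vmap, jax.random.split) rather than brainstate internals, so it illustrates the idea and is not the library's implementation. The function reset_one and the three constants are hypothetical, chosen to match the shapes used in the Examples section (16 instances, batch size 32, hidden size 20).

```python
import jax
import jax.numpy as jnp

AXIS_SIZE = 16    # number of batched module instances (hypothetical, matches the example)
BATCH_SIZE = 32   # per-instance batch size (hypothetical, matches the example)
HIDDEN_SIZE = 20  # hidden width of the RNN cell (hypothetical, matches the example)


def reset_one(key):
    """Hypothetical per-instance reset: draw a fresh random hidden state."""
    return jax.random.normal(key, (BATCH_SIZE, HIDDEN_SIZE))


# One independent PRNG key per batched instance.
keys = jax.random.split(jax.random.PRNGKey(0), AXIS_SIZE)

# vmap maps reset_one over the leading axis of `keys`: instance i is reset
# with keys[i], and the per-instance results are stacked along axis 0.
states = jax.vmap(reset_one)(keys)
print(states.shape)  # (16, 32, 20)
```

Because every instance receives a distinct key, the reset states differ across instances, while each slice states[i] is exactly what a single, non-vectorized reset with keys[i] would produce.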

Parameters:
Raises:
  • ValueError – If axis_size is None or not a positive integer.

  • TypeError – If reset_kwargs is not a mapping.
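The documented error contract can be illustrated with a small argument check. This is a hypothetical helper written for illustration, not a brainstate function; it only mirrors the two conditions listed under Raises. Rejecting bool values (which are a subclass of int in Python) is a design choice of this sketch and is not stated in the original.

```python
from collections.abc import Mapping
from typing import Any


def check_reset_arguments(axis_size: Any, reset_kwargs: Any) -> None:
    """Hypothetical validator mirroring the documented Raises contract."""
    # ValueError: axis_size must be supplied and be a positive integer.
    # Treating bool as invalid is a choice made for this sketch only.
    if (
        axis_size is None
        or isinstance(axis_size, bool)
        or not isinstance(axis_size, int)
        or axis_size <= 0
    ):
        raise ValueError(f"axis_size must be a positive integer, got {axis_size!r}")
    # TypeError: the reset keyword arguments must form a mapping.
    if not isinstance(reset_kwargs, Mapping):
        raise TypeError(
            f"reset_kwargs must be a mapping, got {type(reset_kwargs).__name__}"
        )


check_reset_arguments(16, {"batch_size": 32})  # passes silently
```

For example, check_reset_arguments(None, {}) or check_reset_arguments(0, {}) raises ValueError, and passing a list of pairs instead of a dict raises TypeError.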

Return type:

TypeVar(T, bound=Module)

Examples

>>> import brainstate
>>>
>>> rnn = brainstate.nn.RNNCell(10, 20)
>>> # Initialize 16 batched instances, each with a batch size of 32
>>> brainstate.nn.vmap_init_all_states(rnn, batch_size=32, axis_size=16)
>>>
>>> # Process sequences...
>>>
>>> # Reset all 16 batched instances
>>> brainstate.nn.vmap_reset_all_states(rnn, axis_size=16)

See also

reset_all_states

Non-vectorized version.

vmap_call_all_functions

The underlying vmap function call mechanism.