vmap_init_all_states


brainstate.nn.vmap_init_all_states(target, *init_args, axis_size=None, node_to_exclude=None, state_to_exclude=None, state_tag=None, in_states=None, out_states=None, **init_kwargs)

Initialize states with vectorized mapping for creating batched module instances.

This function applies vmap to the initialization process, creating axis_size batched instances of the module's states. Each batch element gets independent state values and its own random key, which makes the function useful for ensemble models and parameter sweeps.
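The core idea, applying `vmap` to an initializer so that each batch element receives its own random key, can be sketched in plain JAX. This is not brainstate's implementation: `init_member`, `vmapped_init`, and the parameter shapes are hypothetical, and the real function operates on a module's state objects rather than returning a dict.

```python
import jax


def init_member(key):
    # Hypothetical per-member initializer: one weight matrix and one bias,
    # each drawn from its own sub-key.
    w_key, b_key = jax.random.split(key)
    return {
        "weight": jax.random.normal(w_key, (10, 20)),
        "bias": jax.random.normal(b_key, (20,)),
    }


def vmapped_init(key, axis_size):
    # One independent key per batch element; vmap then runs the initializer
    # once per key and stacks every leaf along a new leading axis.
    keys = jax.random.split(key, axis_size)
    return jax.vmap(init_member)(keys)


states = vmapped_init(jax.random.PRNGKey(0), axis_size=8)
print(states["weight"].shape)  # (8, 10, 20)
print(states["bias"].shape)    # (8, 20)
```

Splitting the key before mapping, rather than reusing a single key, is what keeps the batch members from being identical copies.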

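Parameters:

target – The module whose states are initialized.

axis_size – Number of batched instances to create. Must be a positive integer.

Raises: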

ValueError – If axis_size is None or not a positive integer.

Return type:

TypeVar(T, bound=Module)
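The field list above amounts to a small contract: reject a missing or non-positive axis_size, and hand back a value whose static type matches target. A minimal stub encoding only that contract, with hypothetical stand-ins for Module and Linear; returning the very object that was passed in is an assumption of this sketch, since the return type alone only guarantees an instance of the same type.

```python
from typing import Optional, TypeVar


class Module:
    """Hypothetical stand-in for brainstate.nn.Module."""


class Linear(Module):
    """Hypothetical subclass, used only to show how T is bound."""


T = TypeVar("T", bound=Module)


def vmap_init_stub(target: T, axis_size: Optional[int] = None) -> T:
    """Hypothetical stub mirroring only the documented contract."""
    # Documented rule: axis_size must be given and be a positive integer.
    if axis_size is None or not isinstance(axis_size, int) or axis_size <= 0:
        raise ValueError(f"axis_size must be a positive integer, got {axis_size!r}")
    # A real implementation would initialize batched states here.
    # Returning the argument preserves its static type: Linear in, Linear out.
    return target


net = Linear()
same = vmap_init_stub(net, axis_size=8)  # a type checker infers Linear

for bad in (None, 0, -3, 2.5):
    try:
        vmap_init_stub(net, axis_size=bad)
    except ValueError as err:
        print(err)
```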

Examples

>>> import brainstate
>>>
>>> net = brainstate.nn.Linear(10, 20)
>>> # Create 8 batched instances with different random initializations
>>> net = brainstate.nn.vmap_init_all_states(net, axis_size=8)
>>>
>>> # The weight parameter now has shape (8, 20, 10) instead of (20, 10)
>>> print(net.weight.shape)
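For the ensemble use case, batched parameters are usually consumed by mapping a single-member forward pass over the leading axis. The sketch below uses plain JAX rather than brainstate; the shapes, the `forward` function, and the (n_in, n_out) weight layout are assumptions chosen for illustration only.

```python
import jax
import jax.numpy as jnp

axis_size, n_in, n_out = 8, 10, 20

# Hypothetical stacked parameters: one (n_in, n_out) matrix per member,
# each drawn from its own key.
keys = jax.random.split(jax.random.PRNGKey(0), axis_size)
weights = jax.vmap(lambda k: jax.random.normal(k, (n_in, n_out)))(keys)
x = jnp.ones((n_in,))


def forward(w, x):
    # Single-member forward pass: a plain vector-matrix product.
    return x @ w


# Map over the member axis of `weights`; the same input goes to every member.
outputs = jax.vmap(forward, in_axes=(0, None))(weights, x)
print(outputs.shape)  # (8, 20)
```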

See also

init_all_states

Non-vectorized version.

vmap_new_states

The underlying vmap transformation for states.