vmap_init_all_states
- class brainstate.nn.vmap_init_all_states(target, *init_args, axis_size=None, node_to_exclude=None, state_to_exclude=None, state_tag=None, in_states=None, out_states=None, **init_kwargs)
Initialize states with vectorized mapping for creating batched module instances.
This function applies vmap to the initialization process, creating multiple batched instances of module states: each vmapped state gains a leading batch dimension of size axis_size. Each batch element has independent state values and random keys, which is useful for ensemble models or parameter sweeps.
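The core idea (one initialization per batch member, each with its own random stream, stacked along a new leading axis) can be sketched without brainstate. The helper `batched_init` and the initializer `init_weight` below are hypothetical. Spawning one NumPy generator per member from `np.random.SeedSequence` and looping is a plain analogue of splitting a JAX key and vmapping the initializer, not brainstate's implementation.

```python
import numpy as np


def batched_init(init_fn, axis_size, seed=0):
    """Hypothetical helper, not part of brainstate.

    Calls init_fn once per batch member with an independent random generator
    and stacks the results along a new leading axis.
    """
    # One independent child seed per batch member (analogous to splitting a JAX key).
    children = np.random.SeedSequence(seed).spawn(axis_size)
    members = [init_fn(np.random.default_rng(child)) for child in children]
    # Stacking adds a leading batch axis, e.g. (20, 10) -> (axis_size, 20, 10).
    return np.stack(members, axis=0)


def init_weight(rng):
    # Hypothetical initializer for a single (20, 10) weight matrix.
    return rng.normal(size=(20, 10))


weights = batched_init(init_weight, axis_size=8)
print(weights.shape)  # (8, 20, 10)
```

Because every member draws from its own spawned generator, `weights[0]` and `weights[1]` differ, matching the "independent state values and random keys" described above.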
- Parameters:
  - target (TypeVar(T, bound=Module)) – The target module whose states are to be initialized.
  - *init_args – Variable positional arguments to pass to each init_state method.
  - axis_size (int) – The size of the batch dimension. Must be a positive integer.
  - node_to_exclude (Filter: type | str | Callable[[Tuple[Key, ...], Any], bool] | bool | Any | None, or a tuple or list of such filters) – A filter to exclude certain nodes from initialization.
  - state_to_exclude (Filter, same accepted forms as node_to_exclude) – A filter to exclude certain states from being vmapped. Excluded states remain shared across all batched instances.
  - state_tag (str | None) – An optional tag to categorize newly created states.
  - **init_kwargs – Variable keyword arguments to pass to each init_state method.
- Raises:
ValueError – If axis_size is None or not a positive integer.
- Return type:
Examples
>>> import brainstate
>>>
>>> net = brainstate.nn.Linear(10, 20)
>>> # Create 8 batched instances with different random initializations
>>> brainstate.nn.vmap_init_all_states(net, axis_size=8)
>>>
>>> # The weight parameter now has shape (8, 20, 10) instead of (20, 10)
>>> print(net.weight.shape)
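The effect of state_to_exclude, together with the axis_size check listed under Raises, can be illustrated the same way. In this hypothetical sketch, `batched_init_states` is not part of brainstate and plain state names stand in for brainstate's Filter type: excluded states are initialized once and shared, while all other states get a leading axis of size axis_size.

```python
import numpy as np


def batched_init_states(init_fns, axis_size, exclude=(), seed=0):
    """Hypothetical sketch, not part of brainstate.

    init_fns maps a state name to a function rng -> array. States named in
    `exclude` are initialized once and shared; all others are batched.
    """
    # Mirrors the documented check: axis_size must be a positive integer.
    if not isinstance(axis_size, int) or axis_size <= 0:
        raise ValueError("axis_size must be a positive integer")
    root = np.random.SeedSequence(seed)
    states = {}
    for name, fn in init_fns.items():
        if name in exclude:
            # Shared state: a single value used by every batch member.
            states[name] = fn(np.random.default_rng(root.spawn(1)[0]))
        else:
            # Batched state: one independent draw per member, stacked on axis 0.
            children = root.spawn(axis_size)
            states[name] = np.stack([fn(np.random.default_rng(c)) for c in children])
    return states


init_fns = {
    "weight": lambda rng: rng.normal(size=(20, 10)),
    "bias": lambda rng: np.zeros(20),
}
states = batched_init_states(init_fns, axis_size=8, exclude={"bias"})
print(states["weight"].shape)  # (8, 20, 10)
print(states["bias"].shape)    # (20,)
```

Here `bias` keeps its unbatched shape and is the same array for all eight members, while `weight` is batched with independent values per member.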
See also
init_all_states – Non-vectorized version.
vmap_new_states – The underlying vmap transformation for states.