vmap_call_all_fns
- class brainstate.nn.vmap_call_all_fns(target, fn_name, args=(), kwargs=None, axis_size=None, node_to_exclude=None, state_tag=None, fn_if_not_exist='raise')
Apply vectorized mapping (vmap) to call a method on all module nodes, with batched state handling.
This function creates `axis_size` batched instances by applying vmap to the named method call across all module nodes. Each batch element keeps its own random key and state values, which makes this useful for building ensembles or batched models.
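As a conceptual analogue only (plain JAX, not brainstate's internals), the core idea of vmapping an initialization over per-element random keys can be sketched as follows. `init_weight` is a hypothetical stand-in for a module's state-initialization method, and the (10, 20) shape simply echoes the `Linear(10, 20)` example.

```python
import jax
import jax.numpy as jnp


# Hypothetical stand-in for a module's state-initialization method.
def init_weight(key, in_size=10, out_size=20):
    # Draw one weight matrix from the given random key.
    return jax.random.normal(key, (in_size, out_size))


axis_size = 5
# One independent random key per batch element.
keys = jax.random.split(jax.random.PRNGKey(0), axis_size)
# vmap maps init_weight over the leading axis of `keys`, giving a batched state.
batched_w = jax.vmap(init_weight)(keys)

print(batched_w.shape)  # (5, 10, 20)
```

Because each row of `keys` is a distinct key, the five weight matrices are drawn independently, which is what makes this pattern suitable for ensembles.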
- Parameters:
  - target (TypeVar(T, bound=Module)) – The target module on which to call functions.
  - fn_name (str) – The name of the method to call on each module node.
  - args (Sequence[Any] | Any) – Positional arguments to pass to the called method. A single non-tuple argument is automatically wrapped in a tuple. Default is ().
  - kwargs (Mapping[str, Any] | None) – Keyword arguments to pass to the called method. Default is None.
  - axis_size (int) – The size of the batch dimension for vmap. Must be a positive integer.
  - node_to_exclude (Filter, i.e. type | str | Callable[[Tuple[Key, ...], Any], bool] | bool | Any | None | Tuple[Filter, ...] | List[Filter]) – A filter to exclude certain nodes from the function call.
  - state_tag (str | None) – An optional tag to categorize newly created states during the vmap operation.
  - fn_if_not_exist (str) – Behavior when the specified method doesn't exist on a node:
    - 'raise': Raise an AttributeError (default).
    - 'pass' or 'none': Skip the node silently.
    - 'warn': Issue a warning and skip the node.
- Raises:
ValueError – If axis_size is None or not a positive integer.
TypeError – If kwargs is not a mapping.
- Return type:
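The argument handling described under Parameters and Raises can be sketched in plain Python. `normalize_vmap_call_args` is a hypothetical helper written for illustration, not brainstate's actual code: it wraps a single non-tuple `args`, defaults `kwargs` to an empty dict, raises TypeError for a non-mapping `kwargs`, and raises ValueError for a missing or non-positive `axis_size`. Rejecting `bool` and non-int values for `axis_size` is an assumption of this sketch.

```python
from collections.abc import Mapping
from typing import Any, Dict, Optional, Tuple


def normalize_vmap_call_args(args: Any = (),
                             kwargs: Optional[Mapping[str, Any]] = None,
                             axis_size: Optional[int] = None) -> Tuple[tuple, Dict[str, Any], int]:
    """Hypothetical sketch of the documented argument checks (not brainstate's code)."""
    # A single non-tuple argument is wrapped in a one-element tuple.
    if not isinstance(args, tuple):
        args = (args,)
    # kwargs defaults to an empty dict; non-mapping values raise TypeError.
    if kwargs is None:
        kwargs = {}
    elif not isinstance(kwargs, Mapping):
        raise TypeError(f"kwargs must be a mapping, got {type(kwargs).__name__}")
    # axis_size must be given and be a positive int.
    # Assumption: bool (a subclass of int) is rejected explicitly.
    if (axis_size is None or isinstance(axis_size, bool)
            or not isinstance(axis_size, int) or axis_size <= 0):
        raise ValueError(f"axis_size must be a positive integer, got {axis_size!r}")
    return args, dict(kwargs), axis_size


print(normalize_vmap_call_args(3, axis_size=5))  # ((3,), {}, 5)
```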
Examples
>>> import brainstate
>>>
>>> net = brainstate.nn.Linear(10, 20)
>>> # Create 5 batched instances with different initializations
>>> brainstate.nn.vmap_call_all_fns(net, 'init_state', axis_size=5)
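The three `fn_if_not_exist` modes can be illustrated with a minimal dispatch loop. This is a plain-Python sketch under simplifying assumptions: `call_on_nodes`, `WithInit` and `WithoutInit` are hypothetical names, nodes are given as a flat list, and vmap, `node_to_exclude` and `state_tag` are ignored entirely.

```python
import warnings
from typing import Any, Iterable, List, Mapping, Optional, Sequence


def call_on_nodes(nodes: Iterable[Any], fn_name: str, args: Sequence[Any] = (),
                  kwargs: Optional[Mapping[str, Any]] = None,
                  fn_if_not_exist: str = 'raise') -> List[Any]:
    """Hypothetical sketch: call `fn_name` on each node, handling missing methods."""
    kwargs = {} if kwargs is None else dict(kwargs)
    results = []
    for node in nodes:
        if not hasattr(node, fn_name):
            msg = f"{type(node).__name__} has no method {fn_name!r}"
            if fn_if_not_exist == 'raise':
                raise AttributeError(msg)
            elif fn_if_not_exist in ('pass', 'none'):
                continue  # skip the node silently
            elif fn_if_not_exist == 'warn':
                warnings.warn(msg + "; skipping")
                continue  # warn, then skip the node
            else:
                # Assumption: any other mode value is rejected.
                raise ValueError(f"unknown fn_if_not_exist: {fn_if_not_exist!r}")
        results.append(getattr(node, fn_name)(*args, **kwargs))
    return results


# Hypothetical node types: one defines init_state, the other does not.
class WithInit:
    def init_state(self, batch_size=None):
        return ('initialized', batch_size)


class WithoutInit:
    pass


nodes = [WithInit(), WithoutInit()]
print(call_on_nodes(nodes, 'init_state', args=(5,), fn_if_not_exist='pass'))
# [('initialized', 5)]
```

With the default `'raise'`, the same call stops with an AttributeError at the `WithoutInit` node.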