brainstate.transform.vmap_new_states
- brainstate.transform.vmap_new_states(fun=<brainstate.typing.Missing object>, *, in_axes=0, out_axes=0, axis_name=None, axis_size=None, spmd_axis_name=None, state_tag=None, state_to_exclude=None, in_states=None, out_states=None)
Vectorize a function over new states created within it.
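This function applies JAX's vmap transformation to the states that are newly created while the function executes. This extends vectorization to stateful computations: states created inside the function are vectorized together with its outputs, rather than only the function's array inputs and outputs.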
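The sketch below uses plain `jax.vmap`, not `vmap_new_states` itself, to illustrate the kind of shape change involved. It assumes that each newly created state ends up holding a value with a leading batch axis whose length plays the role of `axis_size`; the initializer `init_weight` and the constant `AXIS_SIZE` are hypothetical names chosen for illustration.

```python
import jax
import jax.numpy as jnp

AXIS_SIZE = 4  # hypothetical stand-in for the `axis_size` argument


def init_weight(key):
    # Hypothetical per-instance initializer: produces one (3,)-shaped value.
    return jax.random.normal(key, (3,))


# One independent PRNG key per batch element.
keys = jax.random.split(jax.random.PRNGKey(0), AXIS_SIZE)

# Map the initializer over axis 0 of `keys` and stack results along axis 0
# (the same defaults as in_axes=0 / out_axes=0 above).
batched_weight = jax.vmap(init_weight, in_axes=0, out_axes=0)(keys)

print(batched_weight.shape)  # (4, 3)
```

Each row of `batched_weight` is initialized from its own key, so the batch elements differ rather than being copies of one value.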
- Parameters:
fun (Callable) – The function to be vectorized. Defaults to Missing().
in_axes (int | None | Sequence[Any]) – Specification of the input axes to vectorize over. Defaults to 0.
out_axes (Any) – Specification of the output axes after vectorization. Defaults to 0.
axis_name (Hashable | None) – Name of the axis being vectorized over. Defaults to None.
axis_size (int | None) – Size of the axis being vectorized over. Defaults to None.
spmd_axis_name (Hashable | tuple[Hashable, ...] | None) – Name(s) of the SPMD axis or axes. Defaults to None.
state_tag (str | None) – A tag to identify specific states. Defaults to None.
state_to_exclude (Filter: a type, a str, a predicate Callable[[Tuple[Key, ...], Any], bool], a bool, Any, None, or a tuple or list of such filters) – A filter selecting the states to exclude from vectorization. Defaults to None.
- Returns:
A vectorized version of the input function that handles new state creation.
- Return type:
Callable