brainstate.transform.vmap2
- brainstate.transform.vmap2(fn=<brainstate.typing.Missing object>, *, in_axes=0, out_axes=0, axis_name=None, axis_size=None, spmd_axis_name=None, state_in_axes=None, state_out_axes=None, unexpected_out_state_mapping='raise')
Vectorize a callable while preserving BrainState state semantics.
This helper mirrors jax.vmap() but routes execution through StatefulMapping, so that reads and writes to State instances (including newly created random states) are tracked correctly across the mapped axis. When fn is supplied, the returned object is called directly; when fn is omitted, vmap2 returns a decorator.
- Parameters:
  - fn (TypeVar(F, bound=Callable) | Missing) – Function to be vectorised. If omitted, vmap2 returns a decorator instead.
  - in_axes (int | Sequence[Any] | None) – Mapping specification for positional arguments, following the semantics of jax.vmap().
  - out_axes (Any) – Placement of the mapped axis in the result. Must broadcast with the structure of the outputs.
  - axis_name (Hashable | None) – Name for the mapped axis so that collective primitives (e.g. lax.psum) can target it.
  - axis_size (int | None) – Explicit size of the mapped axis. If omitted, the size is inferred from the arguments.
  - spmd_axis_name (Hashable | Tuple[Hashable, ...] | None) – Axis labels used when the transformed function is itself executed inside another SPMD transform (e.g. nested vmap() or pmap()).
  - state_in_axes (Dict[Hashable, Filter] | Filter, where Filter is type | str | Callable[[Tuple[Key, ...], Any], bool] | bool | Any | None | Tuple[Filter, ...] | List[Filter]) – Filters identifying which State objects should be batched on input. Passing a single filter is shorthand for {0: filter}. Filters are converted with brainstate.util.filter.to_predicate().
  - state_out_axes (same type as state_in_axes) – Filters describing how written states are scattered back across the mapped axis. Semantics mirror state_in_axes.
  - unexpected_out_state_mapping (str) – Policy when a state is written during the mapped call but not matched by state_out_axes: 'raise' propagates a BatchAxisError, 'warn' emits a warning, and 'ignore' silently accepts the state.
- Returns:
  If fn is supplied, returns a StatefulMapping instance that behaves like fn but with batch semantics. Otherwise, a decorator is returned.
- Return type:
  StatefulMapping | Callable[[TypeVar(F, bound=Callable)], StatefulMapping]
- Raises:
  - ValueError – If axis sizes are inconsistent or cannot be inferred.
  - BatchAxisError – If a state write violates state_out_axes and the policy is 'raise'.
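Because vmap2 mirrors jax.vmap(), its non-state arguments are easiest to see on plain jax.vmap. The sketch below uses only JAX (no State objects) and assumes vmap2 passes in_axes, out_axes, axis_name and axis_size through with the same meaning: in_axes=None broadcasts an argument to every lane, axis_name lets lax.psum reduce across lanes, and axis_size supplies the batch size when no argument is mapped. The function name normalized_share is hypothetical.

```python
import jax
import jax.numpy as jnp
from jax import lax


def normalized_share(x, scale):
    # Sum x over every lane of the mapped axis named "batch".
    total = lax.psum(x, axis_name="batch")
    return scale * x / total


xs = jnp.array([1.0, 2.0, 5.0])

# Map x along axis 0; broadcast the scalar `scale` (None) to all lanes.
shares = jax.vmap(
    normalized_share, in_axes=(0, None), out_axes=0, axis_name="batch"
)(xs, 2.0)  # 2 * xs / 8 -> [0.25, 0.5, 1.25]

# No argument is mapped here, so the batch size must come from axis_size.
# The unbatched result is broadcast along out_axes=0, giving shape (3, 2).
tiled = jax.vmap(lambda s: s * jnp.ones(2), in_axes=None, axis_size=3)(2.0)
```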
Examples
>>> import brainstate
>>> import jax.numpy as jnp
>>> from brainstate.util.filter import OfType
>>>
>>> # One counter slot per lane of the mapped axis.
>>> counter = brainstate.ShortTermState(jnp.zeros(3))
>>>
>>> @brainstate.transform.vmap2(
...     in_axes=0,
...     out_axes=0,
...     state_in_axes={0: OfType(brainstate.ShortTermState)},
...     state_out_axes={0: OfType(brainstate.ShortTermState)},
... )
... def accumulate(x):
...     counter.value = counter.value + x
...     return counter.value
>>>
>>> xs = jnp.arange(3.0)
>>> accumulate(xs)
Array([0., 1., 2.], dtype=float32)
>>> counter.value
Array([0., 1., 2.], dtype=float32)
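Conceptually, batching a state along axis 0 on both input and output behaves like threading its value through jax.vmap as an explicit argument and an extra output. The pure-JAX sketch below spells that out for a counter with one slot per lane; accumulate_pure is a hypothetical name, and this only illustrates the semantics, it is not how StatefulMapping is implemented.

```python
import jax
import jax.numpy as jnp


def accumulate_pure(counter_value, x):
    # The state's current value arrives as an ordinary argument...
    new_value = counter_value + x
    # ...and its updated value is returned alongside the result.
    return new_value, new_value


counter_value = jnp.zeros(3)  # one slot per lane of the mapped axis
xs = jnp.arange(3.0)

# in_axes=(0, 0) plays the role of state_in_axes={0: ...} plus in_axes=0;
# out_axes=(0, 0) plays the role of out_axes=0 plus state_out_axes={0: ...}.
result, counter_value = jax.vmap(
    accumulate_pure, in_axes=(0, 0), out_axes=(0, 0)
)(counter_value, xs)
# result and counter_value are both [0., 1., 2.]
```

Each lane reads and writes only its own slot, so the update is element-wise rather than a running total; a sequential accumulation across steps would call for jax.lax.scan instead.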
See also
brainstate.transform.StatefulMapping – Underlying state-aware mapping helper.
pmap – Parallel mapping variant for multiple devices.
vmap_new_states – Vectorize newly created states within fn.