StatefulMapping
- class brainstate.transform.StatefulMapping(fun, in_axes=0, out_axes=0, state_in_axes=None, state_out_axes=None, unexpected_out_state_mapping='raise', static_argnums=(), static_argnames=(), axis_env=None, return_only_write=True, axis_size=None, axis_name=None, name=None, mapping_fn=<function vmap>, mapping_kwargs=None)
Vectorized wrapper that preserves BrainState state semantics during mapping.
StatefulMapping extends JAX mapping transforms (such as jax.vmap() and jax.pmap()) with awareness of State instances. It tracks state reads and writes across the mapped axis, ensures deterministic random-number handling, and restores side effects after each batched execution. The helper is typically constructed by brainstate.transform.vmap2() or brainstate.transform.pmap2(), but it can also be instantiated directly for custom mapping primitives.
- Parameters:
  - fun (Callable) – Stateless callable to be wrapped. The callable may close over State objects that should be tracked during the mapping transform.
  - in_axes (int | Tuple[int, ...] | None) – Alignment of the mapped axis per positional argument, following the semantics of jax.vmap(). Arguments whose entry is None are not mapped over; they are shared unchanged by every batch element.
  - out_axes (int | Tuple[int, ...] | None) – Placement of the mapped axis in the return value, consistent with JAX mapping primitives.
  - state_in_axes (Dict[Hashable, Filter] | Filter, where Filter = type | str | Callable[[Tuple[Key, ...], Any], bool] | bool | Any | None | Tuple[Filter, ...] | List[Filter]) – Specification of input states that participate in the mapped axis. A dictionary maps axis identifiers to brainstate.util.filter predicates; passing a single filter applies it to axis 0. Values are normalized via brainstate.util.filter.to_predicate().
  - state_out_axes (same type as state_in_axes) – Specification of state outputs to scatter back along the mapped axis. Uses the same semantics and normalization as state_in_axes.
  - unexpected_out_state_mapping (str) – Strategy for handling states written during the mapped call that are not captured by state_out_axes. Defaults to 'raise'.
  - axis_size (int | None) – Explicit size of the mapped axis. When omitted, the size is inferred from the mapped arguments.
  - axis_name (Hashable | None) – Name for the mapped axis so that collective primitives can target it.
  - name (str | None) – Human-readable identifier for diagnostics and debugging.
  - mapping_fn (Callable) – Mapping primitive that executes fun. The callable must accept the in_axes and out_axes keyword arguments used by jax.vmap().
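The normalization of state_in_axes and state_out_axes into per-axis predicates can be pictured with the simplified sketch below. Both helpers, to_predicate_sketch and normalize_state_axes, are hypothetical stand-ins written for illustration; they are not brainstate.util.filter.to_predicate(), they handle only type, bool, callable, and tuple/list filters (string, None, and other filter kinds are left out), and treating a spec of None as "no states mapped" is an assumption.

```python
from typing import Any, Callable, Dict, Hashable, Tuple

# A predicate receives a state's path and the state object itself.
Predicate = Callable[[Tuple[Hashable, ...], Any], bool]


def to_predicate_sketch(filter_: Any) -> Predicate:
    """Hypothetical, simplified stand-in for brainstate.util.filter.to_predicate()."""
    if isinstance(filter_, type):
        # A class matches states that are instances of it.
        return lambda path, state: isinstance(state, filter_)
    if isinstance(filter_, bool):
        # True matches every state, False matches none.
        return lambda path, state: filter_
    if isinstance(filter_, (tuple, list)):
        # A collection of filters matches if any member matches.
        preds = [to_predicate_sketch(f) for f in filter_]
        return lambda path, state: any(p(path, state) for p in preds)
    if callable(filter_):
        # Callables are assumed to already have the (path, state) -> bool signature.
        return filter_
    raise TypeError(f"unsupported filter in this sketch: {filter_!r}")


def normalize_state_axes(spec: Any) -> Dict[Hashable, Predicate]:
    """Map axis identifiers to predicates; a bare filter is assigned to axis 0."""
    if spec is None:
        return {}  # assumption: no states participate in the mapped axis
    if isinstance(spec, dict):
        return {axis: to_predicate_sketch(f) for axis, f in spec.items()}
    return {0: to_predicate_sketch(spec)}


# Illustrative state classes (hypothetical, not brainstate types).
class ShortTerm: ...
class Param: ...


axes = normalize_state_axes(ShortTerm)  # single filter -> {0: predicate}
mixed = normalize_state_axes({0: ShortTerm, 1: (Param,)})
```

Normalizing everything to a dictionary of predicates up front means the rest of a wrapper only ever has to ask "which axis does this state belong to?", regardless of how the user spelled the filter.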
- origin_fun
Original Python callable wrapped by the mapping helper.
- Type:
callable
- state_in_axes
Normalized predicates describing which states to batch on input.
- state_out_axes
Normalized predicates describing which states to batch on output.
- mapping_fn
Mapping primitive responsible for executing fun.
- Type:
callable
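Because mapping_fn only needs to accept the in_axes and out_axes keyword arguments used by jax.vmap(), a custom primitive can be a thin wrapper around it. The sketch below is hypothetical: counting_vmap and call_log are illustrative names, and forwarding any remaining keyword arguments (such as axis_name or axis_size) to jax.vmap() is an assumption about what StatefulMapping passes through. It also shows the jax.vmap() convention that an in_axes entry of None leaves that argument unmapped.

```python
import jax
import jax.numpy as jnp

# Records how often the wrapper builds a mapped function (illustration only).
call_log = []


def counting_vmap(fun, in_axes=0, out_axes=0, **kwargs):
    """Hypothetical mapping_fn: logs each use, then delegates to jax.vmap."""
    call_log.append((in_axes, out_axes))
    # Extra keyword arguments (e.g. axis_name, axis_size) go straight to jax.vmap.
    return jax.vmap(fun, in_axes=in_axes, out_axes=out_axes, **kwargs)


# in_axes=(0, None): map over the first argument, share the second unchanged.
scale = counting_vmap(lambda x, s: x * s, in_axes=(0, None), out_axes=0)
result = scale(jnp.arange(3.0), 2.0)  # -> [0., 2., 4.]

# The wrapper could then be supplied as, e.g.,
# brainstate.transform.StatefulMapping(fun, mapping_fn=counting_vmap)
```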
- Raises:
TypeError – If in_axes has an unsupported type.
ValueError – If batch dimensions are inconsistent or cannot be inferred.
RuntimeError – If tracing or executing the mapped function fails.
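The TypeError and ValueError cases above arise naturally when the size of the mapped axis is worked out from the arguments. The NumPy sketch below shows one way such an inference could go; infer_axis_size is a hypothetical helper, it handles only flat array arguments (not pytrees), and it is not brainstate's implementation.

```python
import numpy as np


def infer_axis_size(args, in_axes, axis_size=None):
    """Hypothetical sketch: infer the mapped-axis size from positional arguments."""
    # Broadcast a single int (or None) to every argument, as jax.vmap does.
    if in_axes is None or isinstance(in_axes, int):
        in_axes = (in_axes,) * len(args)
    elif isinstance(in_axes, (tuple, list)):
        in_axes = tuple(in_axes)
    else:
        raise TypeError(f"unsupported in_axes type: {type(in_axes).__name__}")
    if len(in_axes) != len(args):
        raise ValueError("in_axes must have one entry per positional argument")

    sizes = set()
    for arg, ax in zip(args, in_axes):
        if ax is None:
            continue  # unmapped argument: contributes no size
        shape = np.shape(arg)
        if not -len(shape) <= ax < len(shape):
            raise ValueError(f"in_axes={ax} is out of range for shape {shape}")
        sizes.add(shape[ax])

    if axis_size is not None:
        sizes.add(axis_size)  # an explicit size must agree with the arguments
    if not sizes:
        raise ValueError("cannot infer the mapped axis size: nothing is mapped")
    if len(sizes) > 1:
        raise ValueError(f"inconsistent sizes for the mapped axis: {sorted(sizes)}")
    return sizes.pop()
```

For example, infer_axis_size((np.ones((3, 2)), 1.0), (0, None)) returns 3, while mixing arrays of length 3 and 4 on the mapped axis raises ValueError.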
Notes
Random states (for example RandomState) encountered during execution are automatically split along the mapped axis and restored afterwards; this behaviour cannot be disabled. The wrapper caches inferred state placements, batch sizes, and trace stacks keyed by abstract argument signatures, so repeated calls with the same structure avoid re-tracing.
Examples
>>> import brainstate
>>> import jax.numpy as jnp
>>> from brainstate.util.filter import OfType
>>>
>>> # One counter slot per batch element, so the state can be mapped along axis 0.
>>> counter = brainstate.ShortTermState(jnp.zeros(3))
>>>
>>> def accumulate(x):
...     counter.value = counter.value + x
...     return counter.value
>>>
>>> batched_accumulate = brainstate.transform.StatefulMapping(
...     accumulate,
...     in_axes=0,
...     out_axes=0,
...     state_in_axes={0: OfType(brainstate.ShortTermState)},
...     state_out_axes={0: OfType(brainstate.ShortTermState)},
...     name="batched_accumulate",
... )
>>>
>>> xs = jnp.ones((3,))
>>> batched_accumulate(xs)  # batch elements run independently, not sequentially
Array([1., 1., 1.], dtype=float32)
>>> counter.value
Array([1., 1., 1.], dtype=float32)
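The Notes state that random states are split along the mapped axis. The plain-JAX sketch below illustrates why splitting matters: each batch element draws from its own subkey, so samples differ across the batch yet stay reproducible. It calls jax.random.split() directly and does not reproduce how brainstate splits or restores RandomState; batched_normal is a hypothetical helper.

```python
import jax
import jax.numpy as jnp


def batched_normal(key, batch_size):
    """Draw one standard-normal sample per batch element from its own subkey."""
    # Split the parent key once per batch element, mirroring how a random
    # state could be split along the mapped axis.
    subkeys = jax.random.split(key, batch_size)
    return jax.vmap(lambda k: jax.random.normal(k, ()))(subkeys)


key = jax.random.PRNGKey(0)
samples = batched_normal(key, 4)

# Without splitting, every batch element sees the same key and draws the same value.
shared = jax.vmap(lambda _: jax.random.normal(key, ()))(jnp.arange(4))
```

Reusing one key inside the mapped function silently correlates every batch element, which is the failure mode that automatic splitting avoids.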