StatefulMapping

class brainstate.transform.StatefulMapping(fun, in_axes=0, out_axes=0, state_in_axes=None, state_out_axes=None, unexpected_out_state_mapping='raise', static_argnums=(), static_argnames=(), axis_env=None, return_only_write=True, axis_size=None, axis_name=None, name=None, mapping_fn=<function vmap>, mapping_kwargs=None)

Vectorized wrapper that preserves BrainState state semantics during mapping.

StatefulMapping extends JAX mapping transforms (such as jax.vmap() and jax.pmap()) with awareness of State instances. It tracks state reads and writes across the mapped axis, handles random-number states deterministically, and writes state updates back after each batched execution. The helper is typically constructed by brainstate.transform.vmap2() or brainstate.transform.pmap2(), but it can also be instantiated directly to wrap a custom mapping primitive.
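To make these state semantics concrete, here is a toy, pure-Python model that runs the lanes one after another instead of vectorizing them. It is not BrainState's implementation: `stateful_map`, its `state_axes` argument (which only understands axis 0 or None), and the string-seeded per-lane `random.Random` streams are assumptions chosen for illustration.

```python
import random
from typing import Any, Callable, Dict, List, Optional


def stateful_map(
    fun: Callable[[Any, Dict[str, Any], random.Random], Any],
    xs: List[Any],
    states: Dict[str, Any],
    state_axes: Dict[str, Optional[int]],
    seed: int = 0,
) -> List[Any]:
    """Toy, sequential model of a state-aware map (hypothetical helper).

    A state with axis 0 holds one value per lane; a state with axis None
    holds a single value shared by every lane.
    """
    written: Dict[str, List[Any]] = {name: [] for name in states}
    outputs = []
    for lane, x in enumerate(xs):
        # Give this lane its own slice of each batched state, or the shared value.
        view = {
            name: (value[lane] if state_axes.get(name) == 0 else value)
            for name, value in states.items()
        }
        # Deterministic per-lane random stream derived from (seed, lane).
        rng = random.Random(f"{seed}-{lane}")
        outputs.append(fun(x, view, rng))
        # Record what this lane left in each state (its "writes").
        for name in states:
            written[name].append(view[name])
    # Write back: batched states keep one value per lane.
    for name in states:
        if state_axes.get(name) == 0:
            states[name] = written[name]
        elif any(v != states[name] for v in written[name]):
            # Loosely mirrors unexpected_out_state_mapping='raise'.
            raise ValueError(f"state {name!r} was written but is not batched on output")
    return outputs


def step(x, s, rng):
    s["counter"] = s["counter"] + x  # each lane updates only its own slice
    return s["counter"] * s["scale"]  # "scale" is shared by all lanes


states = {"counter": [0.0, 10.0, 20.0], "scale": 2.0}
out = stateful_map(step, [1.0, 2.0, 3.0], states, {"counter": 0, "scale": None})
print(out)                # [2.0, 24.0, 46.0]
print(states["counter"])  # [1.0, 12.0, 23.0]
```

Each lane starts from its own slice of the batched state, so lanes never see each other's writes; the shared state is only accepted back if no lane changed it.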

Attributes:
origin_fun

Original Python callable (the fun argument) wrapped by the mapping helper.

Type:

callable

in_axes

Mapping specification for positional arguments.

Type:

int, tuple of int, or None

out_axes

Mapping specification for the return value.

Type:

int, tuple of int, or None

state_in_axes

Normalized predicates describing which states to batch on input.

Type:

dict[AxisName, Predicate]

state_out_axes

Normalized predicates describing which states to batch on output.

Type:

dict[AxisName, Predicate]

axis_size

Size of the mapped axis, if explicitly provided.

Type:

int or None

axis_name

Axis identifier forwarded to collective primitives.

Type:

hashable or None

mapping_fn

Mapping primitive responsible for executing fun.

Type:

callable
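The normalized state_in_axes and state_out_axes attributes can be pictured as a table from axis to predicate, where each state is assigned to the first axis whose predicate accepts it. The sketch below assumes that model; `normalize_state_axes`, `assign_axes`, and the `ShortTerm`/`LongTerm` classes are hypothetical stand-ins (in BrainState the filters would be objects such as `OfType(...)`), and treating unmatched states as unbatched (`None`) is likewise an assumption of the sketch.

```python
from typing import Any, Callable, Dict, Optional

Predicate = Callable[[Any], bool]


def normalize_state_axes(
    spec: Optional[Dict[Optional[int], Any]],
) -> Dict[Optional[int], Predicate]:
    """Hypothetical normalizer: turn a {axis: filter} spec into {axis: predicate}.

    A filter may be a class (matched with isinstance) or any callable predicate.
    """
    if spec is None:
        return {}
    if not isinstance(spec, dict):
        raise TypeError(f"expected dict or None, got {type(spec).__name__}")
    normalized: Dict[Optional[int], Predicate] = {}
    for axis, filt in spec.items():
        if isinstance(filt, type):
            # Bind the class now so each lambda keeps its own filter.
            normalized[axis] = lambda state, cls=filt: isinstance(state, cls)
        elif callable(filt):
            normalized[axis] = filt
        else:
            raise TypeError(f"unsupported filter for axis {axis!r}: {filt!r}")
    return normalized


def assign_axes(
    states: Dict[str, Any], predicates: Dict[Optional[int], Predicate]
) -> Dict[str, Optional[int]]:
    """Give each state the first axis whose predicate matches; others get None."""
    return {
        name: next((axis for axis, pred in predicates.items() if pred(state)), None)
        for name, state in states.items()
    }


# Hypothetical stand-ins for BrainState state classes.
class ShortTerm: ...
class LongTerm: ...


preds = normalize_state_axes({0: ShortTerm})
print(assign_axes({"h": ShortTerm(), "w": LongTerm()}, preds))  # {'h': 0, 'w': None}
```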

Raises:
  • TypeError – If in_axes has an unsupported type.

  • ValueError – If batch dimensions are inconsistent or cannot be inferred.

  • RuntimeError – If tracing or executing the mapped function fails.
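The TypeError and ValueError conditions above typically surface while the mapped-axis size is being inferred. The following is a minimal sketch of such inference under JAX-like in_axes rules (an int or None is broadcast to every positional argument, or a tuple supplies one entry per argument); `infer_batch_size` is a hypothetical helper that operates on shape tuples and is not a BrainState function.

```python
from typing import Optional, Sequence, Tuple, Union

AxesSpec = Union[int, None, Tuple[Optional[int], ...]]


def infer_batch_size(
    shapes: Sequence[Tuple[int, ...]],
    in_axes: AxesSpec,
    axis_size: Optional[int] = None,
) -> int:
    """Hypothetical sketch: infer the mapped-axis size from argument shapes."""
    if in_axes is None or isinstance(in_axes, int):
        # Broadcast a single spec to every positional argument.
        in_axes = (in_axes,) * len(shapes)
    elif not isinstance(in_axes, tuple):
        raise TypeError(f"unsupported in_axes type: {type(in_axes).__name__}")
    if len(in_axes) != len(shapes):
        raise ValueError("in_axes must have one entry per positional argument")
    # Collect the size of every mapped axis; unmapped (None) arguments are skipped.
    sizes = {shape[axis] for shape, axis in zip(shapes, in_axes) if axis is not None}
    if axis_size is not None:
        sizes.add(axis_size)
    if len(sizes) > 1:
        raise ValueError(f"inconsistent batch sizes along mapped axis: {sorted(sizes)}")
    if not sizes:
        raise ValueError("cannot infer batch size: no mapped argument and no axis_size")
    return sizes.pop()


print(infer_batch_size([(3,), (3, 2)], 0))  # 3
```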

Notes

Random states (for example RandomState) encountered during execution are automatically split along the mapped axis and restored afterwards; this behaviour cannot be disabled. The wrapper caches inferred state placements, batch sizes, and trace stacks keyed by abstract argument signatures so repeated calls with the same structure avoid re-tracing.
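The caching described in this note can be sketched as a dictionary keyed by an abstract signature (shapes and element types, never values), so that only a new structure triggers another trace. `abstract_signature` and `CachedMapper` are hypothetical names, nested Python lists stand in for arrays, and the "trace" step here merely records the batch size.

```python
from typing import Any, Callable, Dict, List, Tuple


def abstract_signature(args: Tuple[Any, ...]) -> Tuple[Tuple[Tuple[int, ...], str], ...]:
    """Summarize each argument by (shape, element type name), ignoring values.

    Arguments are nested Python lists here; a real implementation would use
    array shapes and dtypes instead.
    """
    def describe(x: Any) -> Tuple[Tuple[int, ...], str]:
        shape = []
        while isinstance(x, list):
            shape.append(len(x))
            x = x[0] if x else None
        return tuple(shape), type(x).__name__

    return tuple(describe(a) for a in args)


class CachedMapper:
    """Hypothetical sketch: trace once per abstract signature, then reuse."""

    def __init__(self, fun: Callable[..., Any]):
        self.fun = fun
        self.cache: Dict[Any, Dict[str, int]] = {}
        self.trace_count = 0

    def _trace(self, signature) -> Dict[str, int]:
        # Stand-in for the expensive step: here we only record the batch size.
        self.trace_count += 1
        return {"batch_size": signature[0][0][0]}

    def __call__(self, *args: List[Any]) -> List[Any]:
        sig = abstract_signature(args)
        if sig not in self.cache:
            self.cache[sig] = self._trace(sig)
        batch_size = self.cache[sig]["batch_size"]
        # Map fun over the leading axis of every argument.
        return [self.fun(*(a[i] for a in args)) for i in range(batch_size)]


m = CachedMapper(lambda x, y: x + y)
print(m([1.0, 2.0], [10.0, 20.0]))  # [11.0, 22.0]
print(m([3.0, 4.0], [0.5, 0.5]))    # [3.5, 4.5]
print(m.trace_count)                # 1
```

Calls that change the batch size or element type produce a new key and a second trace; calls that change only the values reuse the cached entry.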

Examples

>>> import brainstate
>>> import jax.numpy as jnp
>>> from brainstate.util.filter import OfType
>>>
>>> # The state is mapped on axis 0, so it holds one value per batch lane.
>>> counter = brainstate.ShortTermState(jnp.zeros(3))
>>>
>>> def accumulate(x):
...     counter.value = counter.value + x
...     return counter.value
>>>
>>> batched_accumulate = brainstate.transform.StatefulMapping(
...     accumulate,
...     in_axes=0,
...     out_axes=0,
...     state_in_axes={0: OfType(brainstate.ShortTermState)},
...     state_out_axes={0: OfType(brainstate.ShortTermState)},
...     name="batched_accumulate",
... )
>>>
>>> # Lanes run independently: each adds its own x to its own slice of counter.
>>> xs = jnp.ones((3,))
>>> batched_accumulate(xs)
Array([1., 1., 1.], dtype=float32)
>>> counter.value
Array([1., 1., 1.], dtype=float32)