brainstate.transform.vmap2

brainstate.transform.vmap2(fn=<brainstate.typing.Missing object>, *, in_axes=0, out_axes=0, axis_name=None, axis_size=None, spmd_axis_name=None, state_in_axes=None, state_out_axes=None, unexpected_out_state_mapping='raise')

Vectorize a callable while preserving BrainState state semantics.

This helper mirrors jax.vmap() but routes execution through StatefulMapping so that reads from and writes to State instances (including newly created random states) are tracked correctly across the mapped axis. When fn is supplied, the returned object is called directly; when fn is omitted, vmap2 returns a decorator instead.
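The call-or-decorate convention in the signature (fn defaulting to a Missing sentinel) can be sketched in plain Python. The names map_leading and _Missing are hypothetical, and the list comprehension only stands in for building a StatefulMapping; it performs no real vectorization.

```python
from typing import Callable, List


class _Missing:
    """Hypothetical sentinel standing in for brainstate.typing.Missing."""

    def __repr__(self) -> str:
        return "<Missing>"


MISSING = _Missing()


def map_leading(fn=MISSING, *, in_axes=0):
    """Hypothetical helper copying vmap2's call-or-decorate convention.

    It loops over a Python list; it is not a real vectorizing transform.
    """
    if in_axes != 0:
        raise ValueError("this sketch only maps over axis 0")

    def wrap(f: Callable) -> Callable[[List], List]:
        # Stand-in for building a StatefulMapping around f.
        def mapped(xs: List) -> List:
            return [f(x) for x in xs]

        return mapped

    if fn is MISSING:
        # Used as @map_leading(...): return a decorator.
        return wrap
    # Used as map_leading(f, ...): wrap immediately.
    return wrap(fn)


def square(x):
    return x * x


squares = map_leading(square)  # direct form


@map_leading(in_axes=0)  # decorator form
def cube(x):
    return x ** 3


print(squares([1, 2, 3]))  # [1, 4, 9]
print(cube([1, 2, 3]))  # [1, 8, 27]
```

A dedicated sentinel, rather than None, keeps "fn was not passed" distinct from any value a caller might pass explicitly.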

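Parameters:
  • fn – Callable to vectorize. If omitted, a decorator is returned.

  • in_axes, out_axes, axis_name, axis_size, spmd_axis_name – Interpreted as the corresponding arguments of jax.vmap().

  • state_in_axes, state_out_axes – Mappings from an axis index to a state filter (for example {0: OfType(brainstate.ShortTermState)}) that select which State instances are mapped along that axis on input and on output.

  • unexpected_out_state_mapping – Policy applied when a state write inside fn violates state_out_axes; with the default 'raise', a BatchAxisError is raised.
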
Returns:

If fn is supplied, a StatefulMapping instance that behaves like fn vectorized over the mapped axis. Otherwise, a decorator that wraps a function into a StatefulMapping.

Return type:

StatefulMapping | Callable[[TypeVar(F, bound=Callable)], StatefulMapping]

Raises:
  • ValueError – If axis sizes are inconsistent or cannot be inferred.

  • BatchAxisError – If a state write violates state_out_axes and unexpected_out_state_mapping is 'raise'.
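The two ValueError cases can be pictured with a small shape check. infer_axis_size is a hypothetical helper, and the rules it encodes are assumptions modeled on jax.vmap: every mapped input must have the same length along its mapped axis, and axis_size must be given when no input is mapped. vmap2's actual checks may differ in detail.

```python
from typing import Optional, Sequence

import numpy as np


def infer_axis_size(args: Sequence[np.ndarray],
                    in_axes: Sequence[Optional[int]],
                    axis_size: Optional[int] = None) -> int:
    """Hypothetical helper: find the mapped-axis length or raise ValueError."""
    # Length of every input along its mapped axis (None means broadcast).
    sizes = {np.shape(a)[ax] for a, ax in zip(args, in_axes) if ax is not None}
    if axis_size is not None:
        sizes.add(axis_size)
    if not sizes:
        # Nothing is mapped and no explicit size was given.
        raise ValueError("cannot infer axis size: nothing is mapped and axis_size is None")
    if len(sizes) > 1:
        raise ValueError(f"inconsistent axis sizes: {sorted(sizes)}")
    return sizes.pop()


x = np.zeros((3, 2))
y = np.ones(3)
print(infer_axis_size([x, y], [0, 0]))  # 3
print(infer_axis_size([x, y], [1, None]))  # 2
print(infer_axis_size([x, y], [None, None], axis_size=5))  # 5
try:
    infer_axis_size([x, np.ones(4)], [0, 0])
except ValueError as err:
    print(err)  # inconsistent axis sizes: [3, 4]
```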

Examples

>>> import brainstate
>>> import jax.numpy as jnp
>>> from brainstate.util.filter import OfType
>>>
>>> # One counter slot per batch element, so the state can be mapped along axis 0.
>>> counter = brainstate.ShortTermState(jnp.zeros(3))
>>>
>>> @brainstate.transform.vmap2(
...     in_axes=0,
...     out_axes=0,
...     state_in_axes={0: OfType(brainstate.ShortTermState)},
...     state_out_axes={0: OfType(brainstate.ShortTermState)},
... )
... def accumulate(x):
...     # Each lane reads and updates only its own slot of counter.
...     counter.value = counter.value + x
...     return counter.value
>>>
>>> xs = jnp.arange(3.0)
>>> accumulate(xs)
Array([0., 1., 2.], dtype=float32)
>>> counter.value
Array([0., 1., 2.], dtype=float32)
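For intuition, the state mapping in the example can be written with plain jax.vmap by passing the counter in and out explicitly, assuming one counter slot per batch element mapped along axis 0. This is an illustrative sketch of the semantics, not how StatefulMapping is implemented, and accumulate_pure is a hypothetical name. Because vmap lanes run independently, each lane adds its own x to its own slot; nothing accumulates across the batch.

```python
import jax
import jax.numpy as jnp


def accumulate_pure(counter, x):
    # Each lane sees only its own counter slot and returns the updated value.
    new_counter = counter + x
    return new_counter, new_counter  # (updated state, function output)


counter = jnp.zeros(3)  # one slot per batch element
xs = jnp.arange(3.0)

# State and input mapped along axis 0 on the way in; both results stacked along axis 0.
counter, out = jax.vmap(accumulate_pure, in_axes=(0, 0), out_axes=(0, 0))(counter, xs)
print(out)  # [0. 1. 2.]
print(counter)  # [0. 1. 2.]
```

A running sum across batch elements is a sequential computation, so it calls for a scan-style transform rather than vmap.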

See also

brainstate.transform.StatefulMapping

Underlying state-aware mapping helper.

pmap

Parallel mapping variant for multiple devices.

vmap_new_states

Vectorize newly created states within fn.