brainstate.transform.pmap2

brainstate.transform.pmap2(fn=<brainstate.typing.Missing object>, axis_name=None, *, in_axes=0, out_axes=0, static_broadcasted_argnums=(), devices=None, backend=None, axis_size=None, donate_argnums=(), global_arg_shapes=None, state_in_axes=None, state_out_axes=None, unexpected_out_state_mapping='raise')

Parallel mapping with state-aware semantics across devices.

This function mirrors jax.pmap() but integrates with StatefulMapping, so that State objects (including random states) are replicated and restored correctly on every device. When fn is omitted, pmap2 returns a decorator instead of a mapped function.
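The call-or-decorate behaviour described above follows a common Python pattern. The sketch below is a minimal, hypothetical illustration of that pattern, not brainstate's implementation: `my_transform`, `_Missing`, and `MISSING` are made-up names, with `MISSING` standing in for the `brainstate.typing.Missing` sentinel that appears as the default of `fn` in the signature.

```python
from typing import Callable, TypeVar, Union

F = TypeVar("F", bound=Callable)


class _Missing:
    """Sentinel type meaning 'no function was passed' (hypothetical stand-in)."""

    def __repr__(self) -> str:
        return "<Missing>"


MISSING = _Missing()


def my_transform(fn=MISSING, *, tag: str = "mapped") -> Union[Callable[[F], F], F]:
    """Hypothetical transform usable both as a direct call and as a decorator."""
    if fn is MISSING:
        # my_transform(tag=...) was called without a function:
        # hand back a decorator that finishes the job once it receives one.
        def decorator(f: F) -> F:
            return my_transform(f, tag=tag)

        return decorator

    # my_transform(f, ...) was called with a function: wrap it immediately.
    def wrapped(*args, **kwargs):
        return tag, fn(*args, **kwargs)

    return wrapped


# Direct-call form.
double = my_transform(lambda x: 2 * x)


# Decorator form.
@my_transform(tag="devices")
def triple(x):
    return 3 * x
```

Here `double(4)` returns `("mapped", 8)` and `triple(2)` returns `("devices", 6)`. A dedicated sentinel, rather than `None`, keeps "no function given" distinguishable from any value a caller might pass, which is presumably why the signature defaults `fn` to a `Missing` object.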

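Parameters:
  • fn – Function to map over devices. If omitted, a decorator is returned instead.

  • axis_name – Name of the mapped axis, usable by collectives such as jax.lax.psum inside fn.

  • in_axes – Axis of each positional argument to map over, as in jax.pmap (default 0).

  • out_axes – Axis of each output along which per-device results are stacked, as in jax.pmap (default 0).

  • static_broadcasted_argnums – Positional arguments treated as static and broadcast to every device, as in jax.pmap.

  • devices – Optional sequence of devices to map over, as in jax.pmap.

  • backend – Optional backend name, as in jax.pmap.

  • axis_size – Optional size of the mapped axis, as in jax.pmap.

  • donate_argnums – Positional arguments whose buffers may be donated to the computation, as in jax.pmap.

  • global_arg_shapes – Forwarded as in jax.pmap.

  • state_in_axes – Mapping from an axis index to a filter that selects which State objects are split along that axis on entry, e.g. {0: OfType(brainstate.ParamState)}.

  • state_out_axes – Mapping from an axis index to a filter that selects which State objects are gathered along that axis on exit.

  • unexpected_out_state_mapping – Policy for state writes not covered by state_out_axes; with the default 'raise', such a write raises BatchAxisError.
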
Returns:

If fn is provided, returns a StatefulMapping executing fn over devices. Otherwise returns a decorator that produces such an object.

Return type:

Callable[[F], F] | F, where F = TypeVar('F', bound=Callable)

Raises:
  • ValueError – If axis_size or argument shapes are inconsistent.

  • BatchAxisError – If an unexpected state write occurs and the policy is 'raise'.
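The ValueError above covers inputs whose mapped axes disagree with each other or with axis_size. The following is a simplified, hypothetical sketch of such a consistency check. `infer_mapped_size` is a made-up helper, and it assumes a flat per-argument `in_axes` list (`None` meaning "broadcast, not mapped") rather than the full pytree specification that jax.pmap accepts. The real checks in brainstate and jax.pmap may differ, for example by also comparing the size against the number of available devices.

```python
from typing import Optional, Sequence, Tuple

Shape = Tuple[int, ...]


def infer_mapped_size(
    shapes: Sequence[Shape],
    in_axes: Sequence[Optional[int]],
    axis_size: Optional[int] = None,
) -> int:
    """Return the mapped-axis size, raising ValueError on any inconsistency.

    Hypothetical helper: in_axes[i] is the mapped axis of argument i,
    or None when that argument is broadcast rather than mapped.
    """
    # Collect the length of every mapped axis; all of them must agree.
    sizes = {shape[axis] for shape, axis in zip(shapes, in_axes) if axis is not None}
    if len(sizes) > 1:
        raise ValueError(f"mapped arguments disagree on axis size: {sorted(sizes)}")
    if not sizes:
        # Nothing is mapped, so the size has to come from axis_size.
        if axis_size is None:
            raise ValueError("axis_size is required when no argument is mapped")
        return axis_size
    (size,) = sizes
    # An explicit axis_size must match what the arguments imply.
    if axis_size is not None and axis_size != size:
        raise ValueError(f"axis_size={axis_size} does not match mapped size {size}")
    return size
```

For instance, `infer_mapped_size([(8, 4), (8,)], [0, 0])` returns `8`, while `infer_mapped_size([(8, 4), (4,)], [0, 0])` raises ValueError because the two mapped axes have different lengths.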

Examples

>>> import brainstate
>>> import jax
>>> import jax.numpy as jnp
>>> from brainstate.util.filter import OfType
>>>
>>> n_devices = jax.local_device_count()
>>> # Axis 0 of the state is split across devices: one row per device.
>>> weights = brainstate.ParamState(jnp.ones((n_devices, 4)))
>>>
>>> @brainstate.transform.pmap2(
...     axis_name='devices',
...     in_axes=0,
...     out_axes=0,
...     state_in_axes={0: OfType(brainstate.ParamState)},
...     state_out_axes={0: OfType(brainstate.ParamState)},
... )
... def update(delta):
...     # On each device, weights.value and delta both have shape (4,).
...     weights.value = weights.value + delta
...     return weights.value
>>>
>>> deltas = jnp.arange(n_devices * 4.).reshape(n_devices, 4)
>>> updated = update(deltas)
>>> updated.shape == (n_devices, 4)
True
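In the example above, state_in_axes={0: OfType(brainstate.ParamState)} picks out every ParamState and splits it along axis 0. The selection step can be pictured with the plain-Python sketch below. `State`, `Param`, `Hidden`, `of_type`, and `partition_states` are hypothetical stand-ins, not brainstate's API, and two behaviours are assumptions of the sketch only: the first matching filter wins, and states matched by no filter are left unmapped.

```python
from typing import Callable, Dict, List, Tuple


class State:
    """Hypothetical stand-in for a state container (not brainstate.State)."""

    def __init__(self, name: str):
        self.name = name


class Param(State):
    """Hypothetical stand-in for a parameter state."""


class Hidden(State):
    """Hypothetical stand-in for a non-parameter state."""


Filter = Callable[[State], bool]


def of_type(cls: type) -> Filter:
    """Hypothetical type filter, in the spirit of OfType."""
    return lambda state: isinstance(state, cls)


def partition_states(
    states: List[State], axes: Dict[int, Filter]
) -> Tuple[Dict[int, List[str]], List[str]]:
    """Group state names by the first axis whose filter matches them.

    States matched by no filter end up in the second list, i.e. unmapped.
    """
    mapped: Dict[int, List[str]] = {axis: [] for axis in axes}
    unmapped: List[str] = []
    for state in states:
        for axis, matches in axes.items():
            if matches(state):
                mapped[axis].append(state.name)
                break
        else:
            unmapped.append(state.name)
    return mapped, unmapped


groups = partition_states([Param("weights"), Hidden("membrane")], {0: of_type(Param)})
```

Here `groups` is `({0: ['weights']}, ['membrane'])`: the parameter is assigned to axis 0 and the other state is left out of the mapping.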

See also

  • jax.pmap – Underlying JAX primitive.

  • vmap – Single-device vectorisation with the same state semantics.
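Because pmap2 builds on jax.pmap, the role of axis_name can be seen with plain JAX. The sketch below uses jax.pmap directly, so it shows JAX's own behaviour: the name given to axis_name lets a collective such as jax.lax.pmean reduce across all devices taking part in the map. That the axis_name passed to pmap2 is usable in the same way inside a pmap2-mapped function is an assumption based on pmap2 mirroring jax.pmap.

```python
import jax
import jax.numpy as jnp

n_devices = jax.local_device_count()


def center(x):
    # jax.lax.pmean averages x over every device that shares the 'devices' axis,
    # so each device subtracts the cross-device mean from its own row.
    return x - jax.lax.pmean(x, axis_name="devices")


# One row of 4 values per device; the leading axis is mapped over devices.
inputs = jnp.arange(n_devices * 4.0).reshape(n_devices, 4)
centered = jax.pmap(center, axis_name="devices")(inputs)
```

The rows of `centered` sum to zero column by column, whatever the device count; on a single device every entry is zero, since each value equals its own mean.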