brainstate.transform.pmap2
brainstate.transform.pmap2(fn=<brainstate.typing.Missing object>, axis_name=None, *, in_axes=0, out_axes=0, static_broadcasted_argnums=(), devices=None, backend=None, axis_size=None, donate_argnums=(), global_arg_shapes=None, state_in_axes=None, state_out_axes=None, unexpected_out_state_mapping='raise')
Parallel mapping with state-aware semantics across devices.
This function mirrors jax.pmap() but integrates with StatefulMapping, so that State objects (including random states) are replicated and restored correctly on every device. When fn is omitted, the function can be used as a decorator.

- Parameters:
- fn (Callable[[NestedDict, ...], Any] | Missing) – Function to execute in SPMD style. If omitted, a decorator is returned.
- axis_name (Hashable | None) – Name of the mapped axis, used by collective primitives such as jax.lax.psum.
- in_axes (Any) – Axis mapping for positional arguments, with the same semantics as in jax.pmap().
- out_axes (Any) – Placement of the mapped axis in the outputs.
- static_broadcasted_argnums (int | Iterable[int]) – Indices of positional arguments to treat as compile-time constants and broadcast to every device.
- devices (Sequence[Device] | None) – Explicit list of devices to map over. In multi-host setups it must be identical on every host.
- backend (str | None) – Backend identifier ('cpu', 'gpu', or 'tpu').
- axis_size (int | None) – Size of the mapped axis. Defaults to len(devices), or to the local device count when devices is None.
- donate_argnums (int | Iterable[int]) – Indices of positional arguments whose buffers may be donated to (reused by) the computation.
- global_arg_shapes (Tuple[Tuple[int, ...], ...] | None) – Shapes of globally distributed arguments (i.e. arguments that are not replicated across devices).
- state_in_axes (Dict[Hashable, Filter] | Filter) – Filters selecting which states are treated as device-mapped inputs. In the dictionary form, each key is the mapped axis and each value the filter, e.g. {0: OfType(brainstate.ParamState)}. Here Filter stands for type | str | Callable[[Tuple[Key, ...], Any], bool] | bool | Any | None | Tuple[Filter, ...] | List[Filter].
- state_out_axes (Dict[Hashable, Filter] | Filter) – Filters describing how state writes are scattered back to devices, using the same form as state_in_axes.
- unexpected_out_state_mapping (str) – Policy applied when a state write is not covered by state_out_axes. With the default 'raise', a BatchAxisError is raised.
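The filter arguments can be pictured as a partition of the states that fn touches. The pure-Python sketch below is a toy model under stated assumptions, not brainstate's implementation. Several pieces are hypothetical:

- `State`, `ParamState`, `ShortTermState`, `matches`, `assign_axes` and `check_writes` are all invented for illustration.
- Only the dictionary form of the axes is handled.
- The filter rules are simplified: a type matches by isinstance, a string matches if it appears in the state's path, and a callable is called with (path, state).
- A `ValueError` stands in for `BatchAxisError`.

```python
from typing import Any, Dict, Hashable, Iterable, Tuple

Path = Tuple[str, ...]


class State:
    """Hypothetical stand-in for brainstate.State (not the real class)."""

    def __init__(self, value: Any):
        self.value = value


class ParamState(State):
    """Hypothetical stand-in for brainstate.ParamState."""


class ShortTermState(State):
    """Hypothetical stand-in for a second kind of state."""


def matches(flt: Any, path: Path, state: State) -> bool:
    """Simplified filter rules (assumption), loosely following the Filter type."""
    if flt is None or flt is False:
        return False
    if flt is True or flt is Ellipsis:
        return True
    if isinstance(flt, type):
        return isinstance(state, flt)
    if isinstance(flt, str):
        return flt in path
    if isinstance(flt, (tuple, list)):
        return any(matches(f, path, state) for f in flt)
    if callable(flt):
        return bool(flt(path, state))
    return False


def assign_axes(states: Dict[Path, State], axes: Dict[Hashable, Any]) -> Dict[Path, Any]:
    """Give each state the first axis whose filter matches; None means 'not mapped'."""
    return {
        path: next((ax for ax, flt in axes.items() if matches(flt, path, st)), None)
        for path, st in states.items()
    }


def check_writes(states: Dict[Path, State], written: Iterable[Path],
                 out_axes: Dict[Hashable, Any], policy: str = "raise") -> Dict[Path, Any]:
    """Mimic the 'raise' policy: a write not covered by out_axes is an error."""
    assigned = assign_axes({p: states[p] for p in written}, out_axes)
    uncovered = [p for p, ax in assigned.items() if ax is None]
    if uncovered and policy == "raise":
        # ValueError stands in for brainstate's BatchAxisError in this sketch.
        raise ValueError(f"state writes not covered by out_axes: {uncovered}")
    return assigned


states = {("weights",): ParamState([1.0] * 4), ("trace",): ShortTermState(0.0)}
in_map = assign_axes(states, {0: ParamState})  # {('weights',): 0, ('trace',): None}
```

In this model, a state that matches no filter is simply not mapped. Writing to it without an out-axis entry is what the 'raise' policy rejects.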
- Returns:
If fn is provided, returns a StatefulMapping that executes fn over devices. Otherwise, returns a decorator that produces such an object.
- Return type:
Callable[[TypeVar(F, bound=Callable)], TypeVar(F, bound=Callable)] | TypeVar(F, bound=Callable)
- Raises:
- ValueError – If axis_size and the argument shapes are inconsistent.
- BatchAxisError – If an unexpected state write occurs while unexpected_out_state_mapping is 'raise'.
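The decorator behaviour described under Returns, together with the axis-size check under Raises, can be illustrated with a toy stand-in that uses no devices. `pmap_like` and `MISSING` are hypothetical names. The function only loops over the leading entries of list arguments, so it shows the call pattern and the shape validation, not SPMD execution.

```python
import functools
from typing import Any, List, Optional, Sequence


class _Missing:
    """Hypothetical sentinel playing the role of brainstate.typing.Missing."""


MISSING = _Missing()


def pmap_like(fn: Any = MISSING, *, axis_size: Optional[int] = None):
    """Toy pmap-style transform: maps fn over the leading entries of list arguments."""
    if fn is MISSING:
        # Decorator form: remember the keyword options and wait for the function.
        return functools.partial(pmap_like, axis_size=axis_size)

    @functools.wraps(fn)
    def mapped(*args: Sequence[Any]) -> List[Any]:
        sizes = {len(a) for a in args}
        if len(sizes) != 1:
            raise ValueError(f"mapped arguments need one common leading size, got {sorted(sizes)}")
        (size,) = sizes
        if axis_size is not None and axis_size != size:
            raise ValueError(f"axis_size={axis_size} does not match leading size {size}")
        # Sequential loop; a real transform would run each slice on its own device.
        return [fn(*slices) for slices in zip(*args)]

    return mapped


@pmap_like(axis_size=2)  # decorator form: fn omitted
def add(x, y):
    return x + y


double = pmap_like(lambda x: 2 * x)  # direct form: fn given

print(add([1, 2], [10, 20]))  # [11, 22]
print(double([1, 2, 3]))      # [2, 4, 6]
```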
Examples
>>> import jax
>>> import jax.numpy as jnp
>>> import brainstate
>>> from brainstate.util.filter import OfType
>>>
>>> n = jax.local_device_count()
>>> # The state is mapped along axis 0, so it needs one row per device.
>>> weights = brainstate.ParamState(jnp.ones((n, 4)))
>>>
>>> @brainstate.transform.pmap2(
...     axis_name='devices',
...     in_axes=0,
...     out_axes=0,
...     state_in_axes={0: OfType(brainstate.ParamState)},
...     state_out_axes={0: OfType(brainstate.ParamState)},
... )
... def update(delta):
...     weights.value = weights.value + delta
...     return weights.value
>>>
>>> deltas = jnp.arange(n * 4.).reshape(n, 4)
>>> updated = update(deltas)
>>> updated.shape == (n, 4)
True
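For comparison, the sketch below writes the same computation with plain jax.pmap by threading the state through explicitly. The per-device weights array goes in as a mapped argument, and the updated value comes back as an extra mapped output. This is an assumption about what the state handling amounts to for this particular example, not a description of pmap2's internals. It runs with any number of local devices, including a single CPU.

```python
import jax
import jax.numpy as jnp

n = jax.local_device_count()

# The mapped "state" becomes an explicit argument and an explicit output.
weights = jnp.ones((n, 4))
deltas = jnp.arange(n * 4.0).reshape(n, 4)


def update(w, delta):
    # Per-device view: w and delta both have shape (4,).
    new_w = w + delta
    return new_w, new_w  # (value written back to the state, value returned)


p_update = jax.pmap(update, axis_name="devices", in_axes=(0, 0), out_axes=(0, 0))
weights, updated = p_update(weights, deltas)
```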
See also
- jax.pmap – Underlying JAX primitive.
- vmap – Single-host vectorisation with the same state semantics.
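As a reminder of how axis_name is consumed, this small jax.pmap example, independent of brainstate, normalises one value per device. It uses jax.lax.psum over the named axis so that the per-device shares sum to one.

```python
import jax
import jax.numpy as jnp

n = jax.local_device_count()
x = jnp.arange(1.0, n + 1.0)  # one scalar per device, shape (n,)

# axis_name="i" lets lax.psum reduce across the mapped (device) axis.
normalize = jax.pmap(lambda v: v / jax.lax.psum(v, axis_name="i"), axis_name="i")
shares = normalize(x)
```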