brainstate.transform.shard_map
brainstate.transform.shard_map(fun, mesh, in_specs, out_specs, *, state_in_specs=None, state_out_specs=None, check_vma=True)
Map a stateful function over shards of data across a device mesh (SPMD).
A state-aware wrapper over jax.shard_map(). The function’s State objects are sharded or replicated across mesh according to state_in_specs / state_out_specs (replicated by default), while positional arguments are sharded per in_specs. Inputs are placed on the mesh automatically (via jax.device_put()), so the wrapper works both eagerly and under jit().
- Parameters:
  - fun (Callable) – The function to shard. May read and write State objects. Keyword arguments are broadcast (closed over), not sharded.
  - mesh (Mesh) – The device mesh, e.g. jax.make_mesh((4,), ('x',)).
  - in_specs (Any) – Sharding spec for each positional argument. A tuple must match the number of positional arguments; a single spec is applied to all.
  - out_specs (Any) – Sharding spec for the function’s output.
  - state_in_specs (P | Dict[State, P] | None) – Input sharding for states. A single spec applies to all states; a dict maps specific states. States not covered are replicated (PartitionSpec()).
  - state_out_specs (P | Dict[State, P] | None) – Output sharding for written states. Same conventions as state_in_specs.
  - check_vma (bool) – Forwarded to jax.shard_map() (varying-manual-axes checking).
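The `in_specs` and `state_in_specs` conventions above amount to two small resolution rules. The helpers below are hypothetical (they are not part of brainstate) and only mirror the documented behaviour: a lone `PartitionSpec` is broadcast, a tuple needs one entry per positional argument, and states missing from a dict fall back to `PartitionSpec()`. Raising `ValueError` on a length mismatch is this sketch's own choice, and strings stand in for `State` objects in the usage lines.

```python
from jax.sharding import PartitionSpec as P


def resolve_in_specs(in_specs, num_args):
    """Hypothetical: expand `in_specs` to one spec per positional argument."""
    # Test for PartitionSpec first: in some JAX versions it subclasses tuple.
    if isinstance(in_specs, P):
        return (in_specs,) * num_args  # a single spec applies to all arguments
    if isinstance(in_specs, tuple):
        if len(in_specs) != num_args:
            raise ValueError(
                f"in_specs has {len(in_specs)} entries for {num_args} arguments")
        return in_specs
    return (in_specs,) * num_args


def resolve_state_specs(state_specs, states):
    """Hypothetical: map every state to a PartitionSpec (default: replicated)."""
    if state_specs is None:
        return {s: P() for s in states}
    if isinstance(state_specs, P):
        return {s: state_specs for s in states}  # one spec for all states
    # A dict covers specific states; the rest are replicated with P().
    return {s: state_specs.get(s, P()) for s in states}


# Usage, with strings standing in for State objects:
arg_specs = resolve_in_specs(P('x'), 2)
state_specs = resolve_state_specs({'buffer': P('x')}, ['buffer', 'weight'])
```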
- Returns:
  A function with the same positional signature as fun that executes under SPMD sharding and applies state writes.
- Return type:
  Callable
See also
jax.shard_map, vmap, pmap
Notes
jax.shard_map traces fun at per-shard shapes, so the wrapper re-runs fun (rather than replaying a global jaxpr): it discovers the touched states once via StatefulFunction, injects per-shard values with State.restore_value, runs fun, and restores every state afterward (writes to their new values, reads to their originals).
Examples
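The re-run strategy described in the Notes can be illustrated with a toy wrapper over plain `jax.shard_map`. This is a simplified sketch, not brainstate's implementation: `Box` and `toy_shard_map` are hypothetical names, every box is replicated with `PartitionSpec()`, written boxes are detected by object identity, and the import fallback assumes older JAX releases provide `shard_map` under `jax.experimental.shard_map`.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P

try:
    from jax import shard_map  # public API in recent JAX releases
except ImportError:  # older releases ship it under jax.experimental
    from jax.experimental.shard_map import shard_map


class Box:
    """Hypothetical stand-in for a State: a mutable `.value` slot."""

    def __init__(self, value):
        self.value = value


def toy_shard_map(fun, mesh, in_specs, out_specs, boxes):
    """Toy re-run strategy; every box is replicated with P()."""
    written = []  # indices of boxes that `fun` reassigned, found while tracing

    def body(box_vals, *args):
        originals = [b.value for b in boxes]
        for b, v in zip(boxes, box_vals):  # inject the per-shard values
            b.value = v
        try:
            out = fun(*args)  # re-run the user's function at per-shard shapes
            new_vals = [b.value for b in boxes]
        finally:
            for b, v in zip(boxes, originals):  # keep tracers from escaping
                b.value = v
        written[:] = [i for i, (nv, v) in enumerate(zip(new_vals, box_vals))
                      if nv is not v]
        return out, [new_vals[i] for i in written]

    def wrapped(*args):
        box_vals = [b.value for b in boxes]
        out, new_vals = shard_map(
            body, mesh=mesh,
            in_specs=(P(),) + tuple(in_specs),
            out_specs=(out_specs, P()))(box_vals, *args)
        for i, v in zip(written, new_vals):  # writes take their new values
            boxes[i].value = v
        return out  # read-only boxes still hold their original values

    return wrapped


# Usage: `w` is only read, `counter` is written.
mesh = Mesh(np.array(jax.devices()), ('x',))
w = Box(jnp.float32(3.0))
counter = Box(jnp.float32(0.0))


def fun(data):
    counter.value = counter.value + 1.0
    return data * w.value


f = toy_shard_map(fun, mesh, (P('x'),), P('x'), [w, counter])
x = jnp.arange(jax.device_count() * 2, dtype=jnp.float32)
y = f(x)
```

Restoring the boxes inside `body` before it returns means the caller only ever sees concrete arrays: `counter` ends up with its written value and `w` keeps its original one.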
>>> import brainstate
>>> import jax, jax.numpy as jnp
>>> from jax.sharding import PartitionSpec as P
>>>
>>> mesh = jax.make_mesh((jax.device_count(),), ('x',))
>>> w = brainstate.State(jnp.array(3.0))
>>> def fun(data):
...     return data * w.value
>>> f = brainstate.transform.shard_map(fun, mesh, in_specs=(P('x'),), out_specs=P('x'))
>>> f(jnp.arange(jax.device_count() * 2, dtype=jnp.float32))
Keep a per-shard buffer by giving a state an explicit partition through state_in_specs / state_out_specs; the buffer is read and written in place on each device:
>>> import brainstate
>>> import jax, jax.numpy as jnp
>>> from jax.sharding import PartitionSpec as P
>>> mesh = jax.make_mesh((jax.device_count(),), ('x',))
>>> buffer = brainstate.State(jnp.zeros(jax.device_count() * 2))
>>> def accumulate(data):
...     buffer.value = buffer.value + data
...     return data
>>> f = brainstate.transform.shard_map(
...     accumulate, mesh, in_specs=(P('x'),), out_specs=P('x'),
...     state_in_specs={buffer: P('x')}, state_out_specs={buffer: P('x')})
>>> _ = f(jnp.ones(jax.device_count() * 2))
>>> buffer.value
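In plain JAX the same per-shard buffer has to be threaded by hand: it is passed in and returned with the same `PartitionSpec('x')`, and the caller writes the result back, which is roughly the bookkeeping that `state_in_specs` / `state_out_specs` take over. The sketch also places the buffer once with `jax.device_put()` and a `NamedSharding`, so it stays split across devices between steps. `accumulate` and `step` are illustrative names, and the import fallback assumes older JAX keeps `shard_map` in `jax.experimental.shard_map`.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

try:
    from jax import shard_map  # public API in recent JAX releases
except ImportError:  # older releases ship it under jax.experimental
    from jax.experimental.shard_map import shard_map

n = jax.device_count()
mesh = Mesh(np.array(jax.devices()), ('x',))


def accumulate(buffer, data):
    # Inside shard_map each device sees only its own slice of buffer and data.
    return buffer + data, data


step = shard_map(accumulate, mesh=mesh,
                 in_specs=(P('x'), P('x')), out_specs=(P('x'), P('x')))

# Place the buffer once so it is already split along mesh axis 'x'.
buffer = jax.device_put(jnp.zeros(n * 2), NamedSharding(mesh, P('x')))
for _ in range(3):
    # The caller writes the updated buffer back, like a state write.
    buffer, _ = step(buffer, jnp.ones(n * 2))
```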
Communicate across shards with collectives such as jax.lax.psum(), referring to the mesh axis by name. Here each device contributes a partial sum and psum reduces them to the global total (replicated back):
>>> import brainstate
>>> import jax, jax.numpy as jnp
>>> from jax.sharding import PartitionSpec as P
>>> mesh = jax.make_mesh((jax.device_count(),), ('x',))
>>> def global_sum(data):
...     return jax.lax.psum(jnp.sum(data, keepdims=True), axis_name='x')
>>> f = brainstate.transform.shard_map(
...     global_sum, mesh, in_specs=(P('x'),), out_specs=P())
>>> f(jnp.arange(jax.device_count() * 2, dtype=jnp.float32))
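To see what `psum` contributes, compare it with returning the per-device partial sums directly. The plain-JAX sketch below runs both variants: without a collective, `out_specs=P('x')` concatenates one partial sum per device; with `psum` over axis `'x'`, every device holds the same total, so `out_specs=P()` is valid. The import fallback assumes older JAX keeps `shard_map` in `jax.experimental.shard_map`.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P

try:
    from jax import shard_map  # public API in recent JAX releases
except ImportError:  # older releases ship it under jax.experimental
    from jax.experimental.shard_map import shard_map

n = jax.device_count()
mesh = Mesh(np.array(jax.devices()), ('x',))
data = jnp.arange(n * 2, dtype=jnp.float32)

# No collective: each device returns its own partial sum, concatenated along 'x'.
partials = shard_map(lambda d: jnp.sum(d, keepdims=True), mesh=mesh,
                     in_specs=(P('x'),), out_specs=P('x'))(data)

# With psum: partial sums are reduced over 'x', so the result is replicated.
total = shard_map(lambda d: jax.lax.psum(jnp.sum(d, keepdims=True), axis_name='x'),
                  mesh=mesh, in_specs=(P('x'),), out_specs=P())(data)
```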
shard_map re-traces fun on each call to discover its state usage; wrap it in jax.jit() to amortise that on the hot path:
>>> import brainstate
>>> import jax, jax.numpy as jnp
>>> from jax.sharding import PartitionSpec as P
>>> mesh = jax.make_mesh((jax.device_count(),), ('x',))
>>> bias = brainstate.State(jnp.array(5.0))
>>> f = brainstate.transform.shard_map(
...     lambda data: data + bias.value, mesh,
...     in_specs=(P('x'),), out_specs=P('x'))
>>> jit_f = jax.jit(f)
>>> jit_f(jnp.arange(jax.device_count() * 2, dtype=jnp.float32))
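The amortisation comes from `jax.jit` caching its trace: a Python side effect inside a traced function runs while JAX traces it, not on every call. The sketch below uses a hypothetical `trace_log` list to count traces of a plain-JAX `shard_map` body under `jax.jit`; a second call with the same shapes and dtypes hits the cache and leaves the log unchanged. It demonstrates JAX's caching only, not brainstate's internal state discovery, and the import fallback assumes older JAX keeps `shard_map` in `jax.experimental.shard_map`.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P

try:
    from jax import shard_map  # public API in recent JAX releases
except ImportError:  # older releases ship it under jax.experimental
    from jax.experimental.shard_map import shard_map

n = jax.device_count()
mesh = Mesh(np.array(jax.devices()), ('x',))
trace_log = []  # hypothetical counter: one entry per trace of the body


def add_bias(data, bias):
    trace_log.append(data.shape)  # Python side effect, runs only while tracing
    return data + bias


jit_f = jax.jit(shard_map(add_bias, mesh=mesh,
                          in_specs=(P('x'), P()), out_specs=P('x')))

x = jnp.arange(n * 2, dtype=jnp.float32)
bias = jnp.float32(5.0)
y1 = jit_f(x, bias)
traces_after_first_call = len(trace_log)
y2 = jit_f(x, bias)  # same shapes and dtypes: cached, no re-trace
```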
Multi-axis meshes express combined data- and model-parallel shardings by naming each axis in the PartitionSpec:
>>> import jax, jax.numpy as jnp
>>> import brainstate
>>> from jax.sharding import PartitionSpec as P
>>> n = jax.device_count()
>>> mesh2d = jax.make_mesh((n // 2, 2), ('data', 'model'))  # needs n >= 2
>>> f = brainstate.transform.shard_map(
...     lambda data: data + 1.0, mesh2d,
...     in_specs=(P('data', 'model'),), out_specs=P('data', 'model'))
>>> f(jnp.ones((n // 2, 2)))
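The `(n // 2, 2)` mesh above needs at least two devices. As a hedged plain-JAX variant, an `(n, 1)` mesh with the same axis names runs on any device count, including a single CPU, and lets you check the per-shard block shape that the mapped function sees: the leading axis is split `n` ways over `'data'`, while the trailing axis stays whole because `'model'` has size 1. `seen_shapes` is an illustrative name, and the import fallback assumes older JAX keeps `shard_map` in `jax.experimental.shard_map`.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P

try:
    from jax import shard_map  # public API in recent JAX releases
except ImportError:  # older releases ship it under jax.experimental
    from jax.experimental.shard_map import shard_map

n = jax.device_count()
# An (n, 1) mesh is valid for any device count, unlike (n // 2, 2).
mesh2d = Mesh(np.array(jax.devices()).reshape(n, 1), ('data', 'model'))

seen_shapes = []


def add_one(block):
    seen_shapes.append(block.shape)  # recorded while tracing: the per-shard shape
    return block + 1.0


f = shard_map(add_one, mesh=mesh2d,
              in_specs=(P('data', 'model'),), out_specs=P('data', 'model'))
out = f(jnp.ones((n * 2, 4)))  # each device receives a (2, 4) block
```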