brainstate.transform.shard_map



brainstate.transform.shard_map(fun, mesh, in_specs, out_specs, *, state_in_specs=None, state_out_specs=None, check_vma=True)

Map a stateful function over shards of data across a device mesh (SPMD).

A state-aware wrapper over jax.shard_map(). Positional arguments are sharded according to in_specs, while the State objects that fun reads or writes are sharded or replicated across mesh according to state_in_specs / state_out_specs (replicated by default). Inputs are placed on the mesh automatically (via jax.device_put()), so the wrapper works both eagerly and under jax.jit().

Parameters:
  • fun (Callable) – The function to shard. May read and write State objects. Keyword arguments are broadcast (closed over), not sharded.

  • mesh (Mesh) – The device mesh, e.g. jax.make_mesh((4,), ('x',)).

  • in_specs (Any) – Sharding spec for each positional argument. A tuple must match the number of positional arguments; a single spec is applied to all.

  • out_specs (Any) – Sharding spec for the function’s output.

  • state_in_specs (P | Dict[State, P] | None) – Input sharding for states. A single spec applies to all states; a dict maps specific states. States not covered are replicated (PartitionSpec()).

  • state_out_specs (P | Dict[State, P] | None) – Output sharding for written states. Same conventions as state_in_specs.

  • check_vma (bool) – Whether to check varying manual axes (VMA); forwarded to jax.shard_map().
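The spec conventions in the parameter list reduce to two normalisation rules. The helpers expand_in_specs and expand_state_specs below are hypothetical: they only restate the documented behaviour and are not part of brainstate's API, whose internals (including the exact error raised on a length mismatch, assumed here to be ValueError) may differ. Plain objects stand in for State instances, since states are used as dictionary keys.

```python
from jax.sharding import PartitionSpec as P


def expand_in_specs(in_specs, n_args):
    """Hypothetical helper: return one spec per positional argument."""
    if isinstance(in_specs, P):
        return (in_specs,) * n_args            # a single spec applies to all
    if len(in_specs) != n_args:                # a tuple must match the arity
        raise ValueError(f"expected {n_args} specs, got {len(in_specs)}")
    return tuple(in_specs)


def expand_state_specs(state_specs, states):
    """Hypothetical helper: return one spec per state, replicated by default."""
    if state_specs is None:
        return {s: P() for s in states}        # nothing given: replicate all
    if isinstance(state_specs, P):
        return {s: state_specs for s in states}  # one spec for every state
    return {s: state_specs.get(s, P()) for s in states}  # dict: uncovered -> P()


# Usage with plain objects standing in for State instances.
a, b = object(), object()
print(expand_in_specs(P('x'), 2))                  # both arguments sharded on 'x'
print(expand_state_specs({a: P('x')}, [a, b])[b])  # b is not in the dict: replicated
```

The single-spec check comes before the length check so that the rule holds whether or not PartitionSpec behaves like a tuple in the installed JAX version.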

Returns:

A function with the same positional signature as fun that executes under SPMD sharding and applies state writes.

Return type:

Callable

See also

jax.shard_map, vmap, pmap

Notes

Because jax.shard_map traces fun at per-shard shapes, the wrapper re-runs fun inside the map rather than replaying a jaxpr traced at global shapes. It discovers the states fun touches once via StatefulFunction, injects their per-shard values with State.restore_value, runs fun, and then restores every state: written states take their new values and read-only states revert to their original values.

Examples

>>> import brainstate
>>> import jax, jax.numpy as jnp
>>> from jax.sharding import PartitionSpec as P
>>>
>>> mesh = jax.make_mesh((jax.device_count(),), ('x',))
>>> w = brainstate.State(jnp.array(3.0))
>>> def fun(data):
...     return data * w.value
>>> f = brainstate.transform.shard_map(fun, mesh, in_specs=(P('x'),), out_specs=P('x'))
>>> f(jnp.arange(jax.device_count() * 2, dtype=jnp.float32))

Keep a per-shard buffer by giving a state an explicit partition through state_in_specs / state_out_specs; the buffer is read and written in place on each device:

>>> import brainstate
>>> import jax, jax.numpy as jnp
>>> from jax.sharding import PartitionSpec as P
>>> mesh = jax.make_mesh((jax.device_count(),), ('x',))
>>> buffer = brainstate.State(jnp.zeros(jax.device_count() * 2))
>>> def accumulate(data):
...     buffer.value = buffer.value + data
...     return data
>>> f = brainstate.transform.shard_map(
...     accumulate, mesh, in_specs=(P('x'),), out_specs=P('x'),
...     state_in_specs={buffer: P('x')}, state_out_specs={buffer: P('x')})
>>> _ = f(jnp.ones(jax.device_count() * 2))
>>> buffer.value

Communicate across shards with collectives such as jax.lax.psum(), referring to the mesh axis by name. Here each device computes the partial sum of its shard and psum adds the partial sums, so every device holds the same global total, which is returned replicated (out_specs=P()):

>>> import brainstate
>>> import jax, jax.numpy as jnp
>>> from jax.sharding import PartitionSpec as P
>>> mesh = jax.make_mesh((jax.device_count(),), ('x',))
>>> def global_sum(data):
...     return jax.lax.psum(jnp.sum(data, keepdims=True), axis_name='x')
>>> f = brainstate.transform.shard_map(
...     global_sum, mesh, in_specs=(P('x'),), out_specs=P())
>>> f(jnp.arange(jax.device_count() * 2, dtype=jnp.float32))

When called eagerly, shard_map re-traces fun on every call to discover its state usage; wrap it in jax.jit() so that tracing happens once per input shape and dtype on the hot path:

>>> import brainstate
>>> import jax, jax.numpy as jnp
>>> from jax.sharding import PartitionSpec as P
>>> mesh = jax.make_mesh((jax.device_count(),), ('x',))
>>> bias = brainstate.State(jnp.array(5.0))
>>> f = brainstate.transform.shard_map(
...     lambda data: data + bias.value, mesh,
...     in_specs=(P('x'),), out_specs=P('x'))
>>> jit_f = jax.jit(f)
>>> jit_f(jnp.arange(jax.device_count() * 2, dtype=jnp.float32))

Multi-axis meshes express combined data- and model-parallel shardings by naming each axis in the PartitionSpec:

>>> import jax, jax.numpy as jnp
>>> import brainstate
>>> from jax.sharding import PartitionSpec as P
>>> n = jax.device_count()
>>> mesh2d = jax.make_mesh((n // 2, 2), ('data', 'model'))  # needs an even device count (n >= 2)
>>> f = brainstate.transform.shard_map(
...     lambda data: data + 1.0, mesh2d,
...     in_specs=(P('data', 'model'),), out_specs=P('data', 'model'))
>>> f(jnp.ones((n // 2, 2)))