Parallelisation
brainstate.transform.pmap2 mirrors jax.pmap while keeping BrainState State
objects consistent across devices. This notebook explains how to configure the
API, how random states behave under device parallelism, and how pmap reuses
the same StatefulMapping infrastructure as vmap.
import jax
import jax.numpy as jnp
import brainstate
from brainstate.transform import pmap2
from brainstate.util.filter import OfType
Configuring devices
For CPU-only demonstrations we can provision multiple devices per host by
setting jax_num_cpu_devices before JAX initialises its backends, that is,
before the first call that touches a device. (If you are running on GPU or TPU
you can skip this cell; the environment will report the hardware devices it
already sees.)
jax.config.update('jax_num_cpu_devices', 8)
print('local device count:', jax.local_device_count())
local device count: 8
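On JAX releases that predate the jax_num_cpu_devices option, a commonly used alternative is the XLA flag --xla_force_host_platform_device_count, passed through the XLA_FLAGS environment variable. The following is a minimal sketch that assumes a fresh Python process in which JAX has not been imported yet; the variable name cpu_devices is illustrative.

```python
import os

# Set the flag before JAX initialises its CPU backend; doing it before the
# import is the safest ordering. Existing XLA_FLAGS content is preserved.
os.environ["XLA_FLAGS"] = (
    os.environ.get("XLA_FLAGS", "") + " --xla_force_host_platform_device_count=8"
)

import jax

# Ask for the CPU backend explicitly so the check also works on GPU/TPU hosts.
cpu_devices = jax.devices("cpu")
print("CPU device count:", len(cpu_devices))
```

Either approach only affects the host (CPU) platform; accelerator devices are reported as the hardware exposes them.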
1. Core arguments of pmap
pmap accepts the same signature as jax.pmap plus BrainState-specific
keywords (state_in_axes, state_out_axes, unexpected_out_state_mapping).
Use axis_name to enable collectives and devices / backend when you want to
pin the computation to specific hardware.
class Affine(brainstate.nn.Module):
    def __init__(self, size):
        super().__init__()
        self.weight = brainstate.ParamState(jnp.ones((size,)))

    def __call__(self, delta):
        self.weight.value = self.weight.value + delta
        return self.weight.value

model = Affine(size=jax.local_device_count())
axis_name = 'devices'
pmapped_update = pmap2(
    model,
    axis_name=axis_name,
    in_axes=0,
    out_axes=0,
    state_in_axes={0: OfType(brainstate.ParamState)},
    state_out_axes={0: OfType(brainstate.ParamState)},
)
# Each device receives a different delta vector
per_device_delta = jnp.arange(jax.local_device_count() * 4.).reshape(jax.local_device_count(), 4)
updated = pmapped_update(per_device_delta)
print('updated shape:', updated.shape)
print('final weights:', model.weight.value)
updated shape: (8, 4)
final weights: [[ 1. 2. 3. 4.]
[ 5. 6. 7. 8.]
[ 9. 10. 11. 12.]
[13. 14. 15. 16.]
[17. 18. 19. 20.]
[21. 22. 23. 24.]
[25. 26. 27. 28.]
[29. 30. 31. 32.]]
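The axis_name argument matters once the mapped function calls a collective. The sketch below uses plain jax.pmap rather than pmap2 to show the idea: each device divides its value by a jax.lax.psum taken over the named axis. It assumes pmap2 forwards axis_name in the same way, as the description at the start of this section suggests; the function and variable names are illustrative.

```python
import jax
import jax.numpy as jnp

n = jax.local_device_count()

def normalise(x):
    # psum adds x across every device that participates in the 'devices' axis.
    total = jax.lax.psum(x, axis_name='devices')
    return x / total

values = jnp.arange(1.0, n + 1.0)  # one scalar per device: 1, 2, ..., n
shares = jax.pmap(normalise, axis_name='devices')(values)
print('shares sum to:', shares.sum())
```

Without a matching axis_name on the mapping, the psum call would have no axis to reduce over and tracing would fail.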
axis_size and devices
axis_size is inferred from the device list if possible. It is useful when you
want to simulate a smaller logical mesh than the number of physical devices.
devices lets you provide an explicit list of JAX devices to map over.
logical_devices = jax.devices()[:2]
model = Affine(size=len(logical_devices))
pairwise_update = pmap2(
    model,
    axis_name='pair',
    in_axes=0,
    out_axes=0,
    devices=logical_devices,
    state_in_axes={0: OfType(brainstate.ParamState)},
    state_out_axes={0: OfType(brainstate.ParamState)},
)
deltas = jnp.stack([jnp.ones((4,)), -jnp.ones((4,))], axis=0)
pairwise_update(deltas)
print('weights after pairwise update:', model.weight.value)
weights after pairwise update: [[2. 2. 2. 2.]
[0. 0. 0. 0.]]
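Handling static arguments and donation
Most jax.pmap flags pass straight through: static_broadcasted_argnums marks
an argument as a compile-time constant that is broadcast to every device, while
donate_argnums can reduce memory usage by letting the compiler reuse input
buffers. For an ordinary array or scalar that should simply be shared,
in_axes=None broadcasts the value without making it static, as in the next cell.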
@pmap2(axis_name=axis_name, in_axes=(0, None), out_axes=0)
def add_with_scale(delta, scale):
    return delta + scale

add_with_scale(jnp.arange(jax.local_device_count()), 0.5)
Array([0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5], dtype=float32, weak_type=True)
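To contrast in_axes=None with a truly static argument, here is a small sketch using plain jax.pmap and static_broadcasted_argnums. The function repeated_double is hypothetical; whether pmap2 forwards the flag unchanged is taken from the statement above, not verified here.

```python
import jax
import jax.numpy as jnp

n = jax.local_device_count()

def repeated_double(x, times):
    # `times` arrives as a plain Python int, so it can drive a Python loop.
    for _ in range(times):
        x = x * 2.0
    return x

# Argument 1 is static: it must be hashable, is baked into the compiled
# program, and is shared by every device instead of being split along axis 0.
doubled = jax.pmap(repeated_double, static_broadcasted_argnums=(1,))(
    jnp.ones((n,)), 3
)
print(doubled)
```

Changing the static value triggers a recompilation, so this flag suits configuration-like arguments rather than data that varies from call to call.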
2. Random-number semantics
As with vmap, BrainState splits RandomState keys automatically so that each
device sees a different stream. This makes stochastic simulations reproducible
without manual key management.
rand_state = brainstate.random.RandomState(0)
@pmap2(
    axis_name='devices',
    in_axes=0,
    out_axes=0,
    state_in_axes={0: OfType(brainstate.random.RandomState)},
    state_out_axes={0: OfType(brainstate.random.RandomState)},
)
def sample_normal(scale):
    return brainstate.random.normal(0.0, scale)

per_device_scales = jnp.linspace(1.0, 2.0, jax.local_device_count())
sample_normal(per_device_scales)
Array([ 1.23822 , -0.2782504 , -1.9162552 , -0.21000428, 0.41403982,
-0.7870412 , -1.6281602 , -1.1573448 ], dtype=float32)
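For readers used to explicit key handling, the automatic behaviour is comparable to splitting one key into a key per device and mapping over the result. The sketch below does this by hand with plain JAX; it illustrates the idea only and is not a description of how BrainState implements the split internally.

```python
import jax
import jax.numpy as jnp

n = jax.local_device_count()
keys = jax.random.split(jax.random.PRNGKey(0), n)  # one key per device
scales = jnp.linspace(1.0, 2.0, n)

def sample(key, scale):
    # Each device draws one standard-normal value from its own key.
    return jax.random.normal(key, ()) * scale

samples = jax.pmap(sample)(keys, scales)
print(samples.shape)
```

Because the per-device keys are derived deterministically from the seed, rerunning the cell gives the same samples.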
If you need identical keys on all devices, use jax.random explicitly and
broadcast the key to every device with in_axes=None.
shared_key = jax.random.PRNGKey(0)
@pmap2(axis_name='devices', in_axes=(None, 0), out_axes=0)
def sample_shared(key, scale):
    return jax.random.normal(key, ()) * scale

sample_shared(shared_key, per_device_scales)
Array([1.6226422, 1.8544483, 2.0862544, 2.3180604, 2.5498662, 2.7816722,
3.0134785, 3.2452843], dtype=float32)
3. Relationship to StatefulMapping
pmap creates a StatefulMapping under the hood, just like vmap. The wrapper
analyzes state usage, constructs IR for the batched computation, and restores
state values after every parallel execution.
parallel_mapping = pmap2(
    model,
    axis_name='devices',
    in_axes=0,
    out_axes=0,
    state_in_axes={0: OfType(brainstate.ParamState)},
    state_out_axes={0: OfType(brainstate.ParamState)},
)
print(type(parallel_mapping))
print('origin fun:', parallel_mapping.origin_fun)
print('state_in_axes:', parallel_mapping.state_in_axes)
<class 'brainstate.transform.StatefulMapping'>
origin fun: Affine(
  weight=ParamState(
    value=ShapedArray(float32[2,4])
  )
)
state_in_axes: {0: OfType(<class 'brainstate.ParamState'>)}
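For comparison, this is the bookkeeping that a stateful wrapper takes off your hands, written in plain JAX: the weight is passed in as an explicit argument and handed back as an extra output. It is a hand-written analogue of the Affine example with illustrative names, not BrainState's implementation.

```python
import jax
import jax.numpy as jnp

n = jax.local_device_count()

def update(weight, delta):
    # Pure function: the "state" goes in as an argument and comes back out.
    new_weight = weight + delta
    return new_weight, new_weight  # (result, updated state)

weights = jnp.ones((n,))                       # one scalar weight per device
deltas = jnp.arange(n * 4.0).reshape(n, 4)     # one delta vector per device

# The caller is responsible for rebinding the state after every call.
result, weights = jax.pmap(update)(weights, deltas)
print('weights shape:', weights.shape)
```

The in_axes and out_axes entries for the weight play the role that state_in_axes and state_out_axes play for State objects.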
Advanced users can construct StatefulMapping directly, selecting their own
mapping primitive. Below we recreate the earlier example but pass an explicit
jax.pmap with custom donation settings.
from brainstate.transform import StatefulMapping
model = Affine(size=jax.local_device_count())
custom_pmap = StatefulMapping(
    model,
    in_axes=0,
    out_axes=0,
    state_in_axes={0: OfType(brainstate.ParamState)},
    state_out_axes={0: OfType(brainstate.ParamState)},
    axis_name='devices',
    mapping_fn=lambda fun, *a, **kw: jax.pmap(fun, donate_argnums=(0,), *a, **kw),
)
custom_pmap(jnp.ones((jax.local_device_count(), 4)))
model.weight.value
Array([[2., 2., 2., 2.],
[2., 2., 2., 2.],
[2., 2., 2., 2.],
[2., 2., 2., 2.],
[2., 2., 2., 2.],
[2., 2., 2., 2.],
[2., 2., 2., 2.],
[2., 2., 2., 2.]], dtype=float32)
Summary
- brainstate.transform.pmap2 supports the full jax.pmap interface and adds state-specific controls via state_in_axes, state_out_axes, and unexpected_out_state_mapping.
- Random states are split automatically so each device receives its own key. Use jax.random with in_axes=None to broadcast a shared key instead.
- Like vmap, pmap returns a StatefulMapping that identifies state axis mappings and compiles the computation into a state-aware IR.