Parallelisation#

brainstate.transform.pmap2 mirrors jax.pmap while keeping BrainState State objects consistent across devices. This notebook explains how to configure the API, how random states behave under device parallelism, and how pmap2 reuses the same StatefulMapping infrastructure as vmap.

import jax
import jax.numpy as jnp

import brainstate
from brainstate.transform import pmap2
from brainstate.util.filter import OfType

Configuring devices#

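For CPU-only demonstrations we can provision multiple devices on a single host by setting jax_num_cpu_devices. The option has to be set before JAX initializes its backends, that is, before the first computation or device query; importing JAX beforehand is fine, as the cell below does. (If you are running on GPU or TPU you can skip this cell; the environment will report the hardware devices it already sees.)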

jax.config.update('jax_num_cpu_devices', 8)
print('local device count:', jax.local_device_count())
local device count: 8
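
On JAX versions that do not provide the jax_num_cpu_devices option, a commonly used alternative is the XLA flag --xla_force_host_platform_device_count. The sketch below assumes a fresh Python process and a CPU backend; the flag has no effect once JAX has initialized its backends, and setdefault leaves any XLA_FLAGS value you have already exported untouched.

```python
import os

# Must run before JAX initializes its CPU backend, so set it before importing jax.
# setdefault keeps an XLA_FLAGS value that is already present in the environment.
os.environ.setdefault("XLA_FLAGS", "--xla_force_host_platform_device_count=8")

import jax

n_devices = jax.local_device_count()
print("local device count:", n_devices)
```

On GPU or TPU hosts the reported count reflects the accelerators rather than the simulated CPU devices.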

1. Core arguments of pmap#

pmap2 accepts the same arguments as jax.pmap plus BrainState-specific keywords (state_in_axes, state_out_axes, unexpected_out_state_mapping). Use axis_name to enable collectives, and devices / backend when you want to pin the computation to specific hardware.

class Affine(brainstate.nn.Module):
    def __init__(self, size):
        super().__init__()
        self.weight = brainstate.ParamState(jnp.ones((size,)))

    def __call__(self, delta):
        self.weight.value = self.weight.value + delta
        return self.weight.value


model = Affine(size=jax.local_device_count())
axis_name = 'devices'

pmapped_update = pmap2(
    model,
    axis_name=axis_name,
    in_axes=0,
    out_axes=0,
    state_in_axes={0: OfType(brainstate.ParamState)},
    state_out_axes={0: OfType(brainstate.ParamState)},
)

# Each device receives a different delta vector
per_device_delta = jnp.arange(jax.local_device_count() * 4.).reshape(jax.local_device_count(), 4)
updated = pmapped_update(per_device_delta)
print('updated shape:', updated.shape)
print('final weights:', model.weight.value)
updated shape: (8, 4)
final weights: [[ 1.  2.  3.  4.]
 [ 5.  6.  7.  8.]
 [ 9. 10. 11. 12.]
 [13. 14. 15. 16.]
 [17. 18. 19. 20.]
 [21. 22. 23. 24.]
 [25. 26. 27. 28.]
 [29. 30. 31. 32.]]
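
The example above sets axis_name but never uses a collective. As a minimal sketch of what the name is for, the following uses plain jax.pmap (not pmap2) together with jax.lax.psum. The function name normalize is illustrative, and it is an assumption that pmap2 forwards axis_name to collectives in the same way.

```python
import jax
import jax.numpy as jnp

n = jax.local_device_count()


def normalize(x):
    # psum adds up x across every device that takes part in the 'devices' axis
    total = jax.lax.psum(x, axis_name='devices')
    return x / total


# One scalar per device: 1.0, 2.0, ..., n
values = jnp.arange(1.0, n + 1.0)
shares = jax.pmap(normalize, axis_name='devices')(values)
print('shares sum to:', float(shares.sum()))
```

Each device ends up holding its own value divided by the global sum, so the shares add up to one regardless of how many devices are present.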

axis_size and devices#

axis_size is inferred from the device list if possible. It is useful when you want to simulate a smaller logical mesh than the number of physical devices. devices lets you provide an explicit list of JAX devices to map over.

logical_devices = jax.devices()[:2]
model = Affine(size=len(logical_devices))

pairwise_update = pmap2(
    model,
    axis_name='pair',
    in_axes=0,
    out_axes=0,
    devices=logical_devices,
    state_in_axes={0: OfType(brainstate.ParamState)},
    state_out_axes={0: OfType(brainstate.ParamState)},
)

deltas = jnp.stack([jnp.ones((4,)), -jnp.ones((4,))], axis=0)
pairwise_update(deltas)
print('weights after pairwise update:', model.weight.value)
weights after pairwise update: [[2. 2. 2. 2.]
 [0. 0. 0. 0.]]

Handling static arguments and donation#

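Most jax.pmap flags pass straight through: static_broadcasted_argnums treats an argument as a compile-time constant shared by all devices, while donate_argnums can reduce memory usage by letting the compiler reuse input buffers. When the shared argument is an ordinary array or scalar, in_axes=None is enough to broadcast it, as the example below does for scale.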

@pmap2(axis_name=axis_name, in_axes=(0, None), out_axes=0)
def add_with_scale(delta, scale):
    return delta + scale

add_with_scale(jnp.arange(jax.local_device_count()), 0.5)
Array([0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5], dtype=float32, weak_type=True)
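
For comparison, here is a small sketch of static_broadcasted_argnums itself, written against plain jax.pmap rather than pmap2. The function apply and its op argument are hypothetical; the point is that a static argument is an ordinary Python value at trace time, so it can drive Python control flow, which a traced in_axes=None argument cannot.

```python
import jax
import jax.numpy as jnp

n = jax.local_device_count()


def apply(x, op):
    # op is a plain Python string here because it is marked static,
    # so an ordinary if-statement is allowed
    if op == 'square':
        return x * x
    return -x


# Argument 1 (op) is static and shared by all devices; argument 0 is mapped
static_apply = jax.pmap(apply, static_broadcasted_argnums=(1,))

x = jnp.arange(n, dtype=jnp.float32)
squared = static_apply(x, 'square')
negated = static_apply(x, 'negate')
print(squared, negated)
```

Each distinct static value triggers its own compilation, so this is best reserved for arguments that take only a few values.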

2. Random-number semantics#

As with vmap, BrainState splits RandomState keys automatically so that each device sees a different stream. This makes stochastic simulations reproducible without manual key management.

rand_state = brainstate.random.RandomState(0)

@pmap2(
    axis_name='devices',
    in_axes=0,
    out_axes=0,
    state_in_axes={0: OfType(brainstate.random.RandomState)},
    state_out_axes={0: OfType(brainstate.random.RandomState)},
)
def sample_normal(scale):
    return brainstate.random.normal(0.0, scale)

per_device_scales = jnp.linspace(1.0, 2.0, jax.local_device_count())
sample_normal(per_device_scales)
Array([ 1.23822   , -0.2782504 , -1.9162552 , -0.21000428,  0.41403982,
       -0.7870412 , -1.6281602 , -1.1573448 ], dtype=float32)
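
To see what the automatic splitting saves, here is a sketch of the manual equivalent in plain JAX: split one key into one subkey per device and map over the stack of subkeys. This illustrates the general pattern only; it is not necessarily the splitting scheme BrainState uses internally, so the sampled values will not match the output above.

```python
import jax
import jax.numpy as jnp

n = jax.local_device_count()

key = jax.random.PRNGKey(0)
# One subkey per device, stacked along the leading axis
device_keys = jax.random.split(key, n)
scales = jnp.linspace(1.0, 2.0, n)


def sample(device_key, scale):
    # Each device draws from its own stream
    return jax.random.normal(device_key, ()) * scale


samples = jax.pmap(sample)(device_keys, scales)
print(samples.shape)
```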

If you need identical keys on all devices, use jax.random explicitly and broadcast the key to every device with in_axes=None.

shared_key = jax.random.PRNGKey(0)

@pmap2(axis_name='devices', in_axes=(None, 0), out_axes=0)
def sample_shared(key, scale):
    return jax.random.normal(key, ()) * scale

sample_shared(shared_key, per_device_scales)
Array([1.6226422, 1.8544483, 2.0862544, 2.3180604, 2.5498662, 2.7816722,
       3.0134785, 3.2452843], dtype=float32)

3. Relationship to StatefulMapping#

pmap2 creates a StatefulMapping under the hood, just like vmap. The wrapper analyzes state usage, constructs IR for the batched computation, and writes the updated state values back to the State objects after every parallel execution.

parallel_mapping = pmap2(
    model,
    axis_name='devices',
    in_axes=0,
    out_axes=0,
    state_in_axes={0: OfType(brainstate.ParamState)},
    state_out_axes={0: OfType(brainstate.ParamState)},
)

print(type(parallel_mapping))
print('origin fun:', parallel_mapping.origin_fun)
print('state_in_axes:', parallel_mapping.state_in_axes)
<class 'brainstate.transform.StatefulMapping'>
origin fun: Affine(
  weight=ParamState(
    value=ShapedArray(float32[2,4])
  )
)
state_in_axes: {0: OfType(<class 'brainstate.ParamState'>)}
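
The general idea behind a state-aware mapping can be sketched without BrainState at all. The toy below is an illustration under stated assumptions, not BrainState's implementation: Cell and stateful_map are hypothetical names, only a single state is handled, and jax.vmap is the default mapper so the code runs on any number of devices.

```python
import jax
import jax.numpy as jnp


class Cell:
    """Hypothetical stand-in for a State: a mutable box holding an array."""

    def __init__(self, value):
        self.value = value


def stateful_map(fun, cell, mapper=jax.vmap):
    """Map fun over axis 0 of its input and of cell.value (toy version)."""

    def pure(state_value, x):
        cell.value = state_value   # expose this lane's slice of the state
        out = fun(x)               # fun may read and write cell.value
        return out, cell.value     # hand the new state back explicitly

    mapped = mapper(pure, in_axes=(0, 0), out_axes=(0, 0))

    def wrapped(x):
        out, new_state = mapped(cell.value, x)
        cell.value = new_state     # write the stacked result back
        return out

    return wrapped


weight = Cell(jnp.ones((3,)))


def update(delta):
    weight.value = weight.value + delta
    return weight.value


step = stateful_map(update, weight)
out = step(jnp.ones((3, 4)))
print(weight.value.shape)  # (3, 4)
```

The state is threaded through the mapped function as an explicit input and output, then written back once the mapped call returns. Passing mapper=jax.pmap should behave the same way, provided the leading axis does not exceed the number of local devices.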

Advanced users can construct StatefulMapping directly, selecting their own mapping primitive. Below we recreate the earlier example but pass an explicit jax.pmap with custom donation settings.

from brainstate.transform import StatefulMapping

model = Affine(size=jax.local_device_count())

custom_pmap = StatefulMapping(
    model,
    in_axes=0,
    out_axes=0,
    state_in_axes={0: OfType(brainstate.ParamState)},
    state_out_axes={0: OfType(brainstate.ParamState)},
    axis_name='devices',
    mapping_fn=lambda fun, *a, **kw: jax.pmap(fun, *a, donate_argnums=(0,), **kw),
)

custom_pmap(jnp.ones((jax.local_device_count(), 4)))

model.weight.value
Array([[2., 2., 2., 2.],
       [2., 2., 2., 2.],
       [2., 2., 2., 2.],
       [2., 2., 2., 2.],
       [2., 2., 2., 2.],
       [2., 2., 2., 2.],
       [2., 2., 2., 2.],
       [2., 2., 2., 2.]], dtype=float32)

Summary#

  • brainstate.transform.pmap2 supports the full jax.pmap interface and adds state-specific controls via state_in_axes, state_out_axes, and unexpected_out_state_mapping.

  • Random states are split automatically so each device receives its own key. Use jax.random with in_axes=None to broadcast a shared key instead.

  • Like vmap, pmap2 returns a StatefulMapping that identifies state axis mappings and compiles the computation into a state-aware IR.