brainstate.transform.pmap2_new_states

brainstate.transform.pmap2_new_states(module, init_kwargs, state_tag=None, axis_size=None, state_out_axes=None)

Initialize and parallelize newly created states across multiple devices.

This function creates device-replicated or device-sharded versions of all states initialized by module.init_all_states(**init_kwargs). It uses pmap2() to distribute state initialization across multiple devices, enabling efficient multi-device parallelism for modules with stateful components.

Concretely, the function wraps the module’s initialization in a pmap2() transform, runs it in parallel on axis_size devices, and then writes the device-distributed values back into the original state objects. Subsequent operations on the module therefore see device-parallelized states without further changes to the module code.

Parameters:
  • module (Module) – Module whose states should be parallelized across devices. Must have an init_all_states method that creates the states to be distributed.

  • init_kwargs (Dict) – Keyword arguments forwarded to module.init_all_states(**init_kwargs) during the parallel initialization. These arguments are passed to each device’s initialization call.

  • state_tag (str) – Tag for identifying and grouping the newly created states. Used by BrainState’s state tracking system. Defaults to None.

  • axis_size (int) – Size of the parallel axis, typically the number of devices to map over. If None, defaults to the number of available devices (e.g., jax.local_device_count()).

  • state_out_axes (Dict[int, type | str | Callable[[Tuple[Key, ...], Any], bool] | bool | Any | None | Tuple[Filter, ...] | List[Filter]]) –

    Specification for how to distribute output states across devices and axes. Can be:

    • A dictionary mapping axis indices (int) to brainstate.util.filter predicates that identify which states are distributed along which axis

    • A single filter (treated as {0: filter} for convenience)

    • None (default: all states distributed along axis 0, except NonBatchState which is replicated)

    Filters are converted to predicates via brainstate.util.filter.to_predicate(). States matching NonBatchState are automatically replicated (axis None) across all devices regardless of other specifications.
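
The three accepted forms of state_out_axes can be reduced to a single dictionary form. The following is a minimal sketch of that normalization in plain Python. The helper name normalize_out_axes and the use of bare callables as filters are assumptions made for illustration; the library performs its own conversion through brainstate.util.filter.to_predicate().

```python
from typing import Any, Callable, Dict

# Hypothetical stand-in: a "filter" here is just a predicate on a state object.
Predicate = Callable[[Any], bool]


def normalize_out_axes(state_out_axes) -> Dict[int, Predicate]:
    """Return the dict form of ``state_out_axes`` (illustrative, not library code)."""
    if state_out_axes is None:
        # No custom rules: every state falls through to the default handling.
        return {}
    if isinstance(state_out_axes, dict):
        # Already in {axis: filter} form; copy so the caller's dict is untouched.
        return dict(state_out_axes)
    # A single filter is shorthand for {0: filter}.
    return {0: state_out_axes}


# Hypothetical predicate that selects "parameter-like" objects.
is_param = lambda s: getattr(s, "kind", None) == "param"

print(normalize_out_axes(None))                            # {}
print(normalize_out_axes(is_param) == {0: is_param})       # True
print(normalize_out_axes({1: is_param}) == {1: is_param})  # True
```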

Returns:

Dictionary mapping axis indices to lists of parallelized states. Keys are the axis indices specified in state_out_axes (plus None for replicated states), and values are lists of State objects with their .value attributes set to device-distributed arrays.

Return type:

dict[int | None, list[State]]

Raises:

ValueError – If state assignment is ambiguous, if axis_size exceeds available devices, or if device configuration is invalid.

Notes

Initialization Process:

  1. Wraps module.init_all_states in a pmap2() transform

  2. Executes the initialization on axis_size devices in parallel

  3. Captures all newly created states using catch_new_states()

  4. Assigns states to axes based on state_out_axes predicates

  5. Restores device-distributed values to the actual state objects

  6. Adjusts state stack levels to prevent JAX tracing leakage
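
The first five steps can be mimicked on a single machine to see what ends up in each state. The sketch below is a pure-Python simulation, not the library’s implementation: ToyState, toy_parallel_init, and the replicated flag are hypothetical stand-ins, devices are simulated by a loop, and step 6 (stack-level adjustment) is not modelled.

```python
class ToyState:
    """Hypothetical stand-in for a state object with a ``.value`` attribute."""

    def __init__(self, value, replicated=False):
        self.value = value
        self.replicated = replicated


def toy_parallel_init(init_fn, axis_size):
    """Run ``init_fn`` once per simulated device and merge the results."""
    # Steps 1-3: call the initializer per device and collect the states it creates.
    per_device = [init_fn(device_index) for device_index in range(axis_size)]
    merged = {}
    for name in per_device[0]:
        states = [created[name] for created in per_device]
        first = states[0]
        if first.replicated:
            # Step 4, axis None: keep a single copy shared by every device.
            merged[name] = ToyState(first.value, replicated=True)
        else:
            # Step 4, axis 0, and step 5: stack per-device values along a new
            # leading axis and store the stacked value in one state object.
            merged[name] = ToyState([s.value for s in states])
    return merged


def init_fn(device_index):
    # Each simulated device creates the same two states.
    return {
        "count": ToyState(0),
        "config": ToyState([1.0, 2.0, 3.0], replicated=True),
    }


states = toy_parallel_init(init_fn, axis_size=4)
print(states["count"].value)   # [0, 0, 0, 0]
print(states["config"].value)  # [1.0, 2.0, 3.0]
```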

State Device Distribution:

States are assigned to axes (and thus distributed across devices) in priority order:

  • First, NonBatchState → axis None (replicated)

  • Then, states matching custom filters in state_out_axes

  • Finally, remaining states → axis 0 (default device-parallel axis)
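
The priority order above can be sketched as a small grouping routine. This is an illustrative approximation in plain Python: assign_axes, the tuple-based states, and the predicates are hypothetical, and treating "one state matches several custom axes" as the ambiguous case that raises ValueError is an assumption about what the library checks.

```python
def assign_axes(states, out_axes, is_non_batch):
    """Group states by axis in priority order (illustrative, not library code)."""
    groups = {}
    for state in states:
        if is_non_batch(state):
            # Highest priority: replicated states go to axis None.
            axis = None
        else:
            # Next: custom filters supplied by the caller.
            matches = [ax for ax, pred in out_axes.items() if pred(state)]
            if len(matches) > 1:
                raise ValueError(f"state {state!r} matches several axes: {matches}")
            # Last: anything unmatched lands on the default axis 0.
            axis = matches[0] if matches else 0
        groups.setdefault(axis, []).append(state)
    return groups


# Toy states are (name, kind) tuples; the kinds are made up for this example.
toy_states = [("cfg", "non_batch"), ("w", "param"), ("h", "hidden")]
toy_out_axes = {1: lambda s: s[1] == "param"}
groups = assign_axes(toy_states, toy_out_axes, lambda s: s[1] == "non_batch")

print(groups[None])  # [('cfg', 'non_batch')]
print(groups[1])     # [('w', 'param')]
print(groups[0])     # [('h', 'hidden')]
```

The resulting dictionary has the same shape as the documented return value: axis keys (including None) mapped to lists of states.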

Device Semantics:

  • Axis 0 (default): States are sharded across devices along the first dimension

  • Axis None: States are replicated identically on all devices

  • Custom axes: States can be sharded along different dimensions based on filters
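
To make these semantics concrete, the shape of a distributed state can be predicted from its per-device shape. The helper below is a hypothetical sketch, assuming state_out_axes follows the same convention as out_axes in jax.pmap (the mapped axis is inserted at the given position). It only handles non-negative axis indices.

```python
def distributed_shape(per_device_shape, axis, axis_size):
    """Predict the shape of a state after distribution (illustrative helper).

    ``axis`` is a non-negative position for the device axis, or None for
    replicated states whose shape is left unchanged.
    """
    if axis is None:
        return tuple(per_device_shape)
    shape = list(per_device_shape)
    shape.insert(axis, axis_size)  # the device axis appears at position ``axis``
    return tuple(shape)


print(distributed_shape((), 0, 4))        # (4,)
print(distributed_shape((128,), 0, 4))    # (4, 128)
print(distributed_shape((128,), 1, 4))    # (128, 4)
print(distributed_shape((3,), None, 4))   # (3,)
```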

Multi-Host Considerations:

In multi-host setups, axis_size typically corresponds to the local device count on each host. When explicit device lists are passed to pmap2(), they must be specified consistently across all hosts.

Examples

Basic parallel initialization:

>>> import brainstate
>>> import jax
>>> import jax.numpy as jnp
>>>
>>> class ParallelCounter(brainstate.nn.Module):
...     def init_state(self):
...         self.count = brainstate.ShortTermState(jnp.array(0))
>>>
>>> module = ParallelCounter()
>>> pmap_states = brainstate.transform.pmap2_new_states(
...     module,
...     init_kwargs={},
...     axis_size=jax.local_device_count()
... )
>>> module.count.value.shape == (jax.local_device_count(),)
True

Parallel model with device-sharded parameters:

>>> from brainstate.util.filter import OfType
>>>
>>> class ParallelModel(brainstate.nn.Module):
...     def init_state(self, layer_size):
...         self.weight = brainstate.ParamState(
...             jax.random.normal(jax.random.PRNGKey(0), (layer_size,))
...         )
...         self.bias = brainstate.ParamState(jnp.zeros(layer_size))
>>>
>>> model = ParallelModel()
>>> n_devices = jax.local_device_count()
>>> pmap_states = brainstate.transform.pmap2_new_states(
...     model,
...     init_kwargs={'layer_size': 128},
...     axis_size=n_devices,
...     state_out_axes={0: OfType(brainstate.ParamState)}
... )
>>> # Parameters are sharded across devices
>>> model.weight.value.shape == (n_devices, 128)
True

Mixed replicated and sharded states:

>>> class MixedParallelModule(brainstate.nn.Module):
...     def init_state(self):
...         # Sharded state (different on each device)
...         self.local_state = brainstate.ShortTermState(jnp.array(0))
...         # Replicated state (same on all devices)
...         self.global_config = brainstate.NonBatchState(
...             jnp.array([1.0, 2.0, 3.0])
...         )
>>>
>>> module = MixedParallelModule()
>>> pmap_states = brainstate.transform.pmap2_new_states(
...     module,
...     init_kwargs={},
...     axis_size=jax.local_device_count()
... )
>>> module.local_state.value.shape == (jax.local_device_count(),)  # sharded
True
>>> module.global_config.value.shape  # replicated (not sharded)
(3,)

Using with the Map module mapper for data parallelism:

>>> from brainstate.nn import Map
>>>
>>> model = ParallelModel()
>>> pmapper = Map(
...     model,
...     init_map_size=jax.local_device_count(),
...     behavior='pmap',
...     axis_name='devices'
... )
>>> pmapper.init_all_states(layer_size=128)
>>> # Data-parallel training across devices; `inputs` is assumed to be an
>>> # existing array of shape (batch, features)
>>> batch_per_device = inputs.shape[0] // jax.local_device_count()
>>> sharded_inputs = inputs.reshape(
...     jax.local_device_count(), batch_per_device, -1
... )
>>> outputs = pmapper.update(sharded_inputs)
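
The reshape in the example above only works when the batch size is a multiple of the device count, because the integer division otherwise drops the remainder and the reshape fails. A minimal sketch of an explicit check, using plain Python lists in place of arrays (split_batch is a hypothetical helper, not part of BrainState or JAX):

```python
def split_batch(batch, n_devices):
    """Split ``batch`` into ``n_devices`` equal chunks, or fail loudly."""
    if n_devices <= 0:
        raise ValueError("n_devices must be positive")
    if len(batch) % n_devices != 0:
        raise ValueError(
            f"batch size {len(batch)} is not divisible by {n_devices} devices"
        )
    per_device = len(batch) // n_devices
    # Chunk i holds the samples destined for device i.
    return [batch[i * per_device:(i + 1) * per_device] for i in range(n_devices)]


chunks = split_batch(list(range(8)), n_devices=4)
print(chunks)  # [[0, 1], [2, 3], [4, 5], [6, 7]]
```

With real arrays, the same check can precede the reshape so that a mismatched batch size produces a clear error instead of a shape failure.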

See also

pmap2

Parallel mapping across devices with state semantics.

vmap2_new_states

Vectorized version for single-device batching.

brainstate.State

Base class for stateful objects.

brainstate.NonBatchState

Marker for states that should be replicated.

jax.pmap

Underlying JAX parallel mapping primitive.

catch_new_states

Context manager for capturing newly created states.