brainstate.transform.pmap2_new_states
- brainstate.transform.pmap2_new_states(module, init_kwargs, state_tag=None, axis_size=None, state_out_axes=None)
Initialize and parallelize newly created states across multiple devices.
This function creates device-replicated or device-sharded versions of all states initialized by module.init_all_states(**init_kwargs). It uses pmap2() to distribute state initialization across multiple devices, enabling multi-device parallelism for modules with stateful components.
The parallelization process wraps the module’s initialization in a pmap2() transform, executes it in parallel across axis_size devices, and then restores the device-distributed state values to the original state objects. Subsequent operations on the module can then work with device-parallelized states transparently.
- Parameters:
module (Module) – Module whose states should be parallelized across devices. Must have an init_all_states method that creates the states to be distributed.
init_kwargs (Dict) – Keyword arguments forwarded to module.init_all_states(**init_kwargs) during the parallel initialization. These arguments are passed to each device’s initialization call.
state_tag (str) – Tag for identifying and grouping the newly created states. Used by BrainState’s state tracking system. Defaults to None.
axis_size (int) – Size of the parallel axis, typically the number of devices to map over. If None, defaults to the number of available devices (e.g., jax.local_device_count()).
state_out_axes (Dict[int, type | str | Callable[[Tuple[Key, ...], Any], bool] | bool | Any | None | Tuple[Filter, ...] | List[Filter]]) – Specification for how to distribute output states across devices and axes. Can be:
  - A dictionary mapping axis indices (int) to brainstate.util.filter predicates that identify which states are distributed along which axis
  - A single filter (treated as {0: filter} for convenience)
  - None (default: all states distributed along axis 0, except NonBatchState, which is replicated)
  Filters are converted to predicates via brainstate.util.filter.to_predicate(). States matching NonBatchState are automatically replicated (axis None) across all devices regardless of other specifications.
- Returns:
Dictionary mapping axis indices to lists of parallelized states. Keys are the axis indices specified in state_out_axes (plus None for replicated states), and values are lists of State objects with their .value attributes set to device-distributed arrays.
- Return type:
- Raises:
ValueError – If state assignment is ambiguous, if axis_size exceeds the number of available devices, or if device configuration is invalid.
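The argument handling described above can be sketched in plain Python. This is a simplified model, not BrainState's implementation: normalize_state_out_axes and resolve_axis_size are hypothetical helper names, the device count is passed in explicitly instead of being queried from JAX, and rejecting axis_size values below 1 is an assumption about what "invalid device configuration" might cover.

```python
from typing import Any, Dict, Optional


def normalize_state_out_axes(state_out_axes: Any) -> Dict[int, Any]:
    """Hypothetical helper: coerce the accepted forms into {axis: filter}."""
    if state_out_axes is None:
        # No custom filters; the default placement rules apply later.
        return {}
    if isinstance(state_out_axes, dict):
        return dict(state_out_axes)
    # A single filter is shorthand for {0: filter}.
    return {0: state_out_axes}


def resolve_axis_size(axis_size: Optional[int], available_devices: int) -> int:
    """Hypothetical helper: default to the device count and validate the size."""
    if axis_size is None:
        return available_devices
    if axis_size < 1:  # assumption: non-positive sizes are treated as invalid
        raise ValueError(f"axis_size must be positive, got {axis_size}")
    if axis_size > available_devices:
        raise ValueError(
            f"axis_size={axis_size} exceeds the {available_devices} available devices"
        )
    return axis_size


# With 4 devices available, None falls back to 4 and a bare filter maps to axis 0.
default_size = resolve_axis_size(None, available_devices=4)
shorthand = normalize_state_out_axes(str)
```

In a real program the device count would typically come from jax.local_device_count(), as the axis_size description notes.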
Notes
Initialization Process:
1. Wraps module.init_all_states in a pmap2() transform
2. Executes the initialization on axis_size devices in parallel
3. Captures all newly created states using catch_new_states()
4. Assigns states to axes based on state_out_axes predicates
5. Restores device-distributed values to the actual state objects
6. Adjusts state stack levels to prevent JAX tracing leakage
State Device Distribution:
States are assigned to axes (and thus distributed across devices) in priority order:
1. First, NonBatchState → axis None (replicated)
2. Then, states matching custom filters in state_out_axes
3. Finally, remaining states → axis 0 (default device-parallel axis)
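The priority order above can be modeled with a short, library-free sketch. The State, NonBatchState, and ParamState classes below are stand-ins defined only for this example, not BrainState's classes, and assign_axes is a hypothetical name. Predicates here take a single state for brevity, whereas the real filters receive a path and a value. Treating a state that matches filters on more than one axis as the "ambiguous" case that raises ValueError is an assumption.

```python
from typing import Callable, Dict, List, Optional


class State:
    """Stand-in for a state object (illustration only)."""

    def __init__(self, name: str):
        self.name = name


class NonBatchState(State):
    """Stand-in for a state that is always replicated."""


class ParamState(State):
    """Stand-in for a parameter state."""


def assign_axes(
    states: List[State],
    axis_predicates: Dict[int, Callable[[State], bool]],
) -> Dict[Optional[int], List[State]]:
    """Group states by output axis following the documented priority order."""
    assigned: Dict[Optional[int], List[State]] = {}
    for state in states:
        if isinstance(state, NonBatchState):
            axis: Optional[int] = None  # 1. replicated, regardless of filters
        else:
            matches = [ax for ax, pred in axis_predicates.items() if pred(state)]
            if len(matches) > 1:  # assumed meaning of "ambiguous assignment"
                raise ValueError(f"{state.name} matches several axes: {matches}")
            axis = matches[0] if matches else 0  # 2. custom filter, 3. default
        assigned.setdefault(axis, []).append(state)
    return assigned


states = [ParamState("weight"), NonBatchState("config"), State("count")]
result = assign_axes(states, {1: lambda s: isinstance(s, ParamState)})
for axis, group in result.items():
    print(axis, [s.name for s in group])
```

The returned dictionary has the same shape as the documented return value: axis indices (plus None) mapped to lists of states.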
Device Semantics:
- Axis 0 (default): States are sharded across devices along the first dimension
- Axis None: States are replicated identically on all devices
- Custom axes: States can be sharded along different dimensions based on filters
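To make the axis semantics above concrete, the following sketch computes the shape a state's value would have after distribution. It assumes custom axes follow the jax.pmap out_axes convention, where the device dimension is inserted at the given position; distributed_shape is a hypothetical helper, not part of BrainState.

```python
from typing import Optional, Tuple


def distributed_shape(
    per_device_shape: Tuple[int, ...],
    axis: Optional[int],
    n_devices: int,
) -> Tuple[int, ...]:
    """Shape of a state's value after distribution (illustrative model)."""
    if axis is None:
        # Replicated: every device holds the same value, so the shape is unchanged.
        return tuple(per_device_shape)
    if not 0 <= axis <= len(per_device_shape):
        raise ValueError(f"axis {axis} is out of range for shape {per_device_shape}")
    shape = list(per_device_shape)
    shape.insert(axis, n_devices)  # the device dimension appears at position `axis`
    return tuple(shape)


# Four devices, mirroring the kinds of states used in the examples on this page.
scalar_counter = distributed_shape((), 0, 4)      # (4,)
weight_axis0 = distributed_shape((128,), 0, 4)    # (4, 128)
replicated = distributed_shape((3,), None, 4)     # (3,)
weight_axis1 = distributed_shape((128,), 1, 4)    # (128, 4)
```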
Multi-Host Considerations:
In multi-host setups, axis_size typically corresponds to the local device count. When explicit device lists are used with pmap2(), the devices must be specified consistently across all hosts.
Examples
Basic parallel initialization:
>>> import brainstate
>>> import jax
>>> import jax.numpy as jnp
>>>
>>> class ParallelCounter(brainstate.nn.Module):
...     def init_state(self):
...         self.count = brainstate.ShortTermState(jnp.array(0))
>>>
>>> module = ParallelCounter()
>>> pmap_states = brainstate.transform.pmap2_new_states(
...     module,
...     init_kwargs={},
...     axis_size=jax.local_device_count()
... )
>>> module.count.value.shape
(jax.local_device_count(),)
Parallel model with device-sharded parameters:
>>> from brainstate.util.filter import OfType
>>>
>>> class ParallelModel(brainstate.nn.Module):
...     def init_state(self, layer_size):
...         self.weight = brainstate.ParamState(
...             jax.random.normal(jax.random.PRNGKey(0), (layer_size,))
...         )
...         self.bias = brainstate.ParamState(jnp.zeros(layer_size))
>>>
>>> model = ParallelModel()
>>> n_devices = jax.local_device_count()
>>> pmap_states = brainstate.transform.pmap2_new_states(
...     model,
...     init_kwargs={'layer_size': 128},
...     axis_size=n_devices,
...     state_out_axes={0: OfType(brainstate.ParamState)}
... )
>>> # Parameters are sharded across devices
>>> model.weight.value.shape
(n_devices, 128)
Mixed replicated and sharded states:
>>> class MixedParallelModule(brainstate.nn.Module):
...     def init_state(self):
...         # Sharded state (different on each device)
...         self.local_state = brainstate.ShortTermState(jnp.array(0))
...         # Replicated state (same on all devices)
...         self.global_config = brainstate.NonBatchState(
...             jnp.array([1.0, 2.0, 3.0])
...         )
>>>
>>> module = MixedParallelModule()
>>> pmap_states = brainstate.transform.pmap2_new_states(
...     module,
...     init_kwargs={},
...     axis_size=jax.local_device_count()
... )
>>> module.local_state.value.shape  # sharded
(jax.local_device_count(),)
>>> module.global_config.value.shape  # replicated (not sharded)
(3,)
Using with ModuleMapper for data parallelism:
>>> from brainstate.nn import Map
>>>
>>> model = ParallelModel()
>>> pmapper = Map(
...     model,
...     init_map_size=jax.local_device_count(),
...     behavior='pmap',
...     axis_name='devices'
... )
>>> pmapper.init_all_states(layer_size=128)
>>> # Data-parallel training across devices
>>> # (`inputs` is assumed to be an existing array whose first dimension is the batch)
>>> batch_per_device = inputs.shape[0] // jax.local_device_count()
>>> sharded_inputs = inputs.reshape(
...     jax.local_device_count(), batch_per_device, -1
... )
>>> outputs = pmapper.update(sharded_inputs)
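The reshape in the data-parallel example above relies on the batch size being a multiple of the device count. With floor division, a mismatched batch either fails in the reshape or, when the element counts happen to line up, silently regroups rows (for example, a (10, 8) array reshaped with 4 devices becomes (4, 2, 10)). A minimal, library-free sketch of an explicit guard, using a hypothetical shard_batch helper on a plain list of rows:

```python
from typing import List, Sequence, TypeVar

Row = TypeVar("Row")


def shard_batch(rows: Sequence[Row], n_devices: int) -> List[List[Row]]:
    """Split a batch into equal, contiguous per-device shards (illustrative)."""
    if n_devices < 1:
        raise ValueError(f"n_devices must be positive, got {n_devices}")
    per_device, remainder = divmod(len(rows), n_devices)
    if remainder:
        raise ValueError(
            f"batch of {len(rows)} rows does not divide evenly over {n_devices} devices"
        )
    # Shard i holds rows [i * per_device, (i + 1) * per_device).
    return [list(rows[i * per_device:(i + 1) * per_device]) for i in range(n_devices)]


shards = shard_batch(list(range(8)), n_devices=4)
print(shards)  # [[0, 1], [2, 3], [4, 5], [6, 7]]
```

Raising early is one design choice; padding or dropping the remainder rows are alternatives, depending on the training setup.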
See also
pmap2 – Parallel mapping across devices with state semantics.
vmap2_new_states – Vectorized version for single-device batching.
brainstate.State – Base class for stateful objects.
brainstate.NonBatchState – Marker for states that should be replicated.
jax.pmap – Underlying JAX parallel mapping primitive.
catch_new_states – Context manager for capturing newly created states.