Map

class brainstate.nn.Map(module, init_map_size, init_state_axes=None, state_tag=None, in_axes=0, out_axes=0, axis_name=None, spmd_axis_name=None, call_state_axes=None, behavior='vmap')

Vectorize or parallelize a module using brainstate.transform.vmap2() or brainstate.transform.pmap2().

This wrapper gives explicit control over state management during vectorized or parallel mapping. Unlike Vmap, Map requires the mapped states to be initialized explicitly before use, so you decide how states are distributed across the mapping axes.

The class supports two modes of operation:

  • behavior='vmap': Vectorized mapping using vmap2() for single-device batching

  • behavior='pmap': Parallel mapping using pmap2() for multi-device parallelization

Parameters:
  • module (Module) – Module to wrap with vectorized or parallel mapping.

  • init_map_size (int) – Size of the mapping axis used during state initialization. This determines how many copies of each state will be created.

  • init_state_axes (Dict[int, Filter], optional) – Dictionary mapping axis indices to filters for state initialization. Controls how newly created states are distributed across axes. A Filter may be a state type, a string, a predicate taking (path, value) and returning bool, a bool, Ellipsis, None, or a tuple/list of such filters. Defaults to None, which assigns all states to axis 0 except NonBatchState.

  • state_tag (str, optional) – Tag used to identify and group states during vectorization. Defaults to None.

  • in_axes (int or Sequence[Any], optional) – Specification for mapping over inputs during update calls, following the semantics of jax.vmap(). Defaults to 0.

  • out_axes (Any, optional) – Specification for mapping over outputs during update calls. Defaults to 0.

  • axis_name (AxisName or None, optional) – Name of the mapped axis used by collective primitives (e.g., lax.psum). Defaults to None.

  • spmd_axis_name (AxisName or None, optional) – Name for SPMD (Single Program Multiple Data) axis when using nested mapping transforms. Defaults to None.

  • call_state_axes (Dict[int, Filter], optional) – Dictionary mapping axis indices to filters for states during update calls, specifying which states are mapped along which axes. Filters take the same forms as for init_state_axes. States created by init_all_states() are merged into this mapping automatically. Defaults to None.

  • behavior (str, optional) – Type of mapping to use: 'vmap' for vectorized single-device mapping or 'pmap' for multi-device parallel mapping. Defaults to 'vmap'.
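Both init_state_axes and call_state_axes map an axis index to a filter that selects states. As a rough illustration of how such a dictionary can partition a module's states, the plain-Python sketch below groups states under the first filter that matches them. The stand-in state classes, the first-match rule and the helper names (matches, assign_axes) are assumptions made for this sketch; brainstate's own filter resolution may differ.

```python
from typing import Any, Callable, Dict, List, Tuple, Union

Path = Tuple[str, ...]


# Hypothetical stand-ins for state classes (not brainstate's real types).
class State:
    def __init__(self, value: Any) -> None:
        self.value = value


class ParamState(State):
    pass


class HiddenState(State):
    pass


# A toy filter: a class (matched with isinstance) or a predicate on (path, state).
ToyFilter = Union[type, Callable[[Path, Any], bool]]


def matches(flt: ToyFilter, path: Path, state: Any) -> bool:
    if isinstance(flt, type):
        return isinstance(state, flt)
    return bool(flt(path, state))


def assign_axes(states: Dict[Path, Any], state_axes: Dict[int, ToyFilter]) -> Dict[int, List[Path]]:
    """Group state paths under the first axis whose filter matches them.

    States matched by no filter are left out, i.e. treated as unmapped.
    """
    grouped: Dict[int, List[Path]] = {axis: [] for axis in state_axes}
    for path, state in states.items():
        for axis, flt in state_axes.items():
            if matches(flt, path, state):
                grouped[axis].append(path)
                break
    return grouped


states = {("weight",): ParamState([0.0] * 5), ("v",): HiddenState([0.0] * 5)}
axes = assign_axes(states, {0: ParamState, 1: lambda path, s: path[-1] == "v"})
print(axes)  # {0: [('weight',)], 1: [('v',)]}
```

Under this toy rule a state matched by several filters goes to whichever axis appears first in the dictionary; consult brainstate.transform.vmap2 for the precedence the library actually uses.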

module

The wrapped module being vectorized or parallelized.

Type:

Module

init_map_size

Size of the mapping axis for state initialization.

Type:

int

dict_vmap_states

Dictionary mapping axis indices to lists of vectorized states, populated after calling init_all_states().

Type:

dict[int, list[State]] or None

Notes

This module requires calling init_all_states() before the first update() call. The initialization process:

  1. Calls module.init_all_states(**kwargs) under vectorized/parallel mapping

  2. Captures all newly created states

  3. Distributes states across axes based on init_state_axes

  4. Integrates these states into call_state_axes for subsequent update calls
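The four initialization steps can be mimicked in plain Python to make the bookkeeping concrete. In the sketch below, Recorder, init_mapped and merge_into_call_axes are hypothetical helpers, and vmap_states plays the role of dict_vmap_states. The real method works on JAX arrays under the chosen mapping transform rather than on Python lists, and this sketch does not model non-zero axes or NonBatchState.

```python
from typing import Any, Callable, Dict, List, Set


class Recorder:
    """Hypothetical collector for the states one initializer run creates."""

    def __init__(self) -> None:
        self.created: Dict[str, Any] = {}

    def new_state(self, name: str, value: Any) -> None:
        self.created[name] = value


def init_mapped(init_fn: Callable[..., None], map_size: int, **kwargs: Any) -> Dict[int, Dict[str, List[Any]]]:
    # Steps 1-2: run the initializer once per mapped element, forwarding
    # kwargs, and capture the states each run creates.
    copies: List[Dict[str, Any]] = []
    for _ in range(map_size):
        rec = Recorder()
        init_fn(rec, **kwargs)
        copies.append(rec.created)
    if not copies:
        return {0: {}}
    # Step 3: stack each state's copies along a new leading axis. Every state
    # is placed on axis 0, mirroring the documented default.
    return {0: {name: [c[name] for c in copies] for name in copies[0]}}


def merge_into_call_axes(init_axes: Dict[int, Dict[str, List[Any]]], call_axes: Dict[int, Set[str]]) -> Dict[int, Set[str]]:
    # Step 4: add the state names captured at init time to the call-time spec.
    merged = {axis: set(names) for axis, names in call_axes.items()}
    for axis, named_states in init_axes.items():
        merged.setdefault(axis, set()).update(named_states.keys())
    return merged


def init_state(rec: Recorder, size: int) -> None:
    rec.new_state("weight", [0.0] * size)


vmap_states = init_mapped(init_state, map_size=10, size=5)
call_axes = merge_into_call_axes(vmap_states, {})
print(len(vmap_states[0]["weight"]))  # 10
print(call_axes)  # {0: {'weight'}}
```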

Examples

Basic vectorized mapping:

>>> import brainstate
>>> import jax.numpy as jnp
>>> from brainstate.nn import Map
>>>
>>> class MyModule(brainstate.nn.Module):
...     def init_state(self, size):
...         self.weight = brainstate.ParamState(jnp.zeros(size))
...     def update(self, x):
...         return x @ self.weight.value
>>>
>>> module = MyModule()
>>> vmapper = Map(
...     module,
...     init_map_size=10,
...     in_axes=0,
...     axis_name="batch"
... )
>>> vmapper.init_all_states(size=(5,))  # creates 10 copies of `weight`
>>> inputs = jnp.ones((10, 5))          # one input row per mapped copy
>>> outputs = vmapper.update(inputs)    # shape: (10,)

Parallel mapping across devices:

>>> import jax
>>> pmapper = Map(
...     MyModule(),  # a fresh instance, separate from the one wrapped above
...     init_map_size=jax.device_count(),
...     behavior='pmap',
...     axis_name="devices"
... )
>>> pmapper.init_all_states(size=(5,))
>>> # One input row per device along the mapped axis 0
>>> inputs = jnp.ones((jax.device_count(), 5))
>>> outputs = pmapper.update(inputs)

Mapping custom module methods:

>>> vmapper = Map(MyModule(), init_map_size=10)
>>> vmapper.init_all_states(size=(5,))
>>> inputs = jnp.ones((10, 5))
>>> # Call a module method by name with its own mapping settings
>>> predictions = vmapper.map('update', in_axes=0)(inputs)

See also

brainstate.transform.vmap2

Vectorized mapping with state semantics.

brainstate.transform.pmap2

Parallel mapping across devices.

brainstate.transform.vmap2_new_states

Initialize vectorized states.

brainstate.transform.pmap2_new_states

Initialize parallel states.

Vmap

Simpler vectorization wrapper without explicit state initialization.

init_all_states(**kwargs)

Initialize vectorized states for the wrapped module.

This method must be called before the first call to update() or map(). It calls the wrapped module's init_all_states(**kwargs) under the configured mapping and creates init_map_size mapped copies of every newly created state.

map(fn, in_axes=0, out_axes=0, axis_name=None, state_axes=None)

Call the wrapped module's methods with custom mapping settings.

This method lets you call any method of the wrapped module, or a callable, with its own mapping settings, overriding the default in_axes, out_axes, axis_name, and call_state_axes given to the Map constructor.

Parameters:
  • fn (str | Callable) – The method name (as a string) or callable function to execute with vectorized mapping. If a string, it must be the name of an existing method on the wrapped module.

  • in_axes (Any) – Specification for mapping over input arguments. Can be an integer specifying which axis to map over, a tuple/dict for complex structures, or None to broadcast without mapping. Default is 0.

  • out_axes (Any) – Specification for mapping over outputs. Can be an integer specifying the output axis that carries the mapped dimension, a tuple/dict for complex structures, or None for an output that is not mapped (it must then not vary along the mapped axis). Default is 0.

  • axis_name (str | None) – Name for the mapped axis used by collective operations such as lax.psum or lax.pmean. If None, the axis name given to the Map constructor is used. Default is None.

  • state_axes (Dict[int, Filter] or None) – Dictionary mapping axis indices to state filters (the same filter forms as the class-level init_state_axes parameter) for fine-grained control over which states are mapped along which axes. If None, the default state mapping is used. Default is None.
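in_axes and out_axes follow the semantics of jax.vmap(): an integer selects the mapped axis of an argument or output, while None passes an argument unchanged to every element (broadcast) or marks an output as unmapped. The toy list_vmap below is a hypothetical helper that supports only axis 0 or None and operates on Python lists; it illustrates these semantics and is not how vmap2 is implemented.

```python
from typing import Any, Callable, Optional, Sequence


def list_vmap(fn: Callable[..., Any], in_axes: Sequence[Optional[int]], out_axes: Optional[int] = 0) -> Callable[..., Any]:
    """Toy vmap over Python lists; axis entries may only be 0 or None."""

    def mapped(*args: Any) -> Any:
        if len(args) != len(in_axes):
            raise ValueError("in_axes must have one entry per argument")
        sizes = {len(a) for a, ax in zip(args, in_axes) if ax == 0}
        if len(sizes) != 1:
            raise ValueError("mapped arguments must share one axis size")
        (size,) = sizes
        results = []
        for i in range(size):
            # Axis 0: take the i-th slice; None: pass the whole argument.
            call_args = [a[i] if ax == 0 else a for a, ax in zip(args, in_axes)]
            results.append(fn(*call_args))
        if out_axes is None:
            # Unmapped output: every element must agree; return a single copy.
            if any(r != results[0] for r in results):
                raise ValueError("out_axes=None requires identical outputs")
            return results[0]
        return results  # out_axes=0: results stacked along a new leading axis

    return mapped


def add(x, y):
    return [a + b for a, b in zip(x, y)]


xs = [[1, 2], [3, 4], [5, 6]]  # mapped over axis 0 (size 3)
y = [10, 20]                   # broadcast, like in_axes=None
out = list_vmap(add, in_axes=(0, None))(xs, y)
print(out)  # [[11, 22], [13, 24], [15, 26]]
```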

Returns:

A callable wrapper that applies the specified vectorized mapping when invoked. Call this object with the arguments you want to pass to the mapped function.

Return type:

_MapCaller

Raises:
  • ValueError – If init_all_states() has not been called before using this method.

  • AttributeError – If fn is a string but the module has no method with that name.
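The two error conditions and the string-or-callable handling of fn can be pictured with a small plain-Python model. ToyMapper and ToyMapCaller are hypothetical classes that reproduce only the checks described above and do no actual mapping; they also assume, as the custom_fn example for this method suggests, that a callable receives the module as its first argument.

```python
import functools
from typing import Any, Callable, Optional, Union


class ToyMapCaller:
    """Hypothetical analogue of _MapCaller: stores settings, runs when called."""

    def __init__(self, fn: Callable[..., Any], in_axes: Any, out_axes: Any) -> None:
        self.fn = fn
        self.in_axes = in_axes
        self.out_axes = out_axes

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        # A real caller would apply vmap2 or pmap2 here; this toy just forwards.
        return self.fn(*args, **kwargs)


class ToyMapper:
    """Hypothetical wrapper reproducing only the documented checks of map()."""

    def __init__(self, module: Any) -> None:
        self.module = module
        self.dict_vmap_states: Optional[dict] = None  # filled by init_all_states()

    def init_all_states(self) -> None:
        self.dict_vmap_states = {0: []}

    def map(self, fn: Union[str, Callable[..., Any]], in_axes: Any = 0, out_axes: Any = 0) -> ToyMapCaller:
        if self.dict_vmap_states is None:
            raise ValueError("call init_all_states() before map()")
        if isinstance(fn, str):
            if not callable(getattr(self.module, fn, None)):
                raise AttributeError(f"module has no method {fn!r}")
            fn = getattr(self.module, fn)
        else:
            # Assumption: a callable receives the module as its first argument.
            fn = functools.partial(fn, self.module)
        return ToyMapCaller(fn, in_axes, out_axes)


class Doubler:
    def predict(self, x: int) -> int:
        return 2 * x


mapper = ToyMapper(Doubler())
try:
    mapper.map("predict")
except ValueError as err:
    print("before init:", err)
mapper.init_all_states()
print(mapper.map("predict")(21))                       # 42
print(mapper.map(lambda m, x: m.predict(x) + 1)(20))  # 41
```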

Examples

Basic usage with method name:

>>> import brainstate as bst
>>> import jax.numpy as jnp
>>>
>>> class MyModule(bst.nn.Module):
...     def init_state(self):
...         self.weight = bst.ParamState(jnp.ones(5))
...     def predict(self, x):
...         return x @ self.weight.value
>>>
>>> module = MyModule()
>>> vmapper = bst.nn.Map(module, init_map_size=10)
>>> vmapper.init_all_states()
>>> inputs = jnp.ones((10, 5))  # batch of 10 inputs
>>> outputs = vmapper.map('predict')(inputs)  # shape: (10,)

Using a callable function:

>>> def custom_fn(module, x, scale):
...     return module.predict(x) * scale
>>>
>>> # The callable receives the module as its first argument;
>>> # the constant `scale` is closed over rather than mapped
>>> outputs = vmapper.map(lambda m, x: custom_fn(m, x, 2.0))(inputs)

Custom in_axes and out_axes:

>>> class MultiInputModule(bst.nn.Module):
...     def init_state(self, size):
...         self.state = bst.State(jnp.zeros(size))
...     def process(self, x, y):
...         return x + y, x * y
>>>
>>> module = MultiInputModule()
>>> vmapper = bst.nn.Map(module, init_map_size=10)
>>> vmapper.init_all_states(size=(5,))
>>> x = jnp.ones((10, 5))  # mapped over axis 0
>>> y = jnp.ones(5)        # broadcasted (not mapped)
>>> # Map over first input but broadcast second, both outputs mapped
>>> result1, result2 = vmapper.map(
...     'process',
...     in_axes=(0, None),
...     out_axes=(0, 0)
... )(x, y)

Using state_axes for fine-grained control:

>>> from brainstate.util.filter import OfType
>>>
>>> class StatefulModule(bst.nn.Module):
...     def init_state(self, size):
...         self.params = bst.ParamState(jnp.ones(size))
...         self.buffer = bst.State(jnp.zeros(size))
...     def update(self, x):
...         self.buffer.value = x
...         return x @ self.params.value
>>>
>>> module = StatefulModule()
>>> vmapper = bst.nn.Map(module, init_map_size=10)
>>> vmapper.init_all_states(size=(5,))
>>> # Map only ParamState along axis 0, keep State shared
>>> outputs = vmapper.map(
...     'update',
...     state_axes={0: OfType(bst.ParamState)}
... )(inputs)

See also

update

Execute the vectorized module with default settings.

brainstate.transform.vmap2

Underlying vectorization transform.

brainstate.transform.pmap2

Underlying parallel mapping transform.

init_all_states

Required initialization method before using map.

update(*args, **kwargs)

Execute the vectorized module with the given arguments.

Parameters:
  • *args – Positional arguments forwarded to the vectorized module.

  • **kwargs – Keyword arguments forwarded to the vectorized module.

Returns:

Result of executing the vectorized module.

Return type:

Any

Raises:

ValueError – If init_all_states has not been called before this method.
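With the default in_axes=0 and states mapped along axis 0, element i of a vectorized update sees the i-th input slice together with the i-th copy of each state created by init_all_states(). The plain-Python sketch below spells out that pairing for a dot-product update like the one in MyModule; mapped_update is a hypothetical helper, and a real call runs as a single batched computation rather than a Python loop.

```python
from typing import List


def dot(a: List[float], b: List[float]) -> float:
    return sum(x * y for x, y in zip(a, b))


def mapped_update(weights: List[List[float]], inputs: List[List[float]]) -> List[float]:
    """Toy vmapped update: element i uses input slice i and state copy i."""
    if len(weights) != len(inputs):
        raise ValueError("inputs must have one slice per state copy")
    return [dot(x, w) for x, w in zip(inputs, weights)]


# Three state copies (as if init_map_size=3), each a length-2 weight vector.
weights = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
inputs = [[2.0, 3.0], [2.0, 3.0], [2.0, 3.0]]
outputs = mapped_update(weights, inputs)
print(outputs)  # [2.0, 3.0, 5.0]
```

Identical inputs still give different outputs here because each mapped element reads its own copy of the state.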