Map
- class brainstate.nn.Map(module, init_map_size, init_state_axes=None, state_tag=None, in_axes=0, out_axes=0, axis_name=None, spmd_axis_name=None, call_state_axes=None, behavior='vmap')
Vectorize or parallelize a module using brainstate.transform.vmap2 or pmap2.
This wrapper provides enhanced control over state management during vectorized or parallel mapping operations. Unlike Vmap, Map requires explicit initialization of vectorized states before use, which enables fine-grained control over how states are distributed across mapping axes.
The class supports two modes of operation:
- behavior='vmap': vectorized mapping using vmap2() for single-device batching.
- behavior='pmap': parallel mapping using pmap2() for multi-device parallelization.
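The practical difference between the two modes can be illustrated with the plain JAX transforms jax.vmap and jax.pmap. This sketch uses those instead of brainstate's vmap2/pmap2 and does not call Map; the function normalize and the axis name "i" are illustrative. The same function body runs under both transforms, but pmap places each mapped element on its own device, so the mapped axis cannot be longer than jax.device_count().

```python
import jax
import jax.numpy as jnp
from jax import lax


def normalize(x):
    # Divide each mapped element by the sum over the mapped axis named "i".
    return x / lax.psum(x, axis_name="i")


xs = jnp.arange(1.0, 5.0)  # four mapped elements: [1, 2, 3, 4]

# 'vmap'-style: all four elements are batched on a single device.
vmapped = jax.vmap(normalize, axis_name="i")(xs)

# 'pmap'-style: one element per device, so we map over at most
# jax.device_count() elements (a single element on a CPU-only machine).
n_dev = jax.device_count()
pmapped = jax.pmap(normalize, axis_name="i")(xs[:n_dev])
```

On a single-device machine the pmap call maps over one element only, which is why Map with behavior='pmap' is typically constructed with init_map_size=jax.device_count().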
- Parameters:
module (Module) – Module to wrap with vectorized or parallel mapping.
init_map_size (int) – Size of the mapping axis used during state initialization. This determines how many copies of each state are created.
init_state_axes (Dict[int, Filter], optional) – Dictionary mapping axis indices to filters for state initialization, where Filter is type | str | Callable[[Tuple[Key, ...], Any], bool] | bool | EllipsisType | None | Tuple[Filter, ...] | List[Filter]. Controls how newly created states are distributed across axes. Defaults to None, which assigns all states to axis 0 except NonBatchState.
state_tag (str, optional) – Tag for identifying and grouping states during vectorization. Defaults to None.
in_axes (int or Sequence[Any], optional) – Specification for mapping over inputs during update calls, following the semantics of jax.vmap(). Defaults to 0.
out_axes (Any, optional) – Specification for mapping over outputs during update calls. Defaults to 0.
axis_name (AxisName or None, optional) – Name of the mapped axis used by collective primitives (e.g., lax.psum). Defaults to None.
spmd_axis_name (AxisName or None, optional) – Name of the SPMD (Single Program Multiple Data) axis when using nested mapping transforms. Defaults to None.
call_state_axes (Dict[int, Filter], optional) – Dictionary mapping axis indices to state filters used during update calls, with Filter as above. Specifies how states are mapped over different axes. States created during initialization are merged into this mapping automatically. Defaults to None.
behavior (str, optional) – Type of parallelization to use: 'vmap' for vectorized single-device mapping, 'pmap' for multi-device parallel mapping. Defaults to 'vmap'.
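Since in_axes and out_axes follow jax.vmap() semantics, their effect can be checked with jax.vmap alone. This sketch does not use Map; the function process and the array shapes are illustrative. It maps over the first input, broadcasts the second, and places the two outputs on different axes.

```python
import jax
import jax.numpy as jnp


def process(x, y):
    # Per-instance computation: x is one row, y is shared by every instance.
    return x + y, x * y


x = jnp.ones((10, 5))  # mapped along axis 0 -> 10 instances
y = jnp.arange(5.0)    # in_axes=None: the same y is broadcast to each instance

# in_axes=(0, None): map x, broadcast y.
# out_axes=(0, 1): stack the first output along axis 0, the second along axis 1.
summed, product = jax.vmap(process, in_axes=(0, None), out_axes=(0, 1))(x, y)
```

Here summed has shape (10, 5), while product has shape (5, 10) because its mapped axis was placed at position 1.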
- dict_vmap_states
Dictionary mapping axis indices to lists of vectorized states, populated after calling init_all_states().
- Raises:
ValueError – If behavior is not 'vmap' or 'pmap', or if update is called before init_all_states().
AssertionError – If init_map_size is not an integer.
Notes
This module requires calling init_all_states() before the first update() call. The initialization process:
1. Calls module.init_all_states(**kwargs) under vectorized/parallel mapping.
2. Captures all newly created states.
3. Distributes states across axes based on init_state_axes.
4. Integrates these states into call_state_axes for subsequent update calls.
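The four initialization steps can be sketched in plain JAX. TinyMap, init_fn, update_fn and the non_batch argument are hypothetical names invented for this illustration: states are modeled as a dict of arrays rather than brainstate State objects, the non_batch names stand in for NonBatchState, and TinyMap.states only loosely plays the role of dict_vmap_states. The real Map delegates to vmap2/pmap2 and brainstate's filters instead.

```python
import jax
import jax.numpy as jnp


class TinyMap:
    """Hypothetical, heavily simplified analogue of Map built on jax.vmap."""

    def __init__(self, init_fn, update_fn, init_map_size, non_batch=()):
        self.init_fn = init_fn          # key -> dict of newly created states
        self.update_fn = update_fn      # (states, x) -> (states, y)
        self.init_map_size = init_map_size
        self.non_batch = set(non_batch)
        self.states = None              # filled in by init_all_states()
        self.state_axes = None          # per-state mapped axis (0 or None)

    def init_all_states(self, key):
        keys = jax.random.split(key, self.init_map_size)
        # Step 1: run the initializer under vmap, once per mapped copy.
        # Step 2: every state returned by init_fn counts as newly created.
        created = jax.vmap(self.init_fn)(keys)
        # Step 3: distribute states across axes (axis 0 unless non-batch).
        self.state_axes = {
            name: (None if name in self.non_batch else 0) for name in created
        }
        # Non-batch states keep a single copy instead of init_map_size copies.
        self.states = {
            name: (value[0] if self.state_axes[name] is None else value)
            for name, value in created.items()
        }

    def update(self, x):
        if self.states is None:
            raise ValueError("init_all_states() must be called before update().")
        # Step 4: reuse the recorded axes when mapping states during update.
        mapped = jax.vmap(
            self.update_fn,
            in_axes=(self.state_axes, 0),
            out_axes=(self.state_axes, 0),
        )
        self.states, y = mapped(self.states, x)
        return y


def init_fn(key):
    return {"w": jax.random.normal(key, (5,)), "scale": jnp.float32(2.0)}


def update_fn(states, x):
    # Return the states unchanged, plus one output per mapped copy.
    return states, (x @ states["w"]) * states["scale"]


tm = TinyMap(init_fn, update_fn, init_map_size=10, non_batch=("scale",))
tm.init_all_states(jax.random.PRNGKey(0))
y = tm.update(jnp.ones((10, 5)))
```

After initialization, "w" holds 10 independent copies with shape (10, 5), while "scale" stays a single scalar shared by all copies.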
Examples
Basic vectorized mapping:
>>> import brainstate
>>> import jax.numpy as jnp
>>> from brainstate.nn import Map
>>>
>>> class MyModule(brainstate.nn.Module):
...     def init_state(self, size):
...         self.weight = brainstate.ParamState(jnp.zeros(size))
...     def update(self, x):
...         return x @ self.weight.value
>>>
>>> module = MyModule()
>>> vmapper = Map(
...     module,
...     init_map_size=10,
...     in_axes=0,
...     axis_name="batch"
... )
>>> vmapper.init_all_states(size=(5,))  # creates 10 copies of the state
>>> inputs = jnp.ones((10, 5))
>>> outputs = vmapper.update(inputs)  # shape: (10,)
Parallel mapping across devices:
>>> import jax
>>> pmapper = Map(
...     module,
...     init_map_size=jax.device_count(),
...     behavior='pmap',
...     axis_name="devices"
... )
>>> pmapper.init_all_states(size=(5,))
>>> # inputs are split along axis 0, one slice per device
>>> inputs = jnp.ones((jax.device_count(), 5))
>>> outputs = pmapper.update(inputs)
Mapping custom module methods:
>>> vmapper = Map(module, init_map_size=10)
>>> vmapper.init_all_states(size=(5,))
>>> inputs = jnp.ones((10, 5))
>>> # Call a specific method with custom mapping
>>> # (assumes the wrapped module defines a predict(x) method)
>>> predictions = vmapper.map('predict', in_axes=0)(inputs)
See also
brainstate.transform.vmap2 – Vectorized mapping with state semantics.
brainstate.transform.pmap2 – Parallel mapping across devices.
brainstate.transform.vmap2_new_states – Initialize vectorized states.
brainstate.transform.pmap2_new_states – Initialize parallel states.
Vmap – Simpler vectorization wrapper without explicit state initialization.
- init_all_states(**kwargs)
Initialize vectorized states for the wrapped module.
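This method must be called before the first update call. It calls module.init_all_states(**kwargs) under the configured vectorized or parallel mapping and creates vectorized versions of the newly created module states, with init_map_size copies along the mapped axis.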
- map(fn, in_axes=0, out_axes=0, axis_name=None, state_axes=None)
Access the wrapped module’s methods with vectorized mapping.
This method allows you to call any method of the wrapped module with custom vectorization settings, overriding the default in_axes, out_axes, axis_name, and call_state_axes specified when the Map was constructed.
- Parameters:
fn (str | Callable) – The method name (as a string) or callable function to execute with vectorized mapping. If a string, it must be the name of an existing method on the wrapped module.
in_axes (Any) – Specification for mapping over input arguments. Can be an integer specifying which axis to map over, a tuple/dict for complex structures, or None to broadcast without mapping. Default is 0.
out_axes (Any) – Specification for mapping over outputs. Can be an integer specifying which axis the mapped dimension is placed on, a tuple/dict for complex structures, or None for an output that is not mapped (it must then be identical across all mapped instances). Default is 0.
axis_name (str | None) – Name for the mapped axis used by collective operations such as lax.psum or lax.pmean. If None, the axis name specified during Map initialization is used. Default is None.
state_axes (Dict[int, Filter] | None) – Dictionary mapping axis indices to state filters for fine-grained control over which states are mapped along which axes, where Filter is type | str | Callable[[Tuple[Key, ...], Any], bool] | bool | EllipsisType | None | Tuple[Filter, ...] | List[Filter]. Keys are axis indices; values are filters that select which states to map. If None, the default state mapping is used. Default is None.
- Returns:
A callable wrapper that applies the specified vectorized mapping when invoked. Call this object with the arguments you want to pass to the mapped function.
- Return type:
_MapCaller
- Raises:
ValueError – If init_all_states() has not been called before using this method.
AttributeError – If fn is a string but the module has no method with that name.
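The override rule described for map() can be sketched with jax.vmap: per-call settings replace the constructor defaults, and axis_name=None falls back to the axis name given at construction. make_caller and DEFAULT_AXIS_NAME are hypothetical names for this illustration. The real method returns a _MapCaller object rather than a plain vmapped function, and it also handles state_axes, which this sketch omits.

```python
import jax
import jax.numpy as jnp
from jax import lax

DEFAULT_AXIS_NAME = "batch"  # plays the role of axis_name given to Map(...)


def make_caller(fn, in_axes=0, out_axes=0, axis_name=None):
    # Hypothetical analogue of Map.map(): per-call settings override defaults,
    # and axis_name=None falls back to the name chosen at construction.
    name = DEFAULT_AXIS_NAME if axis_name is None else axis_name
    return jax.vmap(fn, in_axes=in_axes, out_axes=out_axes, axis_name=name)


def centered(x, offset):
    # Subtract the mean over the mapped axis, then add an unmapped offset.
    return x - lax.pmean(x, axis_name=DEFAULT_AXIS_NAME) + offset


# Uses the default axis name; offset is broadcast via in_axes=(0, None).
result = make_caller(centered, in_axes=(0, None))(jnp.arange(4.0), 10.0)

# Overrides the axis name for a function that reduces over "other".
other = make_caller(lambda x: x - lax.pmean(x, "other"), axis_name="other")(
    jnp.arange(4.0)
)
```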
Examples
Basic usage with method name:
>>> import brainstate as bst
>>> import jax.numpy as jnp
>>>
>>> class MyModule(bst.nn.Module):
...     def init_state(self):
...         self.weight = bst.ParamState(jnp.ones(5))
...     def predict(self, x):
...         return x @ self.weight.value
>>>
>>> module = MyModule()
>>> vmapper = bst.nn.Map(module, init_map_size=10)
>>> vmapper.init_all_states()
>>> inputs = jnp.ones((10, 5))  # batch of 10 inputs
>>> outputs = vmapper.map('predict')(inputs)  # shape: (10,)
Using a callable function:
>>> def custom_fn(module, x, scale):
...     return module.predict(x) * scale
>>>
>>> vmapper = bst.nn.Map(module, init_map_size=10)
>>> vmapper.init_all_states()
>>> # Bind scale in a closure so that only x is mapped
>>> outputs = vmapper.map(lambda m, x: custom_fn(m, x, 2.0))(inputs)
Custom in_axes and out_axes:
>>> class MultiInputModule(bst.nn.Module):
...     def init_state(self, size):
...         self.state = bst.State(jnp.zeros(size))
...     def process(self, x, y):
...         return x + y, x * y
>>>
>>> module = MultiInputModule()
>>> vmapper = bst.nn.Map(module, init_map_size=10)
>>> vmapper.init_all_states(size=(5,))
>>> x = jnp.ones((10, 5))  # mapped over axis 0
>>> y = jnp.ones(5)  # broadcast (not mapped)
>>> # Map over the first input, broadcast the second; both outputs mapped
>>> result1, result2 = vmapper.map(
...     'process',
...     in_axes=(0, None),
...     out_axes=(0, 0)
... )(x, y)
Using state_axes for fine-grained control:
>>> from brainstate.util.filter import OfType
>>>
>>> class StatefulModule(bst.nn.Module):
...     def init_state(self, size):
...         self.params = bst.ParamState(jnp.ones(size))
...         self.buffer = bst.State(jnp.zeros(size))
...     def update(self, x):
...         self.buffer.value = x
...         return x @ self.params.value
>>>
>>> module = StatefulModule()
>>> vmapper = bst.nn.Map(module, init_map_size=10)
>>> vmapper.init_all_states(size=(5,))
>>> # Map only ParamState along axis 0, keep State shared
>>> outputs = vmapper.map(
...     'update',
...     state_axes={0: OfType(bst.ParamState)}
... )(inputs)
See also
update – Execute the vectorized module with default settings.
brainstate.transform.vmap2 – Underlying vectorization transform.
brainstate.transform.pmap2 – Underlying parallel mapping transform.
init_all_states – Required initialization method before using map.
- update(*args, **kwargs)
Execute the vectorized module with the given arguments.
- Parameters:
*args – Positional arguments forwarded to the vectorized module.
**kwargs – Keyword arguments forwarded to the vectorized module.
- Returns:
Result of executing the vectorized module.
- Raises:
ValueError – If init_all_states has not been called before this method.