brainstate.transform.vmap2_new_states

brainstate.transform.vmap2_new_states(module, init_kwargs, state_tag=None, axis_size=None, state_out_axes=None)

Initialize and vectorize newly created states within a module.

This function creates vectorized versions of all states that are initialized when calling module.init_all_states(**init_kwargs). It uses vmap2() to create multiple copies of each state along the axes selected by state_out_axes, so that modules with stateful components can be run efficiently on batches.

The vectorization process wraps the module’s initialization in a vmap2() transform, runs it as a batched computation over axis_size instances, and then writes the vectorized values back to the original state objects. Subsequent operations on the module therefore work with batched states transparently.
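Conceptually, the wrap, execute, and restore sequence resembles running a per-instance initializer under plain jax.vmap and writing the stacked results back onto the object. The sketch below uses only JAX; ToyModule, init_values, and vectorized_init are hypothetical names for illustration, not brainstate APIs, and the real function additionally tracks State objects and trace levels.

```python
import jax
import jax.numpy as jnp


class ToyModule:
    """Hypothetical stand-in for a module whose init creates state values."""

    def init_values(self, size):
        # Per-instance initializer: returns the values the states would hold.
        return {"weight": jnp.zeros(size), "count": jnp.array(0)}


def vectorized_init(module, init_kwargs, axis_size):
    # Map over a dummy index so jax.vmap has an input to take the batch size
    # from; every instance runs the same initializer.
    def init_one(_index):
        return module.init_values(**init_kwargs)

    # Unbatched outputs are broadcast along a new leading axis of size axis_size.
    stacked = jax.vmap(init_one)(jnp.arange(axis_size))
    # "Restore" step: write the batched arrays back onto the module object.
    for name, value in stacked.items():
        setattr(module, name, value)
    return stacked


module = ToyModule()
vectorized_init(module, {"size": 10}, axis_size=5)
print(module.weight.shape)  # (5, 10)
print(module.count.shape)   # (5,)
```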

Parameters:
  • module (Module) – Module whose states should be vectorized. Must have an init_all_states method that creates the states to be vectorized.

  • init_kwargs (Dict) – Keyword arguments forwarded to module.init_all_states(**init_kwargs) during the vectorized initialization. These arguments are passed to each parallel initialization call.

  • state_tag (str, optional) – Tag for identifying and grouping the newly created states. Used by BrainState’s state tracking system. Defaults to None.

  • axis_size (int, optional) – Size of the vectorization axis. Determines how many copies of each state will be created along the mapped axis. If None, the size must be inferrable from the vectorized function’s execution context.

  • state_out_axes (Dict[int, type | str | Callable[[Tuple[Key, ...], Any], bool] | bool | Any | None | Tuple[Filter, ...] | List[Filter]]) –

    Specification for how to map output states along different axes. Can be:

    • A dictionary mapping axis indices (int) to brainstate.util.filter predicates that identify which states belong to which axis

    • A single filter (treated as {0: filter} for convenience)

    • None (default: all states assigned to axis 0, except NonBatchState which goes to axis None)

    Filters are converted to predicates via brainstate.util.filter.to_predicate(). States matching NonBatchState are automatically assigned to axis None (unbatched) regardless of other specifications.
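In plain JAX terms, the axis a group of states is assigned to plays the role of an out_axes entry: an integer places the batch dimension at that position, and None leaves the value unbatched. The sketch below is only an analogy using jax.vmap directly (init_one is hypothetical); it reproduces the shapes shown in the examples but does not call brainstate.

```python
import jax
import jax.numpy as jnp


def init_one(_index):
    # Hypothetical per-instance initializer returning (weight, counter, shared).
    weight = jnp.zeros(10)
    counter = jnp.array(0)
    shared = jnp.array([1, 2, 3])
    return weight, counter, shared


# One out_axes entry per output, mirroring an axis assignment:
# weight -> axis 1, counter -> axis 0, shared -> None (left unbatched).
weight, counter, shared = jax.vmap(init_one, out_axes=(1, 0, None))(jnp.arange(5))

print(weight.shape)   # (10, 5)
print(counter.shape)  # (5,)
print(shared.shape)   # (3,)
```

An output mapped with None must not depend on the batch; here shared is a constant, so JAX accepts it unbatched.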

Returns:

Dictionary mapping axis indices to lists of vectorized states. Keys are the axis indices specified in state_out_axes (plus None for non-batched states), and values are lists of State objects with their .value attributes set to the vectorized arrays.

Return type:

dict[int, list[State]]

Raises:

ValueError – If state assignment is ambiguous or if axis_size cannot be inferred.

Notes

Initialization Process:

  1. Wraps module.init_all_states in a vmap2() transform

  2. Runs the initialization as a single batched computation over axis_size instances

  3. Captures all newly created states using catch_new_states()

  4. Assigns states to axes based on state_out_axes predicates

  5. Restores vectorized values to the actual state objects

  6. Adjusts state stack levels to prevent JAX tracing leakage

State Axis Assignment:

States are assigned to axes in priority order:

  • First, NonBatchState → axis None (unbatched)

  • Then, states matching custom filters in state_out_axes

  • Finally, remaining states → axis 0 (default batch axis)
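The priority order can be sketched in plain Python. State, NonBatchState, ParamState, ShortTermState, and assign_axes here are hypothetical stand-ins, not brainstate’s classes, and the custom predicates are plain callables rather than brainstate.util.filter objects. Treating a state matched by more than one custom filter as an error is an assumption based on the ValueError listed under Raises.

```python
from typing import Callable, Dict, List, Optional


class State:
    """Hypothetical stand-in for brainstate.State (not the real class)."""

    def __init__(self, name: str):
        self.name = name


class NonBatchState(State):
    pass


class ParamState(State):
    pass


class ShortTermState(State):
    pass


def assign_axes(
    states: List[State],
    custom: Dict[int, Callable[[State], bool]],
) -> Dict[Optional[int], List[State]]:
    """Group states by output axis using the priority order described above."""
    groups: Dict[Optional[int], List[State]] = {}
    for st in states:
        if isinstance(st, NonBatchState):
            axis = None  # 1. NonBatchState is always unbatched
        else:
            matches = [ax for ax, pred in custom.items() if pred(st)]
            if len(matches) > 1:
                # Assumption: a state matched by several filters is ambiguous.
                raise ValueError(f"state {st.name!r} matches axes {matches}")
            axis = matches[0] if matches else 0  # 2. custom filter, else 3. axis 0
        groups.setdefault(axis, []).append(st)
    return groups


states = [ParamState("weight"), ShortTermState("counter"), NonBatchState("shared")]
groups = assign_axes(states, {1: lambda s: isinstance(s, ParamState)})
print({ax: [s.name for s in sts] for ax, sts in groups.items()})
# {1: ['weight'], 0: ['counter'], None: ['shared']}
```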

Critical Implementation Detail:

After restoring values, the function decreases the stack level twice on each state to prevent JAX tracing leakage. This is necessary because the vmap2() transform creates two nested state trace contexts ('vmap2_eval' and 'vmap2') that must be unwound.

Examples

Basic vectorization with default axis:

>>> import brainstate
>>> import jax.numpy as jnp
>>>
>>> class Counter(brainstate.nn.Module):
...     def init_state(self):
...         self.count = brainstate.ShortTermState(jnp.array(0))
>>>
>>> module = Counter()
>>> vmap_states = brainstate.transform.vmap2_new_states(
...     module,
...     init_kwargs={},
...     axis_size=5
... )
>>> module.count.value.shape
(5,)

Custom axis assignment with filters:

>>> from brainstate.util.filter import OfType
>>>
>>> class MyModule(brainstate.nn.Module):
...     def init_state(self, size):
...         self.weight = brainstate.ParamState(jnp.zeros(size))
...         self.counter = brainstate.ShortTermState(0)
>>>
>>> module = MyModule()
>>> vmap_states = brainstate.transform.vmap2_new_states(
...     module,
...     init_kwargs={'size': 10},
...     axis_size=5,
...     state_out_axes={
...         1: OfType(brainstate.ParamState),  # weights on axis 1
...         0: OfType(brainstate.ShortTermState),  # counter on axis 0
...     }
... )
>>> module.weight.value.shape  # (size, axis_size)
(10, 5)
>>> module.counter.value.shape  # (axis_size,)
(5,)

Non-batched states:

>>> class MixedModule(brainstate.nn.Module):
...     def init_state(self):
...         self.batched = brainstate.ShortTermState(0)
...         self.shared = brainstate.NonBatchState(jnp.array([1, 2, 3]))
>>>
>>> module = MixedModule()
>>> vmap_states = brainstate.transform.vmap2_new_states(
...     module,
...     init_kwargs={},
...     axis_size=5
... )
>>> module.batched.value.shape  # batched across 5 instances
(5,)
>>> module.shared.value.shape  # not batched
(3,)

See also

  • vmap2 – Vectorize a callable with state semantics.

  • pmap2_new_states – Parallel version for multi-device initialization.

  • brainstate.State – Base class for stateful objects.

  • brainstate.NonBatchState – Marker for states that should not be batched.

  • catch_new_states – Context manager for capturing newly created states.