Source code for brainstate.nn._common

# Copyright 2025 BrainX Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

# -*- coding: utf-8 -*-

from collections import defaultdict
from typing import Any, Sequence, Hashable, Dict, Optional, Callable

from brainstate import environ
from brainstate._state import State
from brainstate.transform import vmap, vmap2, vmap2_new_states, pmap2, pmap2_new_states
from brainstate.typing import Filter
from brainstate.util import filter
from ._module import Module

AxisName = Hashable

__all__ = [
    'EnvironContext',
    'Vmap',
    'Map',
]


[docs] class EnvironContext(Module): """Wrap a module so it executes inside a brainstate environment context. Parameters ---------- layer : Module Module executed within the environment context. **context Keyword arguments forwarded to ``brainstate.environ.context``. Attributes ---------- layer : Module Wrapped module executed inside the context. context : dict Environment arguments applied to the wrapped module. Examples -------- .. code-block:: python >>> import brainstate >>> from brainstate.nn import EnvironContext >>> wrapped = EnvironContext(layer, fit=True) >>> result = wrapped.update(inputs) """ def __init__(self, layer: Module, **context): """Initialize the wrapper with a module and environment arguments. Parameters ---------- layer : Module Module executed inside the environment context. **context Keyword arguments forwarded to ``brainstate.environ.context``. """ super().__init__() assert isinstance(layer, Module), 'The layer must be an instance of Module.' self.layer = layer self.context = context
[docs] def update(self, *args, context: Dict = None, **kwargs): """Execute the wrapped module inside the environment context. Parameters ---------- *args Positional arguments forwarded to the wrapped module. **kwargs Keyword arguments forwarded to the wrapped module. context: dict, optional Additional environment settings for this call. Merged with the stored context. Returns ------- Any Result returned by the wrapped module. """ if context is None: context = dict() with environ.context(**self.context, **context): return self.layer(*args, **kwargs)
[docs] def add_context(self, **context): """Add more environment settings to the wrapped module. Parameters ---------- **context Keyword arguments merged into the stored environment context. """ self.context.update(context)
def _filter_states( module: Module, filters: Filter | Dict[Filter, int], ) -> Dict: """Normalize state filter specifications for ``Module.states``. Parameters ---------- module : Module Module providing the states interface. filters : Filter or dict[Filter, int] Filters passed by the caller. Dictionary keys are filters and values are the axes they should map over. Returns ------- dict[int, Any] or Any or None Structured filters to pass to ``Module.states``. Returns ``None`` when no filtering is requested. """ if filters is None: filtered_states = None elif isinstance(filters, dict): in_states_filter = defaultdict(list) for filter_, axis in filters: assert isinstance(axis, int), 'The value of in_states must be the map axis, which should be an integer.' in_states_filter[axis].append(filter_) filtered_states = module.states(*in_states_filter.values()) in_states_axis = tuple(in_states_filter.keys()) filtered_states = {axis: states for axis, states in zip(in_states_axis, filtered_states)} else: filtered_states = module.states(filters) return filtered_states
[docs] class Vmap(Module): """ Vectorize a module with ``brainstate.transform.vmap``. This wrapper applies vectorized mapping over a module, enabling efficient batch processing by automatically mapping over specified axes of inputs and states. Parameters ---------- module : Module Module to wrap with vectorized mapping. in_axes : int or None or Sequence[Any], optional Specification for mapping over inputs. Defaults to ``0``. out_axes : Any, optional Specification for mapping over outputs. Defaults to ``0``. vmap_states : Filter or Dict[int, Filter], optional Filter specifying which states should be mapped. Can be a single filter or a dictionary mapping axes (int) to filters. Defaults to ``None``. vmap_out_states : Dict[int, Dict] or Any or None, optional Specification for how to map output states. Can be a dictionary mapping axes to state specifications. Defaults to ``None``. axis_name : AxisName or None, optional Name of the axis being mapped. Defaults to ``None``. axis_size : int or None, optional Size of the mapped axis. Defaults to ``None``. Attributes ---------- module : Module The wrapped module being vectorized. vmapped_fn : Callable The vectorized function that executes the module. Examples -------- .. code-block:: python >>> from brainstate.nn import Vmap >>> vmapped = Vmap(module, in_axes=0, axis_name="batch") >>> outputs = vmapped.update(inputs) """ def __init__( self, module: Module, in_axes: int | None | Sequence[Any] = 0, out_axes: Any = 0, vmap_states: Filter | Dict[int, Filter] = None, vmap_out_states: Dict[int, Dict] | Any | None = None, axis_name: AxisName | None = None, axis_size: int | None = None, ): super().__init__() assert isinstance(module, Module), 'The module must be an instance of Module.' self.in_axes = in_axes self.out_axes = out_axes self.axis_name = axis_name self.axis_size = axis_size self.module = module vmap_states = _filter_states(module, vmap_states) vmap_out_states = _filter_states(module, vmap_out_states) @vmap( in_axes=in_axes, out_axes=out_axes, in_states=vmap_states, out_states=vmap_out_states, axis_name=axis_name, axis_size=axis_size, ) def vmap_run(*args, **kwargs): return module(*args, **kwargs) # vmapped module self.vmapped_fn = vmap_run
[docs] def update(self, *args, **kwargs): """Execute the vmapped module with the given arguments. Parameters ---------- *args Positional arguments forwarded to the vmapped module. **kwargs Keyword arguments forwarded to the vmapped module. Returns ------- Any Result of executing the vmapped module. """ return self.vmapped_fn(*args, **kwargs)
class ToPredicate: """Helper predicate class for filtering states by identity. This class creates a predicate that matches states based on their object identity (id), used internally for state filtering in vectorized mapping. Parameters ---------- states : Iterable[State] Collection of states to match against. Attributes ---------- state_ids : set Set of state object IDs for efficient lookup. """ def __init__(self, states): self.state_ids = set([id(st) for st in states]) def __call__(self, path, st: State): """Check if a state matches the predicate. Parameters ---------- path : Any Path to the state (unused). st : State State to check. Returns ------- bool True if the state's ID is in the predicate's state set. """ return id(st) in self.state_ids class _MapCaller: def __init__( self, fn: Callable, behavior: str, in_axes: Any = 0, out_axes: Any = 0, axis_name: Optional[str] = None, state_axes: Dict[int, Filter] = None, ): self.in_axes = in_axes self.out_axes = out_axes self.axis_name = axis_name self.state_axes = state_axes self.behavior = behavior if behavior == 'vmap': map_fn = vmap2 elif behavior == 'pmap': map_fn = pmap2 else: raise ValueError( 'Invalid behavior specified. Must be "vmap" or "pmap".' ) self.map_fn = map_fn( fn, in_axes=in_axes, out_axes=out_axes, axis_name=axis_name, state_in_axes=state_axes, state_out_axes=state_axes, ) def __call__(self, *args, **kwargs): return self.map_fn(*args, **kwargs) class Map(Module): """ Vectorize or parallelize a module using ``brainstate.transform.vmap2`` or ``pmap2``. This wrapper provides enhanced control over state management during vectorized or parallel mapping operations. Unlike ``Vmap``, ``ModuleMapper`` requires explicit initialization of vectorized states before use, enabling fine-grained control over how states are distributed across mapping axes. The class supports two modes of operation: - ``behavior='vmap'``: Vectorized mapping using :func:`vmap2` for single-device batching - ``behavior='pmap'``: Parallel mapping using :func:`pmap2` for multi-device parallelization Parameters ---------- module : Module Module to wrap with vectorized or parallel mapping. init_map_size : int Size of the mapping axis used during state initialization. This determines how many copies of each state will be created. init_state_axes : Dict[int, Filter], optional Dictionary mapping axis indices to filters for state initialization. Controls how newly created states are distributed across axes. Defaults to ``None``, which assigns all states to axis 0 except :class:`NonBatchState`. state_tag : str, optional Tag for identifying and grouping states during vectorization. Defaults to ``None``. in_axes : int or Sequence[Any], optional Specification for mapping over inputs during ``update`` calls, following the semantics of :func:`jax.vmap`. Defaults to ``0``. out_axes : Any, optional Specification for mapping over outputs during ``update`` calls. Defaults to ``0``. axis_name : AxisName or None, optional Name of the mapped axis used by collective primitives (e.g., ``lax.psum``). Defaults to ``None``. spmd_axis_name : AxisName or None, optional Name for SPMD (Single Program Multiple Data) axis when using nested mapping transforms. Defaults to ``None``. call_state_axes : Dict[int, Filter], optional Dictionary mapping axes to filters for states during ``update`` calls. Specifies how states should be mapped over different axes. This is automatically integrated with states created during initialization. Defaults to ``None``. behavior : {'vmap', 'pmap'}, default 'vmap' Type of parallelization to use. ``'vmap'`` for vectorized single-device mapping, ``'pmap'`` for multi-device parallel mapping. Attributes ---------- module : Module The wrapped module being vectorized or parallelized. init_map_size : int Size of the mapping axis for state initialization. dict_vmap_states : dict[int, list[State]] or None Dictionary mapping axis indices to lists of vectorized states, populated after calling :meth:`init_all_states`. Raises ------ ValueError If ``behavior`` is not ``'vmap'`` or ``'pmap'``, or if ``update`` is called before :meth:`init_all_states`. AssertionError If ``init_map_size`` is not an integer. Notes ----- This module requires calling :meth:`init_all_states` before the first :meth:`update` call. The initialization process: 1. Calls ``module.init_all_states(**kwargs)`` under vectorized/parallel mapping 2. Captures all newly created states 3. Distributes states across axes based on ``init_state_axes`` 4. Integrates these states into ``call_state_axes`` for subsequent ``update`` calls Examples -------- **Basic vectorized mapping:** .. code-block:: python >>> import brainstate >>> from brainstate.nn import Map >>> from brainstate.util.filter import OfType >>> >>> class MyModule(brainstate.nn.Module): ... def init_state(self, size): ... self.weight = brainstate.ParamState(jnp.zeros(size)) ... def update(self, x): ... return x @ self.weight.value >>> >>> module = MyModule() >>> vmapper = Map( ... module, ... init_map_size=10, ... in_axes=0, ... axis_name="batch" ... ) >>> vmapper.init_all_states(size=(5,)) # Creates 10 copies of the state >>> outputs = vmapper.update(inputs) # inputs.shape = (10, 5) **Parallel mapping across devices:** .. code-block:: python >>> import jax >>> pmapper = Map( ... module, ... init_map_size=jax.device_count(), ... behavior='pmap', ... axis_name="devices" ... ) >>> pmapper.init_all_states(size=(5,)) >>> # inputs replicated across devices >>> outputs = pmapper.update(inputs) **Mapping custom module methods:** .. code-block:: python >>> vmapper = Map(module, init_map_size=10) >>> vmapper.init_all_states(size=(5,)) >>> # Call a specific method with custom mapping >>> predictions = vmapper.map('predict', in_axes=0)(inputs) See Also -------- brainstate.transform.vmap2 : Vectorized mapping with state semantics. brainstate.transform.pmap2 : Parallel mapping across devices. brainstate.transform.vmap2_new_states : Initialize vectorized states. brainstate.transform.pmap2_new_states : Initialize parallel states. Vmap : Simpler vectorization wrapper without explicit state initialization. """ __module__ = 'brainstate.nn' def __init__( self, module: 'Module', # vmap parameters for init_all_states init_map_size: int, init_state_axes: Dict[int, Filter] = None, state_tag: str = None, # vmap parameters for update calls in_axes=0, out_axes=0, axis_name=None, spmd_axis_name=None, call_state_axes: Dict[int, Filter] = None, # type to parallelize behavior: str = 'vmap', ): super().__init__() assert isinstance(init_map_size, int), 'init_map_size must be an integer.' assert behavior in ['vmap', 'pmap'], 'behavior must be either "vmap" or "pmap".' self.init_map_size = init_map_size self.module = module self.state_tag = state_tag self.in_axes = in_axes self.out_axes = out_axes self.axis_name = axis_name self.spmd_axis_name = spmd_axis_name self.call_state_axes = call_state_axes self.init_state_axes = init_state_axes self.dict_vmap_states = None self.behavior = behavior self._init = False self._call_state_axes = None def __pretty_repr_item__(self, name, value): if name in [ '_init', 'dict_vmap_states', '_call_state_axes' ]: return None return name, value def _integrate_state_axes(self, call_state_axes): if call_state_axes is None: call_state_axes = dict() call_state_axes = dict(call_state_axes) for k, v in tuple(call_state_axes.items()): if k in self.dict_vmap_states: call_state_axes[k] = filter.Any(v, ToPredicate(self.dict_vmap_states[k])) for k, v in self.dict_vmap_states.items(): if k not in call_state_axes: call_state_axes[k] = ToPredicate(v) return call_state_axes def init_all_states(self, **kwargs): """Initialize vectorized states for the wrapped module. This method must be called before the first ``update`` call. It creates and configures vectorized versions of the module's states based on the specified axis size. """ if self.behavior == 'vmap': map_fn = vmap2_new_states elif self.behavior == 'pmap': map_fn = pmap2_new_states else: raise ValueError( 'Invalid behavior specified. Must be "vmap" or "pmap".' ) self.dict_vmap_states = map_fn( self.module, kwargs, state_tag=self.state_tag, axis_size=self.init_map_size, state_out_axes=self.init_state_axes, ) self._call_state_axes = self._integrate_state_axes(self.call_state_axes) self._init = True def update(self, *args, **kwargs): """Execute the vectorized module with the given arguments. Parameters ---------- *args Positional arguments forwarded to the vectorized module. **kwargs Keyword arguments forwarded to the vectorized module. Returns ------- Any Result of executing the vectorized module. Raises ------ ValueError If ``init_all_states`` has not been called before this method. """ if not self._init: raise ValueError( 'Map.update called before init_all_states. Please call init_all_states first.' ) if self.behavior == 'vmap': map_fn = vmap2 elif self.behavior == 'pmap': map_fn = pmap2 else: raise ValueError( 'Invalid behavior specified. Must be "vmap" or "pmap".' ) return map_fn( self.module, in_axes=self.in_axes, out_axes=self.out_axes, axis_name=self.axis_name, spmd_axis_name=self.spmd_axis_name, state_in_axes=self._call_state_axes, state_out_axes=self._call_state_axes, )(*args, **kwargs) def map( self, fn: str | Callable, in_axes: Any = 0, out_axes: Any = 0, axis_name: Optional[str] = None, state_axes: Dict[int, Filter] = None, ) -> _MapCaller: """ Access the wrapped module's methods with vectorized mapping. This method allows you to call any method of the wrapped module with custom vectorization settings, overriding the default ``in_axes``, ``out_axes``, ``axis_name``, and ``state_axes`` specified during ``ModuleMapper`` initialization. Parameters ---------- fn : str or Callable The method name (as a string) or callable function to execute with vectorized mapping. If a string, it must be the name of an existing method on the wrapped module. in_axes : Any, optional Specification for mapping over input arguments. Can be an integer specifying which axis to map over, a tuple/dict for complex structures, or ``None`` to broadcast without mapping. Default is ``0``. out_axes : Any, optional Specification for mapping over outputs. Can be an integer specifying which axis to map over, a tuple/dict for complex structures, or ``None`` to collect outputs without mapping. Default is ``0``. axis_name : str, optional Name for the mapped axis used by collective operations like ``lax.psum`` or ``lax.pmean``. If ``None``, uses the axis name specified during ``ModuleMapper`` initialization. Default is ``None``. state_axes : Dict[int, Filter], optional Dictionary mapping axis indices to state filters for fine-grained control over which states are mapped along which axes. Keys are axis indices, values are filter functions that select which states to map. If ``None``, uses the default state mapping behavior. Default is ``None``. Returns ------- _MapCaller A callable wrapper that applies the specified vectorized mapping when invoked. Call this object with the arguments you want to pass to the mapped function. Raises ------ ValueError If ``init_all_states()`` has not been called before using this method. AttributeError If ``fn`` is a string but the module has no method with that name. Examples -------- **Basic usage with method name:** .. code-block:: python >>> import brainstate as bst >>> import jax.numpy as jnp >>> >>> class MyModule(bst.nn.Module): ... def init_state(self): ... self.weight = bst.ParamState(jnp.ones(5)) ... def predict(self, x): ... return x @ self.weight.value >>> >>> module = MyModule() >>> vmapper = bst.nn.Map(module, init_map_size=10) >>> vmapper.init_all_states() >>> inputs = jnp.ones((10, 5)) # batch of 10 inputs >>> outputs = vmapper.map('predict')(inputs) # shape: (10,) **Using a callable function:** .. code-block:: python >>> def custom_fn(module, x, scale): ... return module.predict(x) * scale >>> >>> vmapper = bst.nn.Map(module, init_map_size=10) >>> vmapper.init_all_states() >>> outputs = vmapper.map(lambda m, x, s: custom_fn(m, x, s))( ... inputs, scale=2.0 ... ) **Custom in_axes and out_axes:** .. code-block:: python >>> class MultiInputModule(bst.nn.Module): ... def init_state(self, size): ... self.state = bst.State(jnp.zeros(size)) ... def process(self, x, y): ... return x + y, x * y >>> >>> module = MultiInputModule() >>> vmapper = bst.nn.Map(module, init_map_size=10) >>> vmapper.init_all_states(size=(5,)) >>> x = jnp.ones((10, 5)) # mapped over axis 0 >>> y = jnp.ones(5) # broadcasted (not mapped) >>> # Map over first input but broadcast second, both outputs mapped >>> result1, result2 = vmapper.map( ... 'process', ... in_axes=(0, None), ... out_axes=(0, 0) ... )(x, y) **Using state_axes for fine-grained control:** .. code-block:: python >>> from brainstate.util.filter import OfType >>> >>> class StatefulModule(bst.nn.Module): ... def init_state(self, size): ... self.params = bst.ParamState(jnp.ones(size)) ... self.buffer = bst.State(jnp.zeros(size)) ... def update(self, x): ... self.buffer.value = x ... return x @ self.params.value >>> >>> module = StatefulModule() >>> vmapper = bst.nn.Map(module, init_map_size=10) >>> vmapper.init_all_states(size=(5,)) >>> # Map only ParamState along axis 0, keep State shared >>> outputs = vmapper.map( ... 'update', ... state_axes={0: OfType(bst.ParamState)} ... )(inputs) See Also -------- update : Execute the vectorized module with default settings. brainstate.transform.vmap2 : Underlying vectorization transform. brainstate.transform.pmap2 : Underlying parallel mapping transform. init_all_states : Required initialization method before using map. """ if isinstance(fn, str): try: fn = getattr(self.module, fn) except AttributeError: raise AttributeError(f'Module has no method named {fn}.') from None assert callable(fn), 'fn must be a callable or the name of a method.' if not self._init: raise ValueError( 'Map.update called before init_all_states. Please call init_all_states first.' ) return _MapCaller( fn, behavior=self.behavior, in_axes=in_axes, out_axes=out_axes, axis_name=axis_name, state_axes=self._integrate_state_axes(state_axes), )