# Source code for brainstate.nn._collective_ops

# Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================


import warnings
from collections.abc import Sequence, Mapping
from typing import Callable, TypeVar, Any, Dict

import jax

from brainstate._state import catch_new_states
from brainstate._utils import set_module_as
from brainstate.graph import nodes
from brainstate.transform import vmap, vmap_new_states
from brainstate.typing import Filter
from ._module import Module

# Exclusive upper bound for ``call_order`` levels when boundary checking is enabled.
MAX_ORDER = 10

# Type variable used so the collective ops can return the very module they received.
T = TypeVar("T", bound=Module)

# Public names exported by this module.
__all__ = [
    "call_order",
    "call_all_fns",
    "vmap_call_all_fns",
    "init_all_states",
    "vmap_init_all_states",
    "reset_all_states",
    "vmap_reset_all_states",
    "assign_state_values",
]


@set_module_as('brainstate.nn')
def call_order(
    level: int = 0,
    check_order_boundary: bool = True
) -> Callable[[Callable], Callable]:
    """
    Decorator for specifying the execution order of a method in collective operations.

    The decorator attaches a ``call_order`` attribute (set to ``level``) to the
    decorated function and returns the same function object. Collective operations
    such as :func:`call_all_fns`, :func:`init_all_states` and :func:`reset_all_states`
    look up the method named ``fn_name`` on every module node. Methods without a
    ``call_order`` attribute are called first, in traversal order. Methods carrying
    the attribute are called afterwards, in ascending order of ``level``.

    The ordering applies between *nodes* that define the same method name. It does
    not order two differently named methods of a single class relative to each other.

    Parameters
    ----------
    level : int, optional
        The execution order level. Lower values indicate earlier execution.
        Must be in the range [0, MAX_ORDER) when `check_order_boundary` is True.
        Default is 0.
    check_order_boundary : bool, optional
        Whether to validate that the order level is within the valid range [0, MAX_ORDER).
        Default is True.

    Returns
    -------
    Callable[[Callable], Callable]
        A decorator that sets ``fun.call_order = level`` and returns ``fun`` unchanged.

    Raises
    ------
    ValueError
        If `check_order_boundary` is True and `level` is not in [0, MAX_ORDER).

    Examples
    --------
    .. code-block:: python

        >>> import brainstate
        >>>
        >>> class Early(brainstate.nn.Module):
        ...     @brainstate.nn.call_order(0)
        ...     def reset_state(self):
        ...         print("Reset first")
        ...
        >>> class Late(brainstate.nn.Module):
        ...     @brainstate.nn.call_order(1)
        ...     def reset_state(self):
        ...         print("Reset second")
        ...
        >>> # In a network containing both modules, reset_all_states calls
        >>> # Early.reset_state before Late.reset_state.
    """
    # Validate eagerly, at decoration time, rather than when the collective op runs.
    if check_order_boundary and (level < 0 or level >= MAX_ORDER):
        raise ValueError(f'"level" must be an integer in [0, {MAX_ORDER}), but got {level}.')

    def wrap(fun: Callable) -> Callable:
        # Mutates the function in place; bound methods expose this attribute too.
        fun.call_order = level
        return fun

    return wrap


@set_module_as('brainstate.nn')
def call_all_fns(
    target: T,
    fn_name: str,
    args: Sequence[Any] | Any = (),
    kwargs: Mapping[str, Any] | None = None,
    node_to_exclude: Filter = None,
    fn_if_not_exist: str = 'raise',
) -> T:
    """
    Call a named method on every module node within a target, respecting call order.

    All ``Module`` nodes returned by ``nodes(target)`` are collected, except those
    matched by ``node_to_exclude``. The method ``fn_name`` is then looked up on each
    of them. Every lookup happens before any method is invoked, so a missing or
    non-callable attribute is reported before any side effect takes place.

    Methods without a ``call_order`` attribute are called first, in traversal order.
    Methods decorated with :func:`call_order` follow in ascending order of their
    level; methods sharing a level keep traversal order.

    Parameters
    ----------
    target : Module
        The target module on which to call functions.
    fn_name : str
        The name of the method to call on each module node.
    args : Sequence[Any] or Any, optional
        Positional arguments to pass to the called method. A non-tuple value is
        wrapped into a one-element tuple, so a list is passed as a single argument.
        Default is ().
    kwargs : Mapping[str, Any], optional
        Keyword arguments to pass to the called method. ``None`` (the default)
        means no keyword arguments.
    node_to_exclude : Filter, optional
        A filter to exclude certain nodes from the function call.
        Can be a type, predicate function, or any filter supported by the graph API.
    fn_if_not_exist : str, optional
        Behavior when the specified method doesn't exist on a node:

        - 'raise': Raise an AttributeError (default)
        - 'pass' or 'none': Skip the node silently
        - 'warn': Issue a warning and skip the node

    Returns
    -------
    Module
        ``target`` itself, to allow chaining.

    Raises
    ------
    TypeError
        If `fn_name` is not a string, `kwargs` is neither None nor a mapping,
        or the looked-up attribute is not callable.
    ValueError
        If `fn_if_not_exist` is not one of 'raise', 'pass', 'none' or 'warn'.
    AttributeError
        If the specified method doesn't exist on a node and `fn_if_not_exist` is 'raise'.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate
        >>>
        >>> net = brainstate.nn.Sequential(brainstate.nn.Linear(10, 20), brainstate.nn.ReLU())
        >>> brainstate.nn.call_all_fns(net, 'init_state')
    """
    if not isinstance(fn_name, str):
        raise TypeError(f'fn_name must be a string, but got {type(fn_name).__name__}.')
    # Validate the policy up front: previously an invalid value went unnoticed
    # whenever every node happened to define ``fn_name``.
    if fn_if_not_exist not in ('raise', 'pass', 'none', 'warn'):
        raise ValueError(
            f"fn_if_not_exist must be one of ['raise', 'pass', 'none', 'warn'], "
            f"but got '{fn_if_not_exist}'."
        )

    if not isinstance(args, tuple):
        args = (args,)
    # Only ``None`` means "no kwargs"; other falsy non-mappings must still be rejected.
    if kwargs is None:
        kwargs = {}
    if not isinstance(kwargs, Mapping):
        raise TypeError(f'kwargs must be a mapping, but got {type(kwargs).__name__}.')

    all_nodes = nodes(target).filter(Module)
    if node_to_exclude is not None:
        all_nodes -= all_nodes.filter(node_to_exclude)

    # Phase 1: resolve every bound method without calling anything yet.
    unordered_fns = []
    ordered_fns = []
    for path, node in all_nodes.items():
        try:
            fun = getattr(node, fn_name)
        except AttributeError as e:
            if fn_if_not_exist == 'raise':
                raise AttributeError(
                    f"Module {type(node).__name__} with the path {path} does not have method '{fn_name}'"
                ) from e
            if fn_if_not_exist == 'warn':
                warnings.warn(
                    f"Module {type(node).__name__} with the path {path} does not have method '{fn_name}'. "
                    f"Skipping.",
                    UserWarning
                )
            # 'pass', 'none' and 'warn' all skip the node.
            continue

        if not callable(fun):
            raise TypeError(f"'{fn_name}' must be callable, but got {type(fun).__name__}.")

        if hasattr(fun, 'call_order'):
            ordered_fns.append(fun)
        else:
            unordered_fns.append(fun)

    # Phase 2: undecorated methods first, in traversal order ...
    for fun in unordered_fns:
        fun(*args, **kwargs)
    # ... then decorated ones; sorted() is stable, so equal levels keep traversal order.
    for fun in sorted(ordered_fns, key=lambda f: f.call_order):
        fun(*args, **kwargs)
    return target


@set_module_as('brainstate.nn')
def vmap_call_all_fns(
    target: T,
    fn_name: str,
    args: Sequence[Any] | Any = (),
    kwargs: Mapping[str, Any] | None = None,
    axis_size: int = None,
    node_to_exclude: Filter = None,
    state_tag: str | None = None,
    fn_if_not_exist: str = 'raise',
) -> T:
    """
    Apply vectorized mapping to call a function on all module nodes with batched state handling.

    The call to :func:`call_all_fns` is wrapped in ``vmap(axis_size=axis_size)``.
    States newly created during those calls are captured, and the batched values
    produced by the vmapped call are written back into them. This is useful for
    creating ensembles or batched models.

    Parameters
    ----------
    target : Module
        The target module on which to call functions.
    fn_name : str
        The name of the method to call on each module node.
    args : Sequence[Any] or Any, optional
        Positional arguments to pass to the called method. A single non-tuple
        argument will be automatically wrapped in a tuple. Default is ().
    kwargs : Mapping[str, Any], optional
        Keyword arguments to pass to the called method. ``None`` (the default)
        means no keyword arguments.
    axis_size : int
        The size of the batch dimension for vmap. Must be a positive integer.
    node_to_exclude : Filter, optional
        A filter to exclude certain nodes from the function call.
    state_tag : str, optional
        An optional tag to categorize newly created states during the vmap operation.
    fn_if_not_exist : str, optional
        Behavior when the specified method doesn't exist on a node:

        - 'raise': Raise an AttributeError (default)
        - 'pass' or 'none': Skip the node silently
        - 'warn': Issue a warning and skip the node

    Returns
    -------
    Module
        ``target`` itself, to allow chaining.

    Raises
    ------
    ValueError
        If `axis_size` is None or not a positive integer.
    TypeError
        If `kwargs` is neither None nor a mapping.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate
        >>>
        >>> net = brainstate.nn.Linear(10, 20)
        >>> # Create 5 batched instances with different initializations
        >>> brainstate.nn.vmap_call_all_fns(net, 'init_state', axis_size=5)
    """
    if axis_size is None or axis_size <= 0:
        raise ValueError(f"axis_size must be a positive integer, got {axis_size}")

    if not isinstance(args, tuple):
        args = (args,)
    # Only ``None`` means "no kwargs"; other falsy non-mappings must still be rejected.
    if kwargs is None:
        kwargs = {}
    if not isinstance(kwargs, Mapping):
        raise TypeError(f'kwargs must be a mapping, but got {type(kwargs).__name__}.')

    @vmap(axis_size=axis_size)
    def vmapped_fn():
        # Inside the vmapped trace: run the calls and return the values of every
        # state created during them.
        with catch_new_states(state_tag) as inner_catcher:
            call_all_fns(
                target,
                fn_name=fn_name,
                args=args,
                kwargs=kwargs,
                node_to_exclude=node_to_exclude,
                fn_if_not_exist=fn_if_not_exist
            )
        return inner_catcher.get_state_values()

    # Outside the trace: capture the newly created State objects, so the batched
    # values returned by ``vmapped_fn`` can be written back into them.
    # NOTE(review): the zip below assumes the outer catcher sees the same states,
    # in the same order, as the inner one — confirm against catch_new_states.
    with catch_new_states(state_tag) as outer_catcher:
        values = vmapped_fn()
    states = outer_catcher.get_states()
    for state, value in zip(states, values):
        state.value = value
    return target
@set_module_as('brainstate.nn')
def init_all_states(
    target: T,
    *init_args,
    node_to_exclude: Filter = None,
    **init_kwargs,
) -> T:
    """
    Initialize states for all module nodes within the target.

    This is a convenience wrapper around :func:`call_all_fns` that calls the
    `init_state` method on all module nodes. The execution order respects any
    `@call_order()` decorators on the `init_state` methods.

    Parameters
    ----------
    target : Module
        The target module whose states are to be initialized.
    *init_args
        Variable positional arguments to pass to each `init_state` method.
    node_to_exclude : Filter, optional
        A filter to exclude certain nodes from initialization.
        Can be a type, predicate function, or any filter supported by the graph API.
    **init_kwargs
        Variable keyword arguments to pass to each `init_state` method.

    Returns
    -------
    Module
        ``target`` itself, to allow chaining.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate
        >>>
        >>> net = brainstate.nn.Sequential(
        ...     brainstate.nn.Linear(10, 20),
        ...     brainstate.nn.Dropout(0.5)
        ... )
        >>> # Initialize all states
        >>> brainstate.nn.init_all_states(net)
        >>>
        >>> # Initialize with custom arguments
        >>> brainstate.nn.init_all_states(net, batch_size=32)

    See Also
    --------
    call_all_fns : The underlying function that executes the calls.
    vmap_init_all_states : Vectorized version for batched initialization.
    """
    call_all_fns(target, 'init_state', init_args, init_kwargs, node_to_exclude)
    return target


@set_module_as('brainstate.nn')
def vmap_init_all_states(
    target: T,
    *init_args,
    axis_size: int = None,
    node_to_exclude: Filter = None,
    state_to_exclude: Filter = None,
    state_tag: str | None = None,
    in_states: Dict[int, Dict] | Any | None = None,
    out_states: Dict[int, Dict] | Any | None = None,
    **init_kwargs
) -> T:
    """
    Initialize states with vectorized mapping for creating batched module instances.

    :func:`init_all_states` is run inside :func:`vmap_new_states`, so the states it
    creates are produced under a vectorized map of size ``axis_size``. This is
    useful for ensemble models or parameter sweeps.

    Parameters
    ----------
    target : Module
        The target module whose states are to be initialized.
    *init_args
        Variable positional arguments to pass to each `init_state` method.
    axis_size : int
        The size of the batch dimension, forwarded to `vmap_new_states`.
        Validation of this value is left to `vmap_new_states`.
    node_to_exclude : Filter, optional
        A filter to exclude certain nodes from initialization.
    state_to_exclude : Filter, optional
        A filter to exclude certain states from being vmapped. Forwarded to
        `vmap_new_states`.
    state_tag : str, optional
        An optional tag to categorize newly created states.
    in_states : optional
        Forwarded unchanged to `vmap_new_states`.
    out_states : optional
        Forwarded unchanged to `vmap_new_states`.
    **init_kwargs
        Variable keyword arguments to pass to each `init_state` method.

    Returns
    -------
    Module
        ``target`` itself, to allow chaining.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate
        >>>
        >>> rnn = brainstate.nn.RNNCell(10, 20)
        >>> # Initialize 8 batched instances of the recurrent state
        >>> brainstate.nn.vmap_init_all_states(rnn, batch_size=32, axis_size=8)

    See Also
    --------
    init_all_states : Non-vectorized version.
    vmap_new_states : The underlying vmap transformation for states.
    """

    def init_fn():
        init_all_states(
            target,
            *init_args,
            **init_kwargs,
            node_to_exclude=node_to_exclude,
        )

    # vmap_new_states returns a transformed callable; invoking it performs the
    # batched initialization.
    vmap_new_states(
        init_fn,
        state_tag=state_tag,
        axis_size=axis_size,
        state_to_exclude=state_to_exclude,
        in_states=in_states,
        out_states=out_states,
    )()
    return target


@set_module_as('brainstate.nn')
def reset_all_states(
    target: T,
    *reset_args,
    node_to_exclude: Filter = None,
    **reset_kwargs,
) -> T:
    """
    Reset states for all module nodes within the target.

    This is a convenience wrapper around :func:`call_all_fns` that calls the
    `reset_state` method on all module nodes. The execution order respects any
    `@call_order()` decorators on the `reset_state` methods. This is typically
    used to reset recurrent neural network states between sequences.

    Parameters
    ----------
    target : Module
        The target module whose states are to be reset.
    *reset_args
        Variable positional arguments to pass to each `reset_state` method.
    node_to_exclude : Filter, optional
        A filter to exclude certain nodes from reset.
        Can be a type, predicate function, or any filter supported by the graph API.
    **reset_kwargs
        Variable keyword arguments to pass to each `reset_state` method.

    Returns
    -------
    Module
        ``target`` itself, to allow chaining.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate
        >>>
        >>> rnn = brainstate.nn.RNNCell(10, 20)
        >>> brainstate.nn.init_all_states(rnn, batch_size=32)
        >>>
        >>> # Process a sequence
        >>> for x in sequence:
        ...     output = rnn(x)
        >>>
        >>> # Reset states before processing next sequence
        >>> brainstate.nn.reset_all_states(rnn)

    See Also
    --------
    call_all_fns : The underlying function that executes the calls.
    vmap_reset_all_states : Vectorized version for batched reset.
    """
    call_all_fns(
        target,
        fn_name='reset_state',
        args=reset_args,
        kwargs=reset_kwargs,
        node_to_exclude=node_to_exclude
    )
    return target
@set_module_as('brainstate.nn')
def vmap_reset_all_states(
    target: T,
    *reset_args,
    axis_size: int = None,
    node_to_exclude: Filter = None,
    state_tag: str | None = None,
    **reset_kwargs,
) -> T:
    """
    Reset states with vectorized mapping across batched module instances.

    This delegates to :func:`vmap_call_all_fns` with ``fn_name='reset_state'``,
    so the `reset_state` calls run under ``vmap`` with the given ``axis_size``.
    This is useful when working with batched recurrent models or ensembles.

    Parameters
    ----------
    target : Module
        The target module whose states are to be reset.
    *reset_args
        Variable positional arguments to pass to each `reset_state` method.
    axis_size : int
        The size of the batch dimension. Must be a positive integer.
    node_to_exclude : Filter, optional
        A filter to exclude certain nodes from reset.
    state_tag : str, optional
        An optional tag to categorize newly created states during the reset.
    **reset_kwargs
        Variable keyword arguments to pass to each `reset_state` method.

    Returns
    -------
    Module
        ``target`` itself, to allow chaining.

    Raises
    ------
    ValueError
        If `axis_size` is None or not a positive integer.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate
        >>>
        >>> rnn = brainstate.nn.RNNCell(10, 20)
        >>> # Initialize with 16 batched instances
        >>> brainstate.nn.vmap_init_all_states(rnn, batch_size=32, axis_size=16)
        >>>
        >>> # Process sequences...
        >>>
        >>> # Reset all 16 batched instances
        >>> brainstate.nn.vmap_reset_all_states(rnn, axis_size=16)

    See Also
    --------
    reset_all_states : Non-vectorized version.
    vmap_call_all_fns : The underlying vmap function call mechanism.
    """
    vmap_call_all_fns(
        target,
        fn_name='reset_state',
        args=reset_args,
        kwargs=reset_kwargs,
        axis_size=axis_size,
        node_to_exclude=node_to_exclude,
        state_tag=state_tag,
    )
    return target
@set_module_as('brainstate.nn')
def assign_state_values(
    target: Module,
    *state_by_abs_path: Mapping[str, Any]
) -> tuple[list[str], list[str]]:
    """
    Assign state values to a module from one or more state dictionaries.

    The given dictionaries are merged left to right, so a later dictionary wins
    on a duplicate key. The merged result is then matched against the keys of
    ``target.states()``. Every state whose key appears in both receives
    ``jax.numpy.asarray`` of the provided value. Keys present on only one side
    are reported instead of raising.

    Parameters
    ----------
    target : Module
        The target module whose states will be updated.
    *state_by_abs_path : Mapping[str, Any]
        One or more state dictionaries keyed by absolute path
        (e.g. 'layer1.weight').

    Returns
    -------
    tuple[list[str], list[str]]
        ``(unexpected_keys, missing_keys)``, both sorted:

        - unexpected_keys: keys supplied by the dictionaries but absent from the module
        - missing_keys: keys of the module that no dictionary supplied

    Examples
    --------
    .. code-block:: python

        >>> import brainstate
        >>>
        >>> net = brainstate.nn.Linear(10, 20)
        >>> brainstate.nn.init_all_states(net)
        >>>
        >>> # Save state values
        >>> state_dict = {path: state.value for path, state in net.states().items()}
        >>>
        >>> # Later, restore state values
        >>> unexpected, missing = brainstate.nn.assign_state_values(net, state_dict)
        >>> print(f"Unexpected keys: {unexpected}")
        >>> print(f"Missing keys: {missing}")

    Notes
    -----
    - Values are converted with `jax.numpy.asarray` before assignment.
    - Mismatched keys never raise; they are only returned.
    """
    # Later mappings overwrite earlier ones on duplicate keys.
    merged: dict[Any, Any] = {}
    for mapping in state_by_abs_path:
        merged.update(mapping)

    module_states = target.states()
    provided = set(merged.keys())
    existing = set(module_states.keys())

    # Only keys known to both sides are written.
    for shared_key in existing & provided:
        module_states[shared_key].value = jax.numpy.asarray(merged[shared_key])

    return sorted(provided - existing), sorted(existing - provided)