Source code for braintrace._etrace_algorithms.param_dim_vjp

# Copyright 2025 BrainX Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from functools import partial
from typing import Callable, Dict, Tuple, Optional, Sequence, Any

import brainstate
import jax
import jax.numpy as jnp
import saiunit as u

from braintrace._etrace_compiler import HiddenParamOpRelation, HiddenGroup
from braintrace._etrace_op import (
    etp_elemwise_p,
    etp_mm_p,
    etp_mv_p,
    ETP_RULES_YW_TO_W,
    ETP_RULES_XY_TO_DW,
    ETP_RULES_INIT_DRTRL,
    is_batched_primitive,
)
from braintrace._misc import etrace_df_key
from braintrace._typing import (
    PyTree,
    WeightID,
    Path,
    DTypeLike,
    ETraceX_Key,
    ETraceDF_Key,
    ETraceWG_Key,
    Hid2WeightJacobian,
    HiddenGroupJacobian,
    dG_Weight,
)
from ._common import (
    _extract_leaf,
    _reset_state_in_a_dict,
    _route_grads_by_path,
    _sum_dim,
    _update_dict,
)
from .base import EligibilityTrace
from .vjp_base import ETraceVjpAlgorithm

__all__ = ['ParamDimVjpAlgorithm']

# The mm / mv / elemwise primitives share an elementwise ``yw_to_w`` rule of the
# form ``trace * hidden_dim_broadcast``. For these, the nested
# ``vmap(yw_to_w, -1, -1) + sum`` pattern can be replaced by a single ``einsum``
# over the hidden-state axis of ``diag`` and ``trace``. Conv / sparse / LoRA
# primitives use non-elementwise rules and always take the legacy path.
_ELEMENTWISE_YW_PRIMITIVES = (
    etp_mm_p,
    etp_mv_p,
    etp_elemwise_p,
)


def _cast_to_dtype(tree, dtype):
    """Return ``tree`` with every array leaf converted to ``dtype``.

    ``dtype=None`` is a no-op: ``tree`` itself is returned. Otherwise each leaf
    is converted with ``.astype(dtype)``. Quantities are treated as single
    leaves (``is_leaf=u.math.is_quantity``), so a unit-carrying leaf is cast as
    a whole instead of having its internals traversed.

    This is how the eligibility trace, and the inputs that update it, are kept
    at a reduced precision such as ``bfloat16``.
    """
    if dtype is not None:
        def _convert(leaf):
            return leaf.astype(dtype)

        return jax.tree.map(_convert, tree, is_leaf=u.math.is_quantity)
    return tree


def _init_param_dim_state(
    etrace_bwg: Dict[ETraceWG_Key, brainstate.State],
    relation: HiddenParamOpRelation,
    trace_dtype: Optional[DTypeLike] = None,
):
    """
    Allocate one eligibility-trace state per hidden group of ``relation``.

    Each state is registered in ``etrace_bwg`` under
    ``(id(relation.y_var), group.index)``. Its value is the ``Dict[str, Array]``
    produced by the primitive's ``init_drtrl`` rule, keyed by the primitive's
    trainable-input names. For mm/mv/elemwise primitives, that dict is cast to
    ``trace_dtype`` (no-op when ``None``). Conv/sparse/LoRA primitives keep the
    precision returned by their rule.

    Raises:
        ValueError: if a trace for the same key was already registered.
        TypeError: if ``init_drtrl`` does not return a dict.
    """
    y_id = id(relation.y_var)
    reduced_precision_ok = relation.primitive in _ELEMENTWISE_YW_PRIMITIVES
    hidden_group: HiddenGroup
    for hidden_group in relation.hidden_groups:
        trace_key = (y_id, hidden_group.index)
        if trace_key in etrace_bwg:
            raise ValueError(f'The relation {trace_key} has been added. ')

        initial = ETP_RULES_INIT_DRTRL[relation.primitive](
            relation.x_var,
            relation.y_var,
            relation.trainable_vars,
            hidden_group.num_state,
        )
        if not isinstance(initial, dict):
            raise TypeError(
                f'Primitive {relation.primitive.name} init_drtrl must return a dict; '
                f'got {type(initial).__name__}.'
            )

        if reduced_precision_ok:
            initial = _cast_to_dtype(initial, trace_dtype)
        etrace_bwg[trace_key] = EligibilityTrace(initial)


def _fast_recurrent_term(primitive, diag, old_bwg, num_state):
    """Closed-form ``D^t * eps^{t-1}`` for primitives with an elementwise
    ``yw_to_w`` rule (mm / mv / elemwise).

    Shapes (``S == num_state``; ``y_shape == (*batch, out_features)`` for mm and
    ``(out_features,)`` for mv):

    * ``diag``: ``(*varshape, S_alpha, S_beta)`` with ``varshape == y_shape``.
    * mm/mv ``old_bwg['weight']``: ``(*batch, in_features, out_features, S)``.
      This is the layout built by ``_fast_instant_term``
      (``'...i,...ka->...ika'``); ``batch`` is empty for mv.
    * mm/mv ``old_bwg['bias']`` (when present): ``(*y_shape, S)``.
    * elemwise ``old_bwg['weight']``: ``(*varshape, S)``.

    The contraction is
        new[b..., i, k, alpha] = sum_beta diag[b..., k, alpha, beta]
                                       * trace[b..., i, k, beta]
    which maps to ``einsum('...kab,...ikb->...ika')``. For elemwise the ``i``
    axis disappears and it reduces to ``einsum('...ab,...b->...a')``.

    When ``num_state == 1`` (the common single-state case), both state axes
    have size 1. The sum over ``beta`` then has a single term, and the
    contraction becomes a broadcast multiply. This is mathematically equivalent
    to the einsum but avoids a degenerate ``dot_general``. ``diag[..., 0, 0]``
    leaves an array whose last axis is the hidden (``k``) axis.

    Returns:
        Dict[str, Array]: ``'weight'``, plus ``'bias'`` for mm/mv when present in
        ``old_bwg``. Each entry has the same shape as the matching
        ``old_bwg`` entry.
    """
    if num_state == 1:
        if primitive is etp_elemwise_p:
            # diag[..., 0, :] keeps the size-1 beta axis to align with trace.
            return {'weight': diag[..., 0, :] * old_bwg['weight']}
        d = diag[..., 0, 0]  # (*y_shape), last axis is the hidden ``k`` axis
        # Broadcast over the in-features axis (before k) and the state axis (after k).
        out = {'weight': d[..., None, :, None] * old_bwg['weight']}
        if 'bias' in old_bwg:
            out['bias'] = d[..., None] * old_bwg['bias']
        return out
    if primitive is etp_elemwise_p:
        return {
            'weight': jnp.einsum('...ab,...b->...a', diag, old_bwg['weight']),
        }
    # mm / mv: trace['weight'] has an extra in-features axis before ``out``.
    out = {
        'weight': jnp.einsum('...kab,...ikb->...ika', diag, old_bwg['weight']),
    }
    if 'bias' in old_bwg:
        out['bias'] = jnp.einsum('...kab,...kb->...ka', diag, old_bwg['bias'])
    return out


def _fast_instant_term(primitive, x, df, has_bias):
    """Closed-form instantaneous term ``diag(D_f^t) ⊗ x^t`` for mm/mv/elemwise.

    * mm / mv (``y = x @ W + b``): the weight term is the outer product of ``x``
      (``(..., in_features)``) with ``df`` (``(..., out_features, num_state)``),
      giving ``(..., in_features, out_features, num_state)``. The mv form has
      no batch axis. When ``has_bias`` is set, the bias term is ``df`` itself.
    * elemwise (identity op): the weight term is simply ``df``, with no ``x``
      factor.

    Returns:
        Dict[str, Array]: a ``'weight'`` entry, plus ``'bias'`` for mm/mv when
        ``has_bias`` is truthy.
    """
    if primitive is etp_elemwise_p:
        return {'weight': df}

    # mm may carry leading batch axes; mv is strictly unbatched.
    outer_spec = '...i,...ka->...ika' if primitive is etp_mm_p else 'i,ka->ika'
    terms = {'weight': jnp.einsum(outer_spec, x, df)}
    if has_bias:
        terms['bias'] = df
    return terms


def _update_param_dim_etrace_scan_fn(
    hist_etrace_vals: Dict[ETraceWG_Key, jax.Array],
    jacobians: Tuple[
        Dict[ETraceX_Key, jax.Array],  # the weight x
        Dict[ETraceDF_Key, jax.Array],  # the weight df
        Sequence[jax.Array],  # the hidden group Jacobians
    ],
    weight_path_to_vals: Dict[Path, PyTree],
    hidden_param_op_relations,
    normalize_matrix_spectrum: bool = False,
    fast_solve: bool = True,
    trace_dtype: Optional[DTypeLike] = None,
):
    """
    Advance the parameter-dimension eligibility traces by one time step.

    For every hidden-parameter relation, and each of its hidden groups, this
    computes the D-RTRL update

        eps^t = D^t * eps^{t-1} + diag(D_f^t) ⊗ x^t

    The recurrent term uses the hidden-group Jacobian ``D^t``. The
    instantaneous term uses the current weight input ``x^t`` and ``df``.
    For mm/mv/elemwise primitives with ``fast_solve=True``, both terms use
    closed-form einsum kernels. Otherwise they use the primitive's ``xy_to_dw``
    and ``yw_to_w`` rules under nested ``vmap`` (the legacy path).

    Args:
        hist_etrace_vals (Dict[ETraceWG_Key, jax.Array]): Traces from the
            previous step, keyed by ``(id(relation.y_var), group.index)``. Each
            value is a ``Dict[str, Array]`` keyed by trainable-input name.
        jacobians (Tuple[Dict[ETraceX_Key, jax.Array], Dict[ETraceDF_Key, jax.Array], Sequence[jax.Array]]):
            ``(xs, dfs, hidden_group_jacobians)``, where:
            ``xs`` are the weight inputs keyed by ``id(x_var)``;
            ``dfs`` are the ``df`` values keyed by ``etrace_df_key(y, group.index)``;
            ``hidden_group_jacobians`` holds one hidden-to-hidden Jacobian per
            hidden group, indexed by ``group.index``.
        weight_path_to_vals (Dict[Path, PyTree]): Current parameter values by
            path. Each relation's trainable leaves are extracted from here for
            the legacy ``xy_to_dw`` rule.
        hidden_param_op_relations: A sequence of ``HiddenParamOpRelation``
            objects describing which weight op feeds which hidden groups.
        normalize_matrix_spectrum (bool): If True, spectrally clip each
            hidden-group Jacobian and magnitude-clip both the instantaneous
            term and the updated trace.
        fast_solve (bool): Use the closed-form einsum kernels for mm/mv/elemwise
            primitives.
        trace_dtype (Optional[DTypeLike]): Storage dtype of mm/mv/elemwise
            traces (``None`` keeps native precision). It must match the dtype
            used by ``_init_param_dim_state``. The new trace is cast to it so
            the trace dtype, and therefore the scan carry type, stays fixed
            across steps.

    Returns:
        Tuple[Dict[ETraceWG_Key, Dict[str, jax.Array]], None]: The updated
        traces (the scan carry) and ``None`` (no per-step scan output).
    """
    # --- the data --- #

    #
    # + "hist_etrace_vals" has the following structure:
    #    - key: (id(relation.y_var), hidden-group index)
    #    - value: Dict[str, Array] of (possibly batched) weight traces
    #

    # + "jacobians" has the following structure:
    #    - a dict of weight x gradients
    #       * key: the weight x jax var id
    #       * value: the weight x gradients
    #    - a dict of weight y gradients
    #       * key: built from the weight y var and the hidden group index
    #       * value: the weight y gradients
    #    - the hidden-to-hidden Jacobians, one per hidden group
    #
    etrace_xs_at_t: Dict[ETraceX_Key, jax.Array] = jacobians[0]
    etrace_ys_at_t: Dict[ETraceDF_Key, jax.Array] = jacobians[1]

    #
    # the hidden-to-hidden Jacobians
    #
    hid_group_jacobians: Sequence[jax.Array] = jacobians[2]
    if normalize_matrix_spectrum:
        hid_group_jacobians = [_normalize_matrix_spectrum(diag) for diag in hid_group_jacobians]

    # The etrace weight gradients at the current time step,
    # i.e. the "hist_etrace_vals" of the next time step.
    #
    new_etrace_bwg = dict()

    relation: HiddenParamOpRelation
    for relation in hidden_param_op_relations:

        # Build the weights dict the rules consume.
        weights_dict = {
            key: _extract_leaf(
                weight_path_to_vals[relation.trainable_paths[key]],
                relation.trainable_leaf_indices[key],
            )
            for key in relation.trainable_vars
        }

        xy_to_dw_rule = ETP_RULES_XY_TO_DW[relation.primitive]
        yw_to_w_rule = ETP_RULES_YW_TO_W[relation.primitive]
        eqn_params = relation.eqn_params
        is_elemwise = relation.primitive is etp_elemwise_p
        batched = is_batched_primitive(relation.primitive)
        has_bias = eqn_params.get('has_bias', False)
        # ``_init_param_dim_state`` stores traces of these primitives at
        # ``trace_dtype`` regardless of ``fast_solve``, so the update must too.
        stores_trace_dtype = relation.primitive in _ELEMENTWISE_YW_PRIMITIVES
        # Fast path only applies to primitives with elementwise yw_to_w.
        use_fast = fast_solve and stores_trace_dtype

        if is_elemwise:
            x = None
        else:
            x = etrace_xs_at_t[id(relation.x_var)]

        # Default arguments bind the per-relation rule/params at definition time.
        def _call_xy_to_dw_dict(x_, df_, weights_, _rule=xy_to_dw_rule, _params=eqn_params):
            return _rule(x_, df_, weights_, **_params)

        def _call_yw_to_w_dict(d, trace_, _rule=yw_to_w_rule, _params=eqn_params):
            return _rule(d, trace_, **_params)

        def comp_dw_with_x(x_, df_, _wdict=weights_dict):
            return _call_xy_to_dw_dict(x_, df_, _wdict)

        def _comp_instant_legacy(df_all):
            """Legacy nested-vmap path: vmap xy_to_dw over num_state (and batch)."""

            @partial(jax.vmap, in_axes=-1, out_axes=-1)
            def _inner(df_slice):
                if batched:
                    return jax.vmap(comp_dw_with_x)(x, df_slice)
                return comp_dw_with_x(x, df_slice)

            return _inner(df_all)

        def _comp_recurrent_legacy(diag_, old_bwg_, num_state_):
            """Legacy nested-vmap yw_to_w + sum path."""

            def fn_bwg_pre(d, _old=old_bwg_):
                return jax.tree.map(
                    lambda arr: _sum_dim(arr, axis=-1),
                    jax.vmap(_call_yw_to_w_dict, in_axes=-1, out_axes=-1)(d, _old),
                )

            # num_state == 1 shortcut: squeeze the size-1 alpha axis to skip
            # outer vmap overhead; re-expand at the end.
            if num_state_ == 1:
                d_squeezed = u.math.squeeze(diag_, axis=-2)
                res = fn_bwg_pre(d_squeezed)
                return jax.tree.map(lambda a: u.math.expand_dims(a, axis=-1), res)
            return jax.vmap(fn_bwg_pre, in_axes=-2, out_axes=-1)(diag_)

        group: HiddenGroup
        for group in relation.hidden_groups:

            df = etrace_ys_at_t[etrace_df_key(relation.y, group.index)]

            # Instantaneous term: diag(D_f^t) ⊗ x^t  (Dict[str, Array]).
            # Cast the update inputs to ``trace_dtype`` (no-op when None) so the
            # multiply-add runs in the trace precision and the new trace stays
            # there; Jacobians/learning-signal remain full precision elsewhere.
            if use_fast:
                phg_to_pw = _fast_instant_term(
                    relation.primitive,
                    _cast_to_dtype(x, trace_dtype),
                    _cast_to_dtype(df, trace_dtype),
                    has_bias,
                )
            else:
                phg_to_pw = _comp_instant_legacy(df)
            if normalize_matrix_spectrum:
                phg_to_pw = jax.tree.map(_normalize_vector, phg_to_pw)

            w_key = (id(relation.y_var), group.index)
            diag = hid_group_jacobians[group.index]

            old_bwg = hist_etrace_vals[w_key]  # Dict[str, Array]

            # Recurrent term: D^t · ε^{t-1}.
            if use_fast:
                new_bwg_pre = _fast_recurrent_term(
                    relation.primitive,
                    _cast_to_dtype(diag, trace_dtype),
                    old_bwg,
                    group.num_state,
                )
            else:
                new_bwg_pre = _comp_recurrent_legacy(diag, old_bwg, group.num_state)

            # new_bwg_pre + phg_to_pw per-leaf.
            new_bwg = jax.tree.map(
                u.math.add, new_bwg_pre, phg_to_pw, is_leaf=u.math.is_quantity,
            )
            if normalize_matrix_spectrum:
                new_bwg = jax.tree.map(_normalize_vector, new_bwg)
            if stores_trace_dtype:
                # The legacy path (``fast_solve=False``) computes in the
                # Jacobians' precision. Cast back so the trace keeps the dtype
                # it was allocated with; otherwise the ``lax.scan`` carry type
                # changes between steps. No-op on the fast path, whose inputs
                # were already cast to ``trace_dtype``.
                new_bwg = _cast_to_dtype(new_bwg, trace_dtype)
            new_etrace_bwg[w_key] = new_bwg

    return new_etrace_bwg, None


def _normalize_matrix_spectrum(diag):
    """Spectrally clip a matrix, or a stack of matrices, without branching.

    Each trailing ``(n, n)`` matrix is divided by ``max(rho, 1)``, where ``rho``
    is its largest eigenvalue magnitude. Matrices with ``rho <= 1`` are returned
    unchanged; larger ones are rescaled to spectral radius 1. Every axis before
    the last two is mapped over with ``jax.vmap``.

    Using ``jnp.maximum`` instead of ``jax.lax.cond`` keeps the computation free
    of control flow, so XLA can fuse it with the surrounding update.
    """

    def _clip_one(mat):
        spectral_radius = jnp.abs(jnp.linalg.eigvals(mat)).max()
        return mat / jnp.maximum(spectral_radius, 1.0)

    clip = _clip_one
    n_leading = diag.ndim - 2
    while n_leading > 0:
        clip = jax.vmap(clip)
        n_leading -= 1
    return clip(diag)


def _normalize_vector(v):
    """Branch-free magnitude clipping: return ``v / max(max|v|, 1)``.

    Arrays whose largest absolute entry is at most 1 are divided by 1, so their
    values are unchanged. Larger arrays are rescaled so that their largest
    absolute entry becomes 1.
    """
    # ``initial=0.0`` gives the max-reduction an identity, so a zero-size array
    # is returned as-is instead of raising. It cannot change the result for
    # non-empty input because every ``|v|`` is >= 0.
    max_elem = jnp.max(jnp.abs(v), initial=0.0)
    return v / jnp.maximum(max_elem, 1.0)


def _fast_solve_contract(primitive, diag_like, etrace_data, fold_batch=False):
    """Solve-time closed-form contraction for mm/mv/elemwise primitives.

    ``diag_like`` is the dL/dh group gradient of shape ``(*y_shape, num_state)``.
    ``etrace_data`` is the weight-shaped trace dict. For an elementwise
    ``yw_to_w`` rule, ``sum_alpha diag_like[..., alpha] * yw_to_w(etrace[..., alpha])``
    is a single einsum along the ``num_state`` axis.

    With ``fold_batch=True``, the leading batch axis ``b`` is also contracted
    inside the einsum, so the result is already summed over the batch. This
    avoids materializing a ``(B, I, O)`` intermediate followed by
    ``sum(axis=0)``. It assumes exactly one leading batch axis, which is the
    same assumption that trailing sum makes.
    """
    if primitive is etp_elemwise_p:
        elemwise_spec = 'b...a,b...a->...' if fold_batch else '...a,...a->...'
        return {'weight': jnp.einsum(elemwise_spec, diag_like, etrace_data['weight'])}

    # mm / mv: the weight trace has an extra in-features axis ``i``.
    if fold_batch:
        weight_spec, bias_spec = 'bka,bika->ik', 'bka,bka->k'
    else:
        weight_spec, bias_spec = '...ka,...ika->...ik', '...ka,...ka->...k'

    grads = {'weight': jnp.einsum(weight_spec, diag_like, etrace_data['weight'])}
    if 'bias' in etrace_data:
        grads['bias'] = jnp.einsum(bias_spec, diag_like, etrace_data['bias'])
    return grads


def _solve_param_dim_weight_gradients(
    hist_etrace_data: Dict[ETraceWG_Key, PyTree],  # the history etrace data
    dG_weights: Dict[Path, dG_Weight],  # weight gradients
    dG_hidden_groups: Sequence[jax.Array],  # hidden group gradients
    weight_hidden_relations: Sequence[HiddenParamOpRelation],
    weight_vals: Dict[Path, PyTree],  # current ParamState pytree values for structure
    fast_solve: bool = True,
):
    """
    Contract the eligibility traces with the learning signal to obtain weight gradients.

    For every relation and hidden group, the trace ``eps`` is combined with the
    hidden-group gradient ``dL/dh`` and summed over the ``num_state`` axis:

    * mm/mv/elemwise primitives with ``fast_solve=True`` use a closed-form
      einsum. A reduced-precision trace is first upcast to the learning-signal
      dtype, and batched primitives fold the batch sum into the einsum.
    * All other cases use the primitive's ``yw_to_w`` rule (vmapped over the
      batch axis for batched primitives).

    Units are stripped before the contraction, and the trace's units are
    re-attached to the result. Per-relation results are routed to their owning
    ``ParamState`` paths. If any relation uses a batched primitive, each path's
    gradient is summed over axis 0, except paths already batch-reduced inside
    the einsum. The results are then merged into ``dG_weights`` via
    ``_update_dict``.

    Args:
        hist_etrace_data (Dict[ETraceWG_Key, PyTree]): The eligibility traces,
            keyed by ``(id(relation.y_var), group.index)``. Each value is a
            ``Dict[str, Array]`` keyed by trainable-input name.
        dG_weights (Dict[Path, dG_Weight]): Gradient dictionary updated in
            place, keyed by weight path.
        dG_hidden_groups (Sequence[jax.Array]): One learning signal per hidden
            group, indexed by ``group.index``.
        weight_hidden_relations (Sequence[HiddenParamOpRelation]): The
            relationships between weight ops and hidden groups.
        weight_vals (Dict[Path, PyTree]): Current ``ParamState`` pytree values,
            passed to ``_route_grads_by_path`` for structure.
        fast_solve (bool): Use the closed-form einsum kernels for mm/mv/elemwise
            primitives.

    Returns:
        None: ``dG_weights`` is updated in place.
    """
    # update the etrace weight gradients
    temp_data: Dict[Path, PyTree] = dict()
    # Paths whose gradient was already batch-reduced inside the fast-path einsum
    # (fold_batch). The trailing batch-sum must skip these.
    folded_paths: set = set()
    for relation in weight_hidden_relations:
        yw_to_w_rule = ETP_RULES_YW_TO_W[relation.primitive]
        eqn_params = relation.eqn_params
        batched = is_batched_primitive(relation.primitive)
        use_fast = fast_solve and (relation.primitive in _ELEMENTWISE_YW_PRIMITIVES)

        def _call_yw_to_w_dict(d, trace_, _rule=yw_to_w_rule, _params=eqn_params):
            return _rule(d, trace_, **_params)

        yw_to_w = (
            jax.vmap(_call_yw_to_w_dict)
            if batched
            else _call_yw_to_w_dict
        )

        group: HiddenGroup
        for group in relation.hidden_groups:

            w_key = (id(relation.y_var), group.index)
            etrace_data = hist_etrace_data[w_key]  # Dict[str, Array]
            dg_hidden = dG_hidden_groups[group.index]

            # dimensionless processing (unit strip + restore). Apply per-leaf.
            # Only the trace's units are restored onto the result.
            etrace_data_unitless, fn_unit_restore = _remove_units(etrace_data)
            dg_hidden_unitless, _ = _remove_units(dg_hidden)

            if use_fast:
                # Upcast a reduced-precision trace to (at least) the learning-
                # signal dtype so the gradient reduction accumulates in full
                # precision. ``promote_types`` never downcasts, so this is a
                # no-op for the default fp32 trace.
                sig_dtype = jax.tree.leaves(dg_hidden_unitless)[0].dtype
                etrace_for_solve = jax.tree.map(
                    lambda a: a.astype(jnp.promote_types(a.dtype, sig_dtype)),
                    etrace_data_unitless,
                )
                # Closed-form einsum path for mm/mv/elemwise primitives. For a
                # batched primitive, fold the batch reduction into the einsum so
                # no (B, I, O) intermediate is materialized; record the routed
                # paths so the trailing batch-sum skips them (already reduced).
                dg_weight_dict = _fast_solve_contract(
                    relation.primitive, dg_hidden_unitless, etrace_for_solve,
                    fold_batch=batched,
                )
                if batched:
                    folded_paths.update(relation.trainable_paths.values())
            elif group.num_state == 1:
                # num_state==1 shortcut: skip outer vmap of size 1.
                dg_hid_squeezed = jax.tree.map(
                    lambda a: u.math.squeeze(a, axis=-1), dg_hidden_unitless
                )
                etr_squeezed = jax.tree.map(
                    lambda a: u.math.squeeze(a, axis=-1), etrace_data_unitless
                )
                dg_weight_dict = yw_to_w(dg_hid_squeezed, etr_squeezed)
            else:
                dg_weight_dict = jax.tree.map(
                    lambda arr: _sum_dim(arr, axis=-1),
                    jax.vmap(yw_to_w, in_axes=-1, out_axes=-1)(
                        dg_hidden_unitless, etrace_data_unitless
                    ),
                )
            dg_weight_dict = fn_unit_restore(dg_weight_dict)

            # Route per-key to owning ParamState path.
            _route_grads_by_path(relation, dg_weight_dict, weight_vals, temp_data)

    #
    # Step 3:
    #
    # sum up the batched weight gradients
    # Check if ANY relation uses a batched primitive
    has_batched = any(is_batched_primitive(r.primitive) for r in weight_hidden_relations)
    if has_batched:
        for key, val in temp_data.items():
            if key in folded_paths:
                # already batch-reduced inside the fast-path einsum (fold_batch)
                continue
            temp_data[key] = jax.tree.map(lambda x: u.math.sum(x, axis=0), val)

    # update the weight gradients
    for key, val in temp_data.items():
        _update_dict(dG_weights, key, val)


def _remove_units(xs_maybe_quantity: brainstate.typing.PyTree):
    """
    Split a PyTree into unitless mantissas plus a function that re-attaches the units.

    Quantities are treated as leaves (``is_leaf=u.math.is_quantity``). Each leaf
    is split with ``u.split_mantissa_unit``, and the units are remembered in
    leaf order.

    Args:
        xs_maybe_quantity (brainstate.typing.PyTree): A PyTree whose leaves may
            be quantities with units.

    Returns:
        Tuple[brainstate.typing.PyTree, Callable]: A tuple containing:
            - A PyTree with the same structure as the input, holding the
              unitless mantissas.
            - ``restore_units(tree)``, which multiplies each leaf of ``tree`` by
              the corresponding recorded unit (dimensionless units are left
              off). It raises ``ValueError`` if ``tree`` does not have the same
              structure as the input.
    """
    leaves, treedef = jax.tree.flatten(xs_maybe_quantity, is_leaf=u.math.is_quantity)
    mantissas, units = [], []
    for leaf in leaves:
        mantissa, unit = u.split_mantissa_unit(leaf)
        mantissas.append(mantissa)
        units.append(unit)

    def restore_units(xs_unitless: brainstate.typing.PyTree):
        unitless_leaves, unitless_treedef = jax.tree.flatten(xs_unitless)
        # Explicit check rather than ``assert`` so it is not stripped under
        # ``python -O``. A mismatched tree would otherwise silently pair
        # leaves with the wrong units.
        # jax's PyTreeDef stubs omit __eq__; the comparison is valid at runtime.
        if treedef != unitless_treedef:  # type: ignore[operator]
            raise ValueError('The tree structure should be the same. ')
        restored = [
            leaf if unit.dim.is_dimensionless else leaf * unit
            for leaf, unit in zip(unitless_leaves, units)
        ]
        return jax.tree.unflatten(treedef, restored)

    return jax.tree.unflatten(treedef, mantissas), restore_units


[docs] class ParamDimVjpAlgorithm(ETraceVjpAlgorithm): r"""Online gradient algorithm with diagonal approximation and parameter-dimension complexity. This algorithm computes the gradients of the weights with the diagonal approximation and the parameter-dimension complexity. It is based on the RTRL algorithm (Real-Time Recurrent Learning). Parameters ---------- model : brainstate.nn.Module The model function, which receives the input arguments and returns the model output. vjp_method : str, optional The method for computing the VJP. It should be either ``"single-step"`` or ``"multi-step"``. - ``"single-step"``: the VJP is computed at the current time step, i.e., :math:`\partial L^t/\partial h^t`. - ``"multi-step"``: the VJP is computed at multiple time steps, i.e., :math:`\partial L^t/\partial h^{t-k}`, where :math:`k` is determined by the data input. name : str, optional The name of the etrace algorithm. mode : braintrace.mixin.Mode, optional The computing mode, indicating the batching behavior. Notes ----- The learning rule is .. math:: \begin{aligned} &\boldsymbol{\epsilon}^t \approx \mathbf{D}^t \boldsymbol{\epsilon}^{t-1}+\operatorname{diag}\left(\mathbf{D}_f^t\right) \otimes \mathbf{x}^t \\ & \nabla_{\boldsymbol{\theta}} \mathcal{L}=\sum_{t^{\prime} \in \mathcal{T}} \frac{\partial \mathcal{L}^{t^{\prime}}}{\partial \mathbf{h}^{t^{\prime}}} \circ \boldsymbol{\epsilon}^{t^{\prime}} \end{aligned} where :math:`\boldsymbol{\epsilon}^t` is the per-parameter eligibility trace, :math:`\mathbf{D}^t` the hidden-to-hidden Jacobian, :math:`\mathbf{D}_f^t` the state-to-output Jacobian, :math:`\mathbf{x}^t` the presynaptic input, and :math:`\partial \mathcal{L}^{t'}/\partial \mathbf{h}^{t'}` the learning signal back-propagated from the loss at each step. Real-Time Recurrent Learning (RTRL) propagates the full sensitivity :math:`\partial \mathbf{h}^t/\partial \boldsymbol{\theta}` forward in time, which costs :math:`O(|\theta| \cdot H)` memory. 
D-RTRL keeps only the *diagonal* of the hidden-to-hidden Jacobian, collapsing the trace to one value per parameter. The trace is then contracted with the instantaneous learning signal at each step to accumulate the gradient — no backward pass through time and memory linear in the parameter count. :class:`ParamDimVjpAlgorithm` is a subclass of :class:`brainstate.nn.Module` and is sensitive to the context/mode of the computation. In particular, it is sensitive to ``brainstate.mixin.Batching`` behavior. This algorithm has :math:`O(B\theta)` memory complexity, where :math:`\theta` is the number of parameters and :math:`B` the batch size. For a convolutional layer, the weight gradients are computed with :math:`O(B\theta)` memory complexity, where :math:`\theta` is the dimension of the convolutional kernel. For a linear transformation layer, the weight gradients are computed with :math:`O(BIO)` computational complexity, where :math:`I` and :math:`O` are the number of input and output dimensions. For more details, please see `the D-RTRL algorithm presented in our manuscript <https://www.biorxiv.org/content/10.1101/2024.09.24.614728v2>`_. Examples -------- .. code-block:: python >>> import brainstate >>> import braintrace >>> >>> class RNN(brainstate.nn.Module): ... def __init__(self): ... super().__init__() ... self.cell = braintrace.nn.ValinaRNNCell(1, 20, activation='tanh') ... self.out = braintrace.nn.Linear(20, 1) ... def update(self, x): ... return x >> self.cell >> self.out >>> >>> model = RNN() >>> _ = brainstate.nn.init_all_states(model) >>> learner = braintrace.D_RTRL(model) # alias of ParamDimVjpAlgorithm >>> x0 = brainstate.random.randn(1) >>> learner.compile_graph(x0) # trace the graph once >>> y = learner(x0) # forward pass + eligibility-trace update References ---------- .. [1] Wang, C., Dong, X., Ji, Z., Xiao, M., Jiang, J., Liu, X., Huan, Y., & Wu, S. (2026). "Model-agnostic linear-memory online learning in spiking neural networks." 
*Nature Communications*. https://doi.org/10.1038/s41467-026-68453-w (preprint: bioRxiv 2024.09.24.614728) .. [2] Williams, R. J., & Zipser, D. (1989). "A Learning Algorithm for Continually Running Fully Recurrent Neural Networks" (RTRL). *Neural Computation*, 1(2), 270-280. https://doi.org/10.1162/neco.1989.1.2.270 """ # batch of weight gradients etrace_bwg: Dict[ETraceWG_Key, brainstate.State] def __init__( self, model: brainstate.nn.Module, name: Optional[str] = None, vjp_method: str = 'single-step', fast_solve: bool = True, normalize_matrix_spectrum: bool = False, trace_dtype: Optional[DTypeLike] = None, **kwargs, ): super().__init__(model, name=name, vjp_method=vjp_method) # ``fast_solve=True`` enables closed-form einsum kernels for # mm/mv/elemwise primitives, replacing the nested-vmap legacy path. # Conv / sparse / LoRA primitives always use the legacy path. self.fast_solve = fast_solve # When True, clip trace magnitudes > 1 each step via a branch-free # ``v / max(max_abs, 1)``. Default False (disabled). The previous # implementation applied this unconditionally to the instantaneous # term only, which silently distorted gradients. self.normalize_matrix_spectrum = normalize_matrix_spectrum # Optional reduced-precision storage for the eligibility trace (e.g. # ``jnp.bfloat16`` / ``jnp.float16``); ``None`` keeps native fp32. Only # the mm/mv/elemwise fast path honors it. Reduced-precision eigenvalues # are unreliable, so this is incompatible with spectral normalization. if ( trace_dtype is not None and normalize_matrix_spectrum and jnp.issubdtype(jnp.dtype(trace_dtype), jnp.floating) and jnp.dtype(trace_dtype).itemsize < 4 ): raise ValueError( f'trace_dtype={trace_dtype!r} has reduced precision (<4 bytes), ' 'which is incompatible with normalize_matrix_spectrum=True: ' 'spectral normalization needs full-precision eigenvalues. Use ' 'trace_dtype=None (or float32) with normalization, or disable it.' ) self.trace_dtype = trace_dtype
[docs] def init_etrace_state(self, *args, **kwargs): """Initialize the eligibility trace states of the etrace algorithm. This method is needed after compiling the etrace graph. See :meth:`compile_graph` for the details. """ # The states of batched weight gradients self.etrace_bwg = dict() for relation in self.graph.hidden_param_op_relations: _init_param_dim_state(self.etrace_bwg, relation, self.trace_dtype)
[docs] def reset_state(self, batch_size: int = None, **kwargs): """Reset the eligibility trace states. Parameters ---------- batch_size : int, optional The batch size used to reshape the reset trace states. Default ``None``. """ self.running_index.value = 0 _reset_state_in_a_dict(self.etrace_bwg, batch_size)
[docs] def get_etrace_of(self, weight: brainstate.ParamState | Path) -> Dict: """Get the eligibility trace of the given weight. Parameters ---------- weight : brainstate.ParamState or Path The weight whose eligibility trace is requested, given either as a :class:`brainstate.ParamState` instance or as its path in the model. Returns ------- dict A dictionary mapping ``(y_var id, hidden-group index)`` keys to the eligibility-trace values associated with the given weight. Raises ------ ValueError If no eligibility trace is found for the given weight. """ self._assert_compiled() # get the wight id weight_id = ( id(weight) if isinstance(weight, brainstate.ParamState) else id(self.graph_executor.path_to_states[weight]) ) find_this_weight = False etraces = dict() relation: HiddenParamOpRelation for relation in self.graph.hidden_param_op_relations: primary_state = next(iter(relation.trainable_param_states.values()), None) if primary_state is None or id(primary_state) != weight_id: continue find_this_weight = True # retrieve the etrace data group: HiddenGroup for group in relation.hidden_groups: key = (id(relation.y_var), group.index) etraces[key] = self.etrace_bwg[key].value if not find_this_weight: raise ValueError(f'Do not the etrace of the given weight: {weight}.') return etraces
def _get_etrace_data(self) -> Dict:
    """Retrieve the current eligibility-trace data from all trace states.

    Reads ``.value`` from every state stored in ``self.etrace_bwg`` and
    returns the values under the same keys, so the current traces can be
    processed functionally (e.g. inside a JAX transformation).

    Returns:
        Dict[ETraceWG_Key, jax.Array]: A dictionary mapping eligibility-trace
        keys to their current values.
    """
    return {k: v.value for k, v in self.etrace_bwg.items()}


def _assign_etrace_data(self, etrace_vals: Dict) -> None:
    """Write eligibility-trace values back into their state objects.

    Implements the abstract hook of the parent class for the
    parameter-dimension algorithm. This algorithm keeps all traces in the
    single dictionary ``self.etrace_bwg``, rather than in separate
    dictionaries for inputs and differential functions.

    Args:
        etrace_vals: Dict[ETraceWG_Key, jax.Array]
            Mapping from eligibility-trace keys to updated values. Every key
            must already exist in ``self.etrace_bwg``.

    Returns:
        None
    """
    for x, val in etrace_vals.items():
        self.etrace_bwg[x].value = val


def _make_etrace_stepper(self, weight_vals: Dict[Path, PyTree]) -> Callable:
    """Build the per-step D-RTRL eligibility-trace stepper.

    Returns the ``partial`` of :func:`_update_param_dim_etrace_scan_fn` that
    serves as the body of the trace scan. Exposing it lets the graph executor
    fuse the roll into its over-time scan for multi-step input (see the
    base-class :meth:`_make_etrace_stepper`).

    Args:
        weight_vals: Dict[Path, PyTree]
            Current weight values, bound into the stepper.

    Returns:
        Callable: ``stepper(carry, xs) -> (new_carry, ys)`` suitable for
        ``jax.lax.scan``.
    """
    return partial(
        _update_param_dim_etrace_scan_fn,
        weight_path_to_vals=weight_vals,
        hidden_param_op_relations=self.graph.hidden_param_op_relations,
        normalize_matrix_spectrum=self.normalize_matrix_spectrum,
        fast_solve=self.fast_solve,
        trace_dtype=self.trace_dtype,
    )


def _update_etrace_data(
    self,
    running_index: Optional[int],
    hist_etrace_vals: Dict[ETraceWG_Key, PyTree],
    hid2weight_jac_single_or_multi_times: Hid2WeightJacobian,
    hid2hid_jac_single_or_multi_times: HiddenGroupJacobian,
    weight_vals: Dict[Path, PyTree],
    input_is_multi_step: bool,
) -> Dict[ETraceWG_Key, PyTree]:
    """Update eligibility-trace data for the parameter-dimension algorithm.

    Implements the core D-RTRL trace recursion

        ε^t ≈ D^t·ε^{t-1} + diag(D_f^t)⊗x^t

    by applying the stepper from :meth:`_make_etrace_stepper` to the
    previous traces and the current Jacobians.

    Args:
        running_index: Optional[int]
            Current timestep counter. It is not used by this implementation
            and is kept for interface compatibility with the base class.
        hist_etrace_vals: Dict[ETraceWG_Key, PyTree]
            Eligibility-trace values from the previous timestep. Keys identify
            parameter / hidden-state relationships.
        hid2weight_jac_single_or_multi_times: Hid2WeightJacobian
            Jacobians of hidden states with respect to weights. Index ``0``
            holds the input terms and index ``1`` the differential-function
            terms.
        hid2hid_jac_single_or_multi_times: HiddenGroupJacobian
            Jacobians between hidden states (recurrent connections).
        weight_vals: Dict[Path, PyTree]
            Current weight values of the model.
        input_is_multi_step: bool
            If ``True``, the Jacobians carry a leading time axis and the
            stepper is run with ``jax.lax.scan`` over that axis. If ``False``,
            they describe a single step and the stepper is called once
            directly.

    Returns:
        Dict[ETraceWG_Key, PyTree]: Updated eligibility traces with the same
        structure as ``hist_etrace_vals``.
    """
    scan_fn = self._make_etrace_stepper(weight_vals)

    # Both branches feed the stepper the same (x, df, hid2hid) triple; build
    # it once instead of duplicating the construction per branch.
    step_inputs = (
        hid2weight_jac_single_or_multi_times[0],
        hid2weight_jac_single_or_multi_times[1],
        hid2hid_jac_single_or_multi_times,
    )
    if input_is_multi_step:
        # Scan over the leading time axis; only the final carry is needed.
        new_etrace = jax.lax.scan(scan_fn, hist_etrace_vals, step_inputs)[0]
    else:
        # Single step: call the scan body directly and keep its carry.
        new_etrace = scan_fn(hist_etrace_vals, step_inputs)[0]
    return new_etrace


def _solve_weight_gradients(
    self,
    running_index: int,
    etrace_h2w_at_t: Dict[ETraceWG_Key, PyTree],
    dl_to_hidden_groups: Sequence[jax.Array],
    weight_vals: Dict[Path, PyTree],
    dl_to_nonetws_at_t: Dict[Path, PyTree],
    dl_to_etws_at_t: Optional[Dict[Path, PyTree]],
):
    """Compute weight gradients from the parameter-dimension eligibility traces.

    Combines the eligibility traces with the loss gradients with respect to
    hidden states:

        ∇_θ L = ∑_{t' ∈ T} ∂L^{t'}/∂h^{t'} ∘ ε^{t'}

    Gradients of non-etrace parameters and optional extra etrace-parameter
    gradients are then merged into the same result.

    Args:
        running_index: int
            Current timestep counter. It is not used by this implementation
            and is kept for interface compatibility with the base class.
        etrace_h2w_at_t: Dict[ETraceWG_Key, PyTree]
            Eligibility-trace values at the current timestep.
        dl_to_hidden_groups: Sequence[jax.Array]
            Gradients of the loss with respect to the hidden groups at the
            current timestep.
        weight_vals: Dict[Path, PyTree]
            Current values of all weights in the model, keyed by path.
        dl_to_nonetws_at_t: Dict[Path, PyTree]
            Gradients of parameters that are not handled by eligibility traces.
        dl_to_etws_at_t: Optional[Dict[Path, PyTree]]
            Optional additional gradients for eligibility-trace parameters.
            Every path must already be a key of the result. ``_update_dict``
            is called with ``error_when_no_key=True`` for these.

    Returns:
        Dict[Path, PyTree]: Mapping from parameter paths to gradient values.
        Every entry starts as ``None`` and is filled by the helpers below.
    """
    dG_weights: Dict[Path, Any] = {path: None for path in self.param_states}

    # Gradients of the etrace weights, derived from the traces.
    _solve_param_dim_weight_gradients(
        etrace_h2w_at_t,
        dG_weights,
        dl_to_hidden_groups,
        self.graph.hidden_param_op_relations,
        weight_vals,
        fast_solve=self.fast_solve,
    )

    # Gradients of the non-etrace weights.
    for path, dg in dl_to_nonetws_at_t.items():
        _update_dict(dG_weights, path, dg)

    # Extra gradients for etrace weights, if provided.
    if dl_to_etws_at_t is not None:
        for path, dg in dl_to_etws_at_t.items():
            _update_dict(dG_weights, path, dg, error_when_no_key=True)
    return dG_weights