# Copyright 2025 BrainX Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from functools import partial
from typing import Dict, Tuple, Optional, Sequence
import brainstate
import jax
import jax.numpy as jnp
import saiunit as u
from braintrace._etrace_compiler import HiddenParamOpRelation, HiddenGroup
from braintrace._etrace_op import (
etp_elemwise_p,
etp_mm_p,
etp_mv_p,
ETP_RULES_YW_TO_W,
ETP_RULES_XY_TO_DW,
ETP_RULES_INIT_DRTRL,
is_batched_primitive,
)
from braintrace._misc import etrace_df_key
from braintrace._typing import (
PyTree,
WeightID,
Path,
ETraceX_Key,
ETraceDF_Key,
ETraceWG_Key,
Hid2WeightJacobian,
HiddenGroupJacobian,
dG_Weight,
)
from .base import EligibilityTrace
from .misc import (
_extract_leaf,
_reset_state_in_a_dict,
_route_grads_by_path,
_sum_dim,
_update_dict,
)
from .vjp_base import ETraceVjpAlgorithm
__all__ = [
'ParamDimVjpAlgorithm',
'D_RTRL',
]
# Primitives with an elementwise ``yw_to_w`` rule, i.e. rules of the form
# ``trace * hidden_dim_broadcast``. For these we can replace the nested
# ``vmap(yw_to_w, -1, -1) + sum`` pattern with a single ``einsum`` contraction
# over the hidden-state axis of ``diag`` and ``trace``. Conv / sparse / LoRA
# primitives have non-elementwise rules and stay on the legacy path.
_ELEMENTWISE_YW_PRIMITIVES = (etp_mm_p, etp_mv_p, etp_elemwise_p)
def _init_param_dim_state(
etrace_bwg: Dict[ETraceWG_Key, brainstate.State],
relation: HiddenParamOpRelation
):
"""
Initialize the eligibility trace states for parameter dimensions.
Traces are stored as ``Dict[str, Array]`` keyed by the primitive's
trainable-input names (dict-based rule API).
"""
for group in relation.hidden_groups:
group: HiddenGroup
bwg_key = (id(relation.y_var), group.index)
if bwg_key in etrace_bwg:
raise ValueError(f'The relation {bwg_key} has been added. ')
init_fn = ETP_RULES_INIT_DRTRL[relation.primitive]
init_val = init_fn(
relation.x_var,
relation.y_var,
relation.trainable_vars,
group.num_state,
)
if not isinstance(init_val, dict):
raise TypeError(
f'Primitive {relation.primitive.name} init_drtrl must return a dict; '
f'got {type(init_val).__name__}.'
)
etrace_bwg[bwg_key] = EligibilityTrace(init_val)
def _fast_recurrent_term(primitive, diag, old_bwg):
"""Closed-form ``D^t * eps^{t-1}`` for primitives with an elementwise
``yw_to_w`` rule.
``diag`` has shape ``(*varshape, num_state_alpha, num_state_beta)`` with
``varshape == y_shape`` for mm/mv/elemwise. ``old_bwg`` is the per-key
trace dict. For mm/mv the weight trace has shape
``(*y_shape, in_features, num_state)`` (batched mm has a leading batch
axis; mv/elemwise do not); the bias trace has shape
``(*y_shape, num_state)`` when present.
The contraction is
new[b..., i, k, alpha] = sum_beta diag[b..., k, alpha, beta]
* trace[b..., i, k, beta]
which maps exactly to ``einsum('...kab,...ikb->...ika')``. For elemwise
the ``i`` axis disappears and it reduces to ``einsum('...ab,...b->...a')``.
"""
if primitive is etp_elemwise_p:
return {
'weight': jnp.einsum('...ab,...b->...a', diag, old_bwg['weight']),
}
# mm / mv: trace['weight'] has an extra in-features axis before ``out``.
out = {
'weight': jnp.einsum('...kab,...ikb->...ika', diag, old_bwg['weight']),
}
if 'bias' in old_bwg:
out['bias'] = jnp.einsum('...kab,...kb->...ka', diag, old_bwg['bias'])
return out
def _fast_instant_term(primitive, x, df, has_bias):
"""Closed-form ``diag(D_f^t) ⊗ x^t`` for mm/mv/elemwise primitives.
For mm/mv the instantaneous gradient of ``y = x @ W + b`` w.r.t. ``W``
is the outer product ``x ⊗ df`` with a ``num_state`` axis tagged on.
For the elemwise identity op it is simply ``df`` (no ``x`` factor).
"""
if primitive is etp_elemwise_p:
return {'weight': df}
if primitive is etp_mm_p:
out = {'weight': jnp.einsum('...i,...ka->...ika', x, df)}
else: # etp_mv_p — no batch axis
out = {'weight': jnp.einsum('i,ka->ika', x, df)}
if has_bias:
out['bias'] = df
return out
def _update_param_dim_etrace_scan_fn(
hist_etrace_vals: Dict[ETraceWG_Key, jax.Array],
jacobians: Tuple[
Dict[ETraceX_Key, jax.Array], # the weight x
Dict[ETraceDF_Key, jax.Array], # the weight df
Sequence[jax.Array], # the hidden group Jacobians
],
weight_path_to_vals: Dict[Path, PyTree],
hidden_param_op_relations,
normalize_matrix_spectrum: bool = False,
fast_solve: bool = True,
):
"""
Update the eligibility trace values for parameter dimensions.
This function updates the eligibility trace values for the parameter dimensions
based on the provided Jacobians and the current mode. It computes the new eligibility
trace values by applying vector-Jacobian products and incorporating the current
Jacobian values.
Args:
hist_etrace_vals (Dict[ETraceWG_Key, jax.Array]): A dictionary containing
historical eligibility trace values for the weight gradients, keyed by
ETraceWG_Key.
jacobians (Tuple[Dict[ETraceX_Key, jax.Array], Dict[ETraceDF_Key, jax.Array], Sequence[jax.Array]]):
A tuple containing dictionaries of current Jacobian values for the weight x
and df, and a sequence of hidden group Jacobians.
weight_path_to_vals (Dict[Path, PyTree]): A dictionary mapping weight paths to
their corresponding PyTree values.
hidden_param_op_relations: A sequence of HiddenParamOpRelation objects representing
the relationships between hidden parameters and operations.
mode (brainstate.mixin.Mode): The mode indicating whether batching is enabled.
Returns:
Tuple[Dict[ETraceWG_Key, jax.Array], None]: A tuple containing a dictionary of
updated eligibility trace values for the weight gradients, keyed by ETraceWG_Key,
and None.
"""
# --- the data --- #
#
# + "hist_etrace_vals" has the following structure:
# - key: the weight id, the weight-x jax var, the hidden state var
# - value: the batched weight gradients
#
# + "hid2weight_jac" has the following structure:
# - a dict of weight x gradients
# * key: the weight x jax var
# * value: the weight x gradients
# - a dict of weight y gradients
# * key: the tuple of the weight y jax var and the hidden state jax var
# * value: the weight y gradients
#
etrace_xs_at_t: Dict[ETraceX_Key, jax.Array] = jacobians[0]
etrace_ys_at_t: Dict[ETraceDF_Key, jax.Array] = jacobians[1]
#
# the hidden-to-hidden Jacobians
#
hid_group_jacobians: Sequence[jax.Array] = jacobians[2]
if normalize_matrix_spectrum:
hid_group_jacobians = [_normalize_matrix_spectrum(diag) for diag in hid_group_jacobians]
# The etrace weight gradients at the current time step.
# i.e., The "hist_etrace_vals" at the next time step
#
new_etrace_bwg = dict()
for relation in hidden_param_op_relations:
relation: HiddenParamOpRelation
# Build the weights dict the rules consume.
weights_dict = {
key: _extract_leaf(
weight_path_to_vals[relation.trainable_paths[key]],
relation.trainable_leaf_indices[key],
)
for key in relation.trainable_vars
}
xy_to_dw_rule = ETP_RULES_XY_TO_DW[relation.primitive]
yw_to_w_rule = ETP_RULES_YW_TO_W[relation.primitive]
eqn_params = relation.eqn_params
is_elemwise = relation.primitive is etp_elemwise_p
batched = is_batched_primitive(relation.primitive)
has_bias = eqn_params.get('has_bias', False)
# Fast path only applies to primitives with elementwise yw_to_w.
use_fast = fast_solve and (relation.primitive in _ELEMENTWISE_YW_PRIMITIVES)
if is_elemwise:
x = None
else:
x = etrace_xs_at_t[id(relation.x_var)]
def _call_xy_to_dw_dict(x_, df_, weights_, _rule=xy_to_dw_rule, _params=eqn_params):
return _rule(x_, df_, weights_, **_params)
def _call_yw_to_w_dict(d, trace_, _rule=yw_to_w_rule, _params=eqn_params):
return _rule(d, trace_, **_params)
def comp_dw_with_x(x_, df_, _wdict=weights_dict):
return _call_xy_to_dw_dict(x_, df_, _wdict)
def _comp_instant_legacy(df_all):
"""Legacy nested-vmap path: vmap xy_to_dw over num_state (and batch)."""
@partial(jax.vmap, in_axes=-1, out_axes=-1)
def _inner(df_slice):
if batched:
return jax.vmap(comp_dw_with_x)(x, df_slice)
return comp_dw_with_x(x, df_slice)
return _inner(df_all)
def _comp_recurrent_legacy(diag_, old_bwg_, num_state_):
"""Legacy nested-vmap yw_to_w + sum path."""
def fn_bwg_pre(d, _old=old_bwg_):
return jax.tree.map(
lambda arr: _sum_dim(arr, axis=-1),
jax.vmap(_call_yw_to_w_dict, in_axes=-1, out_axes=-1)(d, _old),
)
# num_state == 1 shortcut: squeeze the size-1 alpha axis to skip
# outer vmap overhead; re-expand at the end.
if num_state_ == 1:
d_squeezed = u.math.squeeze(diag_, axis=-2)
res = fn_bwg_pre(d_squeezed)
return jax.tree.map(lambda a: u.math.expand_dims(a, axis=-1), res)
return jax.vmap(fn_bwg_pre, in_axes=-2, out_axes=-1)(diag_)
for group in relation.hidden_groups:
group: HiddenGroup
df = etrace_ys_at_t[etrace_df_key(relation.y, group.index)]
# Instantaneous term: diag(D_f^t) ⊗ x^t (Dict[str, Array]).
if use_fast:
phg_to_pw = _fast_instant_term(relation.primitive, x, df, has_bias)
else:
phg_to_pw = _comp_instant_legacy(df)
if normalize_matrix_spectrum:
phg_to_pw = jax.tree.map(_normalize_vector, phg_to_pw)
w_key = (id(relation.y_var), group.index)
diag = hid_group_jacobians[group.index]
old_bwg = hist_etrace_vals[w_key] # Dict[str, Array]
# Recurrent term: D^t · ε^{t-1}.
if use_fast:
new_bwg_pre = _fast_recurrent_term(relation.primitive, diag, old_bwg)
else:
new_bwg_pre = _comp_recurrent_legacy(diag, old_bwg, group.num_state)
# new_bwg_pre + phg_to_pw per-leaf.
new_bwg = jax.tree.map(
u.math.add, new_bwg_pre, phg_to_pw, is_leaf=u.math.is_quantity,
)
if normalize_matrix_spectrum:
new_bwg = jax.tree.map(_normalize_vector, new_bwg)
new_etrace_bwg[w_key] = new_bwg
return new_etrace_bwg, None
def _normalize_matrix_spectrum(diag):
"""Branch-free spectral clipping: divide by max(|eigvals|, 1).
The previous implementation used ``jax.lax.cond`` which blocks XLA
fusion and serialises into a control-flow region per call. Dividing by
``max(max_eigenvalue, 1)`` is semantically identical (the conditional
only clipped when ``max_eigenvalue > 1``) and stays in a single fused
kernel with the rest of the update.
"""
def base_fn(matrix):
eigenvalues = jnp.linalg.eigvals(matrix)
max_eigenvalue = jnp.max(jnp.abs(eigenvalues))
return matrix / jnp.maximum(max_eigenvalue, 1.0)
fn = base_fn
for _ in range(diag.ndim - 2):
fn = jax.vmap(fn)
return fn(diag)
def _normalize_vector(v):
"""Branch-free magnitude clipping: divide by max(max_abs, 1)."""
max_elem = jnp.abs(v).max()
return v / jnp.maximum(max_elem, 1.0)
def _fast_solve_contract(primitive, diag_like, etrace_data):
"""Solve-time closed-form contraction for mm/mv/elemwise.
``diag_like`` is the dl/dh group gradient with shape ``(*y_shape, num_state)``;
``etrace_data`` is the weight-shaped trace dict. The solver computes
``sum_alpha diag_like[..., alpha] * yw_to_w(etrace[..., alpha])``, which
for elementwise ``yw_to_w`` is an einsum along the ``num_state`` axis.
"""
if primitive is etp_elemwise_p:
return {
'weight': jnp.einsum('...a,...a->...', diag_like, etrace_data['weight']),
}
out = {
'weight': jnp.einsum('...ka,...ika->...ik', diag_like, etrace_data['weight']),
}
if 'bias' in etrace_data:
out['bias'] = jnp.einsum('...ka,...ka->...k', diag_like, etrace_data['bias'])
return out
def _solve_param_dim_weight_gradients(
hist_etrace_data: Dict[ETraceWG_Key, PyTree], # the history etrace data
dG_weights: Dict[Path, dG_Weight], # weight gradients
dG_hidden_groups: Sequence[jax.Array], # hidden group gradients
weight_hidden_relations: Sequence[HiddenParamOpRelation],
weight_vals: Dict[Path, PyTree], # current ParamState pytree values for structure
fast_solve: bool = True,
):
"""
Compute and update the weight gradients for parameter dimensions using eligibility trace data.
This function calculates the weight gradients by utilizing the eligibility trace data and the
hidden-to-hidden Jacobians. It applies a correction factor to avoid exponential smoothing bias
at the beginning of the computation.
Args:
hist_etrace_data (Dict[ETraceWG_Key, PyTree]): A dictionary containing historical eligibility
trace data for the weight gradients, keyed by ETraceWG_Key.
dG_weights (Dict[Path, dG_Weight]): A dictionary to store the computed weight gradients,
keyed by the path of the weight.
dG_hidden_groups (Sequence[jax.Array]): A sequence of hidden group gradients, with the same
length as the total number of hidden groups.
weight_hidden_relations (Sequence[HiddenParamOpRelation]): A sequence of HiddenParamOpRelation
objects representing the relationships between hidden parameters and operations.
mode (brainstate.mixin.Mode): The mode indicating whether batching is enabled.
Returns:
None: The function updates the dG_weights dictionary in place with the computed weight gradients.
"""
# update the etrace weight gradients
temp_data: Dict[Path, PyTree] = dict()
for relation in weight_hidden_relations:
yw_to_w_rule = ETP_RULES_YW_TO_W[relation.primitive]
eqn_params = relation.eqn_params
batched = is_batched_primitive(relation.primitive)
use_fast = fast_solve and (relation.primitive in _ELEMENTWISE_YW_PRIMITIVES)
def _call_yw_to_w_dict(d, trace_, _rule=yw_to_w_rule, _params=eqn_params):
return _rule(d, trace_, **_params)
yw_to_w = (
jax.vmap(_call_yw_to_w_dict)
if batched
else _call_yw_to_w_dict
)
for group in relation.hidden_groups:
group: HiddenGroup
w_key = (id(relation.y_var), group.index)
etrace_data = hist_etrace_data[w_key] # Dict[str, Array]
dg_hidden = dG_hidden_groups[group.index]
# dimensionless processing (unit strip + restore). Apply per-leaf.
etrace_data_unitless, fn_unit_restore = _remove_units(etrace_data)
dg_hidden_unitless, _ = _remove_units(dg_hidden)
if use_fast:
# Closed-form einsum path for mm/mv/elemwise primitives.
dg_weight_dict = _fast_solve_contract(
relation.primitive, dg_hidden_unitless, etrace_data_unitless
)
elif group.num_state == 1:
# num_state==1 shortcut: skip outer vmap of size 1.
dg_hid_squeezed = jax.tree.map(
lambda a: u.math.squeeze(a, axis=-1), dg_hidden_unitless
)
etr_squeezed = jax.tree.map(
lambda a: u.math.squeeze(a, axis=-1), etrace_data_unitless
)
dg_weight_dict = yw_to_w(dg_hid_squeezed, etr_squeezed)
else:
dg_weight_dict = jax.tree.map(
lambda arr: _sum_dim(arr, axis=-1),
jax.vmap(yw_to_w, in_axes=-1, out_axes=-1)(
dg_hidden_unitless, etrace_data_unitless
),
)
dg_weight_dict = fn_unit_restore(dg_weight_dict)
# Route per-key to owning ParamState path.
_route_grads_by_path(relation, dg_weight_dict, weight_vals, temp_data)
#
# Step 3:
#
# sum up the batched weight gradients
# Check if ANY relation uses a batched primitive
has_batched = any(is_batched_primitive(r.primitive) for r in weight_hidden_relations)
if has_batched:
for key, val in temp_data.items():
temp_data[key] = jax.tree.map(lambda x: u.math.sum(x, axis=0), val)
# update the weight gradients
for key, val in temp_data.items():
_update_dict(dG_weights, key, val)
def _remove_units(xs_maybe_quantity: brainstate.typing.PyTree):
"""
Removes units from a PyTree of quantities, returning a unitless PyTree and a function to restore the units.
This function traverses a PyTree structure, removing units from each quantity and returning a new PyTree
with the same structure but without units. It also returns a function that can be used to restore the
original units to the unitless PyTree.
Args:
xs_maybe_quantity (brainstate.typing.PyTree): A PyTree structure containing quantities with units.
Returns:
Tuple[brainstate.typing.PyTree, Callable]: A tuple containing:
- A PyTree with the same structure as the input, but with units removed from each quantity.
- A function that takes a unitless PyTree and restores the original units to it.
"""
leaves, treedef = jax.tree.flatten(xs_maybe_quantity, is_leaf=u.math.is_quantity)
new_leaves, units = [], []
for leaf in leaves:
leaf, unit = u.split_mantissa_unit(leaf)
new_leaves.append(leaf)
units.append(unit)
def restore_units(xs_unitless: brainstate.typing.PyTree):
leaves, treedef2 = jax.tree.flatten(xs_unitless)
assert treedef == treedef2, 'The tree structure should be the same. '
new_leaves = [
leaf if unit.dim.is_dimensionless else leaf * unit
for leaf, unit in zip(leaves, units)
]
return jax.tree.unflatten(treedef, new_leaves)
return jax.tree.unflatten(treedef, new_leaves), restore_units
[docs]
class ParamDimVjpAlgorithm(ETraceVjpAlgorithm):
r"""
The online gradient computation algorithm with the diagonal approximation and the parameter dimension complexity.
This algorithm computes the gradients of the weights with the diagonal approximation and the parameter dimension complexity.
Its algorithm is based on the RTRL algorithm, and has the following learning rule:
$$
\begin{aligned}
&\boldsymbol{\epsilon}^t \approx \mathbf{D}^t \boldsymbol{\epsilon}^{t-1}+\operatorname{diag}\left(\mathbf{D}_f^t\right) \otimes \mathbf{x}^t \\
& \nabla_{\boldsymbol{\theta}} \mathcal{L}=\sum_{t^{\prime} \in \mathcal{T}} \frac{\partial \mathcal{L}^{t^{\prime}}}{\partial \mathbf{h}^{t^{\prime}}} \circ \boldsymbol{\epsilon}^{t^{\prime}}
\end{aligned}
$$
For more details, please see `the D-RTRL algorithm presented in our manuscript <https://www.biorxiv.org/content/10.1101/2024.09.24.614728v2>`_.
Note than the :py:class:`ParamDimVjpAlgorithm` is a subclass of :py:class:`brainstate.nn.Module`,
and it is sensitive to the context/mode of the computation. Particularly,
the :py:class:`ParamDimVjpAlgorithm` is sensitive to ``brainstate.mixin.Batching`` behavior.
This algorithm has the :math:`O(B\theta)` memory complexity, where :math:`\theta` is the number of parameters,
and :math:`B` the batch size.
For a convolutional layer, the algorithm computes the weight gradients with the :math:`O(B\theta)`
memory complexity, where :math:`\theta` is the dimension of the convolutional kernel.
For a Linear transformation layer, the algorithm computes the weight gradients with the :math:`O(BIO)``
computational complexity, where :math:`I` and :math:`O` are the number of input and output dimensions.
Parameters
-----------
model: brainstate.nn.Module
The model function, which receives the input arguments and returns the model output.
vjp_method: str, optional
The method for computing the VJP. It should be either "single-step" or "multi-step".
- "single-step": The VJP is computed at the current time step, i.e., $\partial L^t/\partial h^t$.
- "multi-step": The VJP is computed at multiple time steps, i.e., $\partial L^t/\partial h^{t-k}$,
where $k$ is determined by the data input.
name: str, optional
The name of the etrace algorithm.
mode: braintrace.mixin.Mode
The computing mode, indicating the batching behavior.
"""
# batch of weight gradients
etrace_bwg: Dict[ETraceWG_Key, brainstate.State]
def __init__(
self,
model: brainstate.nn.Module,
name: Optional[str] = None,
vjp_method: str = 'single-step',
fast_solve: bool = True,
normalize_matrix_spectrum: bool = False,
**kwargs,
):
super().__init__(model, name=name, vjp_method=vjp_method)
# ``fast_solve=True`` enables closed-form einsum kernels for
# mm/mv/elemwise primitives, replacing the nested-vmap legacy path.
# Conv / sparse / LoRA primitives always use the legacy path.
self.fast_solve = fast_solve
# When True, clip trace magnitudes > 1 each step via a branch-free
# ``v / max(max_abs, 1)``. Default False (disabled). The previous
# implementation applied this unconditionally to the instantaneous
# term only, which silently distorted gradients.
self.normalize_matrix_spectrum = normalize_matrix_spectrum
[docs]
def init_etrace_state(self, *args, **kwargs):
"""
Initialize the eligibility trace states of the etrace algorithm.
This method is needed after compiling the etrace graph. See
`.compile_graph()` for the details.
"""
# The states of batched weight gradients
self.etrace_bwg = dict()
for relation in self.graph.hidden_param_op_relations:
_init_param_dim_state(self.etrace_bwg, relation)
[docs]
def reset_state(self, batch_size: int = None, **kwargs):
"""
Reset the eligibility trace states.
"""
self.running_index.value = 0
_reset_state_in_a_dict(self.etrace_bwg, batch_size)
[docs]
def get_etrace_of(self, weight: brainstate.ParamState | Path) -> Dict:
"""
Get the eligibility trace of the given weight.
The eligibility trace contains the following structures:
"""
self._assert_compiled()
# get the wight id
weight_id = (
id(weight)
if isinstance(weight, brainstate.ParamState) else
id(self.graph_executor.path_to_states[weight])
)
find_this_weight = False
etraces = dict()
for relation in self.graph.hidden_param_op_relations:
relation: HiddenParamOpRelation
primary_state = next(iter(relation.trainable_param_states.values()), None)
if primary_state is None or id(primary_state) != weight_id:
continue
find_this_weight = True
# retrieve the etrace data
for group in relation.hidden_groups:
group: HiddenGroup
key = (id(relation.y_var), group.index)
etraces[key] = self.etrace_bwg[key].value
if not find_this_weight:
raise ValueError(f'Do not the etrace of the given weight: {weight}.')
return etraces
def _get_etrace_data(self) -> Dict:
"""Retrieve the current eligibility trace data from all trace states.
This method collects all eligibility trace values from the internal state dictionary,
extracting the current values from the brainstate.State objects that store them.
It returns these values in a dictionary with the same keys as the original state
dictionary, making the current trace values available for processing.
This is an internal method used in the parameter dimension eligibility trace algorithm
to access the current trace state for updates and gradient calculations.
Returns:
Dict[ETraceWG_Key, jax.Array]: A dictionary mapping eligibility trace keys to
their current values. Each key represents a specific trace component
(typically involving a parameter and hidden state relationship), and
the corresponding value represents the accumulated eligibility trace.
"""
return {
k: v.value
for k, v in self.etrace_bwg.items()
}
def _assign_etrace_data(self, etrace_vals: Dict) -> None:
"""Assign eligibility trace values to their corresponding state objects.
This method updates the internal eligibility trace state dictionary (etrace_bwg)
with new values from the provided dictionary. It iterates through each key-value
pair in the input dictionary and assigns the value to the corresponding state
object's value attribute.
This is an implementation of the abstract method from the parent class,
customized for the parameter dimension eligibility trace algorithm which
stores traces in a single dictionary rather than separate ones for inputs
and differential functions.
Args:
etrace_vals: Dict[ETraceWG_Key, jax.Array]
Dictionary mapping eligibility trace keys to their updated values.
Each key represents a specific parameter-hidden state relationship,
and the value represents the updated eligibility trace value.
Returns:
None
"""
for x, val in etrace_vals.items():
self.etrace_bwg[x].value = val
def _update_etrace_data(
self,
running_index: Optional[int],
hist_etrace_vals: Dict[ETraceWG_Key, PyTree],
hid2weight_jac_single_or_multi_times: Hid2WeightJacobian,
hid2hid_jac_single_or_multi_times: HiddenGroupJacobian,
weight_vals: Dict[Path, PyTree],
input_is_multi_step: bool,
) -> Dict[ETraceWG_Key, PyTree]:
"""Update eligibility trace data for the parameter dimension-based algorithm.
This method implements the core update equation for the D-RTRL algorithm's eligibility traces:
ε^t ≈ D^t·ε^{t-1} + diag(D_f^t)⊗x^t
It uses JAX's scan operation to efficiently process the historical trace values and
combines them with current Jacobians to compute updated traces according to the
parameter-dimension approximation approach.
Args:
running_index: Optional[int]
Current timestep counter, used for correcting exponential smoothing bias.
hist_etrace_vals: Dict[ETraceWG_Key, PyTree]
Dictionary containing historical eligibility trace values from previous timestep.
Keys are tuples identifying parameter-hidden state relationships.
hid2weight_jac_single_or_multi_times: Hid2WeightJacobian
Jacobians of hidden states with respect to weights at the current timestep.
Contains input gradients and differential function gradients.
hid2hid_jac_single_or_multi_times: HiddenGroupJacobian
Jacobians between hidden states (recurrent connections) at the current timestep.
weight_vals: Dict[Path, PyTree]
Dictionary mapping paths to current weight values in the model.
Returns:
Dict[ETraceWG_Key, PyTree]: Updated eligibility trace values dictionary with the
same structure as hist_etrace_vals but containing new values for the current timestep.
"""
scan_fn = partial(
_update_param_dim_etrace_scan_fn,
weight_path_to_vals=weight_vals,
hidden_param_op_relations=self.graph.hidden_param_op_relations,
normalize_matrix_spectrum=self.normalize_matrix_spectrum,
fast_solve=self.fast_solve,
)
if input_is_multi_step:
new_etrace = jax.lax.scan(
scan_fn,
hist_etrace_vals,
(
hid2weight_jac_single_or_multi_times[0],
hid2weight_jac_single_or_multi_times[1],
hid2hid_jac_single_or_multi_times,
)
)[0]
else:
new_etrace = scan_fn(
hist_etrace_vals,
(
hid2weight_jac_single_or_multi_times[0],
hid2weight_jac_single_or_multi_times[1],
hid2hid_jac_single_or_multi_times,
)
)[0]
return new_etrace
def _solve_weight_gradients(
self,
running_index: int,
etrace_h2w_at_t: Dict[ETraceWG_Key, PyTree],
dl_to_hidden_groups: Sequence[jax.Array],
weight_vals: Dict[WeightID, PyTree],
dl_to_nonetws_at_t: Dict[Path, PyTree],
dl_to_etws_at_t: Optional[Dict[Path, PyTree]],
):
"""Compute weight gradients using parameter dimension eligibility traces.
This method implements the parameter dimension D-RTRL algorithm's weight gradient
computation. It combines the eligibility traces with the gradients of the loss
with respect to hidden states to compute the full parameter gradients according to:
∇_θ L = ∑_{t' ∈ T} ∂L^{t'}/∂h^{t'} ∘ ε^{t'}
Where ε represents the eligibility traces and ∂L/∂h are the gradients of the loss
with respect to hidden states.
Args:
running_index: int
Current timestep counter used for bias correction.
etrace_h2w_at_t: Dict[ETraceWG_Key, PyTree]
Eligibility trace values at the current timestep, mapping parameter-hidden
state relationship keys to trace values.
dl_to_hidden_groups: Sequence[jax.Array]
Gradients of the loss with respect to hidden states at the current timestep.
weight_vals: Dict[WeightID, PyTree]
Current values of all weights in the model.
dl_to_nonetws_at_t: Dict[Path, PyTree]
Gradients of non-eligibility trace parameters at the current timestep.
dl_to_etws_at_t: Optional[Dict[Path, PyTree]]
Optional additional gradients for eligibility trace parameters at the
current timestep.
Returns:
Dict[Path, PyTree]: Dictionary mapping parameter paths to their gradient values.
"""
dG_weights = {path: None for path in self.param_states}
# update the etrace weight gradients
_solve_param_dim_weight_gradients(
etrace_h2w_at_t,
dG_weights,
dl_to_hidden_groups,
self.graph.hidden_param_op_relations,
weight_vals,
fast_solve=self.fast_solve,
)
# update the non-etrace weight gradients
for path, dg in dl_to_nonetws_at_t.items():
_update_dict(dG_weights, path, dg)
# update the etrace parameters when "dl_to_etws_at_t" is not None
if dl_to_etws_at_t is not None:
for path, dg in dl_to_etws_at_t.items():
_update_dict(dG_weights, path, dg, error_when_no_key=True)
return dG_weights
D_RTRL = ParamDimVjpAlgorithm