ETP Primitives Deep Dive

Introduction

ETP (Eligibility Trace Propagation) primitives are JAX custom primitives that mark weight operations in the computational graph. They replace the old ETraceOp / JIT-name-matching system with a cleaner, more robust approach.

Key design principles:

  • Type identity, not string matching. The compiler identifies ETP primitives by checking eqn.primitive in ETP_PRIMITIVES — a set-membership test on the primitive object itself. This is more reliable than the old approach of matching JIT function names.

  • All JAX rules are auto-derived. Each primitive’s impl delegates to standard JAX ops (e.g., x @ w, jax.lax.conv_general_dilated). The register_primitive() helper automatically derives abstract_eval, MLIR lowering, JVP, transpose, and batching rules from the implementation. No hand-written derivative formulas are needed.

  • Only ETP-specific rules need hand-writing. Four global dictionaries capture the online-learning-specific logic:

    • ETP_RULES_YW_TO_W — D-RTRL trace propagation (the \(\mathbf{D}^t \boldsymbol{\epsilon}^{t-1}\) term)

    • ETP_RULES_XY_TO_DW — instantaneous hidden-to-weight Jacobian (the \(\operatorname{diag}(\mathbf{D}_f^t) \otimes \mathbf{x}^t\) term)

    • ETP_RULES_INIT_DRTRL — D-RTRL parameter-dimension trace initialiser

    • ETP_RULES_INIT_PP — pp-prop / ES-D-RTRL output-dimension df-trace initialiser

  • Primitive-based parameter selection. A parameter participates in ETP if and only if it flows through an ETP primitive (braintrace.matmul, braintrace.element_wise, etc.). Parameters used with regular JAX ops are automatically excluded — no special parameter class is needed.

  • N-trainable-inputs per primitive (dict rule API). A single primitive may declare several trainable inputs at once — e.g. {weight, bias} for Linear, {B, A, bias} for LoRA. The four ETP rules consume and return Dict[str, Array] so the executor routes gradients to every owning ParamState in one pass.
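The type-identity test from the first bullet can be sketched with stock JAX primitives. This is an illustration only, not braintrace code: `jax.lax.dot_general_p` stands in for an ETP primitive, and `MARKED_PRIMITIVES` is a hypothetical substitute for `ETP_PRIMITIVES`.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for ETP_PRIMITIVES: a set of primitive *objects*.
MARKED_PRIMITIVES = {jax.lax.dot_general_p}


def step(x, w):
    # lax.dot binds dot_general_p; lax.tanh binds tanh_p.
    return jax.lax.tanh(jax.lax.dot(x, w))


closed = jax.make_jaxpr(step)(jnp.ones((4, 3)), jnp.ones((3, 5)))

# Set-membership test on the primitive object itself, no string matching.
marked = [eqn for eqn in closed.jaxpr.eqns if eqn.primitive in MARKED_PRIMITIVES]
marked_names = [eqn.primitive.name for eqn in marked]
print(marked_names)
```

Because membership is decided by object identity, renaming the surrounding Python function or wrapping it in `jax.jit` does not change which equations are picked up.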

import jax
import jax.numpy as jnp

import braintrace

The Five Primitive Functions

braintrace provides five user-facing ETP primitive functions:

| Function | Underlying primitives | Purpose |
|---|---|---|
| braintrace.matmul | etp_mm_p (batched) / etp_mv_p (unbatched) | Dense matrix multiplication |
| braintrace.element_wise | etp_elemwise_p | Element-wise (diagonal) weight ops |
| braintrace.conv | etp_conv_p | Convolution |
| braintrace.sparse_matmul | etp_sp_mm_p / etp_sp_mv_p | Sparse matrix multiplication |
| braintrace.lora_matmul | etp_lora_mm_p / etp_lora_mv_p | LoRA (Low-Rank Adaptation) matmul |

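The matmul-style functions (matmul, sparse_matmul, lora_matmul) auto-dispatch between a batched (*_mm_p) and an unbatched (*_mv_p) primitive based on input dimensionality; element_wise and conv each map to a single primitive.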

1. braintrace.matmul(x, weight, bias=None) – Dense Matrix Multiplication

Computes \(y = x \, @ \, w \; (+ b)\).

Auto-dispatches based on x.ndim:

  • x.ndim >= 2 -> etp_mm_p (batched): expects x of shape (batch, in_features)

  • x.ndim == 1 -> etp_mv_p (unbatched): expects x of shape (in_features,)

# Batched matmul: x has shape (batch, in_features)
x_batched = jnp.ones((4, 3))    # batch=4, in_features=3
w = jnp.ones((3, 5))            # in_features=3, out_features=5

y_batched = braintrace.matmul(x_batched, w)
print("Batched output shape:", y_batched.shape)   # (4, 5)

# Unbatched matmul: x has shape (in_features,)
x_single = jnp.ones((3,))       # in_features=3

y_single = braintrace.matmul(x_single, w)
print("Unbatched output shape:", y_single.shape)   # (5,)
Batched output shape: (4, 5)
Unbatched output shape: (5,)
# With bias
b = jnp.zeros((5,))

y_with_bias = braintrace.matmul(x_batched, w, bias=b)
print("With bias:", y_with_bias.shape)              # (4, 5)
With bias: (4, 5)

2. braintrace.element_wise(weight, fn=lambda w: w) – Element-wise Operation

Applies fn to the weight and passes the result through a marker primitive. The operation is treated as diagonal in the hidden-state space.

\[y = \texttt{fn}(w)\]

fn defaults to the identity (lambda w: w); supply any JAX-differentiable function when you want a non-trivial transformation.

Common use cases:

  • Gating mechanisms in RNNs (learnable gate biases)

  • Learnable time constants or thresholds in spiking neural networks

  • Any parameter that enters the computation element-wise

Note: etp_elemwise_p is the only primitive registered with gradient_enabled=True. The compiler descends into it when walking y -> h, so it does not act as a tail boundary for upstream ETP weights. See the gradient_enabled Flag section below for details.

# Identity (default fn): just marks the weight for ETP
w_gate = jnp.array([0.5, -0.3, 0.8, 0.1])

y_identity = braintrace.element_wise(w_gate)
print("Identity:", y_identity)

# With a transformation function
y_sigmoid = braintrace.element_wise(w_gate, fn=jax.nn.sigmoid)
print("Sigmoid:", y_sigmoid)

# With absolute value (e.g., enforcing positive time constants)
y_abs = braintrace.element_wise(w_gate, fn=jnp.abs)
print("Abs:", y_abs)
Identity: [ 0.5 -0.3  0.8  0.1]
Sigmoid: [0.62245935 0.4255575  0.6899745  0.5249792 ]
Abs: [0.5 0.3 0.8 0.1]

3. braintrace.conv(x, kernel, bias=None, *, strides, padding, ...) – Convolution

ETP-aware convolution that wraps jax.lax.conv_general_dilated. Computes:

\[y = \text{conv}(x, \text{kernel}) \; (+ b)\]

Important: Always expects a batch dimension on x.

Supports all parameters of jax.lax.conv_general_dilated: strides, padding, lhs_dilation, rhs_dilation, feature_group_count, batch_group_count, and dimension_numbers.

# 1D convolution example
# x: (batch, spatial, channels) with dimension_numbers
x_1d = jnp.ones((2, 16, 3))         # batch=2, length=16, in_channels=3
kernel_1d = jnp.ones((4, 3, 8))     # kernel_size=4, in_channels=3, out_channels=8

y_conv = braintrace.conv(
    x_1d, kernel_1d,
    strides=(1,),
    padding='SAME',
    dimension_numbers=('NWC', 'WIO', 'NWC'),
)
print("Conv1D output shape:", y_conv.shape)  # (2, 16, 8)
Conv1D output shape: (2, 16, 8)
# 2D convolution example
x_2d = jnp.ones((2, 32, 32, 3))          # batch=2, H=32, W=32, in_channels=3
kernel_2d = jnp.ones((3, 3, 3, 16))      # kH=3, kW=3, in_channels=3, out_channels=16

y_conv2d = braintrace.conv(
    x_2d, kernel_2d,
    strides=(1, 1),
    padding='SAME',
    dimension_numbers=('NHWC', 'HWIO', 'NHWC'),
)
print("Conv2D output shape:", y_conv2d.shape)  # (2, 32, 32, 16)
Conv2D output shape: (2, 32, 32, 16)

4. braintrace.sparse_matmul(x, weight_data, *, sparse_mat, bias=None) – Sparse Matmul

ETP-aware sparse matrix multiplication. Computes:

\[y = x \, @ \, \text{sparse}(w) \; (+ b)\]

The sparse_mat argument provides the sparse structure (indices), while weight_data contains only the non-zero values. This is useful for models with sparse connectivity patterns, such as biologically plausible neural networks or graph neural networks.

import saiunit as u
from saiunit import sparse as ss

# Create a sparse connectivity matrix
dense_w = jnp.where(
    jax.random.uniform(jax.random.PRNGKey(0), (50, 50)) < 0.1,
    jax.random.normal(jax.random.PRNGKey(1), (50, 50)),
    0.0
)
sparse_mat = ss.CSR.fromdense(dense_w)

# The learnable parameter is just the non-zero data
weight_data = sparse_mat.data

x_sp = jnp.ones((4, 50))  # batch=4, features=50
y_sp = braintrace.sparse_matmul(x_sp, weight_data, sparse_mat=sparse_mat)
print("Sparse matmul output shape:", y_sp.shape)  # (4, 50)
Sparse matmul output shape: (4, 50)
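The split between a fixed sparsity structure and a trainable vector of non-zero values can be illustrated in plain JAX. The sketch below does not use braintrace or saiunit: the coordinate arrays `rows` and `cols` and the helper `sparse_forward` are hypothetical names chosen for this illustration, and the dense scatter is used only for clarity.

```python
import jax
import jax.numpy as jnp

# Fixed sparsity pattern (structure) and the trainable non-zero values (data).
rows = jnp.array([0, 1, 2, 2])
cols = jnp.array([1, 0, 0, 2])
data = jnp.array([0.5, -1.0, 2.0, 0.25])


def sparse_forward(x, data):
    # Scatter the values into a dense (3, 3) matrix, then multiply.
    # A real sparse kernel would avoid materialising the dense matrix.
    w = jnp.zeros((3, 3)).at[rows, cols].set(data)
    return x @ w


x = jnp.ones((4, 3))
y = sparse_forward(x, data)

# The gradient has one entry per stored non-zero, not per dense element.
grad_data = jax.grad(lambda d: jnp.sum(sparse_forward(x, d)))(data)
print(y.shape, grad_data.shape)
```

Only the four stored values receive gradients; the zero entries of the pattern are not parameters at all.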

5. braintrace.lora_matmul(x, B, A, *, alpha=1.0, bias=None) – LoRA Matmul

Low-Rank Adaptation matmul. Computes:

\[y = \alpha \cdot x \, @ \, B \, @ \, A \; (+ b)\]

where \(B \in \mathbb{R}^{\text{in} \times \text{rank}}\) and \(A \in \mathbb{R}^{\text{rank} \times \text{out}}\) are low-rank factors. This is useful for parameter-efficient fine-tuning of large models, where only the low-rank factors are trained.

in_features, out_features, rank = 64, 32, 4

B = jax.random.normal(jax.random.PRNGKey(0), (in_features, rank)) * 0.01
A = jax.random.normal(jax.random.PRNGKey(1), (rank, out_features)) * 0.01

x_lora = jnp.ones((8, in_features))  # batch=8

y_lora = braintrace.lora_matmul(x_lora, B, A, alpha=2.0)
print("LoRA output shape:", y_lora.shape)  # (8, 32)
print("LoRA output (first sample):", y_lora[0])
LoRA output shape: (8, 32)
LoRA output (first sample): [ 1.2993842e-04  2.9886682e-03 -5.7171844e-04  1.7000248e-03
  2.8961061e-03 -1.4734608e-03  1.5826351e-03  1.1692893e-04
 -5.4561032e-04 -1.3646147e-03  1.0654312e-04  2.0160272e-03
  3.1371990e-03  1.8710974e-03  2.9124888e-03 -8.2047645e-04
  9.7736251e-04 -1.1875220e-03 -2.1796541e-03 -5.8191392e-05
 -1.5415263e-03 -1.3381819e-03 -2.1044153e-03 -8.4472442e-04
  1.6430757e-05 -2.9564608e-04  4.5550600e-04 -1.6011632e-03
 -1.4516843e-03 -1.3209208e-03  3.2850087e-04 -6.0276774e-04]
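To make the parameter saving concrete, the following plain-JAX sketch (it does not call braintrace.lora_matmul) compares the factored form from the formula above with the equivalent dense weight \(\alpha B A\), using the same shapes as the example. The variable names are chosen for this illustration.

```python
import jax
import jax.numpy as jnp

in_features, out_features, rank, alpha = 64, 32, 4, 2.0
B = jax.random.normal(jax.random.PRNGKey(0), (in_features, rank)) * 0.01
A = jax.random.normal(jax.random.PRNGKey(1), (rank, out_features)) * 0.01
x = jnp.ones((8, in_features))

# Factored form: two thin matmuls, never forming the (in, out) matrix.
y_factored = alpha * ((x @ B) @ A)

# Equivalent dense form, built here only for comparison.
w_dense = alpha * (B @ A)
y_dense = x @ w_dense

lora_params = B.size + A.size     # 64*4 + 4*32 = 384
dense_params = w_dense.size       # 64*32 = 2048
print(lora_params, dense_params)
```

With rank 4, the two factors hold 384 trainable values where the dense matrix would hold 2048.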

Physical Units (saiunit / Quantity) Support

Every user-facing ETP function in braintrace (matmul, conv, element_wise, sparse_matmul, lora_matmul) accepts saiunit.Quantity inputs transparently. The wrapper:

  1. splits each quantity into a plain-array mantissa and a unit,

  2. binds the primitive on the mantissas only,

  3. re-attaches the combined unit to the result with u.maybe_decimal.

This keeps the primitives themselves unit-free (JAX sees only arrays), while users write physical quantities naturally. Bias is re-scaled into the combined x × weight unit before bind so the addition is dimensionally valid.

import saiunit as u

# Quantity-valued inputs pass through unchanged.
x_q = jnp.ones((4, 3)) * u.volt          # shape (4, 3), unit = V
w_q = jnp.ones((3, 5)) * u.siemens       # shape (3, 5), unit = S
b_q = jnp.zeros((5,)) * u.amp             # must match V * S = A

y_q = braintrace.matmul(x_q, w_q, bias=b_q)
print("Output:", y_q)
print("Unit :", u.get_unit(y_q))
Output: [[3. 3. 3. 3. 3.]
 [3. 3. 3. 3. 3.]
 [3. 3. 3. 3. 3.]
 [3. 3. 3. 3. 3.]] A
Unit : A

JAX Compatibility

All ETP primitives are fully compatible with JAX transformations. Since register_primitive() auto-derives JIT, grad, vmap, and JVP rules from the implementation, they work seamlessly with the standard JAX API.

x = jnp.ones((4, 3))
w = jnp.ones((3, 5))

# ---- JIT compilation ----
y_jit = jax.jit(braintrace.matmul)(x, w)
print("JIT output shape:", y_jit.shape)
JIT output shape: (4, 5)
# ---- Gradient computation ----
grad_fn = jax.grad(lambda w: jnp.sum(braintrace.matmul(x, w)))
dw = grad_fn(w)
print("Gradient shape:", dw.shape)
print("Gradient values:\n", dw)
Gradient shape: (3, 5)
Gradient values:
 [[4. 4. 4. 4. 4.]
 [4. 4. 4. 4. 4.]
 [4. 4. 4. 4. 4.]]
# ---- Vectorized mapping (vmap) ----
# vmap over a batch of inputs, each of shape (4, 3)
xs = jnp.ones((8, 4, 3))  # 8 different batches
vmap_fn = jax.vmap(lambda x_i: braintrace.matmul(x_i, w))
ys = vmap_fn(xs)
print("vmap output shape:", ys.shape)  # (8, 4, 5)
vmap output shape: (8, 4, 5)
# ---- JVP (forward-mode differentiation) ----
primals = (x, w)
tangents = (jnp.ones_like(x), jnp.ones_like(w))

y_primal, y_tangent = jax.jvp(braintrace.matmul, primals, tangents)
print("JVP primal shape:", y_primal.shape)
print("JVP tangent shape:", y_tangent.shape)
JVP primal shape: (4, 5)
JVP tangent shape: (4, 5)
# ---- Composability: JIT + grad + vmap ----
@jax.jit
def batched_grad(xs, w):
    """Compute per-sample gradients w.r.t. the weight."""
    def single_grad(x_i):
        return jax.grad(lambda w_: jnp.sum(braintrace.matmul(x_i, w_)))(w)
    return jax.vmap(single_grad)(xs)

xs = jnp.ones((8, 4, 3))
per_sample_grads = batched_grad(xs, w)
print("Per-sample gradients shape:", per_sample_grads.shape)  # (8, 3, 5)
Per-sample gradients shape: (8, 3, 5)

Argument Conventions

Every ETP primitive follows specific conventions for its input variables (invars) and static parameters. Understanding these conventions is essential when working with the compiler or adding custom primitives.

Invar layout

| Primitive | invars[0] | invars[1] | invars[2] | invars[3] | Static params |
|---|---|---|---|---|---|
| etp_mm_p / etp_mv_p | input x | weight W | bias b (opt) | | has_bias |
| etp_elemwise_p | processed y | | | | (none) |
| etp_conv_p | input x | kernel W | bias b (opt) | | has_bias, strides, padding, lhs_dilation, rhs_dilation, feature_group_count, batch_group_count, dimension_numbers |
| etp_sp_mm_p / etp_sp_mv_p | input x | weight data | bias b (opt) | | sparse_mat, has_bias |
| etp_lora_mm_p / etp_lora_mv_p | input x | matrix B | matrix A | bias b (opt) | alpha, has_bias |

trainable_invars_fn — the N-trainable-input contract

Instead of hard-coding a single weight_invar_index, each primitive exposes a function

trainable_invars_fn: Callable[[dict], Dict[str, int]]

which maps the equation’s static params onto {trainable_name: invar_index}. The compiler calls it at analysis time to discover every trainable input and to route gradients to the owning ParamState pytree leaf.

Built-in examples:

| Primitive | has_bias=False | has_bias=True |
|---|---|---|
| etp_mm_p / etp_mv_p | {'weight': 1} | {'weight': 1, 'bias': 2} |
| etp_conv_p | {'weight': 1} | {'weight': 1, 'bias': 2} |
| etp_sp_mm_p / etp_sp_mv_p | {'weight': 1} | {'weight': 1, 'bias': 2} |
| etp_lora_mm_p / etp_lora_mv_p | {'lora_b': 1, 'lora_a': 2} | {'lora_b': 1, 'lora_a': 2, 'bias': 3} |
| etp_elemwise_p | {'weight': 0} | |

Notes:

  • The has_bias flag is a static parameter (not a traced value) that controls whether the optional bias argument is present.

  • For convolution, all jax.lax.conv_general_dilated parameters are passed as static params.

  • x_invar_index points to the non-trainable input; etp_elemwise_p sets it to None because the op has no separate input.
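As a rough illustration of the contract, the pure-Python sketch below mirrors the LoRA row of the built-in examples and resolves each trainable name to an invar by index. The helper `resolve` and the use of strings in place of jaxpr variables are assumptions made for this sketch; they are not part of braintrace.

```python
from typing import Dict


def lora_trainable_invars(params: dict) -> Dict[str, int]:
    """Mirror of the LoRA row: B at invars[1], A at invars[2], bias at invars[3]."""
    mapping = {'lora_b': 1, 'lora_a': 2}
    if params.get('has_bias', False):
        mapping['bias'] = 3
    return mapping


def resolve(invars: list, params: dict) -> dict:
    """Hypothetical helper: look up each trainable name's invar by index."""
    return {name: invars[i] for name, i in lora_trainable_invars(params).items()}


# Strings stand in for jaxpr variables.
no_bias = resolve(['x', 'B', 'A'], {'alpha': 2.0, 'has_bias': False})
with_bias = resolve(['x', 'B', 'A', 'b'], {'alpha': 2.0, 'has_bias': True})
print(no_bias)
print(with_bias)
```

Since the mapping is computed from static params only, it can be evaluated at analysis time without running the forward pass.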

Rule Registries (dict API)

ETP uses four global dictionaries to store operation-specific rules. These are the only things that need hand-writing — all standard JAX rules are auto-derived from the implementation function.

Three of the four rules (yw_to_w, xy_to_dw, init_drtrl) operate on Dict[str, Array], keyed by the names returned by trainable_invars_fn. The exception is init_pp, which returns a single output-shaped array: pp-prop factorises the trace as \(\boldsymbol{\epsilon}_f \otimes \boldsymbol{\epsilon}_x\) and needs only one df-tensor per primitive output.

ETP_RULES_YW_TO_W — D-RTRL trace propagation

yw_to_w(hidden_dim: Array, trace: Dict[str, Array], **static_params) -> Dict[str, Array]

Propagates the hidden-state cotangent \(\partial h/\partial y\) through the \(y \to W\) chain factor of the D-RTRL term \(\mathbf{D}^t \boldsymbol{\epsilon}^{t-1}\). Applied per stored trace key.
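For a dense matmul, the custom-primitive example later on this page implements this rule as a broadcast multiply along the output axis. The sketch below isolates that one step in plain JAX for an unbatched weight trace. The shapes (trace of shape (in, out), hidden_dim of shape (out,)) are assumptions chosen for illustration.

```python
import jax.numpy as jnp

in_features, out_features = 3, 5
trace = jnp.arange(in_features * out_features, dtype=jnp.float32).reshape(
    in_features, out_features
)
hidden_dim = jnp.array([1.0, 0.5, 0.0, 2.0, -1.0])   # one factor per output unit

# Broadcast the per-output factor along the input axis (axis=-2).
new_trace = trace * jnp.expand_dims(hidden_dim, axis=-2)

# The same operation written as an explicit index expression.
reference = jnp.einsum('ij,j->ij', trace, hidden_dim)
print(new_trace.shape)
```

Each output column of the trace is scaled by its own factor, which is why the rule stays element-wise in the output dimension.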

ETP_RULES_XY_TO_DW — instantaneous hidden-to-weight Jacobian

xy_to_dw(x: Array, hidden_dim: Array, weights: Dict[str, Array], **static_params) -> Dict[str, Array]

Returns \(\partial h / \partial W\) for every trainable key. This supplies the \(\operatorname{diag}(\mathbf{D}_f^t) \otimes \mathbf{x}^t\) term in D-RTRL and the solve-time pullback in ES-D-RTRL. Typical implementation: a single fused jax.vjp over a dict-valued forward function.
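The fused-VJP pattern mentioned above can be checked on a small unbatched example in plain JAX. For \(y = x \, @ \, W + b\), pulling a cotangent on y back through the forward function should give the outer product of x and the cotangent for the weight, and the cotangent itself for the bias. The names and shapes here are chosen for illustration.

```python
import jax
import jax.numpy as jnp

x = jnp.array([1.0, 2.0, 3.0])                 # (in,)
hidden_dim = jnp.array([0.1, 0.2, 0.3, 0.4])   # (out,), cotangent on y
weights = {'weight': jnp.zeros((3, 4)), 'bias': jnp.zeros((4,))}


def forward(w_dict):
    # Dict-valued parameters in, array out.
    return x @ w_dict['weight'] + w_dict['bias']


# One VJP call yields a gradient for every trainable key.
_, vjp_fn = jax.vjp(forward, weights)
dw = vjp_fn(hidden_dim)[0]
print({k: v.shape for k, v in dw.items()})
```

The result is a dict with the same keys as `weights`, which is the shape of output the dict rule API expects.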

ETP_RULES_INIT_DRTRL — D-RTRL trace initialiser

init_drtrl(x_var, y_var, weight_vars: Dict[str, Var], num_hidden_state: int) -> Dict[str, Array]

Returns a zero-filled Dict[str, Array] shaped to hold the parameter-dimension trace used by D_RTRL / ParamDimVjpAlgorithm. One leaf per trainable key.

ETP_RULES_INIT_PP — pp-prop / ES-D-RTRL df-trace initialiser

init_pp(x_var, y_var, weight_vars: Dict[str, Var], num_hidden_state: int) -> Array

Returns a single zero-filled array shaped to hold the output-dimension df trace used by ES_D_RTRL / IODimVjpAlgorithm. The matching \(\boldsymbol{\epsilon}_x\) factor is managed separately by the executor’s x-trace dictionary.

The two INIT_* registries exist because the two algorithm families factorise the trace differently. Both are required for a primitive that should support both algorithms.
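The difference between the two factorisations shows up directly in the trace shapes. The sketch below allocates both layouts for a small dense layer. The parameter-dimension and df shapes follow the custom-primitive example on this page, while the shape of the x trace, (batch, in_features), is an assumption made for this illustration.

```python
import jax.numpy as jnp

batch, in_features, out_features, num_hidden_state = 2, 3, 5, 1

# Parameter-dimension trace (D-RTRL style): one entry per weight element.
param_dim_trace = jnp.zeros((batch, in_features, out_features, num_hidden_state))

# Output-dimension df trace (pp-prop style) plus a separately managed x trace.
df_trace = jnp.zeros((batch, out_features, num_hidden_state))
x_trace = jnp.zeros((batch, in_features))   # assumed shape, see lead-in

param_dim_size = param_dim_trace.size
io_dim_size = df_trace.size + x_trace.size
print(param_dim_size, io_dim_size)
```

The parameter-dimension trace grows with in_features × out_features, whereas the factored form grows with in_features + out_features per sample.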

from braintrace._etrace_op import (
    ETP_RULES_YW_TO_W,
    ETP_RULES_XY_TO_DW,
    ETP_RULES_INIT_DRTRL,
    ETP_RULES_INIT_PP,
    ETP_PRIMITIVES,
    BATCHED_PRIMITIVES,
)

print("All ETP primitives:")
for p in sorted(ETP_PRIMITIVES, key=lambda p: p.name):
    batched_tag = " [batched]" if p in BATCHED_PRIMITIVES else ""
    print(f"  {p.name}{batched_tag}")

print("\nTrace propagation rules (ETP_RULES_YW_TO_W):")
for p in sorted(ETP_RULES_YW_TO_W.keys(), key=lambda p: p.name):
    print(f"  {p.name}")

print("\nWeight gradient rules (ETP_RULES_XY_TO_DW):")
for p in sorted(ETP_RULES_XY_TO_DW.keys(), key=lambda p: p.name):
    print(f"  {p.name}")

print("\nD-RTRL init rules (ETP_RULES_INIT_DRTRL):")
for p in sorted(ETP_RULES_INIT_DRTRL.keys(), key=lambda p: p.name):
    print(f"  {p.name}")

print("\npp_prop init rules (ETP_RULES_INIT_PP):")
for p in sorted(ETP_RULES_INIT_PP.keys(), key=lambda p: p.name):
    print(f"  {p.name}")
All ETP primitives:
  etp_conv [batched]
  etp_elemwise
  etp_lora_mm [batched]
  etp_lora_mv
  etp_mm [batched]
  etp_mv
  etp_sp_mm [batched]
  etp_sp_mv

Trace propagation rules (ETP_RULES_YW_TO_W):
  etp_conv
  etp_elemwise
  etp_lora_mm
  etp_lora_mv
  etp_mm
  etp_mv
  etp_sp_mm
  etp_sp_mv

Weight gradient rules (ETP_RULES_XY_TO_DW):
  etp_conv
  etp_elemwise
  etp_lora_mm
  etp_lora_mv
  etp_mm
  etp_mv
  etp_sp_mm
  etp_sp_mv

D-RTRL init rules (ETP_RULES_INIT_DRTRL):
  etp_conv
  etp_elemwise
  etp_lora_mm
  etp_lora_mv
  etp_mm
  etp_mv
  etp_sp_mm
  etp_sp_mv

pp_prop init rules (ETP_RULES_INIT_PP):
  etp_conv
  etp_elemwise
  etp_lora_mm
  etp_lora_mv
  etp_mm
  etp_mv
  etp_sp_mm
  etp_sp_mv

Adding a Custom Primitive

Adding a new ETP primitive takes only a few steps. Here we create a scaled matrix multiplication with an optional bias as an example:

\[y = \text{scale} \cdot (x \, @ \, W) \; (+ b).\]

The example exercises the whole dict rule API: both the weight and bias branches are wired end-to-end.

import braintrace
from braintrace import register_primitive


# Step 1: Define the implementation.
# Plain JAX function — no special annotations needed.
def _scaled_matmul_impl(*args, scale=1.0, has_bias=False):
    x, w = args[0], args[1]
    y = scale * (x @ w)
    if has_bias:
        y = y + args[2]
    return y


# Step 2: Register as an ETP primitive.
# register_primitive() returns an ``ETPPrimitive`` and auto-derives all
# standard JAX rules (abstract_eval, lowering, JVP, transpose, batching).
scaled_mm_p = register_primitive('etp_scaled_mm', _scaled_matmul_impl, batched=True)

print("Primitive registered:", scaled_mm_p)
print("Type:", type(scaled_mm_p).__name__)
Primitive registered: etp_scaled_mm
Type: ETPPrimitive
# Step 3: Register the four ETP-specific rules (dict API).
# Each rule accepts / returns ``Dict[str, Array]`` keyed by the names
# in ``trainable_invars_fn`` — here ``'weight'`` and (optionally) ``'bias'``.


def _scaled_trainable_invars(params):
    """Tell the compiler which invars are trainable."""
    base = {'weight': 1}
    if params.get('has_bias', False):
        base['bias'] = 2
    return base


def _scaled_yw_to_w(hidden_dim, trace, *, scale=1.0, has_bias=False):
    # y = scale * x @ w + b
    #   -> ∂y/∂w along the "out" axis is scaled by `scale`; the y→w chain
    #      link is still elementwise along `out` axis (singleton at axis=-2).
    out = {'weight': trace['weight'] * jnp.expand_dims(hidden_dim, axis=-2) * scale}
    if has_bias:
        out['bias'] = trace['bias'] * hidden_dim
    return out


def _scaled_xy_to_dw(x, hidden_dim, weights, *, scale=1.0, has_bias=False):
    # Single fused VJP over a dict-valued forward function — returns
    # gradients for both 'weight' and 'bias' in one pass.
    def _fwd(w_dict):
        y = scale * (x @ w_dict['weight'])
        if has_bias:
            y = y + w_dict['bias']
        return y
    _, vjp_fn = jax.vjp(_fwd, weights)
    return vjp_fn(hidden_dim)[0]


def _scaled_init_drtrl(x_var, y_var, weight_vars, num_hidden_state):
    """D-RTRL parameter-dim trace: one leaf per trainable key."""
    batch = x_var.aval.shape[0]
    out = {
        'weight': jnp.zeros(
            (batch, *weight_vars['weight'].aval.shape, num_hidden_state)
        )
    }
    if 'bias' in weight_vars:
        out['bias'] = jnp.zeros(
            (batch, *weight_vars['bias'].aval.shape, num_hidden_state)
        )
    return out


def _scaled_init_pp(x_var, y_var, weight_vars, num_hidden_state):
    """pp-prop df trace: single array shaped like the output."""
    return jnp.zeros(
        (*y_var.aval.shape, num_hidden_state),
        dtype=y_var.aval.dtype,
    )


scaled_mm_p.register_etp_rules(
    yw_to_w=_scaled_yw_to_w,
    xy_to_dw=_scaled_xy_to_dw,
    init_drtrl=_scaled_init_drtrl,
    init_pp=_scaled_init_pp,
)

# Shorthand: every ``register_*`` method also exists as a standalone call,
# and ``ETPPrimitiveSpec`` (below) bundles the whole thing in one record.

print("yw_to_w registered:  ", scaled_mm_p in ETP_RULES_YW_TO_W)
print("xy_to_dw registered: ", scaled_mm_p in ETP_RULES_XY_TO_DW)
print("init_drtrl registered:", scaled_mm_p in ETP_RULES_INIT_DRTRL)
print("init_pp registered:  ", scaled_mm_p in ETP_RULES_INIT_PP)
yw_to_w registered:   True
xy_to_dw registered:  True
init_drtrl registered: True
init_pp registered:   True
# Step 4: Use the custom primitive via ``primitive.bind()``.

x = jnp.ones((4, 3))
w = jnp.ones((3, 5))

y = scaled_mm_p.bind(x, w, scale=2.0, has_bias=False)
y_expected = 2.0 * (x @ w)

print("Output shape :", y.shape)
print("Matches 2·xw :", bool(jnp.allclose(y, y_expected)))

# With bias:
b = jnp.full((5,), 0.1)
y_bias = scaled_mm_p.bind(x, w, b, scale=2.0, has_bias=True)
print("With bias    :", y_bias[0])
Output shape : (4, 5)
Matches 2·xw : True
With bias    : [6.1 6.1 6.1 6.1 6.1]
# All JAX transformations work automatically for the custom primitive.

# JIT
y_jit = jax.jit(lambda x, w: scaled_mm_p.bind(x, w, scale=2.0, has_bias=False))(x, w)
print("JIT works:", bool(jnp.allclose(y_jit, y_expected)))

# Grad
dw = jax.grad(lambda w: jnp.sum(scaled_mm_p.bind(x, w, scale=2.0, has_bias=False)))(w)
print("Grad shape:", dw.shape)

# Vmap
xs = jnp.ones((8, 4, 3))
ys = jax.vmap(lambda xi: scaled_mm_p.bind(xi, w, scale=2.0, has_bias=False))(xs)
print("Vmap output shape:", ys.shape)
JIT works: True
Grad shape: (3, 5)
Vmap output shape: (8, 4, 5)

Compiler integration. The class-style registration above is enough for direct primitive.bind() use, JIT, grad, vmap, and JVP. For the primitive to be discovered by the ETP compiler (compile_etrace_graph, D_RTRL, ES_D_RTRL), you must also publish the trainable_invars_fn and the invar layout via an ETPPrimitiveSpec — see the next section. _scaled_trainable_invars defined above is re-used there.

Spec-based Registration

When a primitive is intended for the ETP compiler, use the spec form. An ETPPrimitiveSpec bundles the implementation, all four rules, the invar/outvar layout, and the trainable_invars_fn into one frozen dataclass. Passing it to register_primitive_spec wires everything up and records the spec in ETP_PRIMITIVE_SPECS so the compiler can query it via get_primitive_spec.

Spec fields:

| Field | Purpose |
|---|---|
| name | Primitive name |
| impl | Plain JAX forward function |
| yw_to_w, xy_to_dw, init_drtrl, init_pp | The four ETP rules |
| trainable_invars_fn | params -> {trainable_name: invar_index} (required) |
| x_invar_index | Position of the non-trainable input, or None for identity-like ops |
| y_outvar_index | Position of y in eqn.outvars (default 0) |
| batched | Whether the primitive takes batched input |
| gradient_enabled | Compiler traverses this primitive when walking y -> h (default False; set only for identity-like ops) |

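For direct primitive.bind() use, JIT, grad, vmap, and JVP, the spec form behaves the same as the class-based form. Only the spec form also records trainable_invars_fn and the invar layout, so use it whenever the ETP compiler needs to discover the primitive.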

import braintrace


def _spec_impl(*args, scale=1.0, has_bias=False):
    x, w = args[0], args[1]
    y = scale * (x @ w)
    if has_bias:
        y = y + args[2]
    return y


def _spec_trainable_invars(params):
    base = {'weight': 1}
    if params.get('has_bias', False):
        base['bias'] = 2
    return base


def _spec_yw_to_w(hidden_dim, trace, *, scale=1.0, has_bias=False):
    out = {'weight': trace['weight'] * jnp.expand_dims(hidden_dim, axis=-2) * scale}
    if has_bias:
        out['bias'] = trace['bias'] * hidden_dim
    return out


def _spec_xy_to_dw(x, hidden_dim, weights, *, scale=1.0, has_bias=False):
    def _fwd(w_dict):
        y = scale * (x @ w_dict['weight'])
        if has_bias:
            y = y + w_dict['bias']
        return y
    _, vjp_fn = jax.vjp(_fwd, weights)
    return vjp_fn(hidden_dim)[0]


def _spec_init_drtrl(x_var, y_var, weight_vars, n):
    batch = x_var.aval.shape[0]
    out = {
        'weight': jnp.zeros(
            (batch, *weight_vars['weight'].aval.shape, n)
        )
    }
    if 'bias' in weight_vars:
        out['bias'] = jnp.zeros(
            (batch, *weight_vars['bias'].aval.shape, n)
        )
    return out


def _spec_init_pp(x_var, y_var, weight_vars, n):
    return jnp.zeros((*y_var.aval.shape, n), dtype=y_var.aval.dtype)


spec = braintrace.ETPPrimitiveSpec(
    name='etp_spec_demo',
    impl=_spec_impl,
    yw_to_w=_spec_yw_to_w,
    xy_to_dw=_spec_xy_to_dw,
    init_drtrl=_spec_init_drtrl,
    init_pp=_spec_init_pp,
    trainable_invars_fn=_spec_trainable_invars,
    x_invar_index=0,   # ``invars[0]`` is the input; trainable invars are
                       # fully described by ``trainable_invars_fn`` above.
    batched=True,
)

spec_p = braintrace.register_primitive_spec(spec)

# The spec is recoverable from the primitive at any time:
assert braintrace.get_primitive_spec(spec_p) is spec

# Quick sanity check of the invar layout:
print("trainable (no bias):", spec.resolve_trainable_invars({'has_bias': False}))
print("trainable (bias)   :", spec.resolve_trainable_invars({'has_bias': True}))

# And the primitive binds normally:
y_spec = spec_p.bind(x, w, scale=3.0, has_bias=False)
print("spec bind output  :", y_spec.shape)
trainable (no bias): {'weight': 1}
trainable (bias)   : {'weight': 1, 'bias': 2}
spec bind output  : (4, 5)

The gradient_enabled Flag

register_primitive() accepts a gradient_enabled keyword (default False). It controls how the compiler treats this primitive when walking from a weight’s output back to a hidden state.

| gradient_enabled | Compiler behaviour | Example |
|---|---|---|
| False (default) | Treats the primitive as a tail boundary. A preceding ETP weight whose only path to a hidden state passes through this primitive is excluded from ETP, because per-primitive ETP rules cannot express weight-then-weight composition. | All trainable matmul/conv/sparse/LoRA primitives use this. |
| True | The primitive is identity-like and may sit on the tail of the y -> h walk. Its presence does not exclude an upstream ETP weight. | Only etp_elemwise_p, intended for gating biases, learnable thresholds, etc. |

Use gradient_enabled=True only when the primitive’s xy_to_dw rule is itself an identity-like passthrough; mark all genuinely trainable ops with the default. The “weight -> weight -> hidden” exclusion is what makes per-primitive ETP rules sound – see advanced/limitations.ipynb for a worked example with GRUCell (3 Linears, only 2 ETP relations).
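The exclusion rule can be pictured with a toy model of the y -> h walk. The sketch below is not the braintrace compiler: it represents a path as a list of primitive names, and the flag table `GRADIENT_ENABLED` and the function `upstream_weight_included` are hypothetical names used only to illustrate the tail-boundary idea.

```python
# Hypothetical flags keyed by primitive name (braintrace keys by primitive object).
GRADIENT_ENABLED = {'etp_mm': False, 'etp_conv': False, 'etp_elemwise': True}


def upstream_weight_included(path_to_hidden):
    """Return True if no tail-boundary ETP primitive sits on the y -> h path."""
    for name in path_to_hidden:
        if name in GRADIENT_ENABLED and not GRADIENT_ENABLED[name]:
            return False  # weight-then-weight composition: excluded from ETP
    return True  # only ordinary ops or identity-like ETP ops on the path


direct = upstream_weight_included(['add', 'tanh'])
through_matmul = upstream_weight_included(['tanh', 'etp_mm', 'tanh'])
through_elemwise = upstream_weight_included(['etp_elemwise', 'tanh'])
print(direct, through_matmul, through_elemwise)
```

In this toy model a weight whose output reaches the hidden state only through another trainable matmul is dropped, while ordinary ops and identity-like ETP ops leave it eligible.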

Integrating a Primitive with Online Learning

Marking a weight operation with a braintrace.* primitive is the only thing a model has to do to opt that parameter into online learning. The compiler then walks the jaxpr, finds every ETP primitive, connects it to the downstream hidden states, and builds the eligibility-trace machinery for either D_RTRL (parameter-dim trace) or ES_D_RTRL / pp_prop (IO-dim trace).

Rule of thumb

| Goal | Use |
|---|---|
| Include a parameter in online learning | braintrace.matmul(x, W) (or conv, sparse_matmul, lora_matmul, element_wise) |
| Exclude a parameter from online learning | regular JAX op: x @ W, lax.conv_general_dilated, … |

The short example below wires a vanilla RNN into D_RTRL: only the recurrent weight is marked with braintrace.matmul, so only it receives an eligibility trace. The input weight uses plain @ and is learned by BPTT through the unrolled scan.

import brainstate


class TinyRNN(brainstate.nn.Module):
    def __init__(self, in_dim=4, hid_dim=6):
        super().__init__()
        self.in_dim = in_dim
        self.hid_dim = hid_dim
        # Recurrent weight: ETP-enabled (online learning via D-RTRL).
        self.W_rec = brainstate.ParamState(
            0.1 * jax.random.normal(jax.random.PRNGKey(0), (hid_dim, hid_dim))
        )
        # Input weight: plain matmul, learned via BPTT instead.
        self.W_in = brainstate.ParamState(
            0.1 * jax.random.normal(jax.random.PRNGKey(1), (in_dim, hid_dim))
        )

    def init_state(self, batch_size=None, **kwargs):
        # ``HiddenState`` is what the ETP compiler traces through.
        self.h = brainstate.HiddenState(
            jnp.zeros((batch_size or 1, self.hid_dim))
        )

    def update(self, x):
        # W_in is NOT marked -> excluded from ETP.
        input_drive = x @ self.W_in.value
        # W_rec IS marked -> included in ETP.
        rec_drive = braintrace.matmul(self.h.value, self.W_rec.value)
        self.h.value = jax.nn.tanh(input_drive + rec_drive)
        return self.h.value


model = TinyRNN(in_dim=4, hid_dim=6)
brainstate.nn.init_all_states(model, batch_size=2)

# Wrap the model in a D-RTRL algorithm and compile the ETP graph.
alg = braintrace.D_RTRL(model)
alg.compile_graph(jnp.zeros((2, model.in_dim)))

print("Compiled ETP relations:", len(alg.graph.hidden_param_op_relations))
for rel in alg.graph.hidden_param_op_relations:
    print("   primitive =", rel.primitive.name,
          "  trainable keys =", list(rel.trainable_vars.keys()))
Compiled ETP relations: 1
   primitive = etp_mm   trainable keys = ['weight']

Summary

ETP primitives provide a clean, extensible foundation for online learning in recurrent networks:

  • 8 built-in primitives cover the most common use cases: dense matmul (mm/mv), element-wise ops, convolution, sparse matmul (mm/mv), and LoRA matmul (mm/mv).

  • Dict rule API — every primitive declares its full set of trainable inputs via trainable_invars_fn, and the four ETP rules consume and return Dict[str, Array]. A single primitive can own several ParamState objects (e.g. weight + bias, or \(B + A + b\) in LoRA) and the executor routes gradients to each in one pass.

  • Custom primitives can be added in a few dozen lines: implement the forward function, hand-write the four ETP rules, then either call register_primitive and register_etp_rules (class style) or bundle everything in an ETPPrimitiveSpec (compiler-ready).

  • All JAX transformations (JIT, grad, vmap, JVP) work automatically — only the four online-learning-specific rules need hand-writing.

  • Parameter selection is primitive-based — every brainstate.ParamState is eligible for ETP, and participation depends only on whether a braintrace.* ETP primitive consumed it. Use gradient_enabled=True exclusively for identity-like ops such as etp_elemwise_p.

  • Saiunit quantities are handled transparently by every user-facing wrapper.

Where to look for the math:

| Rule | Algorithm term | Source with derivation |
|---|---|---|
| xy_to_dw | \(\operatorname{diag}(\mathbf{D}_f^t) \otimes \mathbf{x}^t\) | docstrings in braintrace/_etrace_op/{dense,conv,elemwise,sparse,lora}.py |
| yw_to_w | \(\mathbf{D}^t \boldsymbol{\epsilon}^{t-1}\) (\(y \to W\) link) | same files |
| init_drtrl | param-dim trace shape | same files |
| init_pp | output-dim df-trace shape | same files |

Further reading: advanced/limitations.ipynb explains the non-parametric-tail invariant and walks through GRUCell (3 Linears, only 2 ETP relations).