ETP Primitives Deep Dive#
Introduction#
ETP (Eligibility Trace Propagation) primitives are JAX custom primitives that mark weight operations in the computational graph. They replace the old ETraceOp / JIT-name-matching system with a cleaner, more robust approach.
Key design principles:
Type identity, not string matching. The compiler identifies ETP primitives by checking eqn.primitive in ETP_PRIMITIVES, a set-membership test on the primitive object itself. This is more reliable than the old approach of matching JIT function names.
All JAX rules are auto-derived. Each primitive's impl delegates to standard JAX ops (e.g., x @ w, jax.lax.conv_general_dilated). The register_primitive() helper automatically derives abstract_eval, MLIR lowering, JVP, transpose, and batching rules from the implementation. No hand-written derivative formulas are needed.
Only ETP-specific rules need hand-writing. Four global dictionaries capture the online-learning-specific logic:
ETP_RULES_YW_TO_W: D-RTRL trace propagation (the \(\mathbf{D}^t \boldsymbol{\epsilon}^{t-1}\) term)
ETP_RULES_XY_TO_DW: instantaneous hidden-to-weight Jacobian (the \(\operatorname{diag}(\mathbf{D}_f^t) \otimes \mathbf{x}^t\) term)
ETP_RULES_INIT_DRTRL: D-RTRL parameter-dimension trace initialiser
ETP_RULES_INIT_PP: pp-prop / ES-D-RTRL output-dimension df-trace initialiser
Primitive-based parameter selection. A parameter participates in ETP if and only if it flows through an ETP primitive (braintrace.matmul, braintrace.element_wise, etc.). Parameters used with regular JAX ops are automatically excluded; no special parameter class is needed.
N-trainable-inputs per primitive (dict rule API). A single primitive may declare several trainable inputs at once, e.g. {weight, bias} for Linear, {B, A, bias} for LoRA. The four ETP rules consume and return Dict[str, Array] so the executor routes gradients to every owning ParamState in one pass.
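The first principle can be illustrated without braintrace at all. The sketch below uses plain JAX and assumes a hypothetical set, MARKED_PRIMITIVES, standing in for ETP_PRIMITIVES. It walks a jaxpr and selects equations by testing the primitive object for set membership, never by comparing names.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for ETP_PRIMITIVES: a set of primitive *objects*.
MARKED_PRIMITIVES = {jax.lax.dot_general_p}


def forward(x, w):
    # lax.dot binds dot_general_p; lax.tanh binds tanh_p.
    return jax.lax.tanh(jax.lax.dot(x, w))


closed = jax.make_jaxpr(forward)(jnp.ones((4, 3)), jnp.ones((3, 5)))

# Set membership on the primitive object itself; no name strings are compared.
marked = [eqn for eqn in closed.jaxpr.eqns if eqn.primitive in MARKED_PRIMITIVES]
unmarked = [eqn for eqn in closed.jaxpr.eqns if eqn.primitive not in MARKED_PRIMITIVES]

print("marked equations  :", [eqn.primitive.name for eqn in marked])
print("unmarked equations:", [eqn.primitive.name for eqn in unmarked])
```

Because the test compares objects, the name of the Python function that produced an equation plays no role in the selection.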
import jax
import jax.numpy as jnp
import braintrace
The Five Primitive Functions#
braintrace provides five user-facing ETP primitive functions:
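| Function | Underlying primitives | Purpose |
|---|---|---|
| braintrace.matmul | etp_mm, etp_mv | Dense matrix multiplication |
| braintrace.element_wise | etp_elemwise | Element-wise (diagonal) weight ops |
| braintrace.conv | etp_conv | Convolution |
| braintrace.sparse_matmul | etp_sp_mm, etp_sp_mv | Sparse matrix multiplication |
| braintrace.lora_matmul | etp_lora_mm, etp_lora_mv | LoRA (Low-Rank Adaptation) matmul |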
Each function auto-dispatches between batched and unbatched variants based on input dimensionality.
1. braintrace.matmul(x, weight, bias=None) – Dense Matrix Multiplication#
Computes \(y = x \, @ \, w \; (+ b)\).
Auto-dispatches based on x.ndim:
x.ndim >= 2 --> etp_mm_p (batched): expects x of shape (batch, in_features)
x.ndim == 1 --> etp_mv_p (unbatched): expects x of shape (in_features,)
# Batched matmul: x has shape (batch, in_features)
x_batched = jnp.ones((4, 3)) # batch=4, in_features=3
w = jnp.ones((3, 5)) # in_features=3, out_features=5
y_batched = braintrace.matmul(x_batched, w)
print("Batched output shape:", y_batched.shape) # (4, 5)
# Unbatched matmul: x has shape (in_features,)
x_single = jnp.ones((3,)) # in_features=3
y_single = braintrace.matmul(x_single, w)
print("Unbatched output shape:", y_single.shape) # (5,)
Batched output shape: (4, 5)
Unbatched output shape: (5,)
# With bias
b = jnp.zeros((5,))
y_with_bias = braintrace.matmul(x_batched, w, bias=b)
print("With bias:", y_with_bias.shape) # (4, 5)
With bias: (4, 5)
2. braintrace.element_wise(weight, fn=lambda w: w) — Element-wise Operation#
Applies fn to the weight and passes the result through a marker primitive. The operation is treated as diagonal in the hidden-state space.
fn defaults to the identity (lambda w: w); supply any JAX-differentiable function when you want a non-trivial transformation.
Common use cases:
Gating mechanisms in RNNs (learnable gate biases)
Learnable time constants or thresholds in spiking neural networks
Any parameter that enters the computation element-wise
Note: etp_elemwise_p is the only primitive registered with gradient_enabled=True. The compiler descends into it when walking y -> h, so it does not act as a tail boundary for upstream ETP weights. See the gradient_enabled Flag section below for details.
# Identity (default fn): just marks the weight for ETP
w_gate = jnp.array([0.5, -0.3, 0.8, 0.1])
y_identity = braintrace.element_wise(w_gate)
print("Identity:", y_identity)
# With a transformation function
y_sigmoid = braintrace.element_wise(w_gate, fn=jax.nn.sigmoid)
print("Sigmoid:", y_sigmoid)
# With absolute value (e.g., enforcing positive time constants)
y_abs = braintrace.element_wise(w_gate, fn=jnp.abs)
print("Abs:", y_abs)
Identity: [ 0.5 -0.3 0.8 0.1]
Sigmoid: [0.62245935 0.4255575 0.6899745 0.5249792 ]
Abs: [0.5 0.3 0.8 0.1]
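One way to read "diagonal" here: an element-wise fn maps each weight entry to exactly one output entry, so its Jacobian has no off-diagonal terms. The check below uses plain JAX only (it does not call braintrace) and illustrates that property rather than the library's internals.

```python
import jax
import jax.numpy as jnp

w_gate = jnp.array([0.5, -0.3, 0.8, 0.1])

# Jacobian of an element-wise function: J[i, j] = d sigmoid(w)[i] / d w[j].
J = jax.jacobian(jax.nn.sigmoid)(w_gate)

# Everything off the main diagonal should vanish.
off_diagonal = J - jnp.diag(jnp.diag(J))

# The diagonal itself is the usual sigmoid derivative s * (1 - s).
s = jax.nn.sigmoid(w_gate)

print("Jacobian shape:", J.shape)
print("Max off-diagonal magnitude:", float(jnp.max(jnp.abs(off_diagonal))))
```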
3. braintrace.conv(x, kernel, bias=None, *, strides, padding, ...) – Convolution#
ETP-aware convolution that wraps jax.lax.conv_general_dilated and adds the optional bias to its result.
Important: Always expects a batch dimension on x.
Supports all parameters of jax.lax.conv_general_dilated: strides, padding, lhs_dilation, rhs_dilation, feature_group_count, batch_group_count, and dimension_numbers.
# 1D convolution example
# x: (batch, spatial, channels) with dimension_numbers
x_1d = jnp.ones((2, 16, 3)) # batch=2, length=16, in_channels=3
kernel_1d = jnp.ones((4, 3, 8)) # kernel_size=4, in_channels=3, out_channels=8
y_conv = braintrace.conv(
x_1d, kernel_1d,
strides=(1,),
padding='SAME',
dimension_numbers=('NWC', 'WIO', 'NWC'),
)
print("Conv1D output shape:", y_conv.shape) # (2, 16, 8)
Conv1D output shape: (2, 16, 8)
# 2D convolution example
x_2d = jnp.ones((2, 32, 32, 3)) # batch=2, H=32, W=32, in_channels=3
kernel_2d = jnp.ones((3, 3, 3, 16)) # kH=3, kW=3, in_channels=3, out_channels=16
y_conv2d = braintrace.conv(
x_2d, kernel_2d,
strides=(1, 1),
padding='SAME',
dimension_numbers=('NHWC', 'HWIO', 'NHWC'),
)
print("Conv2D output shape:", y_conv2d.shape) # (2, 32, 32, 16)
Conv2D output shape: (2, 32, 32, 16)
4. braintrace.sparse_matmul(x, weight_data, *, sparse_mat, bias=None) – Sparse Matmul#
ETP-aware sparse matrix multiplication of x with a sparse weight matrix, with an optional bias.
The sparse_mat argument provides the sparse structure (indices), while weight_data contains only the non-zero values. This is useful for models with sparse connectivity patterns, such as biologically plausible neural networks or graph neural networks.
import saiunit as u
from saiunit import sparse as ss
# Create a sparse connectivity matrix
dense_w = jnp.where(
jax.random.uniform(jax.random.PRNGKey(0), (50, 50)) < 0.1,
jax.random.normal(jax.random.PRNGKey(1), (50, 50)),
0.0
)
sparse_mat = ss.CSR.fromdense(dense_w)
# The learnable parameter is just the non-zero data
weight_data = sparse_mat.data
x_sp = jnp.ones((4, 50)) # batch=4, features=50
y_sp = braintrace.sparse_matmul(x_sp, weight_data, sparse_mat=sparse_mat)
print("Sparse matmul output shape:", y_sp.shape) # (4, 50)
Sparse matmul output shape: (4, 50)
5. braintrace.lora_matmul(x, B, A, *, alpha=1.0, bias=None) – LoRA Matmul#
Low-Rank Adaptation matmul. It applies the low-rank product of B and A to x, with the scaling factor alpha and an optional bias. Here \(B \in \mathbb{R}^{\text{in} \times \text{rank}}\) and \(A \in \mathbb{R}^{\text{rank} \times \text{out}}\) are the low-rank factors. This is useful for parameter-efficient fine-tuning of large models, where only the low-rank factors are trained.
in_features, out_features, rank = 64, 32, 4
B = jax.random.normal(jax.random.PRNGKey(0), (in_features, rank)) * 0.01
A = jax.random.normal(jax.random.PRNGKey(1), (rank, out_features)) * 0.01
x_lora = jnp.ones((8, in_features)) # batch=8
y_lora = braintrace.lora_matmul(x_lora, B, A, alpha=2.0)
print("LoRA output shape:", y_lora.shape) # (8, 32)
print("LoRA output (first sample):", y_lora[0])
LoRA output shape: (8, 32)
LoRA output (first sample): [ 1.2993842e-04 2.9886682e-03 -5.7171844e-04 1.7000248e-03
2.8961061e-03 -1.4734608e-03 1.5826351e-03 1.1692893e-04
-5.4561032e-04 -1.3646147e-03 1.0654312e-04 2.0160272e-03
3.1371990e-03 1.8710974e-03 2.9124888e-03 -8.2047645e-04
9.7736251e-04 -1.1875220e-03 -2.1796541e-03 -5.8191392e-05
-1.5415263e-03 -1.3381819e-03 -2.1044153e-03 -8.4472442e-04
1.6430757e-05 -2.9564608e-04 4.5550600e-04 -1.6011632e-03
-1.4516843e-03 -1.3209208e-03 3.2850087e-04 -6.0276774e-04]
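To make the parameter-efficiency argument concrete, the arithmetic below compares the number of trainable entries in the two low-rank factors with a full in × out weight, using the sizes from the example above. The comparison against a dense matrix is an illustration added here, not output from braintrace.

```python
in_features, out_features, rank = 64, 32, 4

# A dense weight would hold one entry per (input, output) pair.
full_params = in_features * out_features

# LoRA trains only B (in x rank) and A (rank x out).
lora_params = in_features * rank + rank * out_features

print("dense weight entries:", full_params)               # 2048
print("low-rank entries    :", lora_params)               # 384
print("fraction trained    :", lora_params / full_params)  # 0.1875
```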
Physical Units (saiunit / Quantity) Support#
Every user-facing ETP function in braintrace (matmul, conv, element_wise, sparse_matmul, lora_matmul) accepts saiunit.Quantity inputs transparently. The wrapper:
splits each quantity into a plain-array mantissa and a unit,
binds the primitive on the mantissas only,
re-attaches the combined unit to the result with u.maybe_decimal.
This keeps the primitives themselves unit-free (JAX sees only arrays), while users write physical quantities naturally. Bias is re-scaled into the combined x × weight unit before bind so the addition is dimensionally valid.
import saiunit as u
# Quantity-valued inputs pass through unchanged.
x_q = jnp.ones((4, 3)) * u.volt # shape (4, 3), unit = V
w_q = jnp.ones((3, 5)) * u.siemens # shape (3, 5), unit = S
b_q = jnp.zeros((5,)) * u.amp # must match V * S = A
y_q = braintrace.matmul(x_q, w_q, bias=b_q)
print("Output:", y_q)
print("Unit :", u.get_unit(y_q))
Output: [[3. 3. 3. 3. 3.]
[3. 3. 3. 3. 3.]
[3. 3. 3. 3. 3.]
[3. 3. 3. 3. 3.]] A
Unit : A
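The split / bind / re-attach pattern can be mimicked with a toy model. The sketch below is not saiunit: ToyQuantity and unit_aware_product are hypothetical names, a unit is reduced to a bare SI scale factor, and no dimension checking is done. It only shows why the bias has to be re-expressed in the combined x × weight unit before the addition.

```python
from typing import NamedTuple, Optional


class ToyQuantity(NamedTuple):
    """Hypothetical stand-in for a quantity: value = mantissa * scale (SI)."""
    mantissa: float
    scale: float


def unit_aware_product(x: ToyQuantity, w: ToyQuantity,
                       bias: Optional[ToyQuantity] = None) -> ToyQuantity:
    # 1. Combined unit of x * w.
    out_scale = x.scale * w.scale
    # 2. The "primitive" only ever sees plain numbers.
    y = x.mantissa * w.mantissa
    # 3. Re-express the bias in the combined unit before adding it.
    if bias is not None:
        y += bias.mantissa * (bias.scale / out_scale)
    # 4. Re-attach the combined unit to the result.
    return ToyQuantity(y, out_scale)


x = ToyQuantity(2000.0, 1e-3)  # 2000 mV
w = ToyQuantity(0.5, 1.0)      # 0.5 S
b = ToyQuantity(0.5, 1.0)      # 0.5 A
y = unit_aware_product(x, w, b)
print("result in amperes:", y.mantissa * y.scale)
```

Here 2000 mV × 0.5 S gives 1000 mA, and the 0.5 A bias becomes 500 mA before it is added, so the result corresponds to 1.5 A.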
JAX Compatibility#
All ETP primitives are fully compatible with JAX transformations. Since register_primitive() auto-derives JIT, grad, vmap, and JVP rules from the implementation, they work seamlessly with the standard JAX API.
x = jnp.ones((4, 3))
w = jnp.ones((3, 5))
# ---- JIT compilation ----
y_jit = jax.jit(braintrace.matmul)(x, w)
print("JIT output shape:", y_jit.shape)
JIT output shape: (4, 5)
# ---- Gradient computation ----
grad_fn = jax.grad(lambda w: jnp.sum(braintrace.matmul(x, w)))
dw = grad_fn(w)
print("Gradient shape:", dw.shape)
print("Gradient values:\n", dw)
Gradient shape: (3, 5)
Gradient values:
[[4. 4. 4. 4. 4.]
[4. 4. 4. 4. 4.]
[4. 4. 4. 4. 4.]]
# ---- Vectorized mapping (vmap) ----
# vmap over a batch of inputs, each of shape (4, 3)
xs = jnp.ones((8, 4, 3)) # 8 different batches
vmap_fn = jax.vmap(lambda x_i: braintrace.matmul(x_i, w))
ys = vmap_fn(xs)
print("vmap output shape:", ys.shape) # (8, 4, 5)
vmap output shape: (8, 4, 5)
# ---- JVP (forward-mode differentiation) ----
primals = (x, w)
tangents = (jnp.ones_like(x), jnp.ones_like(w))
y_primal, y_tangent = jax.jvp(braintrace.matmul, primals, tangents)
print("JVP primal shape:", y_primal.shape)
print("JVP tangent shape:", y_tangent.shape)
JVP primal shape: (4, 5)
JVP tangent shape: (4, 5)
# ---- Composability: JIT + grad + vmap ----
@jax.jit
def batched_grad(xs, w):
    """Compute per-sample gradients w.r.t. the weight."""
    def single_grad(x_i):
        return jax.grad(lambda w_: jnp.sum(braintrace.matmul(x_i, w_)))(w)
    return jax.vmap(single_grad)(xs)
xs = jnp.ones((8, 4, 3))
per_sample_grads = batched_grad(xs, w)
print("Per-sample gradients shape:", per_sample_grads.shape) # (8, 3, 5)
Per-sample gradients shape: (8, 3, 5)
Argument Conventions#
Every ETP primitive follows specific conventions for its input variables (invars) and static parameters. Understanding these conventions is essential when working with the compiler or adding custom primitives.
Invar layout#
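| Primitive | invars[0] | invars[1] | invars[2] | invars[3] | Static params |
|---|---|---|---|---|---|
| etp_mm / etp_mv | input | weight | bias | - | has_bias |
| etp_elemwise | processed weight | - | - | - | (none) |
| etp_conv | input | kernel | bias | - | has_bias, plus the jax.lax.conv_general_dilated parameters |
| etp_sp_mm / etp_sp_mv | input | weight data | bias | - | has_bias |
| etp_lora_mm / etp_lora_mv | input | B matrix | A matrix | bias | has_bias |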
trainable_invars_fn — the N-trainable-input contract#
Instead of hard-coding a single weight_invar_index, each primitive exposes a function
trainable_invars_fn: Callable[[dict], Dict[str, int]]
which maps the equation’s static params onto {trainable_name: invar_index}. The compiler calls it at analysis time to discover every trainable input and to route gradients to the owning ParamState pytree leaf.
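Built-in examples, as implied by the invar layout above:
etp_mm / etp_mv: {'weight': 1} without a bias, {'weight': 1, 'bias': 2} with one.
etp_lora_mm / etp_lora_mv: trainable keys B and A, plus bias when has_bias is set.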
Notes:
The has_bias flag is a static parameter (not a traced value) that controls whether the optional bias argument is present.
For convolution, all jax.lax.conv_general_dilated parameters are passed as static params.
x_invar_index points to the non-trainable input; etp_elemwise_p sets it to None because the op has no separate input.
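A minimal, pure-Python sketch of how such a mapping can be consumed is shown below. The function linear_trainable_invars mirrors the contract described above; the helper resolve_trainable and the variable names in invars are hypothetical and only illustrate the lookup, not the compiler's actual code.

```python
from typing import Callable, Dict, List


def linear_trainable_invars(params: dict) -> Dict[str, int]:
    """Map static params onto {trainable_name: invar_index}."""
    mapping = {'weight': 1}
    if params.get('has_bias', False):
        mapping['bias'] = 2
    return mapping


def resolve_trainable(invars: List[str], params: dict,
                      fn: Callable[[dict], Dict[str, int]]) -> Dict[str, str]:
    """Hypothetical helper: pick the trainable invars out of an equation."""
    return {name: invars[index] for name, index in fn(params).items()}


# Equation inputs in positional order: input, weight, optional bias.
invars = ['x', 'W_rec', 'b_rec']

with_bias = resolve_trainable(invars, {'has_bias': True}, linear_trainable_invars)
without_bias = resolve_trainable(invars[:2], {'has_bias': False}, linear_trainable_invars)

print(with_bias)     # {'weight': 'W_rec', 'bias': 'b_rec'}
print(without_bias)  # {'weight': 'W_rec'}
```

Keeping the mapping a function of the static params is what lets one primitive serve both the biased and the unbiased call.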
Rule Registries (dict API)#
ETP uses four global dictionaries to store operation-specific rules. These are the only things that need hand-writing — all standard JAX rules are auto-derived from the implementation function.
All four rules operate on Dict[str, Array] (keyed by the names returned by trainable_invars_fn) — except init_pp, which returns a single output-shaped array because pp-prop factorises the trace as \(\boldsymbol{\epsilon}_f \otimes \boldsymbol{\epsilon}_x\) and only needs one df-tensor per primitive output.
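The outer-product structure behind both the \(\operatorname{diag}(\mathbf{D}_f^t) \otimes \mathbf{x}^t\) term and the factorised pp-prop trace can be checked directly for a dense matmul. The sketch below uses plain JAX; the vector d_f is a made-up stand-in for the hidden-to-output derivative, and nothing here calls braintrace.

```python
import jax
import jax.numpy as jnp

x = jnp.array([1.0, 2.0, 3.0])   # input, shape (in,)
w = jnp.ones((3, 2))             # weight, shape (in, out)
d_f = jnp.array([0.5, -1.0])     # assumed dh/dy, shape (out,)

# Pull the output cotangent d_f back through y = x @ w onto the weight.
_, vjp_fn = jax.vjp(lambda w_: x @ w_, w)
(dw,) = vjp_fn(d_f)

# For a dense matmul this is exactly the outer product x (outer) d_f.
print("dw shape:", dw.shape)
print("matches outer product:", bool(jnp.allclose(dw, jnp.outer(x, d_f))))
```

Since the weight cotangent factorises into an input part and an output part, a trace can store the two factors separately instead of a full weight-shaped tensor.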
ETP_RULES_YW_TO_W — D-RTRL trace propagation#
yw_to_w(hidden_dim: Array, trace: Dict[str, Array], **static_params) -> Dict[str, Array]
Propagates the hidden-state cotangent \(\partial h/\partial y\) through the \(y \to W\) chain factor of the D-RTRL term \(\mathbf{D}^t \boldsymbol{\epsilon}^{t-1}\). Applied per stored trace key.
ETP_RULES_INIT_DRTRL — D-RTRL trace initialiser#
init_drtrl(x_var, y_var, weight_vars: Dict[str, Var], num_hidden_state: int) -> Dict[str, Array]
Returns a zero-filled Dict[str, Array] shaped to hold the parameter-dimension trace used by D_RTRL / ParamDimVjpAlgorithm. One leaf per trainable key.
ETP_RULES_INIT_PP — pp-prop / ES-D-RTRL df-trace initialiser#
init_pp(x_var, y_var, weight_vars: Dict[str, Var], num_hidden_state: int) -> Array
Returns a single zero-filled array shaped to hold the output-dimension df trace used by ES_D_RTRL / IODimVjpAlgorithm. The matching \(\boldsymbol{\epsilon}_x\) factor is managed separately by the executor’s x-trace dictionary.
The two INIT_* registries exist because the two algorithm families factorise the trace differently. Both are required for a primitive that should support both algorithms.
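The practical difference between the two factorisations shows up in storage. The sketch below counts trace entries for a dense layer, using the shape conventions of the custom-primitive example later on this page: (batch, *weight.shape, num_hidden_state) for the parameter-dimension trace and (*y.shape, num_hidden_state) for the df trace. The x-trace shape is an assumption (taken to match the input), and the sizes are made up.

```python
from math import prod

batch, n_in, n_out, num_hidden_state = 4, 3, 5, 2

# Parameter-dimension trace: one entry per weight entry and hidden state.
param_dim_shape = (batch, n_in, n_out, num_hidden_state)

# Output-dimension df trace: shaped like the output y = (batch, n_out).
df_shape = (batch, n_out, num_hidden_state)

# Assumption: the separately managed x-trace is shaped like the input.
x_trace_shape = (batch, n_in)

param_dim_entries = prod(param_dim_shape)
io_dim_entries = prod(df_shape) + prod(x_trace_shape)

print("parameter-dim trace entries:", param_dim_entries)  # 120
print("IO-dim trace entries       :", io_dim_entries)     # 52
```

The parameter-dimension count grows with n_in × n_out, while the factorised count grows with n_in + n_out, which is why the two families need different initialisers.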
from braintrace._etrace_op import (
ETP_RULES_YW_TO_W,
ETP_RULES_XY_TO_DW,
ETP_RULES_INIT_DRTRL,
ETP_RULES_INIT_PP,
ETP_PRIMITIVES,
BATCHED_PRIMITIVES,
)
print("All ETP primitives:")
for p in sorted(ETP_PRIMITIVES, key=lambda p: p.name):
    batched_tag = " [batched]" if p in BATCHED_PRIMITIVES else ""
    print(f" {p.name}{batched_tag}")
print("\nTrace propagation rules (ETP_RULES_YW_TO_W):")
for p in sorted(ETP_RULES_YW_TO_W.keys(), key=lambda p: p.name):
    print(f" {p.name}")
print("\nWeight gradient rules (ETP_RULES_XY_TO_DW):")
for p in sorted(ETP_RULES_XY_TO_DW.keys(), key=lambda p: p.name):
    print(f" {p.name}")
print("\nD-RTRL init rules (ETP_RULES_INIT_DRTRL):")
for p in sorted(ETP_RULES_INIT_DRTRL.keys(), key=lambda p: p.name):
    print(f" {p.name}")
print("\npp_prop init rules (ETP_RULES_INIT_PP):")
for p in sorted(ETP_RULES_INIT_PP.keys(), key=lambda p: p.name):
    print(f" {p.name}")
All ETP primitives:
etp_conv [batched]
etp_elemwise
etp_lora_mm [batched]
etp_lora_mv
etp_mm [batched]
etp_mv
etp_sp_mm [batched]
etp_sp_mv
Trace propagation rules (ETP_RULES_YW_TO_W):
etp_conv
etp_elemwise
etp_lora_mm
etp_lora_mv
etp_mm
etp_mv
etp_sp_mm
etp_sp_mv
Weight gradient rules (ETP_RULES_XY_TO_DW):
etp_conv
etp_elemwise
etp_lora_mm
etp_lora_mv
etp_mm
etp_mv
etp_sp_mm
etp_sp_mv
D-RTRL init rules (ETP_RULES_INIT_DRTRL):
etp_conv
etp_elemwise
etp_lora_mm
etp_lora_mv
etp_mm
etp_mv
etp_sp_mm
etp_sp_mv
pp_prop init rules (ETP_RULES_INIT_PP):
etp_conv
etp_elemwise
etp_lora_mm
etp_lora_mv
etp_mm
etp_mv
etp_sp_mm
etp_sp_mv
Adding a Custom Primitive#
Adding a new ETP primitive takes only a few steps. The example here is a scaled matrix multiplication with an optional bias. It exercises the whole dict rule API: both the weight and bias branches are wired end-to-end.
import braintrace
from braintrace import register_primitive
# Step 1: Define the implementation.
# Plain JAX function — no special annotations needed.
def _scaled_matmul_impl(*args, scale=1.0, has_bias=False):
    x, w = args[0], args[1]
    y = scale * (x @ w)
    if has_bias:
        y = y + args[2]
    return y
# Step 2: Register as an ETP primitive.
# register_primitive() returns an ``ETPPrimitive`` and auto-derives all
# standard JAX rules (abstract_eval, lowering, JVP, transpose, batching).
scaled_mm_p = register_primitive('etp_scaled_mm', _scaled_matmul_impl, batched=True)
print("Primitive registered:", scaled_mm_p)
print("Type:", type(scaled_mm_p).__name__)
Primitive registered: etp_scaled_mm
Type: ETPPrimitive
# Step 3: Register the four ETP-specific rules (dict API).
# Each rule accepts / returns ``Dict[str, Array]`` keyed by the names
# in ``trainable_invars_fn`` — here ``'weight'`` and (optionally) ``'bias'``.
def _scaled_trainable_invars(params):
    """Tell the compiler which invars are trainable."""
    base = {'weight': 1}
    if params.get('has_bias', False):
        base['bias'] = 2
    return base

def _scaled_yw_to_w(hidden_dim, trace, *, scale=1.0, has_bias=False):
    # y = scale * x @ w + b
    # -> ∂y/∂w along the "out" axis is scaled by `scale`; the y→w chain
    # link is still elementwise along `out` axis (singleton at axis=-2).
    out = {'weight': trace['weight'] * jnp.expand_dims(hidden_dim, axis=-2) * scale}
    if has_bias:
        out['bias'] = trace['bias'] * hidden_dim
    return out

def _scaled_xy_to_dw(x, hidden_dim, weights, *, scale=1.0, has_bias=False):
    # Single fused VJP over a dict-valued forward function; returns
    # gradients for both 'weight' and 'bias' in one pass.
    def _fwd(w_dict):
        y = scale * (x @ w_dict['weight'])
        if has_bias:
            y = y + w_dict['bias']
        return y
    _, vjp_fn = jax.vjp(_fwd, weights)
    return vjp_fn(hidden_dim)[0]

def _scaled_init_drtrl(x_var, y_var, weight_vars, num_hidden_state):
    """D-RTRL parameter-dim trace: one leaf per trainable key."""
    batch = x_var.aval.shape[0]
    out = {
        'weight': jnp.zeros(
            (batch, *weight_vars['weight'].aval.shape, num_hidden_state)
        )
    }
    if 'bias' in weight_vars:
        out['bias'] = jnp.zeros(
            (batch, *weight_vars['bias'].aval.shape, num_hidden_state)
        )
    return out

def _scaled_init_pp(x_var, y_var, weight_vars, num_hidden_state):
    """pp-prop df trace: single array shaped like the output."""
    return jnp.zeros(
        (*y_var.aval.shape, num_hidden_state),
        dtype=y_var.aval.dtype,
    )
scaled_mm_p.register_etp_rules(
yw_to_w=_scaled_yw_to_w,
xy_to_dw=_scaled_xy_to_dw,
init_drtrl=_scaled_init_drtrl,
init_pp=_scaled_init_pp,
)
# Shorthand: every ``register_*`` method also exists as a standalone call,
# and ``ETPPrimitiveSpec`` (below) bundles the whole thing in one record.
print("yw_to_w registered: ", scaled_mm_p in ETP_RULES_YW_TO_W)
print("xy_to_dw registered: ", scaled_mm_p in ETP_RULES_XY_TO_DW)
print("init_drtrl registered:", scaled_mm_p in ETP_RULES_INIT_DRTRL)
print("init_pp registered: ", scaled_mm_p in ETP_RULES_INIT_PP)
yw_to_w registered: True
xy_to_dw registered: True
init_drtrl registered: True
init_pp registered: True
# Step 4: Use the custom primitive via ``primitive.bind()``.
x = jnp.ones((4, 3))
w = jnp.ones((3, 5))
y = scaled_mm_p.bind(x, w, scale=2.0, has_bias=False)
y_expected = 2.0 * (x @ w)
print("Output shape :", y.shape)
print("Matches 2·xw :", bool(jnp.allclose(y, y_expected)))
# With bias:
b = jnp.full((5,), 0.1)
y_bias = scaled_mm_p.bind(x, w, b, scale=2.0, has_bias=True)
print("With bias :", y_bias[0])
Output shape : (4, 5)
Matches 2·xw : True
With bias : [6.1 6.1 6.1 6.1 6.1]
# All JAX transformations work automatically for the custom primitive.
# JIT
y_jit = jax.jit(lambda x, w: scaled_mm_p.bind(x, w, scale=2.0, has_bias=False))(x, w)
print("JIT works:", bool(jnp.allclose(y_jit, y_expected)))
# Grad
dw = jax.grad(lambda w: jnp.sum(scaled_mm_p.bind(x, w, scale=2.0, has_bias=False)))(w)
print("Grad shape:", dw.shape)
# Vmap
xs = jnp.ones((8, 4, 3))
ys = jax.vmap(lambda xi: scaled_mm_p.bind(xi, w, scale=2.0, has_bias=False))(xs)
print("Vmap output shape:", ys.shape)
JIT works: True
Grad shape: (3, 5)
Vmap output shape: (8, 4, 5)
Compiler integration. The class-style registration above is enough for direct primitive.bind() use, JIT, grad, vmap, and JVP. For the primitive to be discovered by the ETP compiler (compile_etrace_graph, D_RTRL, ES_D_RTRL), you must also publish the trainable_invars_fn and the invar layout via an ETPPrimitiveSpec; see the next section, which defines an equivalent _spec_trainable_invars with the same body as _scaled_trainable_invars above.
Spec-based Registration#
When a primitive is intended for the ETP compiler, use the spec form. An ETPPrimitiveSpec bundles the implementation, all four rules, the invar/outvar layout, and the trainable_invars_fn into one frozen dataclass. Passing it to register_primitive_spec wires everything up and records the spec in ETP_PRIMITIVE_SPECS so the compiler can query it via get_primitive_spec.
Spec fields:
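| Field | Purpose |
|---|---|
| name | Primitive name |
| impl | Plain JAX forward function |
| yw_to_w, xy_to_dw, init_drtrl, init_pp | The four ETP rules |
| trainable_invars_fn | Maps static params to {trainable_name: invar_index} |
| x_invar_index | Position of the non-trainable input, or None |
|  | Position of the output in the outvar layout |
| batched | Batched-input primitive? |
| gradient_enabled | Compiler traverses this primitive when walking y -> h |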
The spec form is equivalent to the class-based form — pick whichever style fits your codebase.
import braintrace
def _spec_impl(*args, scale=1.0, has_bias=False):
    x, w = args[0], args[1]
    y = scale * (x @ w)
    if has_bias:
        y = y + args[2]
    return y

def _spec_trainable_invars(params):
    base = {'weight': 1}
    if params.get('has_bias', False):
        base['bias'] = 2
    return base

def _spec_yw_to_w(hidden_dim, trace, *, scale=1.0, has_bias=False):
    out = {'weight': trace['weight'] * jnp.expand_dims(hidden_dim, axis=-2) * scale}
    if has_bias:
        out['bias'] = trace['bias'] * hidden_dim
    return out

def _spec_xy_to_dw(x, hidden_dim, weights, *, scale=1.0, has_bias=False):
    def _fwd(w_dict):
        y = scale * (x @ w_dict['weight'])
        if has_bias:
            y = y + w_dict['bias']
        return y
    _, vjp_fn = jax.vjp(_fwd, weights)
    return vjp_fn(hidden_dim)[0]

def _spec_init_drtrl(x_var, y_var, weight_vars, n):
    batch = x_var.aval.shape[0]
    out = {
        'weight': jnp.zeros(
            (batch, *weight_vars['weight'].aval.shape, n)
        )
    }
    if 'bias' in weight_vars:
        out['bias'] = jnp.zeros(
            (batch, *weight_vars['bias'].aval.shape, n)
        )
    return out

def _spec_init_pp(x_var, y_var, weight_vars, n):
    return jnp.zeros((*y_var.aval.shape, n), dtype=y_var.aval.dtype)
spec = braintrace.ETPPrimitiveSpec(
name='etp_spec_demo',
impl=_spec_impl,
yw_to_w=_spec_yw_to_w,
xy_to_dw=_spec_xy_to_dw,
init_drtrl=_spec_init_drtrl,
init_pp=_spec_init_pp,
trainable_invars_fn=_spec_trainable_invars,
x_invar_index=0, # ``invars[0]`` is the input; trainable invars are
# fully described by ``trainable_invars_fn`` above.
batched=True,
)
spec_p = braintrace.register_primitive_spec(spec)
# The spec is recoverable from the primitive at any time:
assert braintrace.get_primitive_spec(spec_p) is spec
# Quick sanity check of the invar layout:
print("trainable (no bias):", spec.resolve_trainable_invars({'has_bias': False}))
print("trainable (bias) :", spec.resolve_trainable_invars({'has_bias': True}))
# And the primitive binds normally:
y_spec = spec_p.bind(x, w, scale=3.0, has_bias=False)
print("spec bind output :", y_spec.shape)
trainable (no bias): {'weight': 1}
trainable (bias) : {'weight': 1, 'bias': 2}
spec bind output : (4, 5)
The gradient_enabled Flag#
register_primitive() accepts a gradient_enabled keyword (default False). It controls how the compiler treats this primitive when walking from a weight’s output back to a hidden state.
| gradient_enabled | Compiler behaviour | Example |
|---|---|---|
| False (default) | Treats the primitive as a tail boundary. A preceding ETP weight whose only path to a hidden state passes through this primitive is excluded from ETP, because per-primitive ETP rules cannot express weight-then-weight composition. | All trainable matmul/conv/sparse/LoRA primitives use this. |
| True | The primitive is identity-like and may sit on the tail of the y -> h walk; the compiler descends into it rather than treating it as a boundary. | Only etp_elemwise_p. |
Use gradient_enabled=True only when the primitive’s xy_to_dw rule is itself an identity-like passthrough; mark all genuinely trainable ops with the default. The “weight -> weight -> hidden” exclusion is what makes per-primitive ETP rules sound – see advanced/limitations.ipynb for a worked example with GRUCell (3 Linears, only 2 ETP relations).
Integrating a Primitive with Online Learning#
Marking a weight operation with a braintrace.* primitive is the only thing a model has to do to opt that parameter into online learning. The compiler then walks the jaxpr, finds every ETP primitive, connects it to the downstream hidden states, and builds the eligibility-trace machinery for either D_RTRL (parameter-dim trace) or ES_D_RTRL / pp_prop (IO-dim trace).
Rule of thumb
| Goal | Use |
|---|---|
| Include a parameter in online learning | a braintrace.* ETP primitive, e.g. braintrace.matmul(x, w) |
| Exclude a parameter from online learning | regular JAX op: x @ w |
The short example below wires a vanilla RNN into D_RTRL: only the recurrent weight is marked with braintrace.matmul, so only it receives an eligibility trace. The input weight uses plain @ and is learned by BPTT through the unrolled scan.
import brainstate
class TinyRNN(brainstate.nn.Module):
    def __init__(self, in_dim=4, hid_dim=6):
        super().__init__()
        self.in_dim = in_dim
        self.hid_dim = hid_dim
        # Recurrent weight: ETP-enabled (online learning via D-RTRL).
        self.W_rec = brainstate.ParamState(
            0.1 * jax.random.normal(jax.random.PRNGKey(0), (hid_dim, hid_dim))
        )
        # Input weight: plain matmul, learned via BPTT instead.
        self.W_in = brainstate.ParamState(
            0.1 * jax.random.normal(jax.random.PRNGKey(1), (in_dim, hid_dim))
        )

    def init_state(self, batch_size=None, **kwargs):
        # ``HiddenState`` is what the ETP compiler traces through.
        self.h = brainstate.HiddenState(
            jnp.zeros((batch_size or 1, self.hid_dim))
        )

    def update(self, x):
        # W_in is NOT marked -> excluded from ETP.
        input_drive = x @ self.W_in.value
        # W_rec IS marked -> included in ETP.
        rec_drive = braintrace.matmul(self.h.value, self.W_rec.value)
        self.h.value = jax.nn.tanh(input_drive + rec_drive)
        return self.h.value
model = TinyRNN(in_dim=4, hid_dim=6)
brainstate.nn.init_all_states(model, batch_size=2)
# Wrap the model in a D-RTRL algorithm and compile the ETP graph.
alg = braintrace.D_RTRL(model)
alg.compile_graph(jnp.zeros((2, model.in_dim)))
print("Compiled ETP relations:", len(alg.graph.hidden_param_op_relations))
for rel in alg.graph.hidden_param_op_relations:
    print(" primitive =", rel.primitive.name,
          " trainable keys =", list(rel.trainable_vars.keys()))
Compiled ETP relations: 1
primitive = etp_mm trainable keys = ['weight']
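To see what an eligibility trace of this kind amounts to, the recursion \(\boldsymbol{\epsilon}^t = \mathbf{D}^t \boldsymbol{\epsilon}^{t-1} + \operatorname{diag}(\mathbf{D}_f^t) \otimes \mathbf{x}^t\) can be hand-rolled for a toy cell and compared with backpropagation through time. The sketch below is plain JAX and does not use braintrace. The cell \(h^t = \tanh(a\,h^{t-1} + x^t W)\), the leak factor a, and all sizes are assumptions chosen so that the recurrence is element-wise, the case in which the diagonal trace is exact.

```python
import jax
import jax.numpy as jnp

a = 0.9                                   # assumed leak factor
T, n_in, n_hid = 5, 3, 4
xs = jax.random.normal(jax.random.PRNGKey(0), (T, n_in))
w = 0.1 * jax.random.normal(jax.random.PRNGKey(1), (n_in, n_hid))


def step(h, x, w):
    return jnp.tanh(a * h + x @ w)


def final_loss(w):
    # Reference: unroll the whole sequence, then differentiate (BPTT).
    h = jnp.zeros(n_hid)
    for x in xs:
        h = step(h, x, w)
    return jnp.sum(h ** 2)


# Online pass: carry the trace forward alongside the hidden state.
h = jnp.zeros(n_hid)
eps = jnp.zeros((n_in, n_hid))            # one trace entry per weight entry
for x in xs:
    h = step(h, x, w)
    d_f = 1.0 - h ** 2                    # dh^t / dy^t, with y^t = x^t @ w
    d_h = a * d_f                         # dh^t / dh^{t-1} (diagonal)
    eps = d_h[None, :] * eps + x[:, None] * d_f[None, :]

grad_online = (2.0 * h)[None, :] * eps    # dL/dh^T times the final trace
grad_bptt = jax.grad(final_loss)(w)
print("online == BPTT:", bool(jnp.allclose(grad_online, grad_bptt, atol=1e-5)))
```

The online pass never stores past inputs or hidden states; the trace alone carries what the final gradient needs in this element-wise setting.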
Summary#
ETP primitives provide a clean, extensible foundation for online learning in recurrent networks:
Eight built-in primitives cover the most common use cases: dense matmul (mm/mv), element-wise ops, convolution, sparse matmul (mm/mv), and LoRA matmul (mm/mv).
Dict rule API: every primitive declares its full set of trainable inputs via trainable_invars_fn, and the four ETP rules consume and return Dict[str, Array]. A single primitive can own several ParamState objects (e.g. weight + bias, or \(B + A + b\) in LoRA) and the executor routes gradients to each in one pass.
Custom primitives can be added in a few dozen lines: implement the forward function, call register_primitive (class style) or build an ETPPrimitiveSpec (compiler-ready), then hand-write the four ETP rules.
All JAX transformations (JIT, grad, vmap, JVP) work automatically; only the four online-learning-specific rules need hand-writing.
Parameter selection is primitive-based: every brainstate.ParamState is eligible for ETP, and participation depends only on whether a braintrace.* ETP primitive consumed it. Use gradient_enabled=True exclusively for identity-like ops such as etp_elemwise_p.
Saiunit quantities are handled transparently by every user-facing wrapper.
Where to look for the math:
| Rule | Algorithm term | Source with derivation |
|---|---|---|
| ETP_RULES_XY_TO_DW | \(\operatorname{diag}(\mathbf{D}_f^t) \otimes \mathbf{x}^t\) | docstrings in the primitive source files |
| ETP_RULES_YW_TO_W | \(\mathbf{D}^t \boldsymbol{\epsilon}^{t-1}\) (\(y \to W\) link) | same files |
| ETP_RULES_INIT_DRTRL | param-dim trace shape | same files |
| ETP_RULES_INIT_PP | output-dim df-trace shape | same files |
Further reading: advanced/limitations.ipynb explains the non-parametric-tail invariant and walks through GRUCell (3 Linears, only 2 ETP relations).