Core Concepts#

ETP Primitives (User API)#

These functions mark weight operations for inclusion in online learning. To include a weight in eligibility trace computation, call braintrace.matmul(x, w) instead of x @ w. Parameters used only with regular JAX ops are excluded automatically, so no special parameter classes are needed.

matmul

ETP-aware matrix multiplication.

element_wise

ETP-aware element-wise operation.

conv

ETP-aware convolution.

sparse_matmul

ETP-aware sparse matrix multiplication.

lora_matmul

ETP-aware LoRA (Low-Rank Adaptation) matrix multiplication.

Controlling Parameter Participation#

import jax  # needed for jax.nn.tanh in update()
import braintrace
import brainstate

class MyRNN(brainstate.nn.Module):
    def __init__(self):
        super().__init__()
        self.w_rec = brainstate.ParamState(...)   # want ETP
        self.w_in = brainstate.ParamState(...)     # do NOT want ETP
        self.h = brainstate.ShortTermState(...)

    def update(self, x):
        # regular matmul -> w_in excluded from ETP
        inp = x @ self.w_in.value
        # ETP matmul -> w_rec included in ETP
        self.h.value = jax.nn.tanh(inp + braintrace.matmul(self.h.value, self.w_rec.value))
        return self.h.value
Table 1 Parameter Selection Rules#

| Goal | How |
| --- | --- |
| Include parameter in online learning | Use a braintrace.* ETP primitive (e.g. braintrace.matmul(x, w)) |
| Exclude parameter from online learning | Use a regular JAX op (e.g. x @ w) |
| Selection mechanism | Operation primitive type, not parameter class type. Every brainstate.ParamState is eligible; participation depends solely on whether an ETP primitive consumed it. |
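
The rule in Table 1 can be illustrated without the library. The following is a toy sketch in plain Python, not braintrace's mechanism: `Param`, `plain_matmul`, `marked_matmul`, and `TRACKED` are hypothetical names, and the bookkeeping is an assumption made only to show that selection follows the operation rather than the parameter's class.

```python
# Toy illustration only: none of these names exist in braintrace.

class Param:
    """A single parameter class shared by every weight."""

    def __init__(self, name, value):
        self.name = name
        self.value = value  # small dense matrix stored as a list of rows


TRACKED = set()  # names of parameters consumed by a marked operation


def _matvec(x, w):
    # x has length n, w is an n x m matrix; the result has length m.
    return [sum(x[i] * w[i][j] for i in range(len(x))) for j in range(len(w[0]))]


def plain_matmul(x, p):
    """Ordinary product: computes the result and records nothing."""
    return _matvec(x, p.value)


def marked_matmul(x, p):
    """Same arithmetic, but also registers the parameter as participating."""
    TRACKED.add(p.name)
    return _matvec(x, p.value)


# Both weights use the same class; only the operation differs.
w_in = Param("w_in", [[1.0, 0.0], [0.0, 1.0]])
w_rec = Param("w_rec", [[0.5, 0.0], [0.0, 0.5]])

x = [1.0, 2.0]
h = [0.0, 4.0]
out = [a + b for a, b in zip(plain_matmul(x, w_in), marked_matmul(h, w_rec))]
```

After this runs, `TRACKED` contains only `"w_rec"`, even though `w_in` and `w_rec` are instances of the same class.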

Input Data#

Wrappers that tell the online learning algorithm whether the input is a single time step or a sequence of time steps.

SingleStepData

The data at a single time step.

MultiStepData

The data at multiple time steps.
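
To make the role of such wrappers concrete, here is a minimal sketch in plain Python of how a tagged input might be dispatched. It does not use the braintrace classes: `OneStep`, `ManySteps`, `run`, and `Accumulator` are hypothetical stand-ins, and treating the wrapped sequence as ordered by time is an assumption.

```python
from dataclasses import dataclass
from typing import Callable, List, Sequence, Union


@dataclass
class OneStep:
    data: float  # input at one time step


@dataclass
class ManySteps:
    data: Sequence[float]  # inputs ordered by time (assumption)


def run(step_fn: Callable[[float], float],
        inp: Union[OneStep, ManySteps]) -> List[float]:
    """Apply step_fn once for OneStep, or once per time step for ManySteps."""
    if isinstance(inp, OneStep):
        return [step_fn(inp.data)]
    if isinstance(inp, ManySteps):
        return [step_fn(x) for x in inp.data]
    raise TypeError(f"unsupported input wrapper: {type(inp).__name__}")


class Accumulator:
    """Stateful step function that returns the running sum of its inputs."""

    def __init__(self):
        self.total = 0.0

    def __call__(self, x):
        self.total += x
        return self.total


single_out = run(Accumulator(), OneStep(2.0))
multi_out = run(Accumulator(), ManySteps([1.0, 2.0, 3.0]))
```

The wrapper type, not the shape of the data, decides whether the step function is called once or looped over time, which removes any ambiguity about whether a leading axis is a batch or a time axis.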

Eligibility Trace State#

EligibilityTrace

The state for storing the eligibility trace during the computation of online learning algorithms.
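
As background on what an eligibility trace holds, the sketch below works through a scalar linear recurrence h_t = a * h_{t-1} + w * x_t in plain Python. It is a textbook-style illustration under that assumed model, not the algorithm braintrace implements; `loss` and `online_grad` are hypothetical helper names. The trace e_t = a * e_{t-1} + x_t equals dh_t/dw, so the gradient can be accumulated forward in time and checked against a finite difference.

```python
def loss(w, a, xs, ys):
    """Sum of squared errors for the recurrence h_t = a * h_{t-1} + w * x_t."""
    h, total = 0.0, 0.0
    for x, y in zip(xs, ys):
        h = a * h + w * x
        total += 0.5 * (h - y) ** 2
    return total


def online_grad(w, a, xs, ys):
    """Accumulate dL/dw forward in time using an eligibility trace."""
    h, e, g = 0.0, 0.0, 0.0
    for x, y in zip(xs, ys):
        h = a * h + w * x
        e = a * e + x      # trace update: e_t = dh_t/dw
        g += (h - y) * e   # learning signal at step t times the trace
    return g


a, w = 0.9, 0.5
xs = [1.0, -0.5, 2.0, 0.3]
ys = [0.2, 0.1, 1.0, 0.8]

g_online = online_grad(w, a, xs, ys)

# Central finite difference as an independent check.
eps = 1e-6
g_numeric = (loss(w + eps, a, xs, ys) - loss(w - eps, a, xs, ys)) / (2 * eps)
```

The two gradient estimates agree up to floating-point error, and the online version never stores past hidden states: only the running trace `e` is carried between steps.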

Gradient Utilities#

GradExpon

Accumulates gradients exponentially.
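
The one-line description does not state the update rule. One common reading of exponential accumulation is a decayed running sum, sketched below in plain Python. This is an assumption for illustration, not GradExpon's implementation: `ExponentialAccumulator` and its `decay` parameter are hypothetical.

```python
class ExponentialAccumulator:
    """Decayed running sum: value <- decay * value + grad (assumed rule)."""

    def __init__(self, decay):
        if not 0.0 <= decay < 1.0:
            raise ValueError("decay must be in [0, 1)")
        self.decay = decay
        self.value = 0.0

    def update(self, grad):
        # Older gradients are discounted geometrically by `decay` per step.
        self.value = self.decay * self.value + grad
        return self.value


acc = ExponentialAccumulator(decay=0.5)
history = [acc.update(g) for g in [1.0, 1.0, 1.0]]
```

With a constant gradient of 1.0 and a decay of 0.5, the accumulated value approaches 1 / (1 - 0.5) = 2.0, so recent gradients dominate while older ones fade.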

Errors#

NotSupportedError

Exception raised for operations that are not supported.

CompilationError

Exception raised for errors that occur during the compilation process.