Core Concepts
ETP Primitives (User API)
These functions mark weight operations for inclusion in online learning.
To include a weight in the eligibility trace computation, use `braintrace.matmul(x, w)` instead of `x @ w`.
Parameters used with regular JAX operations are excluded automatically, so no special parameter classes are needed.
| API | Description |
|---|---|
| `braintrace.matmul` | ETP-aware matrix multiplication. |
| | ETP-aware element-wise operation. |
| | ETP-aware convolution. |
| | ETP-aware sparse matrix multiplication. |
| | ETP-aware LoRA (Low-Rank Adaptation) matrix multiplication. |
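The table names the operations but not their numerics. As a point of reference, here is a plain NumPy sketch of what two of them compute in the ordinary (non-ETP) sense, assuming a COO layout for the sparse weight and the usual LoRA scaling `alpha / rank`. The function names, argument order, and A/B factor convention are illustrative assumptions, not braintrace's signatures.

```python
import numpy as np


def sparse_matmul_ref(x, rows, cols, vals, shape):
    """Dense reference for y = x @ W, where W has shape (n_in, n_out) and its
    nonzeros are stored in COO form: W[rows[k], cols[k]] = vals[k]."""
    _, n_out = shape
    y = np.zeros(x.shape[:-1] + (n_out,), dtype=np.result_type(x, vals))
    # Each stored entry W[r, c] adds x[..., r] * W[r, c] to output column c.
    for r, c, v in zip(rows, cols, vals):
        y[..., c] += x[..., r] * v
    return y


def lora_matmul_ref(x, w, a, b, alpha=1.0):
    """Reference LoRA forward: y = x @ W + (alpha / rank) * (x @ A) @ B,
    with A of shape (n_in, rank) and B of shape (rank, n_out)."""
    rank = a.shape[1]
    return x @ w + (alpha / rank) * ((x @ a) @ b)


rng = np.random.default_rng(0)
x = rng.normal(size=(2, 4))

# A 4x3 weight matrix with two nonzeros, converted to COO form.
dense = np.zeros((4, 3))
dense[0, 1], dense[3, 2] = 2.0, -1.0
rows, cols = np.nonzero(dense)
vals = dense[rows, cols]
y_sparse = sparse_matmul_ref(x, rows, cols, vals, dense.shape)

# Rank-2 LoRA update on a frozen 4x3 weight.
w = rng.normal(size=(4, 3))
a = rng.normal(size=(4, 2))
b = rng.normal(size=(2, 3))
y_lora = lora_matmul_ref(x, w, a, b, alpha=4.0)
```

Both results can be checked against an equivalent dense computation (`x @ dense` and `x @ (w + 2.0 * a @ b)` respectively).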
Controlling Parameter Participation
```python
import jax.numpy as jnp
import braintrace
import brainstate


class MyRNN(brainstate.nn.Module):
    def __init__(self):
        super().__init__()
        # (...) stands for each state's initial value
        self.w_rec = brainstate.ParamState(...)  # want ETP
        self.w_in = brainstate.ParamState(...)   # do NOT want ETP
        self.h = brainstate.ShortTermState(...)

    def update(self, x):
        # regular matmul -> w_in excluded from ETP
        inp = x @ self.w_in.value
        # ETP matmul -> w_rec included in ETP
        self.h.value = jnp.tanh(inp + braintrace.matmul(self.h.value, self.w_rec.value))
        return self.h.value
```
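What it means for `w_rec` to be "included" can be made concrete with the exact forward-mode gradient (real-time recurrent learning, RTRL) of the same cell, written in NumPy so it runs without braintrace. The sketch keeps a sensitivity tensor for `w_rec` only, starting from a zero hidden state, while `w_in` gets none. It is a reference computation chosen for illustration (it needs O(H^3) memory per step); practical online-learning algorithms typically maintain cheaper approximate traces, so do not read it as the library's implementation. `run_with_rtrl` is a hypothetical name.

```python
import numpy as np


def run_with_rtrl(xs, w_in, w_rec):
    """Run h_t = tanh(x_t @ w_in + h_{t-1} @ w_rec) from h_0 = 0 and carry the
    exact RTRL sensitivity p[k, i, j] = d h_t[k] / d w_rec[i, j].
    No sensitivity is kept for w_in, mirroring a parameter excluded from ETP."""
    n_h = w_rec.shape[0]
    eye = np.eye(n_h)
    h = np.zeros(n_h)
    p = np.zeros((n_h, n_h, n_h))
    for x in xs:
        h_prev = h
        h = np.tanh(x @ w_in + h_prev @ w_rec)
        # Direct term: d a[k] / d w_rec[i, j] = delta(k, j) * h_prev[i]
        direct = eye[:, None, :] * h_prev[None, :, None]
        # Recurrent term: sum_m w_rec[m, k] * d h_prev[m] / d w_rec[i, j]
        recurrent = np.einsum("mk,mij->kij", w_rec, p)
        # Chain through tanh: d h[k] / d a[k] = 1 - h[k] ** 2
        p = (1.0 - h**2)[:, None, None] * (direct + recurrent)
    return h, p


rng = np.random.default_rng(0)
xs = rng.normal(size=(5, 3))                # 5 time steps, 3 input features
w_in = rng.normal(scale=0.5, size=(3, 4))   # excluded: no trace kept
w_rec = rng.normal(scale=0.5, size=(4, 4))  # included: trace p is kept
h_T, p_T = run_with_rtrl(xs, w_in, w_rec)

# Gradient of L = sum(h_T) with respect to w_rec, read directly off the trace.
grad_w_rec = p_T.sum(axis=0)

# Central finite differences as an independent check of the trace.
eps = 1e-6
grad_fd = np.zeros_like(w_rec)
for i in range(4):
    for j in range(4):
        d = np.zeros_like(w_rec)
        d[i, j] = eps
        up = run_with_rtrl(xs, w_in, w_rec + d)[0].sum()
        down = run_with_rtrl(xs, w_in, w_rec - d)[0].sum()
        grad_fd[i, j] = (up - down) / (2 * eps)
print(np.max(np.abs(grad_w_rec - grad_fd)))  # small: finite-difference error only
```

Because the trace is carried forward with the state, the gradient is available at every step without storing the input history, which is the property that makes online learning possible.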
| Goal | How |
|---|---|
| Include parameter in online learning | Use an ETP primitive (e.g. `braintrace.matmul(x, w)`) |
| Exclude parameter from online learning | Use a regular JAX op (e.g. `x @ w`) |
| Selection mechanism | Operation primitive type, not parameter class type. Every … |
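The selection rule in the last row can be illustrated with a toy recorder in which both weights are ordinary NumPy arrays of the same type, so only the operation decides participation. `EtpRecorder`, `forward`, and `etp_params` are hypothetical names for this sketch; it does not reflect how braintrace actually detects its primitives.

```python
import numpy as np


class EtpRecorder:
    """Hypothetical stand-in for an ETP-aware tracer: it records every weight
    that flows through its matmul primitive during a forward pass."""

    def __init__(self):
        self.recorded = []

    def matmul(self, x, w):
        self.recorded.append(w)  # mark this weight for eligibility traces
        return x @ w


def forward(rec, params, x, h):
    inp = x @ params["w_in"]                               # regular op: not recorded
    return np.tanh(inp + rec.matmul(h, params["w_rec"]))   # ETP op: recorded


def etp_params(rec, params):
    """Names of parameters whose arrays passed through an ETP primitive."""
    return [name for name, p in params.items() if any(p is r for r in rec.recorded)]


rng = np.random.default_rng(0)
params = {"w_in": rng.normal(size=(3, 4)), "w_rec": rng.normal(size=(4, 4))}
rec = EtpRecorder()
h = forward(rec, params, rng.normal(size=3), np.zeros(4))
print(etp_params(rec, params))  # ['w_rec']
```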
Input Data
Wrappers that tell the online learning algorithm whether the input is a single time step or a sequence of time steps.
| API | Description |
|---|---|
| | The data at a single time step. |
| | The data at multiple time steps. |
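The distinction between the two wrappers can be sketched as a driver that applies a step function once for single-step input, and once per slice along a leading time axis for multi-step input. `SingleStep`, `MultiStep`, and `run` are hypothetical names, and the time-major `(time, batch, features)` layout is an assumption of this sketch, not a documented braintrace convention.

```python
from dataclasses import dataclass

import numpy as np


@dataclass
class SingleStep:  # hypothetical stand-in for the single-time-step wrapper
    data: np.ndarray  # shape (batch, features)


@dataclass
class MultiStep:  # hypothetical stand-in for the multi-time-step wrapper
    data: np.ndarray  # shape (time, batch, features)


def run(step_fn, state, inputs):
    """Apply step_fn once for SingleStep, or once per time slice for MultiStep."""
    if isinstance(inputs, SingleStep):
        return step_fn(state, inputs.data)
    if isinstance(inputs, MultiStep):
        for x_t in inputs.data:  # iterate over the leading time axis
            state = step_fn(state, x_t)
        return state
    raise TypeError(f"unsupported input wrapper: {type(inputs).__name__}")


def step_fn(state, x_t):
    return state + x_t.sum()  # toy update: accumulate the inputs seen so far


single = run(step_fn, 0.0, SingleStep(np.ones((2, 3))))
multi = run(step_fn, 0.0, MultiStep(np.ones((4, 2, 3))))
print(single, multi)  # 6.0 24.0
```

Making the wrapper explicit avoids guessing from array shapes whether the leading axis is time or batch.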
Eligibility Trace State
The state for storing the eligibility trace during the computation of online learning algorithms.
Gradient Utilities
Accumulates gradients exponentially.
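One common reading of exponential accumulation is a decayed running sum, `acc_t = decay * acc_{t-1} + g_t`, applied entry by entry to a gradient tree. The sketch assumes that form (an exponential moving average would instead scale `g_t` by `1 - decay`); the function name `accumulate` and the dict-of-arrays gradient structure are hypothetical.

```python
import numpy as np


def accumulate(acc, grads, decay):
    """Decayed running sum, acc_t = decay * acc_{t-1} + g_t, per dict entry."""
    return {name: decay * acc[name] + g for name, g in grads.items()}


acc = {"w": np.zeros(2)}
for _ in range(3):
    acc = accumulate(acc, {"w": np.ones(2)}, decay=0.5)

# With constant g = 1 the sum is 1 + 0.5 + 0.25 = (1 - 0.5**3) / (1 - 0.5).
print(acc["w"])  # [1.75 1.75]
```

Under a constant gradient `g` the accumulator approaches `g / (1 - decay)`, so `decay` sets both how far back gradients are remembered and the overall scale.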
Errors
| API | Description |
|---|---|
| | Exception raised for operations that are not supported. |
| | Exception raised for errors that occur during the compilation process. |