braintrace documentation
braintrace implements scalable online learning for recurrent neural networks (RNNs) and spiking neural networks (SNNs) using eligibility trace propagation (ETP).
The key idea: to include a weight operation in online learning, express it with an ETP primitive (braintrace.matmul, braintrace.conv, etc.). Operations written with regular JAX are excluded automatically, so no special parameter classes are needed.
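To make the key idea concrete, here is the textbook real-time recurrent learning (RTRL) recursion that eligibility traces compute. It is written for a generic recurrent layer $h_t = f(h_{t-1}, x_t; W)$ with per-step losses $L_t$. In braintrace's terms, $W$ plays the role of a weight used through an ETP primitive, and weights used only in plain JAX operations get no trace. This formulation is standard background and is not taken from the braintrace source. Reading the "D" in `D_RTRL` as "diagonal", meaning that only the diagonal of the recurrent Jacobian is kept, is our assumption.

```latex
\begin{aligned}
\mathbf{e}_t &:= \frac{\mathrm{d} h_t}{\mathrm{d} W}
  = J_t\,\mathbf{e}_{t-1} + \left.\frac{\partial f}{\partial W}\right|_t,
  \qquad J_t = \frac{\partial f}{\partial h_{t-1}}, \qquad \mathbf{e}_0 = 0, \\
\frac{\mathrm{d} L}{\mathrm{d} W} &= \sum_{t=1}^{T} \frac{\partial L_t}{\partial h_t}\,\mathbf{e}_t, \\
\mathbf{e}_t &\approx \operatorname{diag}(J_t)\,\mathbf{e}_{t-1} + \left.\frac{\partial f}{\partial W}\right|_t
  \qquad \text{(diagonal approximation, assumed for D-RTRL)}.
\end{aligned}
```

The trace $\mathbf{e}_t$ depends only on quantities available at step $t$, so each term of the gradient sum can be accumulated as soon as $L_t$ is known, without storing the state history that BPTT needs. With the diagonal approximation, the trace update becomes an element-wise product per hidden unit instead of a full matrix product.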
Basic Usage
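```python
import brainstate
import braintrace
import jax.numpy as jnp


class MyRNN(brainstate.nn.Module):
    def __init__(self):
        super().__init__()
        self.rnn = braintrace.nn.GRUCell(10, 64)  # 10 input features -> 64 hidden units
        self.out = braintrace.nn.Linear(64, 10)   # 64 hidden units -> 10 outputs

    def update(self, x):
        # One time step: update the recurrent state, then read it out
        return self.out(self.rnn(x))


model = MyRNN()
model.init_all_states()  # initialize the hidden states

# One sample with 10 features for compile_graph to build the model graph from.
# The shape is an assumption here; pass a real sample from your data.
example_input = jnp.ones((10,))

# Wrap with an online learning algorithm (just 2 lines)
trainer = braintrace.D_RTRL(model)
trainer.compile_graph(example_input)

# Now use brainstate.transform.grad as usual: gradients are
# computed online via eligibility traces, not BPTT.
```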
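The claim that gradients come from eligibility traces rather than BPTT can be checked on a toy model. The following is a minimal sketch in plain NumPy, not braintrace code. The functions `online_grad` and `bptt_grad` and the toy layer are hypothetical, chosen for illustration. The layer uses an element-wise leaky recurrence $h_t = a \odot h_{t-1} + W x_t$ with squared-error loss. Because the recurrence is diagonal, the diagonal trace update is exact here, so the forward-only gradient should match BPTT.

```python
import numpy as np


def online_grad(a, W, xs, ys):
    """dL/dW for L = sum_t 0.5 * ||h_t - y_t||^2, computed forward in time.

    Hypothetical toy layer: h_t = a * h_{t-1} + W @ x_t, h_0 = 0, element-wise decay a.
    """
    n, m = W.shape
    h = np.zeros(n)
    trace = np.zeros((n, m))  # trace[i, j] = d h_t[i] / d W[i, j]
    grad = np.zeros((n, m))
    for x, y in zip(xs, ys):
        h = a * h + W @ x                        # forward step
        trace = a[:, None] * trace + x[None, :]  # eligibility trace update
        grad += (h - y)[:, None] * trace         # use the trace immediately
    return grad                                  # no past states were stored


def bptt_grad(a, W, xs, ys):
    """Same gradient via backpropagation through time (stores every h_t)."""
    n, m = W.shape
    hs, h = [], np.zeros(n)
    for x in xs:
        h = a * h + W @ x
        hs.append(h)
    grad = np.zeros((n, m))
    delta = np.zeros(n)  # total dL/dh_t, carried backward through time
    for t in reversed(range(len(xs))):
        delta = (hs[t] - ys[t]) + a * delta
        grad += np.outer(delta, xs[t])
    return grad


rng = np.random.default_rng(0)
n, m, T = 4, 3, 20
a = rng.uniform(0.5, 0.95, size=n)  # per-unit decay factors
W = rng.normal(size=(n, m))
xs = rng.normal(size=(T, m))        # T input vectors
ys = rng.normal(size=(T, n))        # T targets
g_online = online_grad(a, W, xs, ys)
g_bptt = bptt_grad(a, W, xs, ys)
print(np.allclose(g_online, g_bptt))  # True
```

The forward-only version keeps a trace the same size as `W` instead of the whole state history, which is what makes learning online possible. For layers whose recurrence is not element-wise, a diagonal trace like this one is an approximation rather than an exact gradient.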
Installation
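Install the variant that matches your hardware. Quoting the package name keeps shells such as zsh from treating the brackets as a glob pattern.
```shell
pip install -U "braintrace[cpu]"     # CPU only
pip install -U "braintrace[cuda12]"  # NVIDIA GPUs with CUDA 12
pip install -U "braintrace[tpu]"     # Google TPUs
```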
See also the ecosystem
braintrace is part of the brain simulation ecosystem.