braintrace documentation


braintrace implements scalable online learning for recurrent neural networks (RNNs) and spiking neural networks (SNNs) using eligibility trace propagation (ETP).
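For background, the exact form of online gradient computation with eligibility traces is real-time recurrent learning (RTRL). For a recurrent state $h_t = f(h_{t-1}, x_t; \theta)$ and a loss $\mathcal{L} = \sum_t \mathcal{L}_t$, the trace $\epsilon_t$ carries the sensitivity of the current state to the parameters forward in time:

$$
\epsilon_t = \frac{d h_t}{d \theta}
= \frac{\partial f}{\partial h_{t-1}}\,\epsilon_{t-1} + \frac{\partial f}{\partial \theta},
\qquad \epsilon_0 = 0,
\qquad
\frac{d \mathcal{L}}{d \theta} = \sum_t \frac{\partial \mathcal{L}_t}{\partial h_t}\,\epsilon_t .
$$

Each step needs only $\epsilon_{t-1}$ and quantities available at time $t$, so no history has to be stored. The catch is cost: $\epsilon_t$ holds one entry per (hidden unit, parameter) pair, so the exact recursion is expensive for large networks. Scalable eligibility-trace methods approximate it; the specific approximation used by each braintrace algorithm is not described on this page.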

The key idea: to include a weight operation in online learning, mark it with an ETP primitive (braintrace.matmul, braintrace.conv, etc.). Regular JAX operations are automatically excluded, so no special parameter classes are needed.
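The opt-in rule can be pictured with a toy registry in plain Python. This is a hypothetical illustration of the idea, not braintrace's implementation: `MARKED`, `etp_mul`, `plain_mul` and `step` are invented names, with `etp_mul` standing in for an ETP primitive such as `braintrace.matmul` and `plain_mul` for an ordinary JAX operation. Only weights that pass through the marked operation end up with trace state.

```python
# Hypothetical illustration of opt-in marking; not braintrace's implementation.
MARKED = set()  # names of weights that went through a marked (ETP-style) op


def etp_mul(name, w, x):
    """Stand-in for an ETP primitive such as braintrace.matmul (hypothetical)."""
    MARKED.add(name)  # opt in: this weight will get an eligibility trace
    return w * x


def plain_mul(w, x):
    """Stand-in for an ordinary JAX op: nothing is recorded."""
    return w * x


def step(params, h, x):
    # Recurrent and input weights go through the marked op...
    h_new = etp_mul("w_rec", params["w_rec"], h) + etp_mul("w_in", params["w_in"], x)
    # ...while this readout scale uses a plain op and stays untracked.
    y = plain_mul(params["scale"], h_new)
    return h_new, y


params = {"w_rec": 0.9, "w_in": 0.5, "scale": 2.0}
h, y = step(params, 0.0, 1.0)

# Allocate trace state only for the marked weights.
traces = {name: 0.0 for name in sorted(MARKED)}
print(traces)  # {'w_in': 0.0, 'w_rec': 0.0}
```

The point of the design is that tracking is decided per operation rather than per parameter type, so a model can mix tracked and untracked computations without wrapping its weights in special classes.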


Basic Usage

import braintrace
import brainstate

class MyRNN(brainstate.nn.Module):
    def __init__(self):
        super().__init__()
        self.rnn = braintrace.nn.GRUCell(10, 64)
        self.out = braintrace.nn.Linear(64, 10)

    def update(self, x):
        return self.out(self.rnn(x))

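model = MyRNN()
model.init_all_states()

# An example input with the shape the model expects (10 input features);
# compile_graph needs a sample input to build the model's computation graph.
example_input = brainstate.random.rand(10)

# Wrap the model with an online learning algorithm (just 2 lines)
trainer = braintrace.D_RTRL(model)
trainer.compile_graph(example_input)

# Now use brainstate.transform.grad as usual: gradients are
# computed online via eligibility traces, not BPTT.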
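To see why gradients can be computed online rather than by BPTT, here is a toy sketch in plain Python (no braintrace) for a scalar linear RNN $h_t = a\,h_{t-1} + w\,x_t$ with squared-error loss. The functions `online_grads` and `bptt_grads` are illustrative names, not library APIs, and this is not how braintrace computes its gradients. The online version updates one eligibility trace per parameter at every step; the BPTT version stores all states and runs a backward pass. Both return the same gradients.

```python
# Toy scalar RNN: h_t = a * h_{t-1} + w * x_t, h_0 = 0,
# loss L = sum_t 0.5 * (h_t - y_t) ** 2. Plain Python, no braintrace.

def online_grads(a, w, xs, ys):
    """Accumulate dL/da and dL/dw step by step using eligibility traces."""
    h = 0.0
    e_a = 0.0  # eligibility trace dh_t/da
    e_w = 0.0  # eligibility trace dh_t/dw
    g_a = g_w = 0.0
    for x, y in zip(xs, ys):
        h_prev = h
        h = a * h_prev + w * x
        # Trace recursion: e_t = (dh_t/dh_{t-1}) * e_{t-1} + direct partial.
        e_a = a * e_a + h_prev
        e_w = a * e_w + x
        err = h - y  # dL_t/dh_t
        g_a += err * e_a
        g_w += err * e_w
    return g_a, g_w


def bptt_grads(a, w, xs, ys):
    """Reference gradients via backpropagation through time (stores all states)."""
    hs = [0.0]
    for x in xs:
        hs.append(a * hs[-1] + w * x)
    g_a = g_w = 0.0
    delta = 0.0  # total dL/dh_t, accumulated backward from the future
    for t in reversed(range(len(xs))):
        delta = (hs[t + 1] - ys[t]) + a * delta
        g_a += delta * hs[t]
        g_w += delta * xs[t]
    return g_a, g_w


print(online_grads(0.5, 2.0, [1.0, -1.0], [0.0, 0.0]))  # (-2.0, 2.5)
print(bptt_grads(0.5, 2.0, [1.0, -1.0], [0.0, 0.0]))    # (-2.0, 2.5)
```

The online version keeps only the current state and two traces, so its memory does not grow with sequence length, while BPTT stores every $h_t$. Because this recurrence is scalar, the traces are exact; with a full recurrent weight matrix, exact traces become much larger.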

Installation

Install the variant that matches your hardware:

pip install -U braintrace[cpu]      # CPU only
pip install -U braintrace[cuda12]   # NVIDIA GPU with CUDA 12
pip install -U braintrace[tpu]      # Google TPU

See also the ecosystem

braintrace is part of the brain simulation ecosystem.