OSTLRecurrent
class braintrace.OSTLRecurrent(model, name=None, vjp_method='single-step', fast_solve=True, normalize_matrix_spectrum=False, trace_dtype=None, **kwargs)
OSTL ‘with-H’ regime — RTRL-exact single-layer factorization.
OSTL derives an online rule by cleanly separating the gradient into a temporal eligibility trace and a spatial learning signal. The ‘with-H’ regime retains the hidden-to-hidden Jacobian, so the trace carries the full temporal term and the rule is gradient-equivalent to BPTT for a single recurrent layer:
\[\boldsymbol{\epsilon}^t = \mathbf{D}^t\,\boldsymbol{\epsilon}^{t-1} + \operatorname{diag}(\mathbf{D}_f^t)\otimes \mathbf{x}^t , \qquad \nabla_{\boldsymbol{\theta}}\mathcal{L} = \sum_t \frac{\partial \mathcal{L}^t}{\partial \mathbf{h}^t} \circ \boldsymbol{\epsilon}^t ,\]
where \(\mathbf{D}^t\) is the hidden-to-hidden Jacobian, \(\mathbf{D}_f^t\) the state-to-output Jacobian, and \(\mathbf{x}^t\) the presynaptic input. This is exactly the per-parameter D-RTRL trace (memory \(O(P\cdot H)\)), so the class delegates entirely to ParamDimVjpAlgorithm.
Parameters:
- model (Module) – The recurrent SNN whose weights are trained online.
- name (str | None) – Forwarded verbatim to ParamDimVjpAlgorithm.
- vjp_method (str) – Forwarded verbatim to ParamDimVjpAlgorithm.
- fast_solve (bool) – Forwarded verbatim to ParamDimVjpAlgorithm.
- normalize_matrix_spectrum (bool) – Forwarded verbatim to ParamDimVjpAlgorithm.
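To make the trace recursion concrete, here is a minimal JAX sketch of the 'with-H' rule on a toy problem. It is not braintrace's implementation: it assumes a hypothetical single-layer tanh RNN \(\mathbf{h}^t = \tanh(\mathbf{W}\mathbf{x}^t + \mathbf{U}\mathbf{h}^{t-1})\), traces only the input weights \(\mathbf{W}\), and uses the per-step loss \(\mathcal{L}^t = \tfrac12\lVert\mathbf{h}^t - \mathbf{y}^t\rVert^2\). In this toy model the local derivative \(\tanh'(\mathbf{a}^t)\) plays the role of the diagonal factor in the immediate term, and \(\mathbf{D}^t = \operatorname{diag}(\tanh'(\mathbf{a}^t))\,\mathbf{U}\). The helper names `step`, `bptt_loss` and `ostl_with_h_grad` are hypothetical. Because the hidden-to-hidden Jacobian is kept, the online gradient should match BPTT (`jax.grad` through `jax.lax.scan`) up to float round-off.

```python
import jax
import jax.numpy as jnp


# Hypothetical toy layer: h^t = tanh(W x^t + U h^{t-1}); only W is traced.
def step(W, U, h_prev, x):
    return jnp.tanh(W @ x + U @ h_prev)


def bptt_loss(W, U, xs, ys):
    # Reference: unroll the whole sequence and sum L^t = 0.5 * ||h^t - y^t||^2.
    def body(h, inp):
        x, y = inp
        h = step(W, U, h, x)
        return h, 0.5 * jnp.sum((h - y) ** 2)

    h0 = jnp.zeros(U.shape[0])
    _, losses = jax.lax.scan(body, h0, (xs, ys))
    return losses.sum()


def ostl_with_h_grad(W, U, xs, ys):
    # Forward-in-time gradient of bptt_loss w.r.t. W via the eligibility trace.
    H, I = W.shape
    h = jnp.zeros(H)
    eps = jnp.zeros((H, H, I))  # eps[k, i, j] = d h_k^t / d W_ij
    grad = jnp.zeros_like(W)
    for x, y in zip(xs, ys):
        h_new = step(W, U, h, x)
        d = 1.0 - h_new ** 2  # tanh'(a^t): diagonal factor of the immediate term
        D = d[:, None] * U  # hidden-to-hidden Jacobian D^t = d h^t / d h^{t-1}
        # Immediate term diag(d) (x) x^t: d h_k / d W_ij = d_k * delta_ki * x_j.
        immediate = jnp.einsum("k,ki,j->kij", d, jnp.eye(H), x)
        # Trace recursion: eps^t = D^t eps^{t-1} + immediate term.
        eps = jnp.einsum("kl,lij->kij", D, eps) + immediate
        # Spatial learning signal dL^t/dh^t, contracted with the trace.
        dL_dh = h_new - y
        grad = grad + jnp.einsum("k,kij->ij", dL_dh, eps)
        h = h_new
    return grad


if __name__ == "__main__":
    k1, k2, k3, k4 = jax.random.split(jax.random.PRNGKey(0), 4)
    H, I, T = 4, 3, 6
    W = 0.5 * jax.random.normal(k1, (H, I))
    U = 0.5 * jax.random.normal(k2, (H, H))
    xs = jax.random.normal(k3, (T, I))
    ys = jax.random.normal(k4, (T, H))
    g_online = ostl_with_h_grad(W, U, xs, ys)
    g_bptt = jax.grad(bptt_loss)(W, U, xs, ys)
    # Expected to be at float32 round-off level.
    print(jnp.max(jnp.abs(g_online - g_bptt)))
```

The trace `eps` stores \(H\) entries for each of the \(P = H \cdot I\) entries of \(\mathbf{W}\), which is the \(O(P\cdot H)\) memory cost mentioned above; it is independent of the sequence length \(T\), which is what makes the rule online.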
Examples
>>> import brainstate
>>> import braintrace
>>>
>>> class Net(brainstate.nn.Module):
...     def __init__(self):
...         super().__init__()
...         self.cell = braintrace.nn.ValinaRNNCell(1, 20, activation='tanh')
...         self.out = braintrace.nn.Linear(20, 1)
...     def update(self, x):
...         return x >> self.cell >> self.out
>>>
>>> model = Net()
>>> _ = brainstate.nn.init_all_states(model)
>>> learner = braintrace.OSTLRecurrent(model)
>>> x0 = brainstate.random.randn(1)
>>> learner.compile_graph(x0)
>>> y = learner(x0)
References