ParamDimVjpAlgorithm
- class braintrace.ParamDimVjpAlgorithm(model, name=None, vjp_method='single-step', fast_solve=True, normalize_matrix_spectrum=False, **kwargs)
An online gradient computation algorithm that uses a diagonal approximation and has parameter-dimension complexity.
The algorithm computes the weight gradients online. It is based on the real-time recurrent learning (RTRL) algorithm and uses the following learning rule:
\[\begin{split} \begin{aligned} &\boldsymbol{\epsilon}^t \approx \mathbf{D}^t \boldsymbol{\epsilon}^{t-1}+\operatorname{diag}\left(\mathbf{D}_f^t\right) \otimes \mathbf{x}^t \\ & \nabla_{\boldsymbol{\theta}} \mathcal{L}=\sum_{t^{\prime} \in \mathcal{T}} \frac{\partial \mathcal{L}^{t^{\prime}}}{\partial \mathbf{h}^{t^{\prime}}} \circ \boldsymbol{\epsilon}^{t^{\prime}} \end{aligned} \end{split}\]
For more details, please see the D-RTRL algorithm presented in our manuscript.
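To make the learning rule concrete, here is a minimal, dependency-free sketch for a single linear weight matrix of shape (I, O). It is not braintrace's implementation. It assumes that the diagonals of \(\mathbf{D}^t\) and \(\mathbf{D}_f^t\) are available as per-output vectors, so the trace update becomes \(\epsilon_{io}^t = D_o^t \epsilon_{io}^{t-1} + D_{f,o}^t x_i^t\). The helper names (`update_trace`, `accumulate_grad`, `online_gradient`) and the toy model \(h^t = a\,h^{t-1} + W^\top x^t\) with loss \(\mathcal{L} = \sum_t \sum_o h_o^t\) are illustrative assumptions.

```python
def update_trace(eps, d_hidden, d_f, x):
    """One trace step: eps[i][o] <- d_hidden[o] * eps[i][o] + d_f[o] * x[i]."""
    return [
        [d_hidden[o] * eps[i][o] + d_f[o] * x[i] for o in range(len(d_f))]
        for i in range(len(x))
    ]


def accumulate_grad(grad, dl_dh, eps):
    """Gradient step: grad[i][o] += dl_dh[o] * eps[i][o]."""
    return [
        [grad[i][o] + dl_dh[o] * eps[i][o] for o in range(len(dl_dh))]
        for i in range(len(eps))
    ]


def online_gradient(xs, decay, n_out):
    """Online weight gradient for the toy model h^t = decay * h^{t-1} + W^T x^t.

    The loss is assumed to be L = sum_t sum_o h_o^t, so dL^t/dh^t is all ones.
    For this model the diagonal of D^t is `decay` and the diagonal of D_f^t is 1.
    """
    n_in = len(xs[0])
    eps = [[0.0] * n_out for _ in range(n_in)]   # eligibility trace, shape (I, O)
    grad = [[0.0] * n_out for _ in range(n_in)]  # accumulated gradient, shape (I, O)
    for x in xs:
        eps = update_trace(eps, [decay] * n_out, [1.0] * n_out, x)
        grad = accumulate_grad(grad, [1.0] * n_out, eps)
    return grad


# Two time steps, two inputs, one output unit.
print(online_gradient([[1.0, 2.0], [0.5, -1.0]], decay=0.5, n_out=1))  # [[2.0], [2.0]]
```

In this toy model the hidden Jacobian is already diagonal, so the result matches the exact gradient: \(\mathcal{L} = 1.5\,W^\top x^1 + W^\top x^2\), which gives \(1.5 \cdot 1 + 0.5 = 2\) and \(1.5 \cdot 2 - 1 = 2\). For models with off-diagonal recurrent coupling the update is only an approximation.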
Note that ParamDimVjpAlgorithm is a subclass of brainstate.nn.Module, and it is sensitive to the context/mode of the computation. In particular, ParamDimVjpAlgorithm is sensitive to brainstate.mixin.Batching behavior.
This algorithm has \(O(B\theta)\) memory complexity, where \(\theta\) is the number of parameters and \(B\) is the batch size.
For a convolutional layer, the algorithm computes the weight gradients with \(O(B\theta)\) memory complexity, where \(\theta\) is the dimension of the convolutional kernel.
For a linear transformation layer, the algorithm computes the weight gradients with \(O(BIO)\) computational complexity, where \(I\) and \(O\) are the numbers of input and output dimensions.
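As a rough illustration of the \(O(B\theta)\) scaling, the sketch below counts trace entries under the assumption that one entry is stored per batch sample and per weight element. The helper `trace_num_elements` is hypothetical and is not part of braintrace. The shapes are example values only.

```python
from math import prod


def trace_num_elements(batch_size, weight_shape):
    """Number of trace entries, assuming one entry per (sample, weight element)."""
    return batch_size * prod(weight_shape)


# Linear layer with I = 128 inputs and O = 64 outputs, batch size B = 32.
print(trace_num_elements(32, (128, 64)))       # 262144

# Convolutional kernel of shape (3, 3, 16, 32), batch size B = 32.
print(trace_num_elements(32, (3, 3, 16, 32)))  # 147456
```

For the linear layer, \(\theta = I \cdot O\), so the count equals \(B \cdot I \cdot O\).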
- Parameters:
model (Module) – The model function, which receives the input arguments and returns the model output.
vjp_method (str) – The method for computing the VJP. It should be either “single-step” or “multi-step”.
“single-step”: The VJP is computed at the current time step, i.e., \(\partial L^t/\partial h^t\).
“multi-step”: The VJP is computed at multiple time steps, i.e., \(\partial L^t/\partial h^{t-k}\), where \(k\) is determined by the data input.
mode (braintrace.mixin.Mode) – The computing mode, indicating the batching behavior.
- get_etrace_of(weight)
Get the eligibility trace of the given weight.
The eligibility trace contains the following structures:
- Return type: