IODimVjpAlgorithm
- class braintrace.IODimVjpAlgorithm(model, decay_or_rank, name=None, vjp_method='single-step', fast_solve=True, **kwargs)
The online gradient computation algorithm with the diagonal approximation and the input-output dimensional complexity.
This algorithm computes the weight gradients using the diagonal approximation, with a cost that scales with the input and output dimensions. It is based on the RTRL algorithm and has the following learning rule:
\[\begin{split} \begin{aligned} & \boldsymbol{\epsilon}^t \approx \boldsymbol{\epsilon}_{\mathbf{f}}^t \otimes \boldsymbol{\epsilon}_{\mathbf{x}}^t \\ & \boldsymbol{\epsilon}_{\mathbf{x}}^t=\alpha \boldsymbol{\epsilon}_{\mathbf{x}}^{t-1}+\mathbf{x}^t \\ & \boldsymbol{\epsilon}_{\mathbf{f}}^t=\alpha \operatorname{diag}\left(\mathbf{D}^t\right) \circ \boldsymbol{\epsilon}_{\mathbf{f}}^{t-1}+(1-\alpha) \operatorname{diag}\left(\mathbf{D}_f^t\right) \\ & \nabla_{\boldsymbol{\theta}} \mathcal{L}=\sum_{t^{\prime} \in \mathcal{T}} \frac{\partial \mathcal{L}^{t^{\prime}}}{\partial \mathbf{h}^{t^{\prime}}} \circ \boldsymbol{\epsilon}^{t^{\prime}} \end{aligned} \end{split}\]
For more details, please see the ES-D-RTRL algorithm presented in our manuscript.
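The learning rule above can be sketched in NumPy for a single unbatched layer. This is a minimal illustration under stated assumptions, not braintrace's implementation: the function names `step_traces` and `weight_grad` are hypothetical, and `diag_D` and `diag_Df` are random stand-ins for the diagonals of \(\mathbf{D}^t\) and \(\mathbf{D}_f^t\), whose precise definitions are given in the manuscript.

```python
import numpy as np


def step_traces(eps_x, eps_f, x, diag_D, diag_Df, alpha):
    """Advance both factored traces by one time step (hypothetical helper)."""
    # Input-side trace: eps_x^t = alpha * eps_x^{t-1} + x^t
    eps_x_new = alpha * eps_x + x
    # Output-side trace:
    # eps_f^t = alpha * diag(D^t) o eps_f^{t-1} + (1 - alpha) * diag(D_f^t)
    eps_f_new = alpha * diag_D * eps_f + (1.0 - alpha) * diag_Df
    return eps_x_new, eps_f_new


def weight_grad(dL_dh, eps_x, eps_f):
    """Per-step gradient term dL/dh o (eps_f (x) eps_x), shape (n_out, n_in)."""
    # The full trace eps_f (x) eps_x is only materialized here, as an outer product.
    return np.outer(dL_dh * eps_f, eps_x)


rng = np.random.default_rng(0)
n_in, n_out, n_steps, alpha = 3, 2, 5, 0.9

eps_x = np.zeros(n_in)    # O(I) state
eps_f = np.zeros(n_out)   # O(O) state
grad = np.zeros((n_out, n_in))

for _ in range(n_steps):
    x = rng.normal(size=n_in)
    diag_D = rng.uniform(0.5, 1.0, size=n_out)    # assumed stand-in values
    diag_Df = rng.uniform(0.5, 1.0, size=n_out)   # assumed stand-in values
    dL_dh = rng.normal(size=n_out)                # assumed stand-in loss gradient
    eps_x, eps_f = step_traces(eps_x, eps_f, x, diag_D, diag_Df, alpha)
    grad += weight_grad(dL_dh, eps_x, eps_f)      # accumulate the sum over t'
```

Only the two vectors `eps_x` and `eps_f` are carried between steps, which is where the memory saving relative to storing a full trace per weight comes from.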
This algorithm has \(O(BI+BO)\) memory complexity and \(O(BIO)\) computational complexity, where \(I\) and \(O\) are the numbers of input and output dimensions and \(B\) is the batch size.
In particular, for a linear transformation layer, the algorithm computes the weight gradients with \(O(Bn)\) memory complexity and \(O(Bn^2)\) computational complexity, where \(n\) is the number of hidden dimensions.
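To make these orders of growth concrete, the sketch below evaluates the \(BI+BO\) and \(BIO\) expressions for given sizes. The helper `iodim_costs` is hypothetical and ignores constant factors, so the numbers are element and multiplication counts for comparison only, not measured memory or runtime.

```python
def iodim_costs(batch, n_in, n_out):
    """Return (trace elements stored, multiplications per step), up to constants."""
    memory = batch * (n_in + n_out)   # B*I + B*O: the two factored traces
    compute = batch * n_in * n_out    # B*I*O: forming the outer-product gradient
    return memory, compute


# A layer with I = 512 inputs and O = 256 outputs at batch size B = 32.
print(iodim_costs(32, 512, 256))   # (24576, 4194304)

# A square linear layer with n = 512 hidden dimensions: memory grows like B*n,
# compute like B*n^2.
print(iodim_costs(32, 512, 512))   # (32768, 8388608)
```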
- Parameters:
  - model (Module) – The model function, which receives the input arguments and returns the model output.
  - vjp_method (str) – The method for computing the VJP. It should be either “single-step” or “multi-step”.
    - “single-step”: The VJP is computed at the current time step, i.e., \(\partial L^t/\partial h^t\).
    - “multi-step”: The VJP is computed at multiple time steps, i.e., \(\partial L^t/\partial h^{t-k}\), where \(k\) is determined by the data input.
  - decay_or_rank (float | int) – The exponential smoothing factor for the eligibility trace. If it is a float, it is the decay factor and should be in the range (0, 1). If it is an integer, it is the approximation rank of the algorithm and should be greater than 0.
  - mode (braintrace.mixin.Mode) – The computing mode, indicating the batching information.
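The two accepted forms of `decay_or_rank` can be illustrated with a small validation sketch. The function `classify_decay_or_rank` is hypothetical and is not part of braintrace; it only encodes the constraints stated above (a float in (0, 1) or an integer greater than 0) and does not reproduce how the library converts between a decay factor and a rank.

```python
def classify_decay_or_rank(value):
    """Classify a decay_or_rank argument as a decay factor or a rank (sketch)."""
    # Assumption of this sketch: booleans are rejected even though bool
    # is a subclass of int in Python.
    if isinstance(value, bool):
        raise TypeError("decay_or_rank must be a float or an int")
    if isinstance(value, float):
        if not 0.0 < value < 1.0:
            raise ValueError("a float decay factor must lie in (0, 1)")
        return "decay", value
    if isinstance(value, int):
        if value <= 0:
            raise ValueError("an integer rank must be greater than 0")
        return "rank", value
    raise TypeError("decay_or_rank must be a float or an int")


print(classify_decay_or_rank(0.9))  # ('decay', 0.9)
print(classify_decay_or_rank(4))    # ('rank', 4)
```

Note that under these rules `1` is read as a rank while `1.0` is rejected as an out-of-range decay factor, so the Python type of the argument matters.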
- get_etrace_of(weight)
Get the eligibility trace of the given weight.
The eligibility trace contains the following structures:
- init_etrace_state(*args, **kwargs)
Initialize the eligibility trace states of the etrace algorithm.
This method is needed after compiling the etrace graph. See compile_graph() for details.