GradExpon
- class braintrace.GradExpon(grad_shape, tau_or_decay)
Accumulates gradients over time with exponential decay.
Mathematically, the update rule is:
\[g_{t+1} = \text{decay} \cdot g_t + \text{grads}\]
where \(g_t\) is the accumulated gradient at time \(t\), \(\text{grads}\) is the gradient at time \(t\), and \(\text{decay}\) is the decay factor.
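As a worked consequence of this rule, the recursion can be unrolled. The following assumes the accumulator starts at zero (\(g_0 = 0\), an assumption not stated here) and writes \(\text{grads}_k\) for the gradient supplied at time \(k\) (notation introduced only for this illustration):
\[g_t = \sum_{k=0}^{t-1} \text{decay}^{\,t-1-k} \cdot \text{grads}_k\]
Under these assumptions the most recent gradient enters with weight 1, and each further step into the past multiplies the weight by \(\text{decay}\). For a constant gradient \(c\) and \(0 \le \text{decay} < 1\), the sum is geometric:
\[g_t = c \cdot \frac{1 - \text{decay}^{\,t}}{1 - \text{decay}} \;\longrightarrow\; \frac{c}{1 - \text{decay}} \quad \text{as } t \to \infty\]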
- Parameters:
grad_shape (PyTree) – The shape of the gradients.
tau_or_decay (Quantity[s] | float) – The decay time constant or the decay factor.
- update(grads)
Updates the accumulated gradients using the exponential decay rule.
This method applies the update rule g_{t+1} = decay * g_t + grads, where g_t is the accumulated gradient at time t, grads is the new gradient, and decay is the decay factor.
- Parameters:
grads (PyTree) – The new gradients to be incorporated into the accumulated gradients.
- Returns:
None. The method updates the self.gradients attribute in-place.
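The update rule and the in-place behaviour described above can be illustrated with a minimal plain-Python sketch. It is not braintrace's implementation: the class name ExponAccumulator is hypothetical, a flat list of floats stands in for a PyTree, the accumulator is assumed to start at zero, and the decay factor is passed directly as a float (the conversion from a time constant is not shown).

```python
class ExponAccumulator:
    """Hypothetical stand-in that mimics the documented update rule.

    Not braintrace.GradExpon: a flat list of floats replaces the PyTree,
    and the accumulator is assumed to start at zero.
    """

    def __init__(self, size, decay):
        self.decay = decay
        self.gradients = [0.0] * size  # assumed initial state g_0 = 0

    def update(self, grads):
        # g_{t+1} = decay * g_t + grads, applied element-wise in place.
        for i, g in enumerate(grads):
            self.gradients[i] = self.decay * self.gradients[i] + g
        return None


acc = ExponAccumulator(size=2, decay=0.5)
acc.update([1.0, 2.0])
acc.update([1.0, 2.0])
result = acc.update([1.0, 2.0])

print(result)         # None: the call only mutates acc.gradients
print(acc.gradients)  # [1.75, 3.5]
```

With a constant gradient and a decay of 0.5, the accumulated values approach twice the per-step gradient, which matches the geometric behaviour of the update rule.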