GradExpon

class braintrace.GradExpon(grad_shape, tau_or_decay)

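Accumulates gradients with an exponentially decaying memory. At each step the running total is scaled by a decay factor before the new gradients are added.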

Mathematically, the update rule is:

\[ g_{t+1} = \text{decay} \cdot g_t + \text{grads} \]

where \(g_t\) is the accumulated gradient at time \(t\), \(\text{grads}\) is the gradient at time \(t\), and \(\text{decay}\) is the decay factor.
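Unrolling the recurrence makes the weighting explicit. Assuming the accumulator starts at \(g_0 = 0\) (the page does not state the initial value), after \(T\) updates \(g_T = \sum_{k=0}^{T-1} \text{decay}^{T-1-k} \, \text{grads}_k\): the most recent gradient has weight 1 and each older one carries one more factor of \(\text{decay}\). The plain-Python sketch below checks the step-by-step rule against this closed form on scalar gradients; `accumulate` and `unrolled` are hypothetical helpers, not part of braintrace.

```python
def accumulate(grads_seq, decay, g0=0.0):
    """Apply g_{t+1} = decay * g_t + grads_t one step at a time."""
    g = g0
    for grads in grads_seq:
        g = decay * g + grads
    return g


def unrolled(grads_seq, decay):
    """Closed form assuming g_0 = 0: g_T = sum_k decay**(T-1-k) * grads_k."""
    T = len(grads_seq)
    return sum(decay ** (T - 1 - k) * grads for k, grads in enumerate(grads_seq))


grads_seq = [1.0, 2.0, 3.0]
print(accumulate(grads_seq, 0.5))  # 4.25
print(unrolled(grads_seq, 0.5))    # 4.25
```

For a constant gradient \(c\) and \(0 < \text{decay} < 1\), the geometric sum approaches \(c / (1 - \text{decay})\), so a decay close to 1 gives both a longer memory and a larger steady-state magnitude.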

Parameters:
  • grad_shape (PyTree) – A PyTree describing the shapes of the gradients to be accumulated.

  • tau_or_decay (Quantity[s] | float) – Either a decay time constant, given as a Quantity in seconds, or the decay factor itself, given as a plain float.
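The page does not say how a time constant is converted into a decay factor. A common convention in discrete-time simulation is \(\text{decay} = \exp(-\Delta t / \tau)\), where \(\Delta t\) is the integration step. The sketch below assumes that convention; `resolve_decay` and the `Seconds` wrapper (a stand-in for a unit-carrying Quantity) are hypothetical and are not braintrace APIs.

```python
import math
from dataclasses import dataclass


@dataclass(frozen=True)
class Seconds:
    """Hypothetical stand-in for a Quantity with the unit of seconds."""
    value: float


def resolve_decay(tau_or_decay, dt):
    """Return a per-step decay factor.

    Assumption: a time constant tau is converted with decay = exp(-dt / tau);
    a plain float is taken to already be the decay factor.
    """
    if isinstance(tau_or_decay, Seconds):
        if tau_or_decay.value <= 0:
            raise ValueError("tau must be positive")
        return math.exp(-dt.value / tau_or_decay.value)
    return float(tau_or_decay)


print(round(resolve_decay(Seconds(10e-3), dt=Seconds(1e-3)), 3))  # 0.905
print(resolve_decay(0.9, dt=Seconds(1e-3)))                       # 0.9
```

Under this assumption, tau = 10 ms with dt = 1 ms keeps about 90% of the accumulated gradient per step, and a larger tau means slower forgetting.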

update(grads)

Updates the accumulated gradients using the exponential decay rule.

This method applies the update rule g_{t+1} = decay * g_t + grads, where g_t is the accumulated gradient at time t, grads is the new gradient, and decay is the decay factor.

Parameters:

grads (PyTree) – The new gradients to be incorporated into the accumulated gradients.

Returns:

None. The method updates the self.gradients attribute in place.
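To make the attribute update concrete, here is a minimal pure-Python analogue of the class. It treats nested dicts, lists and tuples of floats as a stand-in for JAX PyTrees, assumes the accumulator starts at zero, and takes the decay factor directly rather than a time constant. `GradExponSketch` and `tree_map` are hypothetical names; braintrace's own implementation works on JAX arrays and may differ.

```python
def tree_map(fn, *trees):
    """Apply fn to matching leaves of nested dicts/lists/tuples (PyTree stand-in)."""
    first = trees[0]
    if isinstance(first, dict):
        return {k: tree_map(fn, *(t[k] for t in trees)) for k in first}
    if isinstance(first, (list, tuple)):
        return type(first)(tree_map(fn, *children) for children in zip(*trees))
    return fn(*trees)


class GradExponSketch:
    """Hypothetical analogue of GradExpon; not braintrace's implementation."""

    def __init__(self, grad_template, decay):
        # Assumption: accumulation starts from zeros with the template's structure.
        self.decay = decay
        self.gradients = tree_map(lambda _: 0.0, grad_template)

    def update(self, grads):
        # g_{t+1} = decay * g_t + grads, leaf by leaf; rebinds self.gradients, returns None.
        self.gradients = tree_map(
            lambda g, x: self.decay * g + x, self.gradients, grads
        )


acc = GradExponSketch({"w": [0.0, 0.0], "b": 0.0}, decay=0.5)
acc.update({"w": [1.0, 2.0], "b": 4.0})
acc.update({"w": [1.0, 2.0], "b": 4.0})
print(acc.gradients)  # {'w': [1.5, 3.0], 'b': 6.0}
```

Rebinding the attribute rather than mutating leaves mirrors how JAX code typically handles immutable arrays: the object is updated, while the previous gradient tree is left untouched.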