ETraceVjpAlgorithm

class braintrace.ETraceVjpAlgorithm(model, name=None, vjp_method='single-step')

The base class for eligibility trace algorithms that support VJP gradient computation (reverse-mode differentiation).

The name VJP reflects two aspects of the design:

First, this module is designed to be compatible with JAX's VJP mechanism. Gradients are computed through the reverse-mode differentiation interface, such as jax.grad(), jax.vjp(), or jax.jacrev(). The true update function is defined as a custom VJP function, ._true_update_fun(), which receives the inputs, hidden states, other states, and etrace variables from the previous time step, and returns the outputs, hidden states, other states, and etrace variables at the current time step.

Each subclass (that is, each concrete etrace algorithm) relies on the following methods:

  • ._update(): update the eligibility trace states and return the outputs, hidden states, other states, and etrace data.

  • ._update_fwd(): the forward pass of the custom VJP rule.

  • ._update_bwd(): the backward pass of the custom VJP rule.

However, this class already provides default implementations of the ._update(), ._update_fwd(), and ._update_bwd() methods.
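The forward/backward split described above follows the general pattern of jax.custom_vjp. The following is a minimal sketch of that pattern in plain JAX, not braintrace's actual implementation: the function step and its arguments x, h, and w are hypothetical, and the backward rule here simply reproduces the exact derivative. In an etrace algorithm, the backward pass is presumably where a trace-based weight-gradient rule would be substituted.

```python
import jax
import jax.numpy as jnp

# Hypothetical one-step update (not part of braintrace):
# new hidden state from input x, previous hidden state h, and weight w.
@jax.custom_vjp
def step(x, h, w):
    return jnp.tanh(x @ w + h)

def step_fwd(x, h, w):
    # Forward pass: compute the output and save residuals for the backward pass.
    out = jnp.tanh(x @ w + h)
    return out, (x, w, out)

def step_bwd(residuals, g):
    # Backward pass: g is the cotangent dL/d(out).
    # Return one cotangent per primal input, in the order (x, h, w).
    x, w, out = residuals
    d_pre = g * (1.0 - out ** 2)   # derivative through tanh
    dx = d_pre @ w.T
    dh = d_pre
    dw = jnp.outer(x, d_pre)
    return dx, dh, dw

# Register the custom forward and backward rules.
step.defvjp(step_fwd, step_bwd)

x = jnp.ones((3,))
h = jnp.zeros((4,))
w = jnp.full((3, 4), 0.1)

# jax.grad now routes through step_fwd and step_bwd.
grads = jax.grad(lambda x, h, w: step(x, h, w).sum(), argnums=(0, 1, 2))(x, h, w)
```

Because the rule is registered with defvjp, any reverse-mode entry point (jax.grad, jax.vjp, jax.jacrev) uses the custom backward function instead of differentiating the forward code.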

To implement a new etrace algorithm, users only need to override the following methods:

  • ._solve_weight_gradients(): solve the gradients of the learnable weights / parameters.

  • ._update_etrace_data(): update the eligibility trace data.

  • ._assign_etrace_data(): assign the eligibility trace data to the states.

  • ._get_etrace_data(): get the eligibility trace data.

Second, the algorithm computes the spatial gradient \(\partial L^t / \partial H^t\) using standard back-propagation. This design can improve the accuracy and stability of the gradient computation.

Parameters:
  • model (Module) – The model function, which receives the input arguments and returns the model output.

  • name (str | None) – The name of the etrace algorithm.

  • vjp_method (str) –

    The method for computing the VJP. It should be either “single-step” or “multi-step”.

    • “single-step”: The VJP is computed at the current time step, i.e., \(\partial L^t/\partial h^t\).

    • “multi-step”: The VJP is computed at multiple time steps, i.e., \(\partial L^t/\partial h^{t-k}\), where \(k\) is determined by the data input.
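To make the difference between the two quantities concrete, the sketch below computes \(\partial L^t/\partial h^t\) and \(\partial L^t/\partial h^{t-k}\) in plain JAX. It does not use braintrace: the recurrence step, the loss, and all variable names are hypothetical, and \(k = 3\) is fixed here by the number of rows in xs.

```python
import jax
import jax.numpy as jnp

def step(h, x, w):
    # Hypothetical recurrent transition: h^t = tanh(h^{t-1} @ w + x^t).
    return jnp.tanh(h @ w + x)

def loss(h):
    # Hypothetical loss evaluated on the final hidden state h^t.
    return jnp.sum(h ** 2)

w = 0.5 * jnp.eye(4)
xs = jnp.ones((3, 4))        # inputs for k = 3 consecutive steps
h0 = jnp.full((4,), 0.1)     # hidden state before the window, h^{t-k}

def rollout(h):
    # Apply the transition once per row of xs, mapping h^{t-k} to h^t.
    for x in xs:
        h = step(h, x, w)
    return h

h_t = rollout(h0)

# Single-step quantity: dL^t/dh^t, taken at the current step only.
g_single = jax.grad(loss)(h_t)

# Multi-step quantity: dL^t/dh^{t-k}, obtained by pulling the
# cotangent dL^t/dh^t back through the k transitions.
_, pullback = jax.vjp(rollout, h0)
(g_multi,) = pullback(g_single)
```

The multi-step gradient is the single-step gradient propagated backward through the Jacobians of the intermediate transitions, which is why it depends on how many steps of data are supplied.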

update(*args)

Update the model states and the eligibility trace.

The input arguments args support complex data structures, including combinations of SingleStepData and MultiStepData.

  • SingleStepData: data at a single time step, \(x_t\).

  • MultiStepData: data at multiple time steps, \([x_{t-k}, \ldots, x_t]\).

Suppose all inputs have the shape of (10,).

If the input arguments are given by:

x = [jnp.ones((10,)), jnp.zeros((10,))]

then both input arguments are treated as SingleStepData.

If the input arguments are given by:

x = [braintrace.SingleStepData(jnp.ones((10,))),
     braintrace.SingleStepData(jnp.zeros((10,)))]

this is equivalent to the previous case: both are treated as inputs at the current time step.

If the input arguments are given by:

x = [braintrace.MultiStepData(jnp.ones((5, 10))),
     jnp.zeros((10,))]

or,

x = [braintrace.MultiStepData(jnp.ones((5, 10))),
     braintrace.SingleStepData(jnp.zeros((10,)))]

then the first input argument is treated as MultiStepData, and its data is fed into the model over five consecutive steps, while the second input argument is fed into the model unchanged at each of these five steps.
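The feeding semantics described above can be illustrated with jax.lax.scan. This is only a sketch of the behaviour, assuming a hypothetical step function. It does not call braintrace, and whether the library uses scan internally is not stated here.

```python
import jax
import jax.numpy as jnp

def step(h, x_multi_t, x_single):
    # Hypothetical model step combining both inputs with the hidden state.
    return jnp.tanh(h + x_multi_t + x_single)

multi = jnp.ones((5, 10))    # plays the role of MultiStepData: 5 steps of shape (10,)
single = jnp.zeros((10,))    # plays the role of SingleStepData: shape (10,)

def body(h, x_t):
    # x_t is one row of `multi`; `single` is reused unchanged at every step.
    h = step(h, x_t, single)
    return h, h

h0 = jnp.zeros((10,))
# scan slices `multi` along its leading axis and runs `body` five times.
h_last, hs = jax.lax.scan(body, h0, multi)
```

Here hs stacks the hidden state after each of the five steps, and h_last is the state after the final step.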

Parameters:

*args – the input arguments.

Return type:

Any