Gradient Computations

Automatic differentiation transformations for computing gradients, Jacobians, and Hessians. These functions extend JAX’s autodiff capabilities with support for stateful computations, so derivatives can be taken with respect to program state as well as explicit function arguments. This makes them well suited to training neural networks and optimizing complex dynamical systems.

Gradient Transformations

grad([fun, grad_states, argnums, ...])

Compute the gradient of a scalar-valued function with respect to its arguments.
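grad builds on JAX's reverse-mode gradient. As a point of reference, here is a minimal plain-JAX sketch using jax.grad with argnums. It does not use grad_states or any other state-aware feature of this library; the function loss and its inputs are illustrative only.

```python
import jax
import jax.numpy as jnp

def loss(w, x):
    # Scalar-valued: mean of squared predictions
    return jnp.mean((x @ w) ** 2)

w = jnp.array([1.0, 2.0])
x = jnp.eye(2)  # identity inputs, so x @ w == w

# argnums=0 differentiates with respect to w only; x is treated as a constant
dw = jax.grad(loss, argnums=0)(w, x)
```

Here loss reduces to (w0² + w1²) / 2, so the gradient is w itself.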

vector_grad([func, grad_states, argnums, ...])

Compute gradients of a vector-valued function func.
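A common convention for the "gradient" of a vector-valued function is the gradient of the sum of its outputs. That is the same as a vector-Jacobian product with an all-ones cotangent. Assuming vector_grad follows this convention (an assumption, not stated above), the plain-JAX equivalent is the following sketch.

```python
import jax
import jax.numpy as jnp

def f(x):
    # Element-wise, vector-valued output
    return jnp.sin(x) * x

x = jnp.array([0.0, 1.0, 2.0])

# Pull an all-ones cotangent back through f
y, f_vjp = jax.vjp(f, x)
(g,) = f_vjp(jnp.ones_like(y))

# Equivalent: differentiate the summed output
g_sum = jax.grad(lambda x: jnp.sum(f(x)))(x)
```

Because f acts element-wise, g equals the derivative cos(x)·x + sin(x) evaluated at each entry.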

fwd_grad([func, grad_states, argnums, ...])

Compute first-order gradients of func using forward-mode differentiation.
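The internals of fwd_grad are not described here. One way to obtain a full gradient in forward mode is to push one basis tangent per input through jax.jvp, and the sketch below does this in plain JAX. The helper forward_mode_grad is hypothetical and not part of this library.

```python
import jax
import jax.numpy as jnp

def f(x):
    # Scalar-valued function of a vector
    return jnp.sum(x ** 3)

def forward_mode_grad(f, x):
    # Hypothetical helper: one JVP per basis direction e_i gives df/dx_i
    basis = jnp.eye(x.shape[0], dtype=x.dtype)
    return jax.vmap(lambda e: jax.jvp(f, (x,), (e,))[1])(basis)

x = jnp.array([1.0, 2.0, 3.0])
g = forward_mode_grad(f, x)  # 3 * x**2
```

This costs one JVP per input, so reverse mode is usually preferred for scalar losses over many parameters. Forward mode pays off when inputs are few.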

Vector-Jacobian and Jacobian-Vector Products

vjp(fun, *primals[, grad_states, argnums, ...])

Compute a state-aware vector-Jacobian product (reverse-mode autodiff).
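A vector-Jacobian product computes vᵀJ without forming J. The minimal plain-JAX sketch below uses jax.vjp, which returns the primal output and a pullback function. The state-aware options (grad_states, argnums) of this library's vjp are not shown.

```python
import jax
import jax.numpy as jnp

def f(x):
    # R^2 -> R^2, with Jacobian [[x1, x0], [1, 1]]
    return jnp.stack([x[0] * x[1], x[0] + x[1]])

x = jnp.array([2.0, 3.0])
y, f_vjp = jax.vjp(f, x)

# The cotangent v weights the rows of J: result is v^T J
v = jnp.array([1.0, 0.0])
(vjp_out,) = f_vjp(v)
```

With v selecting the first output, vjp_out is the first row of J, here [3, 2].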

jvp(fun, primals, tangents, *[, has_aux])

Compute a state-aware Jacobian-vector product (forward-mode autodiff).
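A Jacobian-vector product computes Jt in a single forward pass. Like the signature above, plain jax.jvp takes primals and tangents as tuples. The minimal sketch below does not exercise any state handling.

```python
import jax
import jax.numpy as jnp

def f(x):
    # R^2 -> R^2, with Jacobian [[x1, x0], [1, 1]]
    return jnp.stack([x[0] * x[1], x[0] + x[1]])

x = jnp.array([2.0, 3.0])
t = jnp.array([1.0, 0.0])  # tangent direction

# Returns the primal output and J @ t
y, jvp_out = jax.jvp(f, (x,), (t,))
```

With t along the first input, jvp_out is the first column of J, here [3, 1].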

Jacobian and Hessian

jacrev(fun[, grad_states, argnums, has_aux, ...])

Compute the Jacobian of fun using reverse-mode autodiff, extended to stateful classes.

jacfwd(func[, grad_states, argnums, ...])

Compute the Jacobian of func using forward-mode autodiff, extended to stateful classes.

jacobian(fun[, grad_states, argnums, ...])

Compute the Jacobian of fun using reverse-mode autodiff (the same mode as jacrev), extended to stateful classes.
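Reverse-mode and forward-mode Jacobians produce the same matrix and differ only in cost. The plain-JAX sketch below compares jax.jacrev, jax.jacfwd, and jax.jacobian, which is JAX's alias for jacrev. It illustrates the underlying JAX primitives, not this library's state-aware wrappers.

```python
import jax
import jax.numpy as jnp

def f(x):
    # R^3 -> R^2
    return jnp.stack([x[0] * x[1], jnp.sin(x[2])])

x = jnp.array([1.0, 2.0, 0.0])

J_rev = jax.jacrev(f)(x)      # built from one VJP per output
J_fwd = jax.jacfwd(f)(x)      # built from one JVP per input
J_alias = jax.jacobian(f)(x)  # same as jax.jacrev
```

As a rule of thumb, reverse mode is cheaper when there are fewer outputs than inputs, and forward mode when there are fewer inputs than outputs.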

hessian(func[, grad_states, argnums, ...])

Compute the Hessian of func as a dense array.
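For a scalar function, the Hessian is the matrix of second partial derivatives. Plain jax.hessian composes forward-over-reverse differentiation. The minimal sketch below is meant only to show the shape and contents of the dense result, not this library's state handling.

```python
import jax
import jax.numpy as jnp

def f(x):
    # f(x) = x0^2 * x1 + x1^3
    return x[0] ** 2 * x[1] + x[1] ** 3

x = jnp.array([1.0, 2.0])
H = jax.hessian(f)(x)  # [[2*x1, 2*x0], [2*x0, 6*x1]]
```

At x = (1, 2), this gives [[4, 2], [2, 12]].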

Base Classes

GradientTransform

Base class for the automatic differentiation transformations of the State system.
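How GradientTransform is implemented is not described here. The core idea of differentiating with respect to state can still be sketched in plain JAX. A hypothetical wrapper temporarily lifts the state values into an explicit argument, calls jax.grad, and restores the originals afterwards. ToyGradTransform and its dict-based "states" are illustrative stand-ins, not this library's classes or API.

```python
import jax
import jax.numpy as jnp

class ToyGradTransform:
    """Hypothetical sketch: gradients of fun w.r.t. values held in a
    mutable dict (a stand-in for State objects), not the library's class."""

    def __init__(self, fun, states):
        self.fun = fun        # fun(*args) reads its parameters from states
        self.states = states  # dict: name -> array

    def __call__(self, *args):
        names = list(self.states)

        def pure(values, *args):
            # Write traced values into the states, run fun, then restore
            # the originals so no tracer leaks out of the transformation.
            saved = {k: self.states[k] for k in names}
            self.states.update(dict(zip(names, values)))
            try:
                return self.fun(*args)
            finally:
                self.states.update(saved)

        values = [self.states[k] for k in names]
        grads = jax.grad(pure)(values, *args)  # grads w.r.t. state values
        return dict(zip(names, grads))

states = {"w": jnp.array([1.0, -2.0]), "b": jnp.array(0.5)}

def loss(x):
    # Reads w and b from the shared, mutable states dict
    return jnp.sum((x * states["w"] + states["b"]) ** 2)

grads = ToyGradTransform(loss, states)(jnp.array([1.0, 1.0]))
```

Restoring the state in a finally block keeps the mutable dict unchanged after the call, even if fun raises an exception.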