saiunit.autograd.vector_grad



saiunit.autograd.vector_grad(func, argnums=0, return_value=False, has_aux=False, unit_aware=True)

Unit-aware element-wise gradient of a vector-valued function.

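Unlike grad() (which requires a scalar output), vector_grad handles vector-valued functions by computing a vector-Jacobian product (VJP) with an all-ones cotangent vector. For each input element, this yields the sum of the partial derivatives of all output elements with respect to that input element, i.e. the column sums of the Jacobian. The result equals the diagonal of the Jacobian (the true element-wise gradient) only when func acts element-wise, so that each output element depends solely on the corresponding input element.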
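The following plain-JAX sketch illustrates what a VJP with an all-ones cotangent computes. It is not saiunit's implementation: units are ignored, and `ones_vjp` is a hypothetical helper. The sketch compares the VJP result with the Jacobian diagonal obtained from `jax.jacrev`.

```python
import jax
import jax.numpy as jnp


def ones_vjp(f, x):
    """Hypothetical helper: ones^T @ J_f(x) from a single reverse-mode pass."""
    y, pullback = jax.vjp(f, x)
    (g,) = pullback(jnp.ones_like(y))  # all-ones cotangent, shaped like f(x)
    return g


def square(v):
    return v ** 2


x = jnp.array([3.0, 4.0, 5.0])

# Element-wise function: the Jacobian is diagonal, so both results agree.
sq_vjp = ones_vjp(square, x)                  # 2*x = [6, 8, 10]
sq_diag = jnp.diag(jax.jacrev(square)(x))     # [6, 8, 10]

# Cumulative sum: output i depends on inputs 0..i, so the Jacobian is a
# lower-triangular matrix of ones. The all-ones VJP returns its column sums.
cs_vjp = ones_vjp(jnp.cumsum, x)              # [3, 2, 1]
cs_diag = jnp.diag(jax.jacrev(jnp.cumsum)(x))   # [1, 1, 1]
```

For the element-wise `square`, the two results agree. For `jnp.cumsum`, the all-ones VJP returns the column sums [3, 2, 1] rather than the diagonal [1, 1, 1].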

Parameters:
  • func (Callable) – A Python callable that computes a vector (array) output from its arguments, which may carry physical units.

  • argnums (int | Sequence[int]) – Specifies which positional argument(s) to differentiate with respect to. Default is 0.

  • return_value (bool) – If True, the returned function also returns the function value. Default is False.

  • has_aux (bool) – If True, func is expected to return a pair (output, aux), and only output is differentiated. Default is False.

  • unit_aware (bool) – If True, physical units are propagated through the differentiation. Default is True.

Returns:

grad_fun – A function with the same signature as func that returns the element-wise gradient. Its return structure depends on return_value and has_aux:

  • Default: gradient

  • return_value=True: (gradient, value)

  • has_aux=True: (gradient, aux)

  • Both: (gradient, value, aux)

Return type:

callable
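The following plain-JAX sketch shows one way the four return layouts listed under Returns could be assembled around a single all-ones VJP. `vector_grad_sketch` is a hypothetical name, and units are ignored entirely. Returning a tuple of gradients when argnums is a sequence is an assumption borrowed from `jax.grad`'s convention; this page does not document that behavior.

```python
import jax
import jax.numpy as jnp


def vector_grad_sketch(func, argnums=0, return_value=False, has_aux=False):
    """Hypothetical, unit-unaware sketch of vector_grad's return conventions."""
    single = isinstance(argnums, int)
    nums = (argnums,) if single else tuple(argnums)

    def grad_fun(*args):
        def wrt(*diff_args):
            # Substitute the differentiated arguments back into the full call.
            full = list(args)
            for i, a in zip(nums, diff_args):
                full[i] = a
            return func(*full)

        primals = tuple(args[i] for i in nums)
        if has_aux:
            # func returns (output, aux); aux is passed through undifferentiated.
            value, pullback, aux = jax.vjp(wrt, *primals, has_aux=True)
        else:
            value, pullback = jax.vjp(wrt, *primals)
        grads = pullback(jnp.ones_like(value))  # one cotangent per primal
        grad = grads[0] if single else grads  # assumption: tuple for sequences

        # Layout from Returns: (gradient, value, aux), omitting absent parts.
        if return_value and has_aux:
            return grad, value, aux
        if return_value:
            return grad, value
        if has_aux:
            return grad, aux
        return grad

    return grad_fun


def f(x, y):
    out = x ** 2 * y
    return out, {"mean": out.mean()}


vg = vector_grad_sketch(f, argnums=(0, 1), return_value=True, has_aux=True)
(gx, gy), value, aux = vg(jnp.array([1.0, 2.0]), jnp.array([3.0, 4.0]))
# gx = 2*x*y = [6, 16], gy = x**2 = [1, 4], value = [3, 16], aux["mean"] = 9.5
```

Because f is element-wise in both x and y, the all-ones VJP here coincides with the per-element partial derivatives.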

Notes

When unit_aware=True, func must return a single array (not a pytree with multiple leaves).

See also

grad

Gradient for scalar-valued functions.

jacrev

Full Jacobian via reverse-mode AD.

Examples

Element-wise gradient of a squared function with units:

>>> import jax.numpy as jnp
>>> import saiunit as u
>>> import saiunit.autograd as suauto
>>> def f(x):
...     return x ** 2
>>> vg_fn = suauto.vector_grad(f)
>>> vg_fn(jnp.array([3.0, 4.0]) * u.ms)
[6.0, 8.0] * ms

Returning both the gradient and the function value:

>>> import jax.numpy as jnp
>>> import saiunit as u
>>> import saiunit.autograd as suauto
>>> def f(x):
...     return x ** 2
>>> vg_fn = suauto.vector_grad(f, return_value=True)
>>> grad, value = vg_fn(jnp.array([3.0, 4.0]) * u.ms)
>>> grad
[6.0, 8.0] * ms
>>> value
[9.0, 16.0] * ms ** 2