saiunit.autograd.vector_grad
- saiunit.autograd.vector_grad(func, argnums=0, return_value=False, has_aux=False, unit_aware=True)
Unit-aware element-wise gradient of a vector-valued function.
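Unlike grad() (which requires scalar outputs), vector_grad computes element-wise gradients for vector-valued functions by evaluating a vector-Jacobian product (VJP) with an all-ones cotangent vector. Entry j of the result is therefore the sum over all outputs, Σ_i ∂y_i/∂x_j, i.e. the j-th column sum of the Jacobian. When each output element depends only on the input element at the same position (an element-wise function), the Jacobian is diagonal and the result equals its diagonal.
- Parameters: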
func (Callable) – A Python callable that computes a vector output given arguments (possibly carrying physical units).
argnums (int | Sequence[int]) – Specifies which positional argument(s) to differentiate with respect to. Default is 0.
return_value (bool) – If True, the returned function also returns the function value. Default is False.
has_aux (bool) – If True, func is expected to return (output, aux), where only output is differentiated. Default is False.
unit_aware (bool) – If True, physical units are propagated through the differentiation. Default is True.
- Returns:
grad_fun – A function with the same signature as func that returns the element-wise gradient. The structure of its return value depends on return_value and has_aux:
Default: gradient
return_value=True: (gradient, value)
has_aux=True: (gradient, aux)
Both: (gradient, value, aux)
- Return type:
callable
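The tuple orderings listed under Returns can be mimicked in a unit-free sketch. vector_grad_with_options below is a hypothetical helper, not saiunit's implementation: it assumes a single positional argument (argnums=0), ignores units entirely, and relies on the has_aux=True mode of jax.vjp to pass aux through without differentiating it.

```python
import jax
import jax.numpy as jnp


def vector_grad_with_options(f, return_value=False, has_aux=False):
    """Hypothetical, unit-free sketch of the documented return structures (argnums=0 only)."""
    def grad_fn(x):
        if has_aux:
            # f returns (output, aux); only output enters the VJP.
            y, vjp_fn, aux = jax.vjp(f, x, has_aux=True)
        else:
            y, vjp_fn = jax.vjp(f, x)
        # Pull back an all-ones cotangent to get the element-wise gradient.
        (g,) = vjp_fn(jnp.ones_like(y))
        if return_value and has_aux:
            return g, y, aux
        if return_value:
            return g, y
        if has_aux:
            return g, aux
        return g
    return grad_fn


def f_aux(x):
    y = x ** 2
    return y, {"mean": jnp.mean(y)}  # aux is returned untouched


x = jnp.array([3.0, 4.0])
g, value, aux = vector_grad_with_options(f_aux, return_value=True, has_aux=True)(x)
print(g)            # [6. 8.]
print(value)        # [ 9. 16.]
print(aux["mean"])  # 12.5
```

With both flags set, the sketch returns (gradient, value, aux), matching the "Both" case above; the other branches return the shorter tuples in the same order.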
Notes
When unit_aware=True, func must return a single array (not a pytree with multiple leaves).
Examples
Element-wise gradient of a squared function with units:
>>> import jax.numpy as jnp
>>> import saiunit as u
>>> import saiunit.autograd as suauto
>>> def f(x):
...     return x ** 2
>>> vg_fn = suauto.vector_grad(f)
>>> vg_fn(jnp.array([3.0, 4.0]) * u.ms)
[6.0, 8.0] * ms
Returning both the gradient and the function value:
>>> import jax.numpy as jnp
>>> import saiunit as u
>>> import saiunit.autograd as suauto
>>> def f(x):
...     return x ** 2
>>> vg_fn = suauto.vector_grad(f, return_value=True)
>>> grad, value = vg_fn(jnp.array([3.0, 4.0]) * u.ms)
>>> grad
[6.0, 8.0] * ms
>>> value
[9.0, 16.0] * ms ** 2
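The all-ones VJP described in the introduction can be reproduced without units in plain JAX. The sketch below is a hypothetical, unit-free stand-in (vector_grad_sketch is not part of saiunit) built only on jax.vjp and jax.jacobian. For an element-wise function such as x ** 2 the result coincides with the Jacobian diagonal; for a function that mixes input elements, the same computation yields the Jacobian's column sums instead, even when input and output have the same shape.

```python
import jax
import jax.numpy as jnp


def vector_grad_sketch(f):
    """Hypothetical unit-free sketch: VJP of f with an all-ones cotangent."""
    def grad_fn(x):
        y, vjp_fn = jax.vjp(f, x)        # forward pass plus VJP closure
        (g,) = vjp_fn(jnp.ones_like(y))  # ones @ J, i.e. column sums of J
        return g
    return grad_fn


x = jnp.array([3.0, 4.0])

# Element-wise function: the Jacobian is diagonal, so the result equals its diagonal.
f = lambda v: v ** 2
g = vector_grad_sketch(f)(x)
jac = jax.jacobian(f)(x)
print(g)              # [6. 8.]
print(jnp.diag(jac))  # [6. 8.]

# Coupled function with matching shapes: J = [[x1, x0], [1, 1]] = [[4, 3], [1, 1]].
h = lambda v: jnp.stack([v[0] * v[1], v[0] + v[1]])
g_h = vector_grad_sketch(h)(x)
jac_h = jax.jacobian(h)(x)
print(g_h)              # column sums: [5. 4.]
print(jnp.diag(jac_h))  # [4. 1.]
```

The element-wise reading of the result is therefore only safe when the i-th output depends on the i-th input alone.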