saiunit.autograd.grad
- saiunit.autograd.grad(fun, argnums=0, has_aux=False, holomorphic=False, allow_int=False)
Physical unit-aware version of jax.grad.
Computes the gradient of fun while correctly propagating physical units through the differentiation.
- Parameters:
  - fun (Callable) – A Python callable that computes a scalar loss given arguments. The output must be a scalar (possibly with physical units).
  - argnums (int | Sequence[int]) – Specifies which positional argument(s) to differentiate with respect to. Default is 0.
  - has_aux (bool) – If True, fun is expected to return a pair (loss, aux) where only loss is differentiated. The returned function produces (gradient, aux). Default is False.
  - holomorphic (bool) – Whether to use holomorphic differentiation (for complex-valued functions). Default is False.
  - allow_int (bool) – Whether to allow differentiation with respect to integer-valued inputs. Default is False.
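To illustrate the unit rule these parameters rely on, here is a toy, dependency-free sketch. It is not saiunit's implementation: it uses central finite differences in place of JAX's automatic differentiation, and the names Q, toy_grad, make_unit, combine and the tuple-of-exponents unit encoding are all hypothetical. For each index in argnums it differentiates the output magnitude with respect to that argument's magnitude and attaches the unit (output unit) / (input unit).

```python
from dataclasses import dataclass
from typing import Callable, Sequence, Tuple, Union

# Hypothetical toy unit encoding: a sorted tuple of (base_unit, exponent) pairs.
Unit = Tuple[Tuple[str, int], ...]


def make_unit(exps: dict) -> Unit:
    """Normalize an exponent dict into a canonical unit, dropping zero exponents."""
    return tuple(sorted((k, v) for k, v in exps.items() if v != 0))


def combine(a: Unit, b: Unit, sign: int) -> Unit:
    """Return a * b (sign=+1) or a / b (sign=-1) by adding or subtracting exponents."""
    exps = dict(a)
    for name, e in b:
        exps[name] = exps.get(name, 0) + sign * e
    return make_unit(exps)


@dataclass(frozen=True)
class Q:
    """Toy quantity (hypothetical): a float magnitude paired with a unit."""
    value: float
    unit: Unit = ()

    def __mul__(self, other: "Q") -> "Q":
        return Q(self.value * other.value, combine(self.unit, other.unit, +1))

    def __truediv__(self, other: "Q") -> "Q":
        return Q(self.value / other.value, combine(self.unit, other.unit, -1))

    def __pow__(self, n: int) -> "Q":
        return Q(self.value ** n, make_unit({k: v * n for k, v in self.unit}))


def toy_grad(fun: Callable[..., Q], argnums: Union[int, Sequence[int]] = 0,
             eps: float = 1e-6) -> Callable[..., Union[Q, Tuple[Q, ...]]]:
    """Finite-difference stand-in for a unit-aware grad (not saiunit's algorithm)."""
    single = isinstance(argnums, int)
    idxs = (argnums,) if single else tuple(argnums)

    def grad_fun(*args: Q):
        out_unit = fun(*args).unit
        grads = []
        for i in idxs:
            x = args[i]
            h = eps * max(1.0, abs(x.value))
            plus, minus = list(args), list(args)
            plus[i] = Q(x.value + h, x.unit)
            minus[i] = Q(x.value - h, x.unit)
            # Central difference on the magnitudes only.
            slope = (fun(*plus).value - fun(*minus).value) / (2 * h)
            # Gradient unit = output unit / unit of the differentiated input.
            grads.append(Q(slope, combine(out_unit, x.unit, -1)))
        return grads[0] if single else tuple(grads)

    return grad_fun


# Usage: v(x, t) = x / t, differentiated with respect to both arguments.
m = Q(1.0, (("m", 1),))
s = Q(1.0, (("s", 1),))
dv_dx, dv_dt = toy_grad(lambda x, t: x / t, argnums=(0, 1))(Q(6.0) * m, Q(2.0) * s)
print(dv_dx)  # magnitude ~ 0.5 (= 1/t), unit s^-1
print(dv_dt)  # magnitude ~ -1.5 (= -x/t^2), unit m * s^-2
```

With a sequence for argnums the sketch returns one gradient per listed argument, following the jax.grad convention; finite differences keep it self-contained at the cost of numerical precision.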
- Returns:
  grad_fun – A function with the same signature as fun that returns the gradient. If has_aux=True, it returns (gradient, aux) instead. Each gradient carries the unit of fun's output divided by the unit of the argument it is taken with respect to.
- Return type:
  Callable
Examples
Compute the gradient of a scalar function with units:
>>> import jax.numpy as jnp
>>> import saiunit as u
>>> import saiunit.autograd as suauto
>>> def f(x):
...     return x ** 2
>>> grad_fn = suauto.grad(f)
>>> grad_fn(jnp.array(3.0) * u.ms)
6.0 * ms
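The expected value above can be checked by hand: the derivative of x² is 2x, and the gradient's unit is the output unit (ms²) divided by the input unit (ms).

```latex
f(x) = x^2, \qquad \frac{\mathrm{d}f}{\mathrm{d}x} = 2x, \qquad
\left.\frac{\mathrm{d}f}{\mathrm{d}x}\right|_{x = 3\,\mathrm{ms}} = 6\,\mathrm{ms},
\qquad \left[\frac{\mathrm{d}f}{\mathrm{d}x}\right] = \frac{\mathrm{ms}^2}{\mathrm{ms}} = \mathrm{ms}
```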
Gradient with auxiliary data:
>>> import jax.numpy as jnp
>>> import saiunit as u
>>> import saiunit.autograd as suauto
>>> def f_aux(x):
...     return x ** 2, x * 3
>>> grad_fn = suauto.grad(f_aux, has_aux=True)
>>> g, aux = grad_fn(jnp.array(3.0) * u.mV)
>>> g
6.0 * mvolt
>>> aux
9.0 * mvolt
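The has_aux contract can be sketched without units as follows. The helper toy_grad_aux is hypothetical and uses central finite differences instead of automatic differentiation; it only illustrates that the first returned value (the loss) is differentiated while aux is passed back unchanged.

```python
from typing import Any, Callable, Tuple


def toy_grad_aux(fun: Callable[[float], Tuple[float, Any]], eps: float = 1e-6):
    """Hypothetical has_aux=True sketch: differentiate the loss, pass aux through."""

    def grad_fun(x: float) -> Tuple[float, Any]:
        _, aux = fun(x)  # aux comes from the unperturbed call and is never differentiated
        h = eps * max(1.0, abs(x))
        # Central finite difference on the first output (the loss) only.
        g = (fun(x + h)[0] - fun(x - h)[0]) / (2 * h)
        return g, aux

    return grad_fun


def f_aux(x: float) -> Tuple[float, float]:
    return x ** 2, x * 3


g, aux = toy_grad_aux(f_aux)(3.0)
print(g, aux)  # g ~ 6.0 (= 2 * 3.0); aux is 9.0 (= 3 * 3.0), returned unchanged
```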