brainunit.autograd.value_and_grad
- brainunit.autograd.value_and_grad(fun, argnums=0, has_aux=False, holomorphic=False, allow_int=False)
Physical unit-aware version of jax.value_and_grad.
Computes both the value and the gradient of fun while correctly propagating physical units through the differentiation.
- Parameters:
  - fun (Callable) – A Python callable that computes a scalar loss given its arguments. The output must be a scalar (possibly with physical units).
  - argnums (int | Sequence[int]) – Specifies which positional argument(s) to differentiate with respect to. Default is 0.
  - has_aux (bool) – If True, fun is expected to return a pair (loss, aux) where only loss is differentiated. Default is False.
  - holomorphic (bool) – Whether to use holomorphic differentiation (for complex-valued functions). Default is False.
  - allow_int (bool) – Whether to allow differentiation with respect to integer-valued inputs. Default is False.
- Returns:
  value_and_grad_fun – A function with the same signature as fun that returns a (value, gradient) pair. If has_aux=True, it returns ((value, aux), gradient) instead. If argnums is a sequence, gradient is a tuple with one entry per selected argument. Each gradient carries the unit of the output divided by the unit of the corresponding input.
- Return type:
  Callable
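To make the return structure concrete, here is a toy pure-Python stand-in that mimics how argnums and has_aux shape the result. It is a sketch under stated assumptions, not brainunit's implementation: toy_value_and_grad is a hypothetical name, it uses central finite differences instead of automatic differentiation, it only accepts plain float arguments, and it ignores physical units as well as the holomorphic and allow_int options.

```python
from typing import Callable, Sequence, Union


def toy_value_and_grad(
    fun: Callable,
    argnums: Union[int, Sequence[int]] = 0,
    has_aux: bool = False,
    eps: float = 1e-6,
) -> Callable:
    """Hypothetical finite-difference stand-in for value_and_grad.

    Mimics only the return structure: no physical units, no autodiff,
    and every differentiated argument must be a plain Python float.
    """
    single = isinstance(argnums, int)
    indices = (argnums,) if single else tuple(argnums)

    def loss_of(out):
        # With has_aux=True, fun returns (loss, aux) and only loss is differentiated.
        return out[0] if has_aux else out

    def wrapped(*args):
        out = fun(*args)

        def shifted_loss(i, delta):
            moved = list(args)
            moved[i] = moved[i] + delta
            return loss_of(fun(*moved))

        # Central difference for each selected positional argument.
        grads = tuple(
            (shifted_loss(i, eps) - shifted_loss(i, -eps)) / (2 * eps)
            for i in indices
        )
        # An int argnums yields a single gradient; a sequence yields a tuple.
        grad = grads[0] if single else grads
        # has_aux=False: out is the value -> (value, grad)
        # has_aux=True:  out is (value, aux) -> ((value, aux), grad)
        return out, grad

    return wrapped


def loss_with_aux(x, y):
    return x * y, {"x_plus_y": x + y}


(value, aux), grads = toy_value_and_grad(loss_with_aux, argnums=(0, 1), has_aux=True)(3.0, 4.0)
# value == 12.0, aux == {"x_plus_y": 7.0}, grads is approximately (4.0, 3.0)
```

The nesting ((value, aux), gradient) keeps the auxiliary output next to the value it was computed with, so callers can unpack the result the same way whether or not they select several arguments.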
Examples
Compute the value and gradient of a scalar function with units:
>>> import jax.numpy as jnp
>>> import brainunit as u
>>> import brainunit.autograd as buauto
>>> def f(x):
...     return x ** 2
>>> vg = buauto.value_and_grad(f)
>>> value, grad = vg(jnp.array(3.0) * u.ms)
>>> value
9.0 * ms ** 2
>>> grad
6.0 * ms
Differentiate with respect to multiple arguments:
>>> import jax.numpy as jnp
>>> import brainunit as u
>>> import brainunit.autograd as buauto
>>> def g(x, y):
...     return x * y
>>> vg = buauto.value_and_grad(g, argnums=(0, 1))
>>> val, grads = vg(jnp.array(3.0) * u.ms, jnp.array(4.0) * u.mV)
>>> grads[0]
4.0 * mvolt
>>> grads[1]
3.0 * msecond
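The units in these examples follow one rule: each gradient carries the unit of the output divided by the unit of the argument it is taken with respect to, so the derivative of x * y with respect to x is in (mV * ms) / ms = mV. The sketch below checks this bookkeeping in pure Python, assuming units are tracked as dictionaries of exponents. This is not brainunit's internal representation, and unit_mul, unit_div and grad_unit are hypothetical helpers.

```python
# Hypothetical representation: a unit is a dict mapping a base symbol to its exponent.
def unit_mul(a: dict, b: dict) -> dict:
    out = dict(a)
    for sym, exp in b.items():
        out[sym] = out.get(sym, 0) + exp
    # Drop symbols whose exponents cancelled out.
    return {sym: exp for sym, exp in out.items() if exp != 0}


def unit_div(a: dict, b: dict) -> dict:
    # Dividing by b is multiplying by b with every exponent negated.
    return unit_mul(a, {sym: -exp for sym, exp in b.items()})


def grad_unit(output_unit: dict, input_unit: dict) -> dict:
    # The derivative d(output)/d(input) carries unit(output) / unit(input).
    return unit_div(output_unit, input_unit)


ms = {"ms": 1}
mV = {"mV": 1}

# f(x) = x ** 2 with x in ms: value in ms ** 2, gradient in ms ** 2 / ms = ms.
f_out = unit_mul(ms, ms)
print(f_out, grad_unit(f_out, ms))  # {'ms': 2} {'ms': 1}

# g(x, y) = x * y with x in ms and y in mV: value in ms * mV.
g_out = unit_mul(ms, mV)
print(grad_unit(g_out, ms))  # {'mV': 1}
print(grad_unit(g_out, mV))  # {'ms': 1}
```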