brainunit.autograd.value_and_grad

brainunit.autograd.value_and_grad(fun, argnums=0, has_aux=False, holomorphic=False, allow_int=False)

Physical unit-aware version of jax.value_and_grad.

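Creates a function that evaluates both fun and its gradient in one call, propagating physical units through the differentiation. The gradient with respect to an argument carries the unit of the output divided by the unit of that argument.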

Parameters:
  • fun (Callable) – A Python callable that computes a scalar loss given arguments. The output must be a scalar (possibly with physical units).

  • argnums (int | Sequence[int]) – Specifies which positional argument(s) to differentiate with respect to. Default is 0.

  • has_aux (bool) – If True, fun is expected to return a pair (loss, aux) where only loss is differentiated. Default is False.

  • holomorphic (bool) – Whether fun is promised to be holomorphic, which is required when differentiating complex-valued functions. Default is False.

  • allow_int (bool) – Whether to allow differentiation with respect to integer-valued inputs. Default is False.

Returns:

value_and_grad_fun – A function that takes the same arguments as fun and returns a (value, gradient) pair. If has_aux=True, it returns ((value, aux), gradient) instead. If argnums is a sequence, gradient is a tuple with one entry per selected argument. Each gradient carries the physical unit derived from the output and input units.

Return type:

Callable[..., tuple[Any, Any]]
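The unit rule for gradients (output unit divided by input unit) can be sketched in plain Python. This is only an illustration of the arithmetic: the dict-of-exponents representation and the divide_units helper are hypothetical and are not part of the library, which handles units internally.

```python
# Illustrative only: units are modelled as {base_unit: exponent} dicts.
# This is not how the library represents units internally.

def divide_units(out_unit, in_unit):
    """Return the exponents of out_unit / in_unit, dropping zero exponents."""
    result = dict(out_unit)
    for base, exp in in_unit.items():
        result[base] = result.get(base, 0) - exp
        if result[base] == 0:
            del result[base]
    return result

# f(x) = x ** 2 with x in ms: the value is in ms ** 2, so df/dx is in ms.
grad_f = divide_units({"ms": 2}, {"ms": 1})

# g(x, y) = x * y with x in ms and y in mV: the value is in ms * mV.
grad_g_x = divide_units({"ms": 1, "mV": 1}, {"ms": 1})  # dg/dx is in mV
grad_g_y = divide_units({"ms": 1, "mV": 1}, {"mV": 1})  # dg/dy is in ms

print(grad_f, grad_g_x, grad_g_y)
```

The same rule explains why differentiating a product with respect to one factor yields a gradient in the unit of the other factor.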

Examples

Compute the value and gradient of a scalar function with units:

>>> import jax.numpy as jnp
>>> import saiunit as u
>>> import saiunit.autograd as suauto
>>> def f(x):
...     return x ** 2
>>> vg = suauto.value_and_grad(f)
>>> value, grad = vg(jnp.array(3.0) * u.ms)
>>> value
9.0 * ms ** 2
>>> grad
6.0 * ms

Differentiate with respect to multiple arguments:

>>> import jax.numpy as jnp
>>> import saiunit as u
>>> import saiunit.autograd as suauto
>>> def g(x, y):
...     return x * y
>>> vg = suauto.value_and_grad(g, argnums=(0, 1))
>>> val, grads = vg(jnp.array(3.0) * u.ms, jnp.array(4.0) * u.mV)
>>> grads[0]
4.0 * mvolt
>>> grads[1]
3.0 * msecond
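
The return structure with has_aux=True can be sketched with plain jax.value_and_grad, without units. According to the Returns description, the unit-aware version returns the same ((value, aux), gradient) nesting; the function loss_with_aux and the inputs w, x, y below are hypothetical names chosen for illustration.

```python
import jax
import jax.numpy as jnp

def loss_with_aux(w, x, y):
    """Return a scalar loss plus auxiliary data that is not differentiated."""
    pred = w * x
    residual = pred - y
    loss = jnp.sum(residual ** 2)
    return loss, {"pred": pred}

# has_aux=True: only the first element of the returned pair is differentiated.
vg = jax.value_and_grad(loss_with_aux, has_aux=True)

w = jnp.array(2.0)
x = jnp.array([1.0, 2.0])
y = jnp.array([1.0, 3.0])

# Note the nesting: the value and aux are grouped, the gradient comes second.
(loss, aux), grad = vg(w, x, y)
```

Here grad is the derivative of loss with respect to w only, because argnums defaults to 0, and aux is passed through unchanged.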