brainunit.autograd.grad



brainunit.autograd.grad(fun, argnums=0, has_aux=False, holomorphic=False, allow_int=False)

Physical unit-aware version of jax.grad.

Computes the gradient of fun while correctly propagating physical units through the differentiation.
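To make the unit rule concrete, here is a minimal pure-Python sketch that uses neither brainunit nor JAX. It assumes a toy quantity representation, a (magnitude, units) pair where units maps a unit symbol to its exponent. The helpers grad_with_units, unit_mul, and unit_div are hypothetical names for this illustration only. The sketch swaps JAX's automatic differentiation for central finite differences to get each derivative's magnitude, then assigns it the output unit divided by the unit of the differentiated input. It does not model has_aux, holomorphic, or allow_int.

```python
# Toy model (an assumption for illustration, not brainunit's internals):
# a quantity is a (magnitude, units) pair, where units maps a unit symbol
# such as "ms" to its integer exponent. Symbols are not reduced to SI base units.

def unit_mul(a, b):
    """Unit exponents of a * b, dropping symbols whose exponent becomes 0."""
    out = dict(a)
    for sym, exp in b.items():
        out[sym] = out.get(sym, 0) + exp
    return {s: e for s, e in out.items() if e != 0}


def unit_div(a, b):
    """Unit exponents of a / b."""
    return unit_mul(a, {s: -e for s, e in b.items()})


def grad_with_units(f, argnums=0, eps=1e-6):
    """Central-difference gradient of a scalar-valued f over toy quantities.

    argnums is an int (one gradient is returned) or a sequence of ints
    (a tuple of gradients is returned), mirroring the argnums parameter of grad.
    """
    single = isinstance(argnums, int)
    indices = (argnums,) if single else tuple(argnums)

    def grad_fn(*args):
        _, out_units = f(*args)
        grads = []
        for i in indices:
            x, x_units = args[i]
            plus, minus = list(args), list(args)
            plus[i] = (x + eps, x_units)
            minus[i] = (x - eps, x_units)
            # Magnitude: numerical slope of the output magnitude.
            slope = (f(*plus)[0] - f(*minus)[0]) / (2 * eps)
            # Unit: output unit divided by this input's unit.
            grads.append((slope, unit_div(out_units, x_units)))
        return grads[0] if single else tuple(grads)

    return grad_fn


def square(x):
    return (x[0] ** 2, unit_mul(x[1], x[1]))


def product(x, y):
    return (x[0] * y[0], unit_mul(x[1], y[1]))


print(grad_with_units(square)((3.0, {"ms": 1})))
# slope close to 6.0, units {'ms': 1}
print(grad_with_units(product, argnums=(0, 1))((2.0, {"mV": 1}), (5.0, {"ms": 1})))
# d/dx: slope close to 5.0 in {'ms': 1}; d/dy: slope close to 2.0 in {'mV': 1}
```

The unit bookkeeping is independent of the numerical method: the unit of each gradient depends only on the output unit and the unit of the argument being differentiated, which is why a sequence-valued argnums yields one differently-united gradient per selected argument.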

Parameters:
  • fun (Callable) – A Python callable that computes a scalar loss given arguments. The output must be a scalar (possibly with physical units).

  • argnums (int | Sequence[int]) – Specifies which positional argument(s) to differentiate with respect to. Default is 0.

  • has_aux (bool) – If True, fun is expected to return a pair (loss, aux) where only loss is differentiated. The returned function produces (gradient, aux). Default is False.

  • holomorphic (bool) – Whether fun is promised to be holomorphic, as in jax.grad; if True, its inputs and outputs must be complex-valued. Default is False.

  • allow_int (bool) – Whether to allow differentiation with respect to integer-valued inputs. Default is False.

Returns:

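grad_fun – A function that accepts the same arguments as fun and returns the gradient of fun with respect to the argument(s) selected by argnums (a tuple of gradients when argnums is a sequence). If has_aux=True, it returns (gradient, aux) instead. Each gradient's unit is the output unit divided by the unit of the corresponding input.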

Return type:

Callable
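The unit rule for the returned gradients can be written as a worked equation, here applied to the first example below (f(x) = x^2 evaluated at x = 3 ms). The bracket notation [q], meaning "the unit of q", is introduced only for this illustration.

```latex
\left[\frac{\partial f}{\partial x_i}\right] = \frac{[f]}{[x_i]}

f(x) = x^2,\quad x = 3\,\mathrm{ms}
\;\Rightarrow\;
\frac{\partial f}{\partial x} = 2x = 6\,\mathrm{ms},
\qquad
\left[\frac{\partial f}{\partial x}\right] = \frac{\mathrm{ms}^2}{\mathrm{ms}} = \mathrm{ms}
```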

Examples

Compute the gradient of a scalar function with units:

>>> import jax.numpy as jnp
>>> import brainunit as u
>>> import brainunit.autograd as buauto
>>> def f(x):
...     return x ** 2
>>> grad_fn = buauto.grad(f)
>>> grad_fn(jnp.array(3.0) * u.ms)
6.0 * ms

Gradient with auxiliary data:

>>> import jax.numpy as jnp
>>> import brainunit as u
>>> import brainunit.autograd as buauto
>>> def f_aux(x):
...     return x ** 2, x * 3
>>> grad_fn = buauto.grad(f_aux, has_aux=True)
>>> g, aux = grad_fn(jnp.array(3.0) * u.mV)
>>> g
6.0 * mvolt
>>> aux
9.0 * mvolt