brainunit.autograd.jacrev

brainunit.autograd.jacrev(fun, argnums=0, has_aux=False, holomorphic=False, allow_int=False)

Physical unit-aware reverse-mode Jacobian of fun.

This is the unit-aware counterpart of jax.jacrev. It computes the Jacobian matrix via reverse-mode automatic differentiation while correctly propagating physical units.
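Reverse mode builds a Jacobian one row at a time. A single forward pass records a vector-Jacobian product (VJP), and pulling each output basis vector back through that VJP yields one row. The sketch below shows this on plain, unit-free JAX arrays. The helper name `jacrev_via_vjp` is hypothetical: it illustrates the technique and is not brainunit's implementation. It assumes a 1-D input and a 1-D output.

```python
import jax
import jax.numpy as jnp


def jacrev_via_vjp(fun, x):
    """Reverse-mode Jacobian of fun at x (1-D input, 1-D output), built row by row."""
    y, vjp_fn = jax.vjp(fun, x)               # one forward pass; keep the pullback
    basis = jnp.eye(y.size, dtype=y.dtype)    # one cotangent per output element
    # Pulling back basis vector e_i gives row i of the Jacobian, dy_i/dx.
    (rows,) = jax.vmap(vjp_fn)(basis)
    return rows


def g(x):
    return jnp.stack([x[0] * x[1], jnp.sin(x[0])])


x = jnp.array([3.0, 4.0])
J = jacrev_via_vjp(g, x)  # matches jax.jacrev(g)(x)
```

In brainunit, each entry additionally carries a unit. The return description gives that unit as output unit / input unit.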

Parameters:
  • fun (Callable) – Function whose Jacobian is to be computed. Its arguments at positions specified by argnums should be arrays, scalars, or standard Python containers thereof (possibly carrying physical units).

  • argnums (int | Sequence[int]) – Specifies which positional argument(s) to differentiate with respect to. Default is 0.

  • has_aux (bool) – If True, fun must return a pair (output, aux). Only output is differentiated, and aux is returned unchanged. Default is False.

  • holomorphic (bool) – Indicates whether fun is promised to be holomorphic (complex-differentiable). Default is False.

  • allow_int (bool) – Whether to allow differentiating with respect to integer-valued inputs. Default is False.
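The has_aux contract is the same as in jax.jacrev, of which this function is the unit-aware counterpart. Here is a minimal sketch with plain JAX arrays and no units. It assumes fun returns (output, aux), where aux is any pytree that is passed through without being differentiated.

```python
import jax
import jax.numpy as jnp


def f(x):
    y = jnp.sin(x)
    # Return (output, aux): only y is differentiated; the dict is passed through.
    return y, {"y_mean": y.mean()}


x = jnp.array([0.0, 1.0])
jac, aux = jax.jacrev(f, has_aux=True)(x)
# jac is the 2 x 2 diagonal matrix diag(cos(x)); aux["y_mean"] is mean(sin(x)).
```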

Returns:

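jacfun – A function with the same signature as fun that returns the Jacobian computed via reverse-mode AD. If argnums is a sequence, the Jacobian is a tuple with one entry per selected argument. If has_aux=True, it returns (jacobian, aux). Each Jacobian leaf carries the unit of the corresponding output divided by the unit of the corresponding input (output unit / input unit).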
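The unit rule can be illustrated with a toy model that keeps magnitudes and units separate. This is a hypothetical sketch. Units are represented as dicts that map a unit symbol to its exponent, and the helpers `unit_div` and `jacrev_with_units` are illustrative names that are not part of brainunit, which uses its own quantity and unit types. Only the magnitudes are differentiated, with jax.jacrev.

```python
import jax
import jax.numpy as jnp


# Hypothetical unit model: {"ms": 2} means ms^2. Not brainunit's Unit class.
def unit_div(num, den):
    """Divide two units by subtracting exponents; drop symbols that cancel."""
    out = dict(num)
    for sym, power in den.items():
        out[sym] = out.get(sym, 0) - power
    return {sym: power for sym, power in out.items() if power != 0}


def jacrev_with_units(fun_mag, out_unit, in_unit):
    """Differentiate a unit-free magnitude function and tag the result with out/in."""
    jac_mag = jax.jacrev(fun_mag)
    return lambda x: (jac_mag(x), unit_div(out_unit, in_unit))


# f(x) = x ** 2 with x in ms: the output unit is ms^2, so df/dx has unit ms^2 / ms = ms.
jac_val, jac_unit = jacrev_with_units(lambda x: x ** 2, {"ms": 2}, {"ms": 1})(jnp.array(3.0))
```

The same rule gives mA for the derivative of an ohm·mA output with respect to an ohm input, because the ohm exponents cancel.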

Return type:

Callable

Notes

jacrev generalises the standard Jacobian to nested Python containers (pytrees). The tree structure of jacrev(fun)(x) is formed by taking a tree product of the structure of fun(x) with the structure of x.
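The tree product can be seen with plain jax.jacrev, whose pytree behaviour this function mirrors. In this sketch, with no units, both the input and the output are dicts. The Jacobian is indexed first by the output key and then by the input key. Each leaf has shape output-leaf shape + input-leaf shape.

```python
import jax
import jax.numpy as jnp


def f(params):
    # Input pytree: {"a": shape (2,), "b": scalar}. Output pytree: {"sum": scalar, "sq": shape (2,)}.
    return {"sum": params["a"].sum() * params["b"], "sq": params["a"] ** 2}


params = {"a": jnp.array([1.0, 2.0]), "b": jnp.array(3.0)}
J = jax.jacrev(f)(params)
# J["sq"]["a"] has shape (2, 2), J["sq"]["b"] has shape (2,),
# J["sum"]["a"] has shape (2,), and J["sum"]["b"] is a scalar.
```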

Examples

Jacobian of a scalar-to-scalar function with units:

>>> import jax.numpy as jnp
>>> import brainunit as u
>>> import brainunit.autograd as buauto
>>> def f(x):
...     return x ** 2
>>> jac_fn = buauto.jacrev(f)
>>> jac_fn(jnp.array(3.0) * u.ms)
6.0 * ms

Jacobian with multiple arguments:

>>> import jax.numpy as jnp
>>> import brainunit as u
>>> import brainunit.autograd as buauto
>>> def g(x, y):
...     return x * y
>>> jac_fn = buauto.jacrev(g, argnums=(0, 1))
>>> x = jnp.array([3.0, 4.0]) * u.ohm
>>> y = jnp.array([5.0, 6.0]) * u.mA
>>> jac_x, jac_y = jac_fn(x, y)
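To see what the last call produces, the same computation can be run on unit-free magnitudes with plain jax.jacrev. Because g is elementwise, each Jacobian is a 2 x 2 diagonal matrix: jac_x is diag(y) and jac_y is diag(x). By the output unit / input unit rule in the return description, the unit-aware result should carry mA on jac_x (ohm·mA / ohm) and ohm on jac_y (ohm·mA / mA). The exact printed form of those quantities is not reproduced here.

```python
import jax
import jax.numpy as jnp


def g(x, y):
    return x * y  # elementwise product, as in the example above


# Magnitudes only: x stands in for [3, 4] ohm, y for [5, 6] mA.
x = jnp.array([3.0, 4.0])
y = jnp.array([5.0, 6.0])
jac_x, jac_y = jax.jacrev(g, argnums=(0, 1))(x, y)
# jac_x equals diag(y) and jac_y equals diag(x), each of shape (2, 2).
```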