brainunit.autograd.jacrev
- brainunit.autograd.jacrev(fun, argnums=0, has_aux=False, holomorphic=False, allow_int=False)
Physical unit-aware reverse-mode Jacobian of fun.
This is the unit-aware counterpart of jax.jacrev. It computes the Jacobian matrix via reverse-mode automatic differentiation while correctly propagating physical units.
- Parameters:
  - fun (Callable) – Function whose Jacobian is to be computed. Its arguments at the positions specified by argnums should be arrays, scalars, or standard Python containers thereof (possibly carrying physical units).
  - argnums (int | Sequence[int]) – Specifies which positional argument(s) to differentiate with respect to. Default is 0.
  - has_aux (bool) – If True, fun is expected to return (output, aux), where only output is differentiated. Default is False.
  - holomorphic (bool) – Whether fun is promised to be holomorphic. Default is False.
  - allow_int (bool) – Whether integer-valued inputs are allowed. Default is False.
- Returns:
  jacfun – A function with the same signature as fun that returns the Jacobian computed via reverse-mode AD. If has_aux=True, it returns (jacobian, aux). Each Jacobian leaf carries the correct physical units (output unit / input unit).
- Return type:
Notes
jacrev generalises the standard Jacobian to nested Python containers (pytrees). The tree structure of jacrev(fun)(x) is formed by taking a tree product of the structure of fun(x) with the structure of x.
Examples
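The tree-product structure described in the Notes can be illustrated with a small sketch. It uses plain jax.jacrev, which this function is documented as the counterpart of, and unitless arrays, so it shows only the container layout and not the unit handling. The function f and the dictionary keys are hypothetical names chosen for illustration.

```python
import jax
import jax.numpy as jnp


def f(params):
    # Input pytree: dict with leaves "a" (shape (2,)) and "b" (shape (3,)).
    # Output pytree: dict with leaves "s" (shape (2,)) and "t" (scalar).
    return {"s": params["a"] * 2.0, "t": jnp.sum(params["b"] ** 2)}


params = {"a": jnp.array([1.0, 2.0]), "b": jnp.array([1.0, 2.0, 3.0])}
jac = jax.jacrev(f)(params)

# Outer keys follow the structure of f(params); inner keys follow the
# structure of params. Each leaf has shape output_shape + input_shape.
shapes = {
    out_key: {in_key: leaf.shape for in_key, leaf in inner.items()}
    for out_key, inner in jac.items()
}
print(shapes)
```

Here jac["t"]["b"] holds the derivative of the scalar output "t" with respect to the input leaf "b". With the unit-aware version, each such leaf would be expected to carry its own unit (output unit / input unit), as stated under Returns.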
Jacobian of a scalar-to-scalar function with units:
>>> import jax.numpy as jnp
>>> import saiunit as u
>>> import saiunit.autograd as suauto
>>> def f(x):
...     return x ** 2
>>> jac_fn = suauto.jacrev(f)
>>> jac_fn(jnp.array(3.0) * u.ms)
6.0 * ms
Jacobian with multiple arguments:
>>> import jax.numpy as jnp
>>> import saiunit as u
>>> import saiunit.autograd as suauto
>>> def g(x, y):
...     return x * y
>>> jac_fn = suauto.jacrev(g, argnums=(0, 1))
>>> x = jnp.array([3.0, 4.0]) * u.ohm
>>> y = jnp.array([5.0, 6.0]) * u.mA
>>> jac_x, jac_y = jac_fn(x, y)
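The multi-argument example does not show its results, so the following sketch reproduces only its numerical part with plain jax.jacrev on the unitless magnitudes. It also shows the has_aux=True return convention. This is an illustration under the assumption that the unit-aware function matches jax.jacrev numerically; g_aux and the "total" key are hypothetical names.

```python
import jax
import jax.numpy as jnp


def g(x, y):
    return x * y


x = jnp.array([3.0, 4.0])  # magnitudes of the ohm-valued input
y = jnp.array([5.0, 6.0])  # magnitudes of the mA-valued input

# argnums=(0, 1) yields one Jacobian per selected argument.
jac_x, jac_y = jax.jacrev(g, argnums=(0, 1))(x, y)
# For an elementwise product: jac_x = diag(y) and jac_y = diag(x).


def g_aux(x, y):
    out = x * y
    # Only `out` is differentiated; the second element is passed through.
    return out, {"total": jnp.sum(out)}


(jx, jy), aux = jax.jacrev(g_aux, argnums=(0, 1), has_aux=True)(x, y)
print(jac_x, jac_y, aux["total"])
```

Following the rule under Returns (output unit / input unit), the unit-aware jac_x would be expected to carry mA (ohm · mA / ohm) and jac_y to carry ohm.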