saiunit.autograd.jacfwd

saiunit.autograd.jacfwd(fun, argnums=0, has_aux=False, holomorphic=False)

Unit-aware forward-mode Jacobian of fun.

This is the unit-aware counterpart of jax.jacfwd. It computes the Jacobian matrix via forward-mode automatic differentiation while correctly propagating physical units.

Parameters:
  • fun (Callable) – Function whose Jacobian is to be computed. Its arguments at positions specified by argnums should be arrays, scalars, or standard Python containers thereof (possibly carrying physical units).

  • argnums (int | Sequence[int]) – Specifies which positional argument(s) to differentiate with respect to. Default is 0.

  • has_aux (bool) – If True, fun is expected to return (output, aux) where only output is differentiated. Default is False.

  • holomorphic (bool) – Whether fun is promised to be holomorphic. Default is False.
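The effect of argnums can be sketched with plain jax.jacfwd, the unit-free counterpart named above. This is only an illustration under that assumption: the function g and the sample values are made up here, and no units are attached.

```python
import jax
import jax.numpy as jnp

def g(x, y):
    # Elementwise product: output i depends only on x[i] and y[i].
    return x * y

x = jnp.array([3.0, 4.0])
y = jnp.array([5.0, 6.0])

# An integer argnums yields a single Jacobian (here with respect to y).
jac_y_only = jax.jacfwd(g, argnums=1)(x, y)

# A sequence of argnums yields a tuple with one Jacobian per argument.
jac_x, jac_y = jax.jacfwd(g, argnums=(0, 1))(x, y)
```

For this g, the Jacobian with respect to x is a diagonal matrix holding y, and the Jacobian with respect to y is a diagonal matrix holding x.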

Returns:

jacfun – A function with the same arguments as fun that returns the Jacobian of fun computed via forward-mode AD. If has_aux=True, it returns the pair (jacobian, aux). Each Jacobian leaf carries the unit of the corresponding output divided by the unit of the corresponding input.

Return type:

Callable
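The (jacobian, aux) return shape for has_aux=True can be seen in a small sketch. It uses plain jax.jacfwd rather than saiunit, on the assumption that the unit-aware version follows the same calling convention as described above; the function f and the aux dictionary are illustrative only.

```python
import jax
import jax.numpy as jnp

def f(x):
    y = jnp.sin(x)                    # the output that is differentiated
    aux = {"max_input": jnp.max(x)}   # auxiliary data, returned as-is
    return y, aux

x = jnp.array([0.0, 1.0, 2.0])

# With has_aux=True the returned function gives back (jacobian, aux).
jac, aux = jax.jacfwd(f, has_aux=True)(x)
```

Here jac is a 3-by-3 matrix with cos(x) on its diagonal, while aux is passed through without being differentiated.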

Notes

Forward-mode (jacfwd) is typically more efficient than reverse-mode (jacrev) when the function has fewer inputs than outputs (a "tall" Jacobian), because forward mode builds the Jacobian one input direction at a time.
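As a rough illustration of the input/output trade-off, the sketch below builds a function with 2 inputs and 6 outputs and computes its Jacobian both ways. It uses plain jax.jacfwd and jax.jacrev without units; the function tall is a made-up example, and no timing is claimed here.

```python
import jax
import jax.numpy as jnp

def tall(x):
    # Maps 2 inputs to 6 outputs, so the Jacobian has shape (6, 2).
    return jnp.concatenate([jnp.sin(x), jnp.cos(x), x ** 2])

x = jnp.array([0.5, 1.5])

# Forward mode pushes one tangent per input (2 directions here).
J_fwd = jax.jacfwd(tall)(x)

# Reverse mode pulls back one cotangent per output (6 directions here).
J_rev = jax.jacrev(tall)(x)
```

Both calls produce the same matrix; they differ only in how many derivative passes are needed to assemble it.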

jacfwd generalises the standard Jacobian to nested Python containers (pytrees). The tree structure of jacfwd(fun)(x) is formed by taking a tree product of the structure of fun(x) with the structure of x.
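The tree-product structure can be inspected with a dictionary-valued function of a dictionary argument. This sketch again relies on plain jax.jacfwd without units; the function f and the keys "a", "b", "sum", and "prod" are hypothetical names chosen for the example.

```python
import jax
import jax.numpy as jnp

def f(params):
    # Output pytree: a dict with two leaves.
    return {
        "sum": params["a"] + params["b"],
        "prod": params["a"] * params["b"],
    }

# Input pytree: a dict with two leaves.
params = {"a": jnp.array([1.0, 2.0]), "b": jnp.array([3.0, 4.0])}

# The result is indexed first by output key, then by input key.
J = jax.jacfwd(f)(params)
d_prod_d_a = J["prod"]["a"]   # derivative of "prod" with respect to "a"
d_sum_d_b = J["sum"]["b"]     # derivative of "sum" with respect to "b"
```

Each leaf of J is an ordinary Jacobian block: the outer keys come from the structure of f(params) and the inner keys from the structure of params.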

See also

  • jacrev – Reverse-mode Jacobian computation.

  • jacobian – Alias of jacrev.

Examples

Jacobian of a scalar-to-scalar function with units:

>>> import jax.numpy as jnp
>>> import saiunit as u
>>> import saiunit.autograd as suauto
>>> def f(x):
...     return x ** 2
>>> jac_fn = suauto.jacfwd(f)
>>> jac_fn(jnp.array(3.0) * u.ms)
6.0 * ms

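Jacobian with respect to multiple arguments. With argnums=(0, 1) the result is a tuple holding one Jacobian per argument; following the output unit / input unit rule, jac_x carries mA and jac_y carries ohm: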

>>> import jax.numpy as jnp
>>> import saiunit as u
>>> import saiunit.autograd as suauto
>>> def g(x, y):
...     return x * y
>>> jac_fn = suauto.jacfwd(g, argnums=(0, 1))
>>> x = jnp.array([3.0, 4.0]) * u.ohm
>>> y = jnp.array([5.0, 6.0]) * u.mA
>>> jac_x, jac_y = jac_fn(x, y)