saiunit.autograd.hessian

saiunit.autograd.hessian(fun, argnums=0, has_aux=False, holomorphic=False)

Physical unit-aware Hessian of fun as a dense array.

This is the unit-aware counterpart of jax.hessian. It computes the Hessian (matrix of second derivatives) while correctly propagating physical units. Internally it is implemented as jacfwd(jacrev(fun)).
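The forward-over-reverse composition named above can be checked in plain JAX, without units, against jax.hessian. This is a minimal sketch: the function f and the evaluation point are illustrative choices, not part of saiunit, and the unit-propagation step that saiunit adds is not shown.

```python
import jax
import jax.numpy as jnp

# Illustrative scalar-valued function of a 2-vector (no physical units).
def f(x):
    return x[0] ** 2 * x[1] + x[1] ** 3

x = jnp.array([1.0, 2.0])

# Forward-over-reverse: jacrev gives the gradient, jacfwd differentiates it again.
h_composed = jax.jacfwd(jax.jacrev(f))(x)

# JAX's own dense Hessian, for comparison.
h_builtin = jax.hessian(f)(x)

# Analytic Hessian of f: [[2*x1, 2*x0], [2*x0, 6*x1]], i.e. [[4, 2], [2, 12]] at x = (1, 2).
h_analytic = jnp.array([[2 * x[1], 2 * x[0]], [2 * x[0], 6 * x[1]]])
```

Forward-over-reverse is the usual choice for dense Hessians because the inner reverse pass produces the full gradient in one sweep, and the outer forward pass only has to push as many tangents as there are inputs.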

Parameters:
  • fun (Callable) – Function whose Hessian is to be computed. Its arguments at positions specified by argnums should be arrays, scalars, or standard Python containers thereof (possibly carrying physical units). It typically returns a scalar (possibly carrying physical units); for array or container outputs, see the Notes on how the result is structured.

  • argnums (int | Sequence[int]) – Specifies which positional argument(s) to differentiate with respect to. Default is 0.

  • has_aux (bool) – If True, fun must return a pair (output, aux); only output is differentiated, and aux is returned unchanged. Default is False.

  • holomorphic (bool) – Whether fun is promised to be holomorphic. Default is False.

Returns:

hess_fun – A function with the same arguments as fun that evaluates the Hessian. If has_aux=True, it returns (hessian, aux). Each leaf of the Hessian carries the appropriate physical unit: output unit / (input_i unit × input_j unit).

Return type:

Callable
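The unit rule stated under Returns follows directly from the definition of a second derivative: each entry divides an output increment by the product of two input increments. Writing [q] for the unit of q:

```latex
H_{ij} = \frac{\partial^2 f}{\partial x_i \, \partial x_j},
\qquad
[H_{ij}] = \frac{[f]}{[x_i]\,[x_j]}
```

For instance, with f in ms³ and x in ms (as in the cubic example under Examples), [H] = ms³ / (ms × ms) = ms.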

Notes

hessian generalises to nested Python containers (pytrees). The tree structure of hessian(fun)(x) is formed by taking a tree product of the structure of fun(x) with two copies of the structure of x.
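The same tree-product rule holds for jax.hessian, so it can be illustrated with plain JAX (no units). In this sketch the function energy and the dict params are hypothetical; with a scalar output, the Hessian is the structure of params nested inside itself.

```python
import jax
import jax.numpy as jnp

# Hypothetical scalar function of a dict-valued (pytree) argument.
def energy(params):
    a, b = params["a"], params["b"]
    return jnp.sum(a ** 2) * b

params = {"a": jnp.array([1.0, 2.0]), "b": jnp.array(3.0)}

hess = jax.hessian(energy)(params)

# hess[k1][k2] holds d^2 energy / (d params[k1] d params[k2]),
# with shape params[k1].shape + params[k2].shape.
shapes = {k1: {k2: hess[k1][k2].shape for k2 in hess[k1]} for k1 in hess}
```

Here hess["a"]["a"] is a (2, 2) block equal to 2·b times the identity, the mixed blocks hess["a"]["b"] and hess["b"]["a"] have shape (2,), and hess["b"]["b"] is a scalar.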

See also

jacrev

Reverse-mode Jacobian computation.

jacfwd

Forward-mode Jacobian computation.

Examples

Hessian of a quadratic function. For an input in ms, f returns ms², so the Hessian is dimensionless (ms² / ms²):

>>> import jax.numpy as jnp
>>> import saiunit as u
>>> import saiunit.autograd as suauto
>>> def f(x):
...     return x ** 2 + 3 * x * u.ms + 2 * u.msecond2
>>> hess_fn = suauto.hessian(f)
>>> hess_fn(jnp.array(1.0) * u.ms)
[2]

Hessian of a cubic function. For an input in ms, g returns ms³, so the Hessian carries units of ms³ / ms² = ms:

>>> import jax.numpy as jnp
>>> import saiunit as u
>>> import saiunit.autograd as suauto
>>> def g(x):
...     return x ** 3 + 3 * x * u.msecond2 + 2 * u.msecond3
>>> hess_fn = suauto.hessian(g)
>>> hess_fn(jnp.array(1.0) * u.ms)
[6] * ms
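The argnums and has_aux parameters combine as they do in jax.hessian. The plain-JAX sketch below (no units; loss is a hypothetical function) differentiates with respect to two arguments at once and carries an auxiliary value through untouched. With argnums=(0, 1) the result is a nested tuple in which entry [i][j] is the second derivative with respect to arguments i and j.

```python
import jax
import jax.numpy as jnp

# Hypothetical two-argument function that also returns an auxiliary value.
def loss(x, y):
    out = x ** 2 * y + y ** 3
    aux = {"out": out}  # returned as-is, not differentiated
    return out, aux

x, y = jnp.array(1.0), jnp.array(2.0)

# argnums=(0, 1): differentiate with respect to both x and y.
# has_aux=True: loss returns (output, aux); only output is differentiated.
hess, aux = jax.hessian(loss, argnums=(0, 1), has_aux=True)(x, y)

# hess[i][j] is d^2 out / (d arg_i d arg_j).
d2_dx2, d2_dxdy = hess[0]
d2_dydx, d2_dy2 = hess[1]
```

Analytically, d²/dx² = 2y = 4, the mixed terms are 2x = 2, and d²/dy² = 6y = 12 at (x, y) = (1, 2).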