brainunit.autograd.hessian
- brainunit.autograd.hessian(fun, argnums=0, has_aux=False, holomorphic=False)
Physical unit-aware Hessian of fun as a dense array.

This is the unit-aware counterpart of jax.hessian. It computes the Hessian (the matrix of second derivatives) while correctly propagating physical units. Internally it is implemented as jacfwd(jacrev(fun)), that is, forward-mode differentiation applied to a reverse-mode gradient.
- Parameters:
  - fun (Callable) – Function whose Hessian is to be computed. Its arguments at the positions specified by argnums should be arrays, scalars, or standard Python containers thereof, possibly carrying physical units. It should return a scalar output.
  - argnums (int | Sequence[int]) – Specifies which positional argument(s) to differentiate with respect to. Default is 0.
  - has_aux (bool) – If True, fun is expected to return a pair (output, aux), where only output is differentiated. Default is False.
  - holomorphic (bool) – Whether fun is promised to be holomorphic. Default is False.
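The effect of argnums and has_aux on the shape of the result can be seen with plain jax.hessian on unitless arrays, whose parameters this function mirrors. This is a sketch of the JAX behaviour, not of brainunit itself; the function f, its auxiliary dict, and the input values are made up for illustration. With a tuple argnums the result is nested, so hess[i][j] holds the second derivative with respect to arguments i and j, and with has_aux=True the call returns (hessian, aux).

```python
import jax
import jax.numpy as jnp

# Hypothetical scalar function of two arguments that also returns auxiliary data.
def f(x, y):
    out = jnp.sum(x ** 2 * y)            # scalar output: the only thing differentiated
    aux = {"norm": jnp.linalg.norm(x)}   # passed through alongside the Hessian
    return out, aux

x = jnp.array([1.0, 2.0])
y = jnp.array(3.0)

# argnums=(0, 1): differentiate with respect to both x and y.
# has_aux=True: f returns (output, aux), so the call returns (hessian, aux).
hess, aux = jax.hessian(f, argnums=(0, 1), has_aux=True)(x, y)

print(hess[0][0])   # d2 out / dx dx, shape (2, 2): 2 * y * identity
print(hess[0][1])   # d2 out / dx dy, shape (2,): 2 * x
print(hess[1][1])   # d2 out / dy dy, shape (): 0
print(aux["norm"])  # sqrt(5)
```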
- Returns:
  hess_fun – A function with the same arguments as fun that evaluates the Hessian. If has_aux=True, it returns (hessian, aux). Each Hessian leaf carries the correct physical unit: output unit / input_i unit / input_j unit, where input_i and input_j are the two inputs the leaf is differentiated with respect to.
- Return type:
  Callable
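The unit rule above is exponent arithmetic: dividing by the units of the two inputs subtracts their exponents from the output's. The sketch below is not brainunit code; it uses a hypothetical representation of a unit as a dict from symbol to integer exponent, and hypothetical helpers mul, inv, and hessian_leaf_unit, to reproduce the units that appear in the examples.

```python
from collections import Counter

# Hypothetical unit representation for illustration only:
# {"ms": 2} stands for ms**2, and {} stands for a dimensionless quantity.
def mul(a, b):
    out = Counter(a)
    out.update(b)  # Counter.update adds exponents, as unit multiplication does
    return {k: v for k, v in out.items() if v != 0}

def inv(a):
    return {k: -v for k, v in a.items()}

def hessian_leaf_unit(out_unit, in_i_unit, in_j_unit):
    # d2(out) / (d in_i d in_j) has unit out_unit / in_i_unit / in_j_unit
    return mul(out_unit, inv(mul(in_i_unit, in_j_unit)))

ms = {"ms": 1}
print(hessian_leaf_unit({"ms": 2}, ms, ms))           # output ms**2, input ms -> {}
print(hessian_leaf_unit({"ms": 3}, ms, ms))           # output ms**3, input ms -> {'ms': 1}
print(hessian_leaf_unit({"mV": 1}, ms, {"mV": 1}))    # mixed inputs -> {'ms': -1}
```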
Notes
hessian generalises to nested Python containers (pytrees). The tree structure of hessian(fun)(x) is formed by taking a tree product of the structure of fun(x) with two copies of the structure of x. For a single array input x, the Hessian has shape fun(x).shape + x.shape + x.shape.

Examples
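The tree-product rule from the Notes can be illustrated with plain jax.hessian on a dict input, assuming the unit-aware version follows the same pytree convention as the JAX function it mirrors. The keys "a" and "b" and the function f are illustrative. Because f returns a scalar, the output structure adds nothing, and H[k1][k2] has shape params[k1].shape + params[k2].shape.

```python
import jax
import jax.numpy as jnp

# Dict pytree input with leaves of different shapes (illustrative values).
params = {"a": jnp.array([1.0, 2.0]), "b": jnp.array(3.0)}

def f(p):
    # Scalar output: b * (a0**2 + a1**2)
    return jnp.sum(p["a"] ** 2) * p["b"]

H = jax.hessian(f)(params)

# H is a dict of dicts: H[k1][k2] is d2f / (d p[k1] d p[k2]).
for k1 in H:
    for k2 in H[k1]:
        print(k1, k2, H[k1][k2].shape)
```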
Hessian of a quadratic function of a time in ms. The output is in ms² and each derivative divides by ms, so the Hessian is dimensionless:
>>> import jax.numpy as jnp
>>> import saiunit as u
>>> import saiunit.autograd as suauto
>>> def f(x):
...     return x ** 2 + 3 * x * u.ms + 2 * u.msecond2
>>> hess_fn = suauto.hessian(f)
>>> hess_fn(jnp.array(1.0) * u.ms)
[2]
Hessian of a cubic function of a time in ms. The output is in ms³, so the Hessian carries units of ms:
>>> import jax.numpy as jnp
>>> import saiunit as u
>>> import saiunit.autograd as suauto
>>> def g(x):
...     return x ** 3 + 3 * x * u.msecond2 + 2 * u.msecond3
>>> hess_fn = suauto.hessian(g)
>>> hess_fn(jnp.array(1.0) * u.ms)
[6] * ms
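The description states that the transform is built as jacfwd(jacrev(fun)). With units stripped, that composition can be checked in plain JAX against jax.hessian on the cubic from the last example. This sketches only the underlying JAX mechanics, not how units are attached.

```python
import jax
import jax.numpy as jnp

# Unitless analogue of g from the example above: x**3 + 3*x + 2.
def g(x):
    return x ** 3 + 3.0 * x + 2.0

x = jnp.array(1.0)

# JAX's built-in Hessian transform.
h_builtin = jax.hessian(g)(x)
# Forward-mode Jacobian of the reverse-mode gradient (forward-over-reverse).
h_composed = jax.jacfwd(jax.jacrev(g))(x)

# Both equal the second derivative 6 * x, i.e. 6.0 at x = 1.
print(h_builtin, h_composed)
```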