Source code for saiunit.autograd._hessian

# Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from __future__ import annotations

from typing import Sequence, Callable

from ._jacobian import jacrev, jacfwd
from ._misc import _check_callable

__all__ = [
    'hessian'
]


def hessian(
    fun: Callable,
    argnums: int | Sequence[int] = 0,
    has_aux: bool = False,
    holomorphic: bool = False
) -> Callable:
    """
    Build a unit-aware function that evaluates the dense Hessian of ``fun``.

    This mirrors
    `jax.hessian <https://jax.readthedocs.io/en/latest/_autosummary/jax.hessian.html>`_
    while propagating physical units through both differentiation passes.
    The second derivatives are obtained by applying forward-mode
    differentiation to the reverse-mode Jacobian, i.e.
    ``jacfwd(jacrev(fun))``.

    Parameters
    ----------
    fun : callable
        Function whose Hessian is to be computed. The arguments selected by
        ``argnums`` should be arrays, scalars, or standard Python containers
        thereof, and may carry physical units. It should return a scalar
        output.
    argnums : int or tuple of int, optional
        Position(s) of the argument(s) to differentiate with respect to.
        The same value is used for both differentiation passes.
        Default is ``0``.
    has_aux : bool, optional
        If ``True``, ``fun`` is expected to return ``(output, aux)`` and only
        ``output`` is differentiated. Default is ``False``.
    holomorphic : bool, optional
        Whether ``fun`` is promised to be holomorphic. Default is ``False``.

    Returns
    -------
    hess_fun : callable
        A function taking the same arguments as ``fun`` that evaluates the
        Hessian. With ``has_aux=True`` it returns ``(hessian, aux)``.
        Every Hessian leaf carries the unit
        ``output unit / input_i unit / input_j unit``.

    Notes
    -----
    ``hessian`` generalises to nested Python containers (pytrees). The tree
    structure of ``hessian(fun)(x)`` is the tree product of the structure of
    ``fun(x)`` with two copies of the structure of ``x``.

    See Also
    --------
    jacrev : Reverse-mode Jacobian computation.
    jacfwd : Forward-mode Jacobian computation.

    Examples
    --------
    Quadratic function whose Hessian is dimensionless:

    .. code-block:: python

        >>> import jax.numpy as jnp
        >>> import saiunit as u
        >>> import saiunit.autograd as suauto
        >>> def f(x):
        ...     return x ** 2 + 3 * x * u.ms + 2 * u.msecond2
        >>> hess_fn = suauto.hessian(f)
        >>> hess_fn(jnp.array(1.0) * u.ms)
        [2]

    Cubic function whose Hessian carries units:

    .. code-block:: python

        >>> import jax.numpy as jnp
        >>> import saiunit as u
        >>> import saiunit.autograd as suauto
        >>> def g(x):
        ...     return x ** 3 + 3 * x * u.msecond2 + 2 * u.msecond3
        >>> hess_fn = suauto.hessian(g)
        >>> hess_fn(jnp.array(1.0) * u.ms)
        [6] * ms
    """
    _check_callable(fun)

    # Both passes receive the same keyword options, so collect them once.
    shared_options = dict(has_aux=has_aux, holomorphic=holomorphic)

    # Inner pass: reverse-mode Jacobian of ``fun`` (its first derivatives).
    first_derivative = jacrev(fun, argnums, **shared_options)

    # Outer pass: forward-mode Jacobian of the first derivatives.
    return jacfwd(first_derivative, argnums, **shared_options)