# Source code for saiunit.autograd._value_and_grad

# Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from __future__ import annotations

from functools import wraps
from typing import Any, Sequence, Callable

import jax

from saiunit._base_getters import get_mantissa, get_unit, maybe_decimal
from saiunit._base_quantity import Quantity
from saiunit._compatible_import import concrete_or_error
from saiunit._misc import maybe_custom_array_tree
from ._misc import _ensure_index

__all__ = [
    'value_and_grad',
    'grad',
]


def value_and_grad(
    fun: Callable,
    argnums: int | Sequence[int] = 0,
    has_aux: bool = False,
    holomorphic: bool = False,
    allow_int: bool = False,
) -> Callable[..., tuple[Any, Any]]:
    """
    Unit-aware counterpart of
    `jax.value_and_grad <https://jax.readthedocs.io/en/latest/_autosummary/jax.value_and_grad.html>`_.

    The wrapped function is differentiated on the mantissa of its loss via
    ``jax.value_and_grad``; afterwards every gradient leaf is rebuilt as a
    ``Quantity`` whose unit is ``unit(loss) / unit(argument)`` and passed
    through ``maybe_decimal``.

    Parameters
    ----------
    fun : callable
        Function returning a scalar loss (optionally carrying a physical
        unit). With ``has_aux=True`` it must return a ``(loss, aux)`` pair.
    argnums : int or tuple of int, optional
        Index (or indices) of the positional argument(s) to differentiate
        with respect to. Default is ``0``.
    has_aux : bool, optional
        If ``True``, ``fun`` returns ``(loss, aux)`` and only ``loss`` is
        differentiated. Default is ``False``.
    holomorphic : bool, optional
        Forwarded unchanged to ``jax.value_and_grad``. Default is ``False``.
    allow_int : bool, optional
        Forwarded unchanged to ``jax.value_and_grad``. Default is ``False``.

    Returns
    -------
    value_and_grad_fun : callable
        Callable taking the same arguments as ``fun``. It returns
        ``(value, gradient)``, or ``((value, aux), gradient)`` when
        ``has_aux=True``. ``value`` is the loss exactly as ``fun`` produced
        it (unit included).

    Examples
    --------
    Value and gradient of a scalar function with units:

    .. code-block:: python

        >>> import jax.numpy as jnp
        >>> import saiunit as u
        >>> import saiunit.autograd as suauto
        >>> def f(x):
        ...     return x ** 2
        >>> vg = suauto.value_and_grad(f)
        >>> value, grad = vg(jnp.array(3.0) * u.ms)
        >>> value
        9.0 * ms ** 2
        >>> grad
        6.0 * ms

    Differentiating with respect to several arguments:

    .. code-block:: python

        >>> import jax.numpy as jnp
        >>> import saiunit as u
        >>> import saiunit.autograd as suauto
        >>> def g(x, y):
        ...     return x * y
        >>> vg = suauto.value_and_grad(g, argnums=(0, 1))
        >>> val, grads = vg(jnp.array(3.0) * u.ms, jnp.array(4.0) * u.mV)
        >>> grads[0]
        4.0 * mvolt
        >>> grads[1]
        3.0 * msecond
    """
    argnums = concrete_or_error(_ensure_index, argnums)

    def fun_return_unitless_loss(*args, **kwargs):
        # JAX differentiates the bare mantissa; the original loss (with its
        # unit) and any auxiliary output ride along as JAX "aux" data.
        result = fun(*args, **kwargs)
        if has_aux:
            loss, aux = result
        else:
            loss, aux = result, None
        return get_mantissa(loss), (loss, aux)

    fun_transformed = jax.value_and_grad(
        fun_return_unitless_loss,
        argnums=argnums,
        has_aux=True,
        holomorphic=holomorphic,
        allow_int=allow_int,
    )

    def _is_quantity(node):
        # Stop tree traversal at Quantity objects so they are handled whole.
        return isinstance(node, Quantity)

    @wraps(fun)
    def value_and_grad_fun(*args, **kwargs):
        args, kwargs = maybe_custom_array_tree((args, kwargs))

        # Plain JAX differentiation on the unitless loss.
        (_, (loss, auxiliary_data)), raw_gradient = fun_transformed(*args, **kwargs)

        loss_unit = get_unit(loss)

        def _select_argument(index):
            return args[index]

        def _attach_unit(argument, grad_leaf):
            # d(loss)/d(argument) carries unit(loss) / unit(argument).
            return maybe_decimal(
                Quantity(get_mantissa(grad_leaf), unit=loss_unit / get_unit(argument))
            )

        # Mirror the structure of ``argnums`` (single index or tuple) so the
        # differentiated arguments line up with the gradient tree.
        differentiated_args = jax.tree.map(_select_argument, argnums)
        gradient = jax.tree.map(
            _attach_unit,
            differentiated_args,
            raw_gradient,
            is_leaf=_is_quantity,
        )

        value = (loss, auxiliary_data) if has_aux else loss
        return value, gradient

    return value_and_grad_fun
def grad(
    fun: Callable,
    argnums: int | Sequence[int] = 0,
    has_aux: bool = False,
    holomorphic: bool = False,
    allow_int: bool = False,
) -> Callable:
    """
    Unit-aware counterpart of
    `jax.grad <https://jax.readthedocs.io/en/latest/_autosummary/jax.grad.html>`_.

    Built on top of :func:`value_and_grad`: the returned callable evaluates
    it and discards the loss value, keeping the gradient (and the auxiliary
    output when ``has_aux=True``).

    Parameters
    ----------
    fun : callable
        Function returning a scalar loss (optionally carrying a physical
        unit). With ``has_aux=True`` it must return a ``(loss, aux)`` pair.
    argnums : int or tuple of int, optional
        Index (or indices) of the positional argument(s) to differentiate
        with respect to. Default is ``0``.
    has_aux : bool, optional
        If ``True``, ``fun`` returns ``(loss, aux)``, only ``loss`` is
        differentiated, and the result is ``(gradient, aux)``.
        Default is ``False``.
    holomorphic : bool, optional
        Forwarded unchanged to :func:`value_and_grad`. Default is ``False``.
    allow_int : bool, optional
        Forwarded unchanged to :func:`value_and_grad`. Default is ``False``.

    Returns
    -------
    grad_fun : callable
        Callable taking the same arguments as ``fun``. It returns the
        gradient, or ``(gradient, aux)`` when ``has_aux=True``.

    Examples
    --------
    Gradient of a scalar function with units:

    .. code-block:: python

        >>> import jax.numpy as jnp
        >>> import saiunit as u
        >>> import saiunit.autograd as suauto
        >>> def f(x):
        ...     return x ** 2
        >>> grad_fn = suauto.grad(f)
        >>> grad_fn(jnp.array(3.0) * u.ms)
        6.0 * ms

    Gradient together with auxiliary data:

    .. code-block:: python

        >>> import jax.numpy as jnp
        >>> import saiunit as u
        >>> import saiunit.autograd as suauto
        >>> def f_aux(x):
        ...     return x ** 2, x * 3
        >>> grad_fn = suauto.grad(f_aux, has_aux=True)
        >>> g, aux = grad_fn(jnp.array(3.0) * u.mV)
        >>> g
        6.0 * mvolt
        >>> aux
        9.0 * mvolt
    """
    value_and_grad_f = value_and_grad(
        fun,
        argnums=argnums,
        has_aux=has_aux,
        holomorphic=holomorphic,
        allow_int=allow_int,
    )

    if has_aux:
        @wraps(fun)
        def grad_f_aux(*args, **kwargs):
            # value_and_grad yields ((value, aux), gradient); drop the value.
            (_, auxiliary), gradient = value_and_grad_f(*args, **kwargs)
            return gradient, auxiliary

        return grad_f_aux

    @wraps(fun)
    def grad_f(*args, **kwargs):
        # value_and_grad yields (value, gradient); drop the value.
        _, gradient = value_and_grad_f(*args, **kwargs)
        return gradient

    return grad_f