Source code for saiunit.math._misc

# Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from __future__ import annotations

from collections.abc import Sequence
from typing import (Union, TypeVar, Any)

import jax
import jax.numpy as jnp
import numpy as np

from saiunit._base_unit import Unit
from saiunit._base_getters import get_unit, is_unitless
from saiunit._base_quantity import Quantity
from saiunit._misc import set_module_as, maybe_custom_array_tree, maybe_custom_array

# Type variable used to annotate the two arguments of ``issubdtype``.
T = TypeVar("T")

# Public names exported by this module.
__all__ = [
    # scalar dtypes and their aliases
    'bool_',
    'uint2',
    'uint4',
    'uint8',
    'uint16',
    'uint32',
    'uint64',
    'int2',
    'int4',
    'int8',
    'int16',
    'int32',
    'int64',
    'bfloat16',
    'float16',
    'float32',
    'float64',
    'complex64',
    'complex128',
    'int_',
    'uint',
    'float_',
    'complex_',
    'single',
    'double',
    'csingle',
    'cdouble',

    # constants
    'e', 'pi', 'inf', 'nan', 'euler_gamma', 'inexact',

    # data types
    'dtype', 'finfo', 'iinfo', 'newaxis',

    # functions that query attributes of arrays, quantities and dtypes
    'is_quantity', 'issubdtype', 'result_type',
    'ndim', 'isreal', 'isscalar', 'isfinite', 'isinf',
    'isnan', 'shape', 'size', 'get_dtype',
    'is_float', 'is_int', 'broadcast_shapes',

    # numerical differentiation
    'gradient',

    # window funcs
    'bartlett', 'blackman', 'hamming', 'hanning', 'kaiser',
]

bool_ = jnp.bool_
uint2 = jnp.uint2
uint4 = jnp.uint4
uint8 = jnp.uint8
uint16 = jnp.uint16
uint32 = jnp.uint32
uint64 = jnp.uint64
int2 = jnp.int2
int4 = jnp.int4
int8 = jnp.int8
int16 = jnp.int16
int32 = jnp.int32
int64 = jnp.int64
bfloat16 = jnp.bfloat16
float16 = jnp.float16
float32 = single = jnp.float32
float64 = double = jnp.float64
complex64 = csingle = jnp.complex64
complex128 = cdouble = jnp.complex128
int_ = jnp.int_
uint = jnp.uint
float_ = jnp.float_
complex_ = jnp.complex_


def _removechars(s, chars):
    """Return *s* with every character listed in *chars* deleted."""
    # Mapping a character to ``None`` in a translation table deletes it.
    deletions = {ch: None for ch in chars}
    return s.translate(str.maketrans(deletions))


# Mathematical constants, taken from NumPy
# ----------------------------------------
pi = np.pi  # ratio of a circle's circumference to its diameter
e = np.e  # base of the natural logarithm
euler_gamma = np.euler_gamma  # Euler-Mascheroni constant
inf = np.inf  # positive infinity
nan = np.nan  # not-a-number

# Data types and indexing helpers, taken from ``jax.numpy``
# ---------------------------------------------------------
inexact = jnp.inexact  # abstract parent of the floating and complex dtypes
dtype = jnp.dtype
newaxis = jnp.newaxis  # inserts a length-one axis when used in an index


def is_quantity(x: Any) -> bool:
    """Tell whether *x* is a ``Quantity``.

    The argument is first passed through ``maybe_custom_array`` and the
    ``isinstance`` test is applied to whatever that helper returns.

    Parameters
    ----------
    x : Any
        The object to test.

    Returns
    -------
    out : bool
        ``True`` if *x* is a ``Quantity``, ``False`` otherwise.

    Examples
    --------
    .. code-block:: python

        >>> import saiunit as u
        >>> u.math.is_quantity(u.Quantity(1.0, unit=u.meter))
        True
        >>> u.math.is_quantity(1.0)
        False
    """
    return isinstance(maybe_custom_array(x), Quantity)
@set_module_as('saiunit.math')
def issubdtype(a: T, b: T) -> bool:
    """Check if a dtype is a sub-dtype of another in the type hierarchy.

    Thin wrapper around ``jnp.issubdtype``.

    Parameters
    ----------
    a : dtype
        First dtype to check.
    b : dtype
        Second dtype (abstract type class or concrete dtype).

    Returns
    -------
    out : bool
        ``True`` if *a* is lower or equal in the type hierarchy to *b*.

    Examples
    --------
    .. code-block:: python

        >>> import jax.numpy as jnp
        >>> import saiunit.math as sumath
        >>> sumath.issubdtype(jnp.float32, jnp.floating)
        True
        >>> sumath.issubdtype(jnp.int32, jnp.floating)
        False
    """
    return jnp.issubdtype(a, b)


@set_module_as('saiunit.math')
def result_type(*args):
    """Determine the result dtype from a set of input arrays or dtypes.

    The arguments are passed through ``maybe_custom_array_tree``, flattened
    with ``jax.tree.leaves`` and the leaves are handed to ``jnp.result_type``.

    Parameters
    ----------
    *args : array_like or dtype
        Input arrays or dtypes.

    Returns
    -------
    out : dtype
        The result dtype that would arise from operating on the inputs.

    Examples
    --------
    .. code-block:: python

        >>> import jax.numpy as jnp
        >>> import saiunit.math as sumath
        >>> sumath.result_type(jnp.float32, jnp.int32)
        dtype('float32')
    """
    args = maybe_custom_array_tree(args)
    return jnp.result_type(*jax.tree.leaves(args))


@set_module_as('saiunit.math')
def ndim(a: Union[Quantity, jax.typing.ArrayLike]) -> int:
    """Return the number of dimensions of an array or ``Quantity``.

    Parameters
    ----------
    a : array_like or Quantity
        Input array.

    Returns
    -------
    out : int
        Number of dimensions.

    Examples
    --------
    .. code-block:: python

        >>> import jax.numpy as jnp
        >>> import saiunit.math as sumath
        >>> sumath.ndim(jnp.zeros((2, 3)))
        2
        >>> import saiunit as u
        >>> sumath.ndim(u.Quantity(jnp.zeros((2, 3, 4)), unit=u.meter))
        3
    """
    a = maybe_custom_array(a)
    if isinstance(a, Quantity):
        return a.ndim
    return jnp.ndim(a)


@set_module_as('saiunit.math')
def isreal(a: Union[Quantity, jax.typing.ArrayLike]) -> jax.Array:
    """Test element-wise whether each element is real (has zero imaginary part).

    For a ``Quantity`` the result is its ``isreal`` attribute; any other
    input is forwarded to ``jnp.isreal``.

    Parameters
    ----------
    a : array_like or Quantity
        Input array.

    Returns
    -------
    out : jax.Array
        Boolean array of the same shape as *a*.

    Examples
    --------
    .. code-block:: python

        >>> import jax.numpy as jnp
        >>> import saiunit.math as sumath
        >>> sumath.isreal(jnp.array([1.0, 2.0 + 0j, 3.0 + 1j]))
        Array([ True,  True, False], dtype=bool)
    """
    a = maybe_custom_array(a)
    if isinstance(a, Quantity):
        return a.isreal
    return jnp.isreal(a)


@set_module_as('saiunit.math')
def isscalar(a: Union[Quantity, jax.typing.ArrayLike]) -> bool:
    """Return ``True`` if the input is a scalar.

    For a ``Quantity`` the result is its ``isscalar`` attribute; any other
    input is forwarded to ``jnp.isscalar``.

    Parameters
    ----------
    a : array_like or Quantity
        Input value.

    Returns
    -------
    out : bool
        ``True`` if *a* is a scalar.

    Examples
    --------
    .. code-block:: python

        >>> import saiunit.math as sumath
        >>> sumath.isscalar(3.14)
        True
        >>> import jax.numpy as jnp
        >>> sumath.isscalar(jnp.array([1, 2]))
        False
    """
    a = maybe_custom_array(a)
    if isinstance(a, Quantity):
        return a.isscalar
    return jnp.isscalar(a)


@set_module_as('saiunit.math')
def isfinite(a: Union[Quantity, jax.typing.ArrayLike]) -> jax.Array:
    """Test element-wise for finiteness (not inf and not NaN).

    For a ``Quantity`` the result is its ``isfinite`` attribute; any other
    input is forwarded to ``jnp.isfinite``.

    Parameters
    ----------
    a : array_like or Quantity
        Input array.

    Returns
    -------
    out : jax.Array
        Boolean array of the same shape as *a*.

    Examples
    --------
    .. code-block:: python

        >>> import jax.numpy as jnp
        >>> import saiunit.math as sumath
        >>> sumath.isfinite(jnp.array([1.0, jnp.inf, jnp.nan]))
        Array([ True, False, False], dtype=bool)
    """
    a = maybe_custom_array(a)
    if isinstance(a, Quantity):
        return a.isfinite
    return jnp.isfinite(a)


@set_module_as('saiunit.math')
def isinf(a: Union[Quantity, jax.typing.ArrayLike]) -> jax.Array:
    """Test element-wise for positive or negative infinity.

    For a ``Quantity`` the result is its ``isinf`` attribute; any other
    input is forwarded to ``jnp.isinf``.

    Parameters
    ----------
    a : array_like or Quantity
        Input array.

    Returns
    -------
    out : jax.Array
        Boolean array of the same shape as *a*.

    Examples
    --------
    .. code-block:: python

        >>> import jax.numpy as jnp
        >>> import saiunit.math as sumath
        >>> sumath.isinf(jnp.array([1.0, jnp.inf, -jnp.inf]))
        Array([False,  True,  True], dtype=bool)
    """
    a = maybe_custom_array(a)
    if isinstance(a, Quantity):
        return a.isinf
    return jnp.isinf(a)


@set_module_as('saiunit.math')
def isnan(a: Union[Quantity, jax.typing.ArrayLike]) -> jax.Array:
    """Test element-wise for NaN.

    For a ``Quantity`` the result is its ``isnan`` attribute; any other
    input is forwarded to ``jnp.isnan``.

    Parameters
    ----------
    a : array_like or Quantity
        Input array.

    Returns
    -------
    out : jax.Array
        Boolean array of the same shape as *a*.

    Examples
    --------
    .. code-block:: python

        >>> import jax.numpy as jnp
        >>> import saiunit.math as sumath
        >>> sumath.isnan(jnp.array([1.0, jnp.nan, 3.0]))
        Array([False,  True, False], dtype=bool)
    """
    a = maybe_custom_array(a)
    if isinstance(a, Quantity):
        return a.isnan
    return jnp.isnan(a)


@set_module_as('saiunit.math')
def shape(a: Union[Quantity, jax.typing.ArrayLike]) -> tuple[int, ...]:
    """
    Return the shape of an array.

    Parameters
    ----------
    a : array_like
        Input array.

    Returns
    -------
    shape : tuple of ints
        The elements of the shape tuple give the lengths of the
        corresponding array dimensions.

    See Also
    --------
    len : ``len(a)`` is equivalent to ``np.shape(a)[0]`` for N-D arrays with
          ``N>=1``.
    ndarray.shape : Equivalent array method.

    Examples
    --------
    >>> saiunit.math.shape(saiunit.math.eye(3))
    (3, 3)
    >>> saiunit.math.shape([[1, 3]])
    (1, 2)
    >>> saiunit.math.shape([0])
    (1,)
    >>> saiunit.math.shape(0)
    ()
    """
    a = maybe_custom_array(a)
    # Objects that already expose ``.shape`` are read directly; everything
    # else (lists, Python scalars, ...) goes through ``np.shape``.
    if isinstance(a, (Quantity, jax.Array, np.ndarray)):
        return a.shape
    return np.shape(a)


@set_module_as('saiunit.math')
def size(a: Union[Quantity, jax.typing.ArrayLike], axis: Union[int, None] = None) -> int:
    """
    Return the number of elements along a given axis.

    Parameters
    ----------
    a : array_like
        Input data.
    axis : int, optional
        Axis along which the elements are counted. By default, give
        the total number of elements.

    Returns
    -------
    element_count : int
        Number of elements along the specified axis.

    See Also
    --------
    shape : dimensions of array
    Array.shape : dimensions of array
    Array.size : number of elements in array

    Examples
    --------
    >>> a = Quantity([[1,2,3], [4,5,6]])
    >>> saiunit.math.size(a)
    6
    >>> saiunit.math.size(a, 1)
    3
    >>> saiunit.math.size(a, 0)
    2
    """
    a = maybe_custom_array(a)
    if isinstance(a, (Quantity, jax.Array, np.ndarray)):
        if axis is None:
            return a.size
        return a.shape[axis]
    return np.size(a, axis=axis)


@set_module_as('saiunit.math')
def finfo(a: Union[Quantity, jax.typing.ArrayLike]) -> jnp.finfo:
    """Return the ``jnp.finfo`` floating-point type information for *a*.

    Parameters
    ----------
    a : array_like, dtype or Quantity
        Value forwarded to ``jnp.finfo``. For a ``Quantity`` its mantissa is
        forwarded instead.

    Returns
    -------
    out : jnp.finfo
        The object returned by ``jnp.finfo``.
    """
    a = maybe_custom_array(a)
    if isinstance(a, Quantity):
        return jnp.finfo(a.mantissa)
    return jnp.finfo(a)


@set_module_as('saiunit.math')
def iinfo(a: Union[Quantity, jax.typing.ArrayLike]) -> jnp.iinfo:
    """Return the ``jnp.iinfo`` integer type information for *a*.

    Parameters
    ----------
    a : array_like, dtype or Quantity
        Value forwarded to ``jnp.iinfo``. For a ``Quantity`` its mantissa is
        forwarded instead.

    Returns
    -------
    out : jnp.iinfo
        The object returned by ``jnp.iinfo``.
    """
    a = maybe_custom_array(a)
    if isinstance(a, Quantity):
        return jnp.iinfo(a.mantissa)
    return jnp.iinfo(a)


@set_module_as('saiunit.math')
def broadcast_shapes(*shapes):
    """Broadcast a sequence of array shapes.

    Thin wrapper around ``jnp.broadcast_shapes``.

    Parameters
    ----------
    *shapes : tuple of int
        The shapes of the arrays to broadcast.

    Returns
    -------
    broadcast_shape : tuple of int
        The broadcasted shape.

    Examples
    --------
    .. code-block:: python

        >>> import saiunit.math as sumath
        >>> sumath.broadcast_shapes((2, 1), (1, 3))
        (2, 3)
    """
    return jnp.broadcast_shapes(*shapes)


# ``brainstate.environ`` is imported lazily (see ``_get_environ``) the first
# time the default dtype of a Python scalar is requested.
environ = None  # type: ignore[assignment]


def _get_environ():
    """Return ``brainstate.environ``, importing and caching it on first use."""
    global environ
    if environ is None:
        from brainstate import environ as _environ
        environ = _environ
    return environ


@set_module_as('saiunit.math')
def get_dtype(a):
    """Get the dtype of an array, ``Quantity``, or Python scalar.

    Objects exposing a ``dtype`` attribute report it directly. For Python
    scalars, ``bool`` maps to ``bool`` and ``int`` / ``float`` / ``complex``
    map to ``brainstate.environ.ditype()`` / ``dftype()`` / ``dctype()``.

    Parameters
    ----------
    a : array_like, Quantity, or scalar
        The input whose dtype is to be determined.

    Returns
    -------
    out : dtype
        The data type of *a*.

    Raises
    ------
    ValueError
        If *a* has no ``dtype`` attribute and is not a ``bool``, ``int``,
        ``float`` or ``complex``.

    Examples
    --------
    .. code-block:: python

        >>> import jax.numpy as jnp
        >>> import saiunit.math as sumath
        >>> sumath.get_dtype(jnp.array([1.0, 2.0]))
        dtype('float32')
    """
    a = maybe_custom_array(a)
    if hasattr(a, 'dtype'):
        return a.dtype
    # ``bool`` is a subclass of ``int`` and therefore has to be tested first.
    if isinstance(a, bool):
        return bool
    if isinstance(a, int):
        return _get_environ().ditype()
    if isinstance(a, float):
        return _get_environ().dftype()
    if isinstance(a, complex):
        return _get_environ().dctype()
    raise ValueError(f'Can not get dtype of {a}.')


@set_module_as('saiunit.math')
def is_float(array):
    """Check if the array has a floating-point dtype.

    Parameters
    ----------
    array : array_like or Quantity
        The input array.

    Returns
    -------
    out : bool
        ``True`` if the array dtype is a floating-point type.

    Examples
    --------
    .. code-block:: python

        >>> import jax.numpy as jnp
        >>> import saiunit.math as sumath
        >>> sumath.is_float(jnp.array([1.0]))
        True
        >>> sumath.is_float(jnp.array([1]))
        False
    """
    array = maybe_custom_array(array)
    return jnp.issubdtype(get_dtype(array), jnp.floating)


@set_module_as('saiunit.math')
def is_int(array):
    """Check if the array has an integer dtype.

    Parameters
    ----------
    array : array_like or Quantity
        The input array.

    Returns
    -------
    out : bool
        ``True`` if the array dtype is an integer type.

    Examples
    --------
    .. code-block:: python

        >>> import jax.numpy as jnp
        >>> import saiunit.math as sumath
        >>> sumath.is_int(jnp.array([1]))
        True
        >>> sumath.is_int(jnp.array([1.0]))
        False
    """
    array = maybe_custom_array(array)
    return jnp.issubdtype(get_dtype(array), jnp.integer)


def _strip_unit(x):
    """Return the mantissa of a ``Quantity``; return any other value unchanged."""
    return x.mantissa if isinstance(x, Quantity) else x


def _wrap_gradient(result, unit):
    """Attach *unit* to the output of ``jnp.gradient``.

    ``jnp.gradient`` returns a list of arrays when several axes are
    differentiated and a single array otherwise; the same structure is kept.
    """
    if isinstance(result, (list, tuple)):
        return [Quantity(r, unit=unit) for r in result]
    return Quantity(result, unit=unit)


@set_module_as('saiunit.math')
def gradient(
    f: Union[jax.typing.ArrayLike, Quantity],
    *varargs: Union[jax.typing.ArrayLike, Quantity],
    axis: Union[int, Sequence[int], None] = None,
    edge_order: Union[int, None] = None,
) -> Union[jax.Array, list[jax.Array], Quantity, list[Quantity]]:
    """
    Computes the gradient of a scalar field.

    Return the gradient of an N-dimensional array.

    The numerical work is delegated to ``jnp.gradient``; this wrapper strips
    units from the inputs where needed and attaches the resulting unit
    (unit of ``f`` divided by the unit of the spacing) to the outputs.
    The returned gradient has the same shape as the input array.

    Parameters
    ----------
    f : array_like, Quantity
        An N-dimensional array containing samples of a scalar function.
    varargs : list of scalar or array, optional
        Spacing between f values. Default unitary spacing for all dimensions.
        Spacing can be specified using:

        1. single scalar to specify a sample distance for all dimensions.
        2. N scalars to specify a constant sample distance for each dimension.
           i.e. `dx`, `dy`, `dz`, ...
        3. N arrays to specify the coordinates of the values along each
           dimension of F. The length of the array must match the size of
           the corresponding dimension
        4. Any combination of N scalars/arrays with the meaning of 2. and 3.

        If `axis` is given, the number of varargs must equal the number of axes.
        Default: 1.
    edge_order : None, optional
        Not supported. Any value other than ``None`` raises
        ``NotImplementedError``.
    axis : None or int or tuple of ints, optional
        Gradient is calculated only along the given axis or axes
        The default (axis = None) is to calculate the gradient for all the axes
        of the input array. axis may be negative, in which case it counts from
        the last to the first axis.

    Returns
    -------
    gradient : ndarray or list of ndarray or Quantity or list of Quantity
        A list with one entry per differentiated axis, or a single
        array / ``Quantity`` when only one axis is differentiated. Each
        derivative has the same shape as f.

    Raises
    ------
    NotImplementedError
        If ``edge_order`` is not ``None``.

    Examples
    --------
    .. code-block:: python

        >>> import jax.numpy as jnp
        >>> import saiunit.math as sumath
        >>> f = jnp.array([1., 2., 4., 7., 11.])
        >>> sumath.gradient(f)
        Array([1. , 1.5, 2.5, 3.5, 4. ], dtype=float32)
    """
    f, varargs = maybe_custom_array_tree((f, varargs))
    if edge_order is not None:
        raise NotImplementedError("The 'edge_order' argument to jnp.gradient is not supported.")

    if len(varargs) == 0:
        if isinstance(f, Quantity) and not is_unitless(f):
            return _wrap_gradient(jnp.gradient(f.mantissa, axis=axis), f.unit)
        # ``axis`` has to be forwarded here as well, otherwise it is
        # silently ignored for inputs without a unit.
        return jnp.gradient(f, axis=axis)

    if len(varargs) == 1:
        spacing = varargs[0]
        unit = get_unit(f) / get_unit(spacing)
        if isinstance(unit, Unit) and unit.is_unitless:
            return jnp.gradient(f, spacing, axis=axis)
        # Either operand may be a plain array (only the other one carrying a
        # unit), so mantissas are extracted defensively.
        raw = jnp.gradient(_strip_unit(f), _strip_unit(spacing), axis=axis)
        return _wrap_gradient(raw, unit)

    unit_list = [get_unit(f) / get_unit(v) for v in varargs]
    raw_list = jnp.gradient(
        _strip_unit(f),
        *[_strip_unit(v) for v in varargs],
        axis=axis,
    )
    # NOTE(review): the quotient of two units is presumably never ``None``,
    # so every entry ends up wrapped in a ``Quantity`` - confirm whether
    # unitless results should be returned as plain arrays here, as the
    # single-spacing branch does.
    return [
        (Quantity(r, unit=unit) if unit is not None else r)
        for r, unit in zip(raw_list, unit_list)
    ]


# window funcs
# ------------
bartlett = jnp.bartlett
blackman = jnp.blackman
hamming = jnp.hamming
hanning = jnp.hanning
kaiser = jnp.kaiser