# Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from __future__ import annotations
import functools
from typing import (Union, Sequence, Tuple, Optional)
import jax
import jax.numpy as jnp
import numpy as np
from jax._src.numpy.util import promote_dtypes as _promote_dtypes
from saiunit._base_unit import UNITLESS
from saiunit._base_getters import (
fail_for_dimension_mismatch,
get_unit,
maybe_decimal,
split_mantissa_unit,
unit_scale_align_to_first,
)
from saiunit._base_quantity import Quantity
from saiunit._misc import set_module_as, maybe_custom_array, maybe_custom_array_tree
from ._fun_array_creation import asarray
# Public names exported by this module, grouped by the kind of operation.
# Some of the listed names are defined further down in the file.
__all__ = [
# sequence inputs
'row_stack', 'concatenate', 'stack', 'vstack', 'hstack', 'dstack', 'column_stack', 'block', 'append',
# sequence outputs
'split', 'array_split', 'dsplit', 'hsplit', 'vsplit',
# broadcasting arrays
'atleast_1d', 'atleast_2d', 'atleast_3d', 'broadcast_arrays', 'broadcast_to',
# array manipulation
'reshape', 'moveaxis', 'transpose', 'swapaxes', 'tile', 'repeat',
'flip', 'fliplr', 'flipud', 'roll', 'expand_dims', 'squeeze',
'sort', 'max', 'min', 'amax', 'amin', 'diagflat', 'diagonal', 'choose', 'ravel',
'flatten', 'unflatten', 'remove_diag',
# math funcs keep unit (unary)
'astype', 'real', 'imag', 'conj', 'conjugate', 'negative', 'positive',
'abs', 'sum', 'nancumsum', 'nansum',
'cumsum', 'ediff1d', 'absolute', 'fabs', 'median',
'nanmin', 'nanmax', 'ptp', 'average', 'mean', 'std',
'nanmedian', 'nanmean', 'nanstd', 'diff', 'rot90', 'intersect1d', 'nan_to_num',
'percentile', 'nanpercentile', 'quantile', 'nanquantile',
# math funcs only accept unitless (unary) can return Quantity
'round', 'around', 'rint', 'floor', 'ceil', 'trunc', 'fix', 'modf',
# math funcs keep unit (binary)
'fmod', 'mod', 'copysign', 'remainder',
'maximum', 'minimum', 'fmax', 'fmin', 'lcm', 'gcd', 'trace',
'add', 'subtract', 'nextafter', 'promote_dtypes',
# math funcs keep unit
'interp', 'clip', 'histogram',
# selection
'compress', 'extract', 'take', 'select', 'where', 'unique', 'gather',
]
# -------------------------------------------------------------------
def _fun_keep_unit_sequence(
    func,
    *args,
    **kwargs
):
    """
    Run ``func`` on the mantissas of the inputs and re-attach the first leaf's unit.

    Parameters
    ----------
    func : callable
        Function invoked as ``func(*args, **kwargs)`` once every leaf of
        ``args`` has been replaced by its mantissa.
    *args
        Positional inputs, possibly nested sequences. ``Quantity`` objects
        are treated as leaves when the inputs are flattened.
    **kwargs
        Keyword arguments handed to ``func`` unchanged.

    Returns
    -------
    res : array or Quantity
        The raw result of ``func`` when the unit taken from the first leaf is
        unitless, otherwise that result wrapped in a ``Quantity`` carrying
        that unit.

    Notes
    -----
    The leaves go through ``unit_scale_align_to_first`` before the unit is
    read; presumably this expresses every leaf in the unit of the first one
    (TODO confirm against ``saiunit._base_getters``).
    """
    is_quantity = lambda node: isinstance(node, Quantity)

    tree = maybe_custom_array_tree(args)
    flat, structure = jax.tree.flatten(tree, is_leaf=is_quantity)

    # After alignment every leaf exposes ``.unit`` and ``.mantissa``.
    aligned = unit_scale_align_to_first(*flat)
    common_unit = aligned[0].unit
    mantissas = [item.mantissa for item in aligned]

    # Rebuild the original nesting so ``func`` sees the same call shape.
    result = func(*structure.unflatten(mantissas), **kwargs)
    return result if common_unit.is_unitless else Quantity(result, unit=common_unit)
@set_module_as('saiunit.math')
def concatenate(
    arrays: Union[Sequence[jax.typing.ArrayLike], Sequence[Quantity]],
    axis: Optional[int] = None,
    dtype: Optional[jax.typing.DTypeLike] = None,
    **kwargs,
) -> Union[jax.Array, Quantity]:
    """
    Join a sequence of quantities or arrays along an existing axis.

    Parameters
    ----------
    arrays : sequence of array_like, Quantity
        Inputs to join. They must have the same shape except along ``axis``.
    axis : int or None, optional
        Axis along which the inputs are joined. The default of this wrapper
        is ``None``, which is forwarded as-is to ``jax.numpy.concatenate``
        (there, ``None`` flattens the inputs before joining).
    dtype : dtype, optional
        Forwarded to ``jax.numpy.concatenate`` as the result dtype.
    **kwargs
        Additional keyword arguments forwarded to ``jax.numpy.concatenate``.

    Returns
    -------
    res : ndarray, Quantity
        The concatenated array. It is a ``Quantity`` carrying the unit of the
        first input unless that unit is unitless.

    Examples
    --------
    .. code-block:: python

        >>> import saiunit as u
        >>> a = [1, 2] * u.second
        >>> b = [3, 4] * u.second
        >>> u.math.concatenate([a, b])
    """
    options = {'axis': axis, 'dtype': dtype, **kwargs}
    return _fun_keep_unit_sequence(jnp.concatenate, arrays, **options)
@set_module_as('saiunit.math')
def stack(
    arrays: Union[Sequence[jax.typing.ArrayLike], Sequence[Quantity]],
    axis: int = 0,
    dtype: Optional[jax.typing.DTypeLike] = None,
    **kwargs,
) -> Union[jax.Array, Quantity]:
    """
    Join a sequence of quantities or arrays along a new axis.

    Parameters
    ----------
    arrays : sequence of array_like, Quantity
        Inputs to stack; they must all have the same shape.
    axis : int, optional
        Position of the new axis in the result. Default is 0.
    dtype : dtype, optional
        Forwarded to ``jax.numpy.stack`` as the result dtype.
    **kwargs
        Additional keyword arguments forwarded to ``jax.numpy.stack``.

    Returns
    -------
    res : ndarray, Quantity
        The stacked array, with one more dimension than the inputs. It is a
        ``Quantity`` carrying the unit of the first input unless that unit is
        unitless.

    Examples
    --------
    .. code-block:: python

        >>> import saiunit as u
        >>> a = [1, 2, 3] * u.second
        >>> b = [4, 5, 6] * u.second
        >>> u.math.stack([a, b])
    """
    options = {'axis': axis, 'dtype': dtype, **kwargs}
    return _fun_keep_unit_sequence(jnp.stack, arrays, **options)
@set_module_as('saiunit.math')
def vstack(
    tup: Union[Sequence[jax.typing.ArrayLike], Sequence[Quantity]],
    dtype: Optional[jax.typing.DTypeLike] = None,
    **kwargs,
) -> Union[jax.Array, Quantity]:
    """
    Stack quantities or arrays in sequence vertically (row wise).

    Parameters
    ----------
    tup : sequence of array_like, Quantity
        Inputs to stack. They must have the same shape along all but the
        first axis.
    dtype : dtype, optional
        Forwarded to ``jax.numpy.vstack`` as the result dtype.
    **kwargs
        Additional keyword arguments forwarded to ``jax.numpy.vstack``.

    Returns
    -------
    res : ndarray, Quantity
        The stacked array. It is a ``Quantity`` carrying the unit of the
        first input unless that unit is unitless.

    Examples
    --------
    .. code-block:: python

        >>> import saiunit as u
        >>> a = [1, 2, 3] * u.meter
        >>> b = [4, 5, 6] * u.meter
        >>> u.math.vstack([a, b])
    """
    options = {'dtype': dtype, **kwargs}
    return _fun_keep_unit_sequence(jnp.vstack, tup, **options)


# ``row_stack`` is exported as a plain alias of ``vstack``.
row_stack = vstack
@set_module_as('saiunit.math')
def hstack(
    arrays: Union[Sequence[jax.typing.ArrayLike], Sequence[Quantity]],
    dtype: Optional[jax.typing.DTypeLike] = None,
    **kwargs,
) -> Union[jax.Array, Quantity]:
    """
    Stack quantities or arrays in sequence horizontally (column wise).

    Parameters
    ----------
    arrays : sequence of array_like, Quantity
        Inputs to stack. They must have the same shape along all but the
        second axis.
    dtype : dtype, optional
        Forwarded to ``jax.numpy.hstack`` as the result dtype.
    **kwargs
        Additional keyword arguments forwarded to ``jax.numpy.hstack``.

    Returns
    -------
    res : ndarray, Quantity
        The stacked array. It is a ``Quantity`` carrying the unit of the
        first input unless that unit is unitless.

    Examples
    --------
    .. code-block:: python

        >>> import saiunit as u
        >>> a = [1, 2, 3] * u.meter
        >>> b = [4, 5, 6] * u.meter
        >>> u.math.hstack([a, b])
    """
    options = {'dtype': dtype, **kwargs}
    return _fun_keep_unit_sequence(jnp.hstack, arrays, **options)
@set_module_as('saiunit.math')
def dstack(
    arrays: Union[Sequence[jax.typing.ArrayLike], Sequence[Quantity]],
    dtype: Optional[jax.typing.DTypeLike] = None,
    **kwargs,
) -> Union[jax.Array, Quantity]:
    """
    Stack quantities or arrays in sequence depth wise (along third axis).

    Parameters
    ----------
    arrays : sequence of array_like, Quantity
        Inputs to stack. They must have the same shape along all but the
        third axis.
    dtype : dtype, optional
        Forwarded to ``jax.numpy.dstack`` as the result dtype.
    **kwargs
        Additional keyword arguments forwarded to ``jax.numpy.dstack``.

    Returns
    -------
    res : ndarray, Quantity
        The stacked array. It is a ``Quantity`` carrying the unit of the
        first input unless that unit is unitless.

    Examples
    --------
    .. code-block:: python

        >>> import saiunit as u
        >>> a = [[1], [2], [3]] * u.meter
        >>> b = [[4], [5], [6]] * u.meter
        >>> u.math.dstack([a, b])
    """
    options = {'dtype': dtype, **kwargs}
    return _fun_keep_unit_sequence(jnp.dstack, arrays, **options)
@set_module_as('saiunit.math')
def column_stack(
    tup: Union[Sequence[jax.typing.ArrayLike], Sequence[Quantity]],
    **kwargs,
) -> Union[jax.Array, Quantity]:
    """
    Stack 1-D arrays as columns into a 2-D array.

    A sequence of 1-D inputs is stacked column by column into one 2-D array;
    2-D inputs are stacked as-is, just like with ``hstack``.

    Parameters
    ----------
    tup : sequence of 1-D array_like, Quantity
        1-D arrays to stack as columns.
    **kwargs
        Additional keyword arguments forwarded to ``jax.numpy.column_stack``.

    Returns
    -------
    res : ndarray, Quantity
        The stacked array. It is a ``Quantity`` carrying the unit of the
        first input unless that unit is unitless.

    Examples
    --------
    .. code-block:: python

        >>> import saiunit as u
        >>> a = [1, 2, 3] * u.second
        >>> b = [4, 5, 6] * u.second
        >>> u.math.column_stack([a, b])
    """
    stacked = _fun_keep_unit_sequence(
        jnp.column_stack,
        tup,
        **kwargs,
    )
    return stacked
@set_module_as('saiunit.math')
def block(
    arrays: Sequence[Union[jax.Array, Quantity]],
    **kwargs,
) -> Union[jax.Array, Quantity]:
    """
    Assemble a quantity or an array from nested lists of blocks.

    Parameters
    ----------
    arrays : sequence of array_like, Quantity
        Blocks to assemble. Each element may itself be a nested sequence of
        arrays, in which case the blocks in the corresponding cells are
        recursively stacked as the elements of the resulting array.
    **kwargs
        Additional keyword arguments forwarded to ``jax.numpy.block``.

    Returns
    -------
    res : ndarray, Quantity
        The array built from the given blocks. It is a ``Quantity`` carrying
        the unit of the first block unless that unit is unitless.

    Examples
    --------
    .. code-block:: python

        >>> import saiunit as u
        >>> a = [[1, 2], [3, 4]] * u.second
        >>> u.math.block(a)
    """
    assembled = _fun_keep_unit_sequence(
        jnp.block,
        arrays,
        **kwargs,
    )
    return assembled
@set_module_as('saiunit.math')
def append(
    arr: Union[jax.Array, Quantity],
    values: Union[jax.Array, Quantity],
    axis: Optional[int] = None,
    **kwargs,
) -> Union[jax.Array, Quantity]:
    """
    Append values to the end of a quantity or an array.

    Parameters
    ----------
    arr : array_like, Quantity
        Base array; the values are appended to a copy of it.
    values : array_like, Quantity
        Values to append. They must have the correct shape (the same shape
        as ``arr``, excluding ``axis``).
    axis : int, optional
        Axis along which ``values`` are appended. If ``None`` (the default),
        the inputs are flattened before use.
    **kwargs
        Additional keyword arguments forwarded to ``jax.numpy.append``.

    Returns
    -------
    res : ndarray, Quantity
        A new array holding ``arr`` followed by ``values``; the operation is
        not in-place. It is a ``Quantity`` carrying the unit of ``arr``
        unless that unit is unitless.

    Examples
    --------
    .. code-block:: python

        >>> import saiunit as u
        >>> a = [1, 2, 3] * u.second
        >>> u.math.append(a, 4 * u.second)
    """
    options = {'axis': axis, **kwargs}
    return _fun_keep_unit_sequence(jnp.append, arr, values, **options)
def _fun_keep_unit_return_sequence(
    func,
    x: jax.typing.ArrayLike | Quantity,
    *args,
    **kwargs
):
    """
    Apply a sequence-returning ``func`` to ``x`` and keep the unit on every piece.

    Parameters
    ----------
    func : callable
        Function called as ``func(data, *args, **kwargs)``; when ``x`` is a
        ``Quantity`` its result is iterated over.
    x : array_like, Quantity
        Primary input. It is first passed through ``maybe_custom_array``.
    *args, **kwargs
        Extra arguments for ``func``; both are passed through
        ``maybe_custom_array_tree`` before the call.

    Returns
    -------
    res : list of Quantity, or whatever ``func`` returns
        For a ``Quantity`` input, a list in which every element produced by
        ``func`` is wrapped in a ``Quantity`` with the unit of ``x`` and then
        passed through ``maybe_decimal``. Otherwise the result of ``func`` is
        returned untouched.
    """
    x = maybe_custom_array(x)
    args, kwargs = maybe_custom_array_tree((args, kwargs))

    # Plain inputs need no unit handling.
    if not isinstance(x, Quantity):
        return func(x, *args, **kwargs)

    pieces = func(x.mantissa, *args, **kwargs)
    return [maybe_decimal(Quantity(piece, unit=x.unit)) for piece in pieces]
@set_module_as('saiunit.math')
def split(
    a: Union[jax.Array, Quantity],
    indices_or_sections: Union[int, Sequence[int]],
    axis: int = 0,
    **kwargs,
) -> Union[Sequence[jax.Array | Quantity]]:
    """
    Split quantity or array into a list of multiple sub-arrays.

    Parameters
    ----------
    a : array_like, Quantity
        Array to be divided into sub-arrays.
    indices_or_sections : int or 1-D array
        If an integer N, the array is divided into N equal arrays along
        ``axis``; an error is raised if such a split is not possible.
        If a 1-D array of sorted integers, the entries indicate where along
        ``axis`` the array is split. For example, ``[2, 3]`` would, for
        ``axis=0``, result in

        - ``a[:2]``
        - ``a[2:3]``
        - ``a[3:]``
    axis : int, optional
        The axis along which to split. Default is 0.
    **kwargs
        Additional keyword arguments forwarded to ``jax.numpy.split``.

    Returns
    -------
    res : list of ndarrays, Quantity
        The sub-arrays. For a ``Quantity`` input each sub-array is wrapped
        with the unit of ``a``.

    Examples
    --------
    .. code-block:: python

        >>> import saiunit as u
        >>> import jax.numpy as jnp
        >>> a = jnp.arange(9.0) * u.second
        >>> u.math.split(a, 3)
    """
    options = {'indices_or_sections': indices_or_sections, 'axis': axis, **kwargs}
    return _fun_keep_unit_return_sequence(jnp.split, a, **options)
@set_module_as('saiunit.math')
def array_split(
    ary: Union[Quantity, jax.typing.ArrayLike],
    indices_or_sections: Union[int, jax.typing.ArrayLike],
    axis: Optional[int] = 0,
    **kwargs,
) -> Union[Sequence[Quantity | jax.Array]]:
    """
    Split an array into multiple sub-arrays, allowing uneven sections.

    Unlike ``split``, an integer ``indices_or_sections`` does not have to
    divide the length of ``axis`` evenly.

    Parameters
    ----------
    ary : Quantity or array
        Array to be divided into sub-arrays.
    indices_or_sections : int or 1-D array
        If an integer N, ``ary`` is divided into N sub-arrays along ``axis``.
        When the axis length ``l`` is not a multiple of N, the first
        ``l % N`` sub-arrays have size ``l // N + 1`` and the rest have size
        ``l // N``.
        If a 1-D array of sorted integers, the entries indicate where along
        ``axis`` the array is split.
    axis : int, optional
        The axis along which to split. Default is 0.
    **kwargs
        Additional keyword arguments forwarded to ``jax.numpy.array_split``.

    Returns
    -------
    sub-arrays : list of Quantity or list of array
        The sub-arrays. For a ``Quantity`` input each sub-array is wrapped
        with the unit of ``ary``.

    Examples
    --------
    .. code-block:: python

        >>> import saiunit as u
        >>> import jax.numpy as jnp
        >>> a = jnp.arange(9.0) * u.second
        >>> u.math.array_split(a, 3)
        >>> u.math.array_split(a, 4)  # uneven: sizes 3, 2, 2, 2
    """
    # Delegate to ``jnp.array_split`` (not ``jnp.split``) so that uneven
    # integer splits are accepted instead of raising.
    return _fun_keep_unit_return_sequence(
        jnp.array_split,
        ary,
        indices_or_sections=indices_or_sections,
        axis=axis,
        **kwargs,
    )
@set_module_as('saiunit.math')
def dsplit(
    a: Union[jax.Array, Quantity],
    indices_or_sections: Union[int, Sequence[int]],
    **kwargs,
) -> Union[Sequence[jax.Array | Quantity]]:
    """
    Split a quantity or an array into multiple sub-arrays along the 3rd axis (depth).

    Parameters
    ----------
    a : array_like, Quantity
        Array to be divided into sub-arrays.
    indices_or_sections : int or 1-D array
        If an integer N, the array is divided into N equal arrays along the
        third axis (depth); an error is raised if such a split is not
        possible. If a 1-D array of sorted integers, the entries indicate
        where along the third axis the array is split.
    **kwargs
        Additional keyword arguments forwarded to ``jax.numpy.dsplit``.

    Returns
    -------
    res : list of ndarrays, Quantity
        The sub-arrays. For a ``Quantity`` input each sub-array is wrapped
        with the unit of ``a``.

    Examples
    --------
    .. code-block:: python

        >>> import saiunit as u
        >>> import jax.numpy as jnp
        >>> a = jnp.arange(16.0).reshape(2, 2, 4) * u.meter
        >>> u.math.dsplit(a, 2)
    """
    pieces = _fun_keep_unit_return_sequence(
        jnp.dsplit,
        a,
        indices_or_sections,
        **kwargs,
    )
    return pieces
@set_module_as('saiunit.math')
def hsplit(
    a: Union[jax.Array, Quantity],
    indices_or_sections: Union[int, Sequence[int]],
    **kwargs,
) -> Union[Sequence[jax.Array | Quantity]]:
    """
    Split a quantity or an array into multiple sub-arrays horizontally (column-wise).

    Parameters
    ----------
    a : array_like, Quantity
        Array to be divided into sub-arrays.
    indices_or_sections : int or 1-D array
        If an integer N, the array is divided into N equal arrays along the
        second axis; an error is raised if such a split is not possible.
        If a 1-D array of sorted integers, the entries indicate where along
        the second axis the array is split.
    **kwargs
        Additional keyword arguments forwarded to ``jax.numpy.hsplit``.

    Returns
    -------
    res : list of ndarrays, Quantity
        The sub-arrays. For a ``Quantity`` input each sub-array is wrapped
        with the unit of ``a``.

    Examples
    --------
    .. code-block:: python

        >>> import saiunit as u
        >>> import jax.numpy as jnp
        >>> a = jnp.arange(16.0).reshape(4, 4) * u.meter
        >>> u.math.hsplit(a, 2)
    """
    pieces = _fun_keep_unit_return_sequence(
        jnp.hsplit,
        a,
        indices_or_sections,
        **kwargs,
    )
    return pieces
@set_module_as('saiunit.math')
def vsplit(
    a: Union[jax.Array, Quantity],
    indices_or_sections: Union[int, Sequence[int]],
    **kwargs,
) -> Union[Sequence[jax.Array | Quantity]]:
    """
    Split a quantity or an array into multiple sub-arrays vertically (row-wise).

    Parameters
    ----------
    a : array_like, Quantity
        Array to be divided into sub-arrays.
    indices_or_sections : int or 1-D array
        If an integer N, the array is divided into N equal arrays along the
        first axis; an error is raised if such a split is not possible.
        If a 1-D array of sorted integers, the entries indicate where along
        the first axis the array is split.
    **kwargs
        Additional keyword arguments forwarded to ``jax.numpy.vsplit``.

    Returns
    -------
    res : list of ndarrays, Quantity
        The sub-arrays. For a ``Quantity`` input each sub-array is wrapped
        with the unit of ``a``.

    Examples
    --------
    .. code-block:: python

        >>> import saiunit as u
        >>> import jax.numpy as jnp
        >>> a = jnp.arange(16.0).reshape(4, 4) * u.meter
        >>> u.math.vsplit(a, 2)
    """
    pieces = _fun_keep_unit_return_sequence(
        jnp.vsplit,
        a,
        indices_or_sections,
        **kwargs,
    )
    return pieces
# broadcasting arrays
# -------------------
def _broadcast_fun(func, *args, **kwargs):
    """
    Apply ``func`` to the flattened leaves of the inputs and rebuild their structure.

    Parameters
    ----------
    func : callable
        Function called as ``func(*leaves, **kwargs)``.
    *args
        Inputs; each one is converted with ``asarray`` before flattening.
    **kwargs
        Keyword arguments handed to ``func`` unchanged.

    Returns
    -------
    res : object or list
        The single rebuilt element when exactly one input was given,
        otherwise the list of rebuilt elements.

    Notes
    -----
    Flattening uses ``jax.tree.flatten`` without an ``is_leaf`` override;
    presumably a ``Quantity`` flattens to its mantissa with the unit kept in
    the tree definition, so units survive the round trip (TODO confirm
    against the ``Quantity`` pytree registration).
    """
    converted = [asarray(item) for item in args]
    flat, structure = jax.tree.flatten(converted)

    out = func(*flat, **kwargs)
    # A lone array result must be wrapped so that it matches the leaf list.
    if isinstance(out, jax.Array):
        out = [out]

    rebuilt = structure.unflatten(out)
    return rebuilt[0] if len(rebuilt) == 1 else rebuilt
# more
# ----
@set_module_as('saiunit.math')
def broadcast_arrays(
    *args: Union[Quantity, jax.typing.ArrayLike],
    **kwargs,
) -> Union[Quantity | jax.Array | Sequence[jax.Array | Quantity]]:
    """
    Broadcast any number of arrays against each other.

    Parameters
    ----------
    `*args` : array_likes
        The arrays or quantities to broadcast.
    **kwargs
        Additional keyword arguments forwarded to
        ``jax.numpy.broadcast_arrays``.

    Returns
    -------
    broadcasted : array, Quantity, or list of them
        The broadcast inputs. When a single input is given, that one result
        is returned directly instead of a one-element list.

    Examples
    --------
    .. code-block:: python

        >>> import saiunit as u
        >>> a = [1, 2, 3] * u.second
        >>> b = [[4], [5]] * u.second
        >>> u.math.broadcast_arrays(a, b)
    """
    broadcasted = _broadcast_fun(
        jnp.broadcast_arrays,
        *args,
        **kwargs,
    )
    return broadcasted
@set_module_as('saiunit.math')
def promote_dtypes(
    *args: Union[Quantity, jax.typing.ArrayLike],
    **kwargs,
) -> Union[Quantity | jax.Array | Sequence[jax.Array | Quantity]]:
    """
    Promote the data types of the inputs to a common type.

    Parameters
    ----------
    `*args` : array_likes
        The arrays or quantities to promote.
    **kwargs
        Additional keyword arguments forwarded to JAX's internal
        ``promote_dtypes`` utility.

    Returns
    -------
    promoted : array, Quantity, or list of them
        The inputs converted to a common dtype, with shapes unchanged. When
        a single input is given, that one result is returned directly
        instead of a one-element list.

    Examples
    --------
    .. code-block:: python

        >>> import saiunit as u
        >>> a = [1, 2, 3] * u.second
        >>> b = [4.0, 5.0, 6.0] * u.second
        >>> u.math.promote_dtypes(a, b)
    """
    promoted = _broadcast_fun(
        _promote_dtypes,
        *args,
        **kwargs,
    )
    return promoted
@set_module_as('saiunit.math')
def broadcast_to(
    array: Union[Quantity, jax.typing.ArrayLike],
    shape: Tuple[int, ...],
    **kwargs,
) -> Quantity | jax.Array:
    """
    Broadcast an array to a new shape.

    Parameters
    ----------
    array : array_like
        The array or quantity to broadcast.
    shape : tuple or int
        The shape of the desired array. A single integer ``i`` is
        interpreted as ``(i,)``.
    **kwargs
        Additional keyword arguments passed, together with
        ``jax.numpy.broadcast_to``, to the module's unit-keeping unary helper.

    Returns
    -------
    broadcast : array
        The input broadcast to ``shape``.

    Raises
    ------
    ValueError
        If the array is not compatible with the new shape according to
        NumPy's broadcasting rules.

    Examples
    --------
    .. code-block:: python

        >>> import saiunit as u
        >>> a = [1, 2, 3] * u.meter
        >>> u.math.broadcast_to(a, (2, 3))
    """
    options = {'shape': shape, **kwargs}
    return _fun_keep_unit_unary(jnp.broadcast_to, array, **options)
@set_module_as('saiunit.math')
def atleast_1d(
    *arys: Union[jax.Array, Quantity],
    **kwargs,
) -> Union[Quantity | jax.Array | Sequence[jax.Array | Quantity]]:
    """
    View inputs as quantities or arrays with at least one dimension.

    Parameters
    ----------
    *arys : array_like, Quantity
        One or more input arrays or quantities.
    **kwargs
        Additional keyword arguments forwarded to ``jax.numpy.atleast_1d``.

    Returns
    -------
    res : ndarray, Quantity
        A single array or quantity when one input is given, otherwise a list
        of them, each with ``a.ndim >= 1``.

    Examples
    --------
    .. code-block:: python

        >>> import saiunit as u
        >>> u.math.atleast_1d(0 * u.second)
    """
    promoted = _broadcast_fun(
        jnp.atleast_1d,
        *arys,
        **kwargs,
    )
    return promoted
@set_module_as('saiunit.math')
def atleast_2d(
    *arys: Union[jax.Array, Quantity],
    **kwargs,
) -> Union[Quantity | jax.Array | Sequence[jax.Array | Quantity]]:
    """
    View inputs as quantities or arrays with at least two dimensions.

    Parameters
    ----------
    *arys : array_like, Quantity
        One or more input arrays or quantities.
    **kwargs
        Additional keyword arguments forwarded to ``jax.numpy.atleast_2d``.

    Returns
    -------
    res : ndarray, Quantity
        A single array or quantity when one input is given, otherwise a list
        of them, each with ``a.ndim >= 2``.

    Examples
    --------
    .. code-block:: python

        >>> import saiunit as u
        >>> a = [1, 2, 3] * u.second
        >>> u.math.atleast_2d(a)
    """
    promoted = _broadcast_fun(
        jnp.atleast_2d,
        *arys,
        **kwargs,
    )
    return promoted
@set_module_as('saiunit.math')
def atleast_3d(
    *arys: Union[jax.Array, Quantity],
    **kwargs,
) -> Union[Quantity | jax.Array | Sequence[jax.Array | Quantity]]:
    """
    View inputs as quantities or arrays with at least three dimensions.

    Parameters
    ----------
    *arys : array_like, Quantity
        One or more input arrays or quantities.
    **kwargs
        Additional keyword arguments forwarded to ``jax.numpy.atleast_3d``.

    Returns
    -------
    res : ndarray, Quantity
        A single array or quantity when one input is given, otherwise a list
        of them, each with ``a.ndim >= 3``.

    Examples
    --------
    .. code-block:: python

        >>> import saiunit as u
        >>> a = [[1, 2], [3, 4]] * u.meter
        >>> u.math.atleast_3d(a)
    """
    promoted = _broadcast_fun(
        jnp.atleast_3d,
        *arys,
        **kwargs,
    )
    return promoted
# array manipulation
# ------------------
@set_module_as('saiunit.math')
def reshape(
    a: Union[jax.Array, Quantity],
    shape: Union[int, Tuple[int, ...]],
    order: str = 'C',
    **kwargs,
) -> Union[jax.Array, Quantity]:
    """
    Gives a new shape to a quantity or an array without changing its data.

    Parameters
    ----------
    a : array_like, Quantity
        Array to be reshaped.
    shape : int or tuple of ints
        The new shape, which should be compatible with the original shape.
        If an integer, the result is a 1-D array of that length. One shape
        dimension can be -1, in which case its value is inferred from the
        length of the array and the remaining dimensions.
    order : {'C', 'F', 'A'}, optional
        Index order used to read the elements of ``a`` and to place them
        into the reshaped array. 'C' means C-like index order (last axis
        index changing fastest); 'F' means Fortran-like index order (first
        axis index changing fastest). These two options refer only to the
        order of indexing, not to the memory layout of the underlying array.
        'A' means Fortran-like order if ``a`` is Fortran *contiguous* in
        memory, C-like order otherwise. Default is 'C'.
    **kwargs
        Additional keyword arguments passed, together with
        ``jax.numpy.reshape``, to the module's unit-keeping unary helper.

    Returns
    -------
    reshaped_array : ndarray, Quantity
        The reshaped array or quantity.

    Examples
    --------
    .. code-block:: python

        >>> import saiunit as u
        >>> a = [1, 2, 3, 4] * u.second
        >>> u.math.reshape(a, (2, 2))
    """
    options = {'shape': shape, 'order': order, **kwargs}
    return _fun_keep_unit_unary(jnp.reshape, a, **options)
@set_module_as('saiunit.math')
def moveaxis(
    a: Union[jax.Array, Quantity],
    source: Union[int, Tuple[int, ...]],
    destination: Union[int, Tuple[int, ...]],
    **kwargs,
) -> Union[jax.Array, Quantity]:
    """
    Moves axes of a quantity or an array to new positions.

    Other axes remain in their original order.

    Parameters
    ----------
    a : array_like, Quantity
        The array whose axes should be reordered.
    source : int or sequence of int
        Original positions of the axes to move. These must be unique.
    destination : int or sequence of int
        Destination positions for each of the original axes. These must
        also be unique.
    **kwargs
        Additional keyword arguments passed, together with
        ``jax.numpy.moveaxis``, to the module's unit-keeping unary helper.

    Returns
    -------
    result : ndarray, Quantity
        Array with moved axes.

    Examples
    --------
    .. code-block:: python

        >>> import saiunit as u
        >>> import jax.numpy as jnp
        >>> a = jnp.zeros((3, 4, 5)) * u.meter
        >>> u.math.moveaxis(a, 0, -1).shape
        (4, 5, 3)
    """
    options = {'source': source, 'destination': destination, **kwargs}
    return _fun_keep_unit_unary(jnp.moveaxis, a, **options)
@set_module_as('saiunit.math')
def transpose(
    a: Union[jax.Array, Quantity],
    axes: Optional[Union[int, Tuple[int, ...]]] = None,
    **kwargs,
) -> Union[jax.Array, Quantity]:
    """
    Permute the dimensions of a quantity or an array.

    Parameters
    ----------
    a : array_like, Quantity
        Input array.
    axes : list of ints, optional
        Permutation of the axes. By default (``None``) the dimensions are
        reversed.
    **kwargs
        Additional keyword arguments passed, together with
        ``jax.numpy.transpose``, to the module's unit-keeping unary helper.

    Returns
    -------
    p : ndarray, Quantity
        ``a`` with its axes permuted.

    Examples
    --------
    .. code-block:: python

        >>> import saiunit as u
        >>> import jax.numpy as jnp
        >>> a = jnp.ones((2, 3)) * u.second
        >>> u.math.transpose(a).shape
        (3, 2)
    """
    options = {'axes': axes, **kwargs}
    return _fun_keep_unit_unary(jnp.transpose, a, **options)
@set_module_as('saiunit.math')
def swapaxes(
    a: Union[jax.Array, Quantity],
    axis1: int,
    axis2: int,
    **kwargs,
) -> Union[jax.Array, Quantity]:
    """
    Interchange two axes of a quantity or an array.

    Parameters
    ----------
    a : array_like, Quantity
        Input array.
    axis1 : int
        First axis.
    axis2 : int
        Second axis.
    **kwargs
        Additional keyword arguments passed, together with
        ``jax.numpy.swapaxes``, to the module's unit-keeping unary helper.

    Returns
    -------
    a_swapped : ndarray, Quantity
        A new array where the axes are swapped.

    Examples
    --------
    .. code-block:: python

        >>> import saiunit as u
        >>> import jax.numpy as jnp
        >>> a = jnp.zeros((3, 4, 5)) * u.meter
        >>> u.math.swapaxes(a, 0, 2).shape
        (5, 4, 3)
    """
    options = {'axis1': axis1, 'axis2': axis2, **kwargs}
    return _fun_keep_unit_unary(jnp.swapaxes, a, **options)
@set_module_as('saiunit.math')
def tile(
    A: Union[jax.Array, Quantity],
    reps: Union[int, Tuple[int, ...]],
    **kwargs,
) -> Union[jax.Array, Quantity]:
    """
    Construct a quantity or an array by repeating A the number of times given by reps.

    Parameters
    ----------
    A : array_like, Quantity
        The input array.
    reps : array_like
        The number of repetitions of ``A`` along each axis.
    **kwargs
        Additional keyword arguments passed, together with
        ``jax.numpy.tile``, to the module's unit-keeping unary helper.

    Returns
    -------
    res : ndarray, Quantity
        The tiled output array.

    Examples
    --------
    .. code-block:: python

        >>> import saiunit as u
        >>> a = [1, 2, 3] * u.meter
        >>> u.math.tile(a, 2)
    """
    options = {'reps': reps, **kwargs}
    return _fun_keep_unit_unary(jnp.tile, A, **options)
@set_module_as('saiunit.math')
def repeat(
    a: Union[jax.Array, Quantity],
    repeats: Union[int, Tuple[int, ...]],
    axis: Optional[int] = None,
    total_repeat_length: Optional[int] = None,
    **kwargs,
) -> Union[jax.Array, Quantity]:
    """
    Repeat elements of a quantity or an array.

    Parameters
    ----------
    a : array_like, Quantity
        Input array.
    repeats : int or tuple of ints
        The number of repetitions for each element. ``repeats`` is
        broadcasted to fit the shape of the given axis.
    axis : int, optional
        The axis along which to repeat values. By default, the flattened
        input array is used and a flat output array is returned.
    total_repeat_length : int, optional
        The total length of the repeated array. If not ``None``, the output
        array will have this length.
    **kwargs
        Additional keyword arguments passed, together with
        ``jax.numpy.repeat``, to the module's unit-keeping unary helper.

    Returns
    -------
    res : ndarray, Quantity
        Output array which has the same shape as ``a``, except along the
        given axis.

    Examples
    --------
    .. code-block:: python

        >>> import saiunit as u
        >>> a = [1, 2, 3] * u.second
        >>> u.math.repeat(a, 2)
    """
    options = {
        'repeats': repeats,
        'axis': axis,
        'total_repeat_length': total_repeat_length,
        **kwargs,
    }
    return _fun_keep_unit_unary(jnp.repeat, a, **options)
@set_module_as('saiunit.math')
def flip(
    m: Union[jax.Array, Quantity],
    axis: Optional[Union[int, Tuple[int, ...]]] = None,
    **kwargs,
) -> Union[jax.Array, Quantity]:
    """
    Reverse the order of elements in a quantity or an array along the given axis.

    Parameters
    ----------
    m : array_like, Quantity
        Input array.
    axis : int or tuple of ints, optional
        Axis or axes along which to flip over. The default, ``axis=None``,
        flips over all of the axes of the input array.
    **kwargs
        Additional keyword arguments passed, together with
        ``jax.numpy.flip``, to the module's unit-keeping unary helper.

    Returns
    -------
    res : ndarray, Quantity
        ``m`` with the entries along ``axis`` reversed.

    Examples
    --------
    .. code-block:: python

        >>> import saiunit as u
        >>> a = [1, 2, 3] * u.meter
        >>> u.math.flip(a)
    """
    options = {'axis': axis, **kwargs}
    return _fun_keep_unit_unary(jnp.flip, m, **options)
@set_module_as('saiunit.math')
def fliplr(
    m: Union[jax.Array, Quantity],
    **kwargs,
) -> Union[jax.Array, Quantity]:
    """
    Flip quantity or array in the left/right direction.

    Parameters
    ----------
    m : array_like, Quantity
        Input array.
    **kwargs
        Additional keyword arguments passed, together with
        ``jax.numpy.fliplr``, to the module's unit-keeping unary helper.

    Returns
    -------
    res : ndarray, Quantity
        ``m`` with the columns reversed.

    Examples
    --------
    .. code-block:: python

        >>> import saiunit as u
        >>> a = [[1, 2], [3, 4]] * u.meter
        >>> u.math.fliplr(a)
    """
    flipped = _fun_keep_unit_unary(
        jnp.fliplr,
        m,
        **kwargs,
    )
    return flipped
@set_module_as('saiunit.math')
def flipud(
    m: Union[jax.Array, Quantity],
    **kwargs,
) -> Union[jax.Array, Quantity]:
    """
    Flip quantity or array in the up/down direction.

    Parameters
    ----------
    m : array_like, Quantity
        Input array.
    **kwargs
        Additional keyword arguments passed, together with
        ``jax.numpy.flipud``, to the module's unit-keeping unary helper.

    Returns
    -------
    res : ndarray, Quantity
        ``m`` with the rows reversed.

    Examples
    --------
    .. code-block:: python

        >>> import saiunit as u
        >>> a = [[1, 2], [3, 4]] * u.meter
        >>> u.math.flipud(a)
    """
    flipped = _fun_keep_unit_unary(
        jnp.flipud,
        m,
        **kwargs,
    )
    return flipped
@set_module_as('saiunit.math')
def roll(
    a: Union[jax.Array, Quantity],
    shift: Union[int, Tuple[int, ...]],
    axis: Optional[Union[int, Tuple[int, ...]]] = None,
    **kwargs,
) -> Union[jax.Array, Quantity]:
    """
    Roll quantity or array elements along a given axis.

    Parameters
    ----------
    a : array_like, Quantity
        Input array.
    shift : int or tuple of ints
        The number of places by which elements are shifted. If a tuple,
        then ``axis`` must be a tuple of the same size, and each of the
        given axes is shifted by the corresponding number. If an int while
        ``axis`` is a tuple of ints, the same value is used for all given
        axes.
    axis : int or tuple of ints, optional
        Axis or axes along which elements are shifted. By default, the
        array is flattened before shifting, after which the original shape
        is restored.
    **kwargs
        Additional keyword arguments passed, together with
        ``jax.numpy.roll``, to the module's unit-keeping unary helper.

    Returns
    -------
    res : ndarray, Quantity
        Output array, with the same shape as ``a``.

    Examples
    --------
    .. code-block:: python

        >>> import saiunit as u
        >>> a = [1, 2, 3] * u.second
        >>> u.math.roll(a, 1)
    """
    options = {'shift': shift, 'axis': axis, **kwargs}
    return _fun_keep_unit_unary(jnp.roll, a, **options)
@set_module_as('saiunit.math')
def expand_dims(
    a: Union[jax.Array, Quantity],
    axis: int,
    **kwargs,
) -> Union[jax.Array, Quantity]:
    """
    Expand the shape of a quantity or an array.

    Parameters
    ----------
    a : array_like, Quantity
        Input array.
    axis : int
        Position in the expanded axes where the new axis is placed.
    **kwargs
        Additional keyword arguments passed, together with
        ``jax.numpy.expand_dims``, to the module's unit-keeping unary helper.

    Returns
    -------
    res : ndarray, Quantity
        ``a`` with the number of dimensions increased by one.

    Examples
    --------
    .. code-block:: python

        >>> import saiunit as u
        >>> a = [1, 2, 3] * u.meter
        >>> u.math.expand_dims(a, axis=0).shape
        (1, 3)
    """
    options = {'axis': axis, **kwargs}
    return _fun_keep_unit_unary(jnp.expand_dims, a, **options)
@set_module_as('saiunit.math')
def squeeze(
    a: Union[jax.Array, Quantity],
    axis: Optional[Union[int, Tuple[int, ...]]] = None,
    **kwargs,
) -> Union[jax.Array, Quantity]:
    """
    Remove single-dimensional entries from the shape of a quantity or an array.

    Parameters
    ----------
    a : array_like, Quantity
        Input data.
    axis : None or int or tuple of ints, optional
        Selects a subset of the single-dimensional entries in the shape.
        If an axis is selected with shape entry greater than one, an error
        is raised.
    **kwargs
        Additional keyword arguments passed, together with
        ``jax.numpy.squeeze``, to the module's unit-keeping unary helper.

    Returns
    -------
    res : ndarray, Quantity
        An array with the same data as ``a``, but with a lower dimension.

    Examples
    --------
    .. code-block:: python

        >>> import saiunit as u
        >>> a = [[[1], [2], [3]]] * u.second
        >>> u.math.squeeze(a).shape
        (3,)
    """
    options = {'axis': axis, **kwargs}
    return _fun_keep_unit_unary(jnp.squeeze, a, **options)
@set_module_as('saiunit.math')
def sort(
    a: Union[jax.Array, Quantity],
    axis: Optional[int] = -1,
    *,
    kind: None = None,
    order: None = None,
    stable: bool = True,
    descending: bool = False,
    **kwargs,
) -> Union[jax.Array, Quantity]:
    """
    Return a sorted copy of a quantity or an array.

    Parameters
    ----------
    a : array_like, Quantity
        Array or quantity to be sorted.
    axis : int or None, optional
        Axis along which to sort. If ``None``, the array is flattened before
        sorting. The default is -1, which sorts along the last axis.
    kind : None, optional
        Passed through unchanged as the ``kind`` argument of
        ``jax.numpy.sort``. The default is ``None``.
    order : None, optional
        Passed through unchanged as the ``order`` argument of
        ``jax.numpy.sort``. The default is ``None``.
    stable : bool, optional
        Whether to use a stable sorting algorithm. The default is True.
    descending : bool, optional
        Whether to sort in descending order. The default is False.
    **kwargs
        Additional keyword arguments passed, together with
        ``jax.numpy.sort``, to the module's unit-keeping unary helper.

    Returns
    -------
    res : ndarray, Quantity
        Sorted copy of the input array.

    Examples
    --------
    .. code-block:: python

        >>> import saiunit as u
        >>> a = [3, 1, 2] * u.meter
        >>> u.math.sort(a)
    """
    sort_options = {
        'axis': axis,
        'kind': kind,
        'order': order,
        'stable': stable,
        'descending': descending,
        **kwargs,
    }
    return _fun_keep_unit_unary(jnp.sort, a, **sort_options)
@set_module_as('saiunit.math')
def max(
    a: Union[jax.Array, Quantity],
    axis: Optional[int] = None,
    keepdims: bool = False,
    initial: Optional[Union[int, float, Quantity]] = None,
    where: Optional[jax.Array] = None,
    **kwargs,
) -> Union[jax.Array, Quantity]:
    """
    Return the maximum of a quantity or an array, or the maximum along an axis.

    Parameters
    ----------
    a : array_like, Quantity
        Array or quantity containing numbers whose maximum is desired.
    axis : int or None, optional
        Axis or axes along which to operate. By default, flattened input is used.
    keepdims : bool, optional
        If this is set to True, the axes which are reduced are left in the
        result as dimensions with size one. With this option, the result
        will broadcast correctly against the input array.
    initial : scalar, Quantity, optional
        The minimum value of an output element. Must be present to allow
        computation on an empty slice. A ``Quantity`` is converted to the
        unit of ``a`` before it is used; a plain number is used as-is, i.e.
        it is interpreted on the scale of the mantissa of ``a``.
    where : array_like, optional
        Boolean mask selecting the elements that take part in the
        reduction; forwarded to ``jax.numpy.max``.
    **kwargs
        Additional keyword arguments forwarded to ``jax.numpy.max``.

    Returns
    -------
    res : ndarray, Quantity
        Maximum of ``a``. If ``axis`` is None, the result is a scalar value.
        If ``axis`` is given, the result is an array of dimension
        ``a.ndim - 1``. A ``Quantity`` is returned when ``a`` is a
        ``Quantity``.

    Examples
    --------
    .. code-block:: python

        >>> import saiunit as u
        >>> a = [1, 2, 3] * u.meter
        >>> u.math.max(a)
    """
    if isinstance(initial, Quantity):
        # The reduction runs on the bare mantissa of ``a``, so a Quantity
        # ``initial`` must be expressed in the same unit first (this mirrors
        # what ``sum``/``nanmin``/``nanmax`` do for their ``initial``).
        initial = initial.in_unit(get_unit(a)).mantissa
    return _fun_keep_unit_unary(
        jnp.max, a, axis=axis, keepdims=keepdims, initial=initial, where=where, **kwargs
    )
@set_module_as('saiunit.math')
def min(
    a: Union[jax.Array, Quantity],
    axis: Optional[int] = None,
    keepdims: bool = False,
    initial: Optional[Union[int, float, Quantity]] = None,
    where: Optional[jax.Array] = None,
    **kwargs,
) -> Union[jax.Array, Quantity]:
    """
    Return the minimum of a quantity or an array, or the minimum along an axis.

    Parameters
    ----------
    a : array_like, Quantity
        Array or quantity containing numbers whose minimum is desired.
    axis : int or None, optional
        Axis or axes along which to operate. By default, flattened input is used.
    keepdims : bool, optional
        If this is set to True, the axes which are reduced are left in the
        result as dimensions with size one. With this option, the result
        will broadcast correctly against the input array.
    initial : scalar, Quantity, optional
        The maximum value of an output element. Must be present to allow
        computation on an empty slice. A ``Quantity`` is converted to the
        unit of ``a`` before it is used; a plain number is used as-is, i.e.
        it is interpreted on the scale of the mantissa of ``a``.
    where : array_like, optional
        Boolean mask selecting the elements that take part in the
        reduction; forwarded to ``jax.numpy.min``.
    **kwargs
        Additional keyword arguments forwarded to ``jax.numpy.min``.

    Returns
    -------
    res : ndarray, Quantity
        Minimum of ``a``. If ``axis`` is None, the result is a scalar value.
        If ``axis`` is given, the result is an array of dimension
        ``a.ndim - 1``. A ``Quantity`` is returned when ``a`` is a
        ``Quantity``.

    Examples
    --------
    .. code-block:: python

        >>> import saiunit as u
        >>> a = [1, 2, 3] * u.meter
        >>> u.math.min(a)
    """
    if isinstance(initial, Quantity):
        # The reduction runs on the bare mantissa of ``a``, so a Quantity
        # ``initial`` must be expressed in the same unit first (this mirrors
        # what ``sum``/``nanmin``/``nanmax`` do for their ``initial``).
        initial = initial.in_unit(get_unit(a)).mantissa
    return _fun_keep_unit_unary(
        jnp.min, a, axis=axis, keepdims=keepdims, initial=initial, where=where, **kwargs
    )
# NumPy-style aliases: ``amax`` and ``amin`` are bound to the very same
# function objects as ``max`` and ``min`` defined above.
amax = max
amin = min
@set_module_as('saiunit.math')
def diagonal(
    a: Union[jax.Array, Quantity],
    offset: int = 0,
    axis1: int = 0,
    axis2: int = 1,
    **kwargs,
) -> Union[jax.Array, Quantity]:
    """
    Extract the specified diagonals of an array or quantity.

    The work is delegated to ``jax.numpy.diagonal``; the unit of ``a`` is
    re-attached to the result when ``a`` is a ``Quantity``.

    Parameters
    ----------
    a : array_like, Quantity
        Array from which the diagonals are taken.
    offset : int, optional
        Offset of the diagonal from the main diagonal. Can be positive or
        negative. Defaults to the main diagonal (0).
    axis1 : int, optional
        Axis to be used as the first axis of the 2-D sub-arrays from which
        the diagonals should be taken. Defaults to the first axis (0).
    axis2 : int, optional
        Axis to be used as the second axis of the 2-D sub-arrays from which
        the diagonals should be taken. Defaults to the second axis (1).
    **kwargs
        Additional keyword arguments forwarded to ``jax.numpy.diagonal``.

    Returns
    -------
    res : ndarray, Quantity
        The extracted diagonals; a ``Quantity`` if ``a`` is a ``Quantity``.

    Examples
    --------
    .. code-block:: python

        >>> import saiunit as u
        >>> a = [[1, 2], [3, 4]] * u.second
        >>> u.math.diagonal(a)
    """
    diagonal_options = dict(offset=offset, axis1=axis1, axis2=axis2)
    return _fun_keep_unit_unary(jnp.diagonal, a, **diagonal_options, **kwargs)
@set_module_as('saiunit.math')
def ravel(
    a: Union[jax.Array, Quantity],
    order: str = 'C',
    **kwargs,
) -> Union[jax.Array, Quantity]:
    """
    Flatten a quantity or array into one dimension.

    The call is delegated to ``jax.numpy.ravel``; the unit of ``a`` is
    re-attached to the result when ``a`` is a ``Quantity``.

    Parameters
    ----------
    a : array_like, Quantity
        Input array. Its elements are read in the order given by ``order``
        and packed into a 1-D array.
    order : {'C', 'F', 'A', 'K'}, optional
        Index order used to read the elements of ``a``:

        - 'C': row-major (C-style), last axis index changing fastest.
        - 'F': column-major (Fortran-style), first axis index changing fastest.
        - 'A': Fortran-like order if ``a`` is Fortran contiguous in memory,
          C-like order otherwise.
        - 'K': the order the elements occur in memory, except for reversing
          the data when strides are negative.

        The default is 'C'.
    **kwargs
        Additional keyword arguments forwarded to ``jax.numpy.ravel``.

    Returns
    -------
    res : ndarray, Quantity
        1-D array (or quantity) holding the elements of ``a``.

    Examples
    --------
    .. code-block:: python

        >>> import saiunit as u
        >>> a = [[1, 2], [3, 4]] * u.meter
        >>> u.math.ravel(a)
    """
    flattened = _fun_keep_unit_unary(jnp.ravel, a, order=order, **kwargs)
    return flattened
@set_module_as('saiunit.math')
def flatten(
    x: jax.typing.ArrayLike | Quantity,
    start_axis: Optional[int] = None,
    end_axis: Optional[int] = None,
    **kwargs,
) -> jax.Array | Quantity:
    """Flatten a contiguous range of axes of the input into a single axis.

    The axes from ``start_axis`` to ``end_axis`` (both inclusive) are merged
    into one; all other axes are left untouched. With the defaults every
    axis is merged and the result is one-dimensional. The order of the
    elements is unchanged, and the unit of ``x`` is kept when ``x`` is a
    ``Quantity``.

    .. note::
       A zero-dimensional input is treated as having a single axis and is
       returned with shape ``(1,)``.

    Parameters
    ----------
    x : Array, Quantity
        The input array. It must expose ``shape`` and ``ndim``.
    start_axis : int, optional
        The first axis to flatten. Negative values count from the last
        axis. Defaults to the first axis.
    end_axis : int, optional
        The last axis to flatten (inclusive). Negative values count from
        the last axis. Defaults to the last axis.
    **kwargs
        Additional keyword arguments forwarded to ``jax.numpy.reshape``.

    Returns
    -------
    out : Array, Quantity
        The reshaped input.

    Raises
    ------
    ValueError
        If ``start_axis`` or ``end_axis`` is out of range, or if
        ``start_axis`` comes after ``end_axis``.

    Examples
    --------
    .. code-block:: python

        >>> import saiunit as u
        >>> a = [[1, 2], [3, 4]] * u.second
        >>> u.math.flatten(a)
    """
    shape = x.shape
    # A scalar has no axes; pretend it has one so the default range is valid.
    ndim = x.ndim if x.ndim != 0 else 1

    first = 0 if start_axis is None else start_axis
    last = ndim - 1 if end_axis is None else end_axis
    if first < 0:
        first += ndim
    if last < 0:
        last += ndim

    # Validate the normalized (inclusive) axes before turning ``last`` into
    # an exclusive slice bound, so out-of-range values cannot slip through.
    if first < 0 or first >= ndim:
        raise ValueError(f'start_axis {start_axis} is out of size.')
    if last < 0 or last >= ndim:
        raise ValueError(f'end_axis {end_axis} is out of size.')
    if first > last:
        raise ValueError(
            f'start_axis {start_axis} must not come after end_axis {end_axis}.'
        )

    stop = last + 1
    # ``kwargs`` belong to ``jnp.reshape`` only; ``np.prod`` must not see them.
    merged_size = int(np.prod(shape[first:stop], dtype=int))
    new_shape = shape[:first] + (merged_size,) + shape[stop:]
    return _fun_keep_unit_unary(jnp.reshape, x, shape=new_shape, **kwargs)
@set_module_as('saiunit.math')
def unflatten(
    x: jax.typing.ArrayLike | Quantity,
    axis: int,
    sizes: Sequence[int],
    **kwargs,
) -> jax.Array | Quantity:
    """
    Expand one axis of the input over multiple axes.

    The axis ``axis`` of ``x`` is replaced by the axes described by
    ``sizes``; all other axes are left untouched. The unit of ``x`` is kept
    when ``x`` is a ``Quantity``.

    Parameters
    ----------
    x : Array, Quantity
        The input array. It must expose ``shape`` and ``ndim``.
    axis : int
        Axis to be unflattened, given as an index into ``x.shape``.
        Negative values count from the last axis.
    sizes : sequence of int
        New shape of the unflattened axis. One of its elements can be -1,
        in which case the corresponding output dimension is inferred.
        Otherwise, the product of ``sizes`` must equal ``x.shape[axis]``.
    **kwargs
        Additional keyword arguments forwarded to ``jax.numpy.reshape``.

    Returns
    -------
    out : Array, Quantity
        An array with the same data as ``x`` but with ``axis`` split into
        ``len(sizes)`` axes.

    Raises
    ------
    ValueError
        If ``axis`` is not a valid axis of ``x``.

    Examples
    --------
    .. code-block:: python

        >>> import saiunit as u
        >>> a = [1, 2, 3, 4, 5, 6] * u.meter
        >>> u.math.unflatten(a, 0, (2, 3))
    """
    ndim = x.ndim
    if not -ndim <= axis < ndim:
        raise ValueError(
            f'unflatten requires "axis" to be in the range [-x.ndim, x.ndim), '
            f'but got axis={axis} and x.ndim={x.ndim}.'
        )
    if axis < 0:
        # Normalize so that ``shape[axis + 1:]`` below selects the trailing
        # axes; with a raw ``axis=-1`` it would select the whole shape.
        axis += ndim
    shape = x.shape
    new_shape = shape[:axis] + tuple(sizes) + shape[axis + 1:]
    return _fun_keep_unit_unary(jnp.reshape, x, shape=new_shape, **kwargs)
@set_module_as('saiunit.math')
def remove_diag(x: jax.typing.ArrayLike | Quantity, **kwargs) -> jax.Array | Quantity:
    """Remove the main diagonal of a matrix.

    Every row loses its diagonal element and the remaining elements are
    packed, in row-major order, into a matrix with one column less.

    Parameters
    ----------
    x : Array, Quantity
        The matrix with the shape of `(M, N)`. The final reshape to
        `(M, N-1)` only succeeds when `M <= N`, because only then does
        every row contain exactly one diagonal element.
    **kwargs
        Additional keyword arguments forwarded to the final
        ``jax.numpy.reshape`` call.

    Returns
    -------
    arr : Array, Quantity
        The matrix without diagonal which has the shape of `(M, N-1)`.
        A ``Quantity`` carrying the unit of ``x`` is returned unless that
        unit is unitless.

    Raises
    ------
    ValueError
        If ``x`` is not two-dimensional.

    Examples
    --------
    .. code-block:: python

        >>> import saiunit as u
        >>> a = [[1, 2, 3], [4, 5, 6], [7, 8, 9]] * u.second
        >>> u.math.remove_diag(a)
    """
    x = maybe_custom_array(x)
    unit = UNITLESS
    if isinstance(x, Quantity):
        unit = x.unit
        x = x.mantissa
    if x.ndim != 2:
        raise ValueError(f'Only support 2D matrix, while we got a {x.ndim}D array.')
    # Boolean mask that is True everywhere except on the main diagonal.
    # JAX arrays are immutable, so ``fill_diagonal`` must be asked for a new
    # array explicitly via ``inplace=False``.
    off_diagonal = jnp.fill_diagonal(jnp.ones(x.shape, dtype=bool), False, inplace=False)
    x = jnp.reshape(x[off_diagonal], (x.shape[0], x.shape[1] - 1), **kwargs)
    if unit.is_unitless:
        return x
    return Quantity(x, unit=unit)
# ---------- selection
@set_module_as('saiunit.math')
def choose(
    a: Union[jax.Array, Quantity],
    choices: Sequence[Union[jax.Array, Quantity]],
    mode: str = 'raise',
    **kwargs,
) -> Union[jax.Array, Quantity]:
    """
    Build an array by picking, element-wise, from a set of choice arrays.

    The call is delegated to ``jax.numpy.choose`` with ``a`` as the index
    array.

    Parameters
    ----------
    a : array_like, Quantity
        Integer index array. Its elements select which entry of ``choices``
        supplies each output element.
    choices : sequence of array_like
        Choice arrays. ``a`` and all ``choices`` must be broadcastable to
        the same shape.
    mode : {'raise', 'wrap', 'clip'}, optional
        Specifies how indices outside [0, n-1] will be treated:

        - 'raise' : raise an error (default)
        - 'wrap' : wrap around
        - 'clip' : clip to the range [0, n-1]
    **kwargs
        Additional keyword arguments forwarded to ``jax.numpy.choose``.

    Returns
    -------
    res : ndarray, Quantity
        The constructed array.

    Examples
    --------
    .. code-block:: python

        >>> import saiunit as u
        >>> import jax.numpy as jnp
        >>> choices = [jnp.array([1, 2, 3]), jnp.array([4, 5, 6])]
        >>> u.math.choose(jnp.array([0, 1, 0]), choices)
    """
    # NOTE(review): the unit attached to the result is the unit of the index
    # array ``a``; ``choices`` are forwarded as given. Confirm that this is
    # the intended behaviour when ``choices`` hold quantities.
    selection_options = dict(choices=choices, mode=mode)
    return _fun_keep_unit_unary(jnp.choose, a, **selection_options, **kwargs)
@set_module_as('saiunit.math')
def diagflat(
    v: Union[jax.Array, Quantity],
    k: int = 0,
    **kwargs,
) -> Union[jax.Array, Quantity]:
    """
    Create a two-dimensional quantity or array with the flattened input as a diagonal.

    The call is delegated to ``jax.numpy.diagflat``; the unit of ``v`` is
    re-attached to the result when ``v`` is a ``Quantity``.

    Parameters
    ----------
    v : array_like, Quantity
        Input data, which is flattened and set as the ``k``-th diagonal of
        the output.
    k : int, optional
        Diagonal in question. The default is 0.
    **kwargs
        Additional keyword arguments forwarded to ``jax.numpy.diagflat``.

    Returns
    -------
    res : ndarray, Quantity
        The 2-D output array.

    Examples
    --------
    .. code-block:: python

        >>> import saiunit as u
        >>> a = [1, 2, 3] * u.meter
        >>> u.math.diagflat(a)
    """
    diagonal_matrix = _fun_keep_unit_unary(jnp.diagflat, v, k=k, **kwargs)
    return diagonal_matrix
# math funcs keep unit (unary)
# ----------------------------
def _fun_keep_unit_unary(func, x, *args, **kwargs):
    """Apply ``func`` to ``x`` and keep the unit of ``x`` on the result.

    ``x`` is first passed through ``maybe_custom_array`` and the remaining
    arguments through ``maybe_custom_array_tree`` (presumably to unwrap
    custom array wrappers -- confirm against ``saiunit._misc``).

    Parameters
    ----------
    func : callable
        Function applied to the (unit-free) data.
    x : array_like, Quantity
        Primary operand.
    *args, **kwargs
        Extra arguments forwarded to ``func``.

    Returns
    -------
    out : Quantity or whatever ``func`` returns
        If ``x`` is a ``Quantity``, ``func`` is applied to its mantissa and
        the result is wrapped in a ``Quantity`` with the unit of ``x``.
        Otherwise the result of ``func(x, ...)`` is returned directly.
    """
    x = maybe_custom_array(x)
    args, kwargs = maybe_custom_array_tree((args, kwargs))
    if not isinstance(x, Quantity):
        return func(x, *args, **kwargs)
    mantissa_result = func(x.mantissa, *args, **kwargs)
    return Quantity(mantissa_result, unit=x.unit)
@set_module_as('saiunit.math')
def astype(
    x: Union[jax.typing.ArrayLike, Quantity],
    dtype: jax.typing.DTypeLike
) -> Union[jax.Array, Quantity]:
    """
    Copy of the array, cast to a specified type.

    The cast is performed by ``jax.numpy.astype``. When ``x`` is a
    ``Quantity`` only its mantissa is cast and the unit is kept.

    Parameters
    ----------
    x : array_like, Quantity
        Input array.
    dtype : dtype
        Typecode or data-type to which the array is cast.

    Returns
    -------
    out : ndarray, Quantity
        A copy of the array, cast to a specified type.

    Examples
    --------
    .. code-block:: python

        >>> import saiunit as u
        >>> import jax.numpy as jnp
        >>> a = [1, 2, 3] * u.second
        >>> u.math.astype(a, jnp.float32)
    """
    return _fun_keep_unit_unary(jnp.astype, x, dtype)
@set_module_as('saiunit.math')
def real(x: Union[Quantity, jax.typing.ArrayLike], **kwargs) -> Union[Quantity, jax.Array]:
    """
    Extract the real part of a complex array or quantity.

    Delegates to ``jax.numpy.real`` and keeps the unit of ``x``.

    Parameters
    ----------
    x : array_like, Quantity
        Input array.
    **kwargs
        Additional keyword arguments forwarded to ``jax.numpy.real``.

    Returns
    -------
    out : jax.Array, Quantity
        Quantity if `x` is a Quantity, else an array.

    Examples
    --------
    .. code-block:: python

        >>> import saiunit as u
        >>> a = [1 + 2j, 3 + 4j] * u.second
        >>> u.math.real(a)
    """
    real_part = _fun_keep_unit_unary(jnp.real, x, **kwargs)
    return real_part
@set_module_as('saiunit.math')
def imag(x: Union[Quantity, jax.typing.ArrayLike], **kwargs) -> Union[Quantity, jax.Array]:
    """
    Extract the imaginary part of a complex array or quantity.

    Delegates to ``jax.numpy.imag`` and keeps the unit of ``x``.

    Parameters
    ----------
    x : array_like, Quantity
        Input array.
    **kwargs
        Additional keyword arguments forwarded to ``jax.numpy.imag``.

    Returns
    -------
    out : jax.Array, Quantity
        Quantity if `x` is a Quantity, else an array.

    Examples
    --------
    .. code-block:: python

        >>> import saiunit as u
        >>> a = [1 + 2j, 3 + 4j] * u.second
        >>> u.math.imag(a)
    """
    imaginary_part = _fun_keep_unit_unary(jnp.imag, x, **kwargs)
    return imaginary_part
@set_module_as('saiunit.math')
def conj(x: Union[Quantity, jax.typing.ArrayLike], **kwargs) -> Union[Quantity, jax.Array]:
    """
    Compute the element-wise complex conjugate.

    Delegates to ``jax.numpy.conj`` and keeps the unit of ``x``.

    Parameters
    ----------
    x : array_like, Quantity
        Input array.
    **kwargs
        Additional keyword arguments forwarded to ``jax.numpy.conj``.

    Returns
    -------
    out : jax.Array, Quantity
        Quantity if `x` is a Quantity, else an array.

    Examples
    --------
    .. code-block:: python

        >>> import saiunit as u
        >>> a = [1 + 2j, 3 + 4j] * u.second
        >>> u.math.conj(a)
    """
    conjugated = _fun_keep_unit_unary(jnp.conj, x, **kwargs)
    return conjugated
@set_module_as('saiunit.math')
def conjugate(x: Union[Quantity, jax.typing.ArrayLike], **kwargs) -> Union[Quantity, jax.Array]:
    """
    Compute the element-wise complex conjugate.

    Delegates to ``jax.numpy.conjugate`` and keeps the unit of ``x``.

    Parameters
    ----------
    x : array_like, Quantity
        Input array.
    **kwargs
        Additional keyword arguments forwarded to ``jax.numpy.conjugate``.

    Returns
    -------
    out : jax.Array, Quantity
        Quantity if `x` is a Quantity, else an array.

    Examples
    --------
    .. code-block:: python

        >>> import saiunit as u
        >>> a = [1 + 2j, 3 + 4j] * u.second
        >>> u.math.conjugate(a)
    """
    conjugated = _fun_keep_unit_unary(jnp.conjugate, x, **kwargs)
    return conjugated
@set_module_as('saiunit.math')
def negative(x: Union[Quantity, jax.typing.ArrayLike], **kwargs) -> Union[Quantity, jax.Array]:
    """
    Compute the element-wise numerical negative.

    Delegates to ``jax.numpy.negative`` and keeps the unit of ``x``.

    Parameters
    ----------
    x : array_like, Quantity
        Input array.
    **kwargs
        Additional keyword arguments forwarded to ``jax.numpy.negative``.

    Returns
    -------
    out : jax.Array, Quantity
        Quantity if `x` is a Quantity, else an array.

    Examples
    --------
    .. code-block:: python

        >>> import saiunit as u
        >>> a = [1, -2, 3] * u.meter
        >>> u.math.negative(a)
    """
    negated = _fun_keep_unit_unary(jnp.negative, x, **kwargs)
    return negated
@set_module_as('saiunit.math')
def positive(x: Union[Quantity, jax.typing.ArrayLike], **kwargs) -> Union[Quantity, jax.Array]:
    """
    Compute the element-wise numerical positive.

    Delegates to ``jax.numpy.positive`` and keeps the unit of ``x``.

    Parameters
    ----------
    x : array_like, Quantity
        Input array.
    **kwargs
        Additional keyword arguments forwarded to ``jax.numpy.positive``.

    Returns
    -------
    out : jax.Array, Quantity
        Quantity if `x` is a Quantity, else an array.

    Examples
    --------
    .. code-block:: python

        >>> import saiunit as u
        >>> a = [1, -2, 3] * u.meter
        >>> u.math.positive(a)
    """
    unchanged_sign = _fun_keep_unit_unary(jnp.positive, x, **kwargs)
    return unchanged_sign
@set_module_as('saiunit.math')
def abs(x: Union[Quantity, jax.typing.ArrayLike], **kwargs) -> Union[Quantity, jax.Array]:
    """
    Compute the element-wise absolute value.

    Delegates to ``jax.numpy.abs`` and keeps the unit of ``x``.

    Parameters
    ----------
    x : array_like, Quantity
        Input array.
    **kwargs
        Additional keyword arguments forwarded to ``jax.numpy.abs``.

    Returns
    -------
    out : jax.Array, Quantity
        Quantity if `x` is a Quantity, else an array.

    Examples
    --------
    .. code-block:: python

        >>> import saiunit as u
        >>> a = [-1, -2, 3] * u.meter
        >>> u.math.abs(a)
    """
    magnitude = _fun_keep_unit_unary(jnp.abs, x, **kwargs)
    return magnitude
@set_module_as('saiunit.math')
def sum(
    x: Union[Quantity, jax.typing.ArrayLike],
    axis: Union[int, Sequence[int], None] = None,
    dtype: Union[jax.typing.DTypeLike, None] = None,
    keepdims: bool = False,
    initial: Union[jax.typing.ArrayLike, Quantity, None] = None,
    where: Union[jax.typing.ArrayLike, None] = None,
    promote_integers: bool = True,
    **kwargs,
) -> Union[Quantity, jax.Array]:
    """
    Sum the elements of an array or quantity, keeping the unit.

    The reduction is delegated to ``jax.numpy.sum``; the unit of ``x`` is
    re-attached to the result when ``x`` is a ``Quantity``.

    Parameters
    ----------
    x : array_like, Quantity
        Input array.
    axis : None or int or tuple of ints, optional
        Axis or axes along which a sum is performed. The default,
        axis=None, will sum all of the elements of the input array. If
        axis is negative it counts from the last to the first axis.
        If axis is a tuple of ints, a sum is performed on all of the axes
        specified in the tuple.
    dtype : dtype, optional
        The type of the returned array and of the accumulator in which the
        elements are summed.
    keepdims : bool, optional
        If this is set to True, the axes which are reduced are left
        in the result as dimensions with size one. With this option,
        the result will broadcast correctly against the input array.
    initial : scalar, Quantity, optional
        Starting value for the sum. When given, it is wrapped in a
        ``Quantity`` and converted to the unit of ``x`` before use.
    where : array_like of bool, optional
        Elements to include in the sum.
    promote_integers : bool, optional
        Forwarded to ``jax.numpy.sum``; controls whether integer inputs
        are promoted to a wider integer type before summing (see the JAX
        documentation for the exact rule). The default is True.
    **kwargs
        Additional keyword arguments forwarded to ``jax.numpy.sum``.

    Returns
    -------
    out : jax.Array, Quantity
        Quantity if `x` is a Quantity, else an array.

    Examples
    --------
    .. code-block:: python

        >>> import saiunit as u
        >>> a = [1.0, 2.0, 3.0] * u.second
        >>> u.math.sum(a)
    """
    if initial is not None:
        # The reduction runs on the mantissa, so express the starting value
        # in the unit of ``x`` first.
        initial_quantity = Quantity(initial)
        initial = initial_quantity.in_unit(get_unit(x)).mantissa
    reduce_options = dict(
        axis=axis,
        dtype=dtype,
        keepdims=keepdims,
        initial=initial,
        where=where,
        promote_integers=promote_integers,
    )
    return _fun_keep_unit_unary(jnp.sum, x, **reduce_options, **kwargs)
@set_module_as('saiunit.math')
def nancumsum(
    x: Union[Quantity, jax.typing.ArrayLike],
    axis: Union[int, Sequence[int], None] = None,
    dtype: Union[jax.typing.DTypeLike, None] = None,
    **kwargs,
) -> Union[Quantity, jax.Array]:
    """
    Cumulative sum of the array elements, ignoring NaNs, keeping the unit.

    Delegates to ``jax.numpy.nancumsum``; the unit of ``x`` is re-attached
    to the result when ``x`` is a ``Quantity``.

    Parameters
    ----------
    x : array_like, Quantity
        Input array.
    axis : int, optional
        Axis along which the cumulative sum is computed. The default
        (None) is to compute the cumsum over the flattened array.
    dtype : dtype, optional
        Type of the returned array and of the accumulator in which the
        elements are summed.
    **kwargs
        Additional keyword arguments forwarded to ``jax.numpy.nancumsum``.

    Returns
    -------
    out : jax.Array, Quantity
        Quantity if `x` is a Quantity, else an array.

    Examples
    --------
    .. code-block:: python

        >>> import saiunit as u
        >>> import jax.numpy as jnp
        >>> a = [1.0, jnp.nan, 3.0] * u.meter
        >>> u.math.nancumsum(a)
    """
    scan_options = dict(axis=axis, dtype=dtype)
    return _fun_keep_unit_unary(jnp.nancumsum, x, **scan_options, **kwargs)
@set_module_as('saiunit.math')
def nansum(
    x: Union[Quantity, jax.typing.ArrayLike],
    axis: Union[int, Sequence[int], None] = None,
    dtype: Union[jax.typing.DTypeLike, None] = None,
    keepdims: bool = False,
    initial: Union[jax.typing.ArrayLike, Quantity, None] = None,
    where: Union[jax.typing.ArrayLike, None] = None,
    **kwargs,
) -> Union[Quantity, jax.Array]:
    """
    Sum the array elements, ignoring NaNs, keeping the unit.

    Delegates to ``jax.numpy.nansum``; the unit of ``x`` is re-attached to
    the result when ``x`` is a ``Quantity``.

    Parameters
    ----------
    x : array_like, Quantity
        Input array.
    axis : {int, tuple of int, None}, optional
        Axis or axes along which the sum is computed. The default is to
        compute the sum of the flattened array.
    dtype : data-type, optional
        The type of the returned array and of the accumulator in which the
        elements are summed.
    keepdims : bool, optional
        If this is set to True, the axes which are reduced are left
        in the result as dimensions with size one. With this option,
        the result will broadcast correctly against the original `x`.
    initial : scalar, Quantity, optional
        Starting value for the sum. When given, it is wrapped in a
        ``Quantity`` and converted to the unit of ``x`` before use.
    where : array_like of bool, optional
        Elements to include in the sum.
    **kwargs
        Additional keyword arguments forwarded to ``jax.numpy.nansum``.

    Returns
    -------
    out : jax.Array, Quantity
        Quantity if `x` is a Quantity, else an array.

    Examples
    --------
    .. code-block:: python

        >>> import saiunit as u
        >>> import jax.numpy as jnp
        >>> a = [1.0, jnp.nan, 3.0] * u.meter
        >>> u.math.nansum(a)
    """
    if initial is not None:
        # The reduction runs on the mantissa, so express the starting value
        # in the unit of ``x`` first.
        initial_quantity = Quantity(initial)
        initial = initial_quantity.in_unit(get_unit(x)).mantissa
    reduce_options = dict(
        axis=axis,
        dtype=dtype,
        keepdims=keepdims,
        initial=initial,
        where=where,
    )
    return _fun_keep_unit_unary(jnp.nansum, x, **reduce_options, **kwargs)
@set_module_as('saiunit.math')
def cumsum(
    x: Union[Quantity, jax.typing.ArrayLike],
    axis: Union[int, Sequence[int], None] = None,
    dtype: Union[jax.typing.DTypeLike, None] = None,
    **kwargs,
) -> Union[Quantity, jax.Array]:
    """
    Cumulative sum of the array elements, keeping the unit.

    Delegates to ``jax.numpy.cumsum``; the unit of ``x`` is re-attached to
    the result when ``x`` is a ``Quantity``.

    Parameters
    ----------
    x : array_like, Quantity
        Input array.
    axis : int, optional
        Axis along which the cumulative sum is computed. The default
        (None) is to compute the cumsum over the flattened array.
    dtype : dtype, optional
        Type of the returned array and of the accumulator in which the
        elements are summed.
    **kwargs
        Additional keyword arguments forwarded to ``jax.numpy.cumsum``.

    Returns
    -------
    out : jax.Array, Quantity
        Quantity if `x` is a Quantity, else an array.

    Examples
    --------
    .. code-block:: python

        >>> import saiunit as u
        >>> a = [1, 2, 3] * u.second
        >>> u.math.cumsum(a)
    """
    scan_options = dict(axis=axis, dtype=dtype)
    return _fun_keep_unit_unary(jnp.cumsum, x, **scan_options, **kwargs)
@set_module_as('saiunit.math')
def ediff1d(
    x: Quantity | jax.typing.ArrayLike,
    to_end: jax.typing.ArrayLike | Quantity = None,
    to_begin: jax.typing.ArrayLike | Quantity = None,
    **kwargs,
) -> Union[Quantity, jax.Array]:
    """
    Differences between consecutive elements of the array, keeping the unit.

    Delegates to ``jax.numpy.ediff1d``; the unit of ``x`` is re-attached to
    the result when ``x`` is a ``Quantity``.

    Parameters
    ----------
    x : array_like, Quantity
        Input array.
    to_end : array_like, Quantity, optional
        Number(s) to append at the end of the returned differences. When
        given, it is wrapped in a ``Quantity`` and converted to the unit of
        ``x`` before use.
    to_begin : array_like, Quantity, optional
        Number(s) to prepend at the beginning of the returned differences.
        When given, it is wrapped in a ``Quantity`` and converted to the
        unit of ``x`` before use.
    **kwargs
        Additional keyword arguments forwarded to ``jax.numpy.ediff1d``.

    Returns
    -------
    out : jax.Array, Quantity
        Quantity if `x` is a Quantity, else an array.

    Examples
    --------
    .. code-block:: python

        >>> import saiunit as u
        >>> a = [1, 2, 4, 7] * u.meter
        >>> u.math.ediff1d(a)
    """
    target_unit = get_unit(x)
    # Padding values are concatenated with bare mantissas, so they have to
    # be expressed in the unit of ``x`` first.
    if to_end is not None:
        end_quantity = Quantity(to_end)
        to_end = end_quantity.in_unit(target_unit).mantissa
    if to_begin is not None:
        begin_quantity = Quantity(to_begin)
        to_begin = begin_quantity.in_unit(target_unit).mantissa
    padding = dict(to_end=to_end, to_begin=to_begin)
    return _fun_keep_unit_unary(jnp.ediff1d, x, **padding, **kwargs)
@set_module_as('saiunit.math')
def absolute(x: Union[Quantity, jax.typing.ArrayLike], **kwargs) -> Union[Quantity, jax.Array]:
    """
    Compute the element-wise absolute value.

    Delegates to ``jax.numpy.absolute`` and keeps the unit of ``x``.

    Parameters
    ----------
    x : array_like, Quantity
        Input array.
    **kwargs
        Additional keyword arguments forwarded to ``jax.numpy.absolute``.

    Returns
    -------
    out : jax.Array, Quantity
        Quantity if `x` is a Quantity, else an array.

    Examples
    --------
    .. code-block:: python

        >>> import saiunit as u
        >>> a = [-1.0, -2.0, 3.0] * u.meter
        >>> u.math.absolute(a)
    """
    magnitude = _fun_keep_unit_unary(jnp.absolute, x, **kwargs)
    return magnitude
@set_module_as('saiunit.math')
def fabs(x: Union[Quantity, jax.typing.ArrayLike], **kwargs) -> Union[Quantity, jax.Array]:
    """
    Compute the element-wise absolute value.

    Delegates to ``jax.numpy.fabs`` and keeps the unit of ``x``.

    Parameters
    ----------
    x : array_like, Quantity
        Input array.
    **kwargs
        Additional keyword arguments forwarded to ``jax.numpy.fabs``.

    Returns
    -------
    out : jax.Array, Quantity
        Quantity if `x` is a Quantity, else an array.

    Examples
    --------
    .. code-block:: python

        >>> import saiunit as u
        >>> a = [-1.0, -2.0, 3.0] * u.meter
        >>> u.math.fabs(a)
    """
    magnitude = _fun_keep_unit_unary(jnp.fabs, x, **kwargs)
    return magnitude
@set_module_as('saiunit.math')
def median(
    x: Union[Quantity, jax.typing.ArrayLike],
    axis: Union[int, Sequence[int], None] = None,
    overwrite_input: bool = False,
    keepdims: bool = False,
    **kwargs,
) -> Union[Quantity, jax.Array]:
    """
    Median of the array elements, keeping the unit.

    Delegates to ``jax.numpy.median``; the unit of ``x`` is re-attached to
    the result when ``x`` is a ``Quantity``.

    Parameters
    ----------
    x : array_like, Quantity
        Input array.
    axis : {int, sequence of int, None}, optional
        Axis or axes along which the medians are computed. The default
        is to compute the median along a flattened version of the array.
    overwrite_input : bool, optional
        Forwarded unchanged to ``jax.numpy.median``. Default is False.
    keepdims : bool, optional
        If this is set to True, the axes which are reduced are left
        in the result as dimensions with size one. With this option,
        the result will broadcast correctly against the original `x`.
    **kwargs
        Additional keyword arguments forwarded to ``jax.numpy.median``.

    Returns
    -------
    out : jax.Array, Quantity
        Quantity if `x` is a Quantity, else an array.

    Examples
    --------
    .. code-block:: python

        >>> import saiunit as u
        >>> a = [1, 2, 3, 4, 5] * u.second
        >>> u.math.median(a)
    """
    median_options = dict(
        axis=axis,
        overwrite_input=overwrite_input,
        keepdims=keepdims,
    )
    return _fun_keep_unit_unary(jnp.median, x, **median_options, **kwargs)
@set_module_as('saiunit.math')
def nanmin(
    x: Union[Quantity, jax.typing.ArrayLike],
    axis: Union[int, Sequence[int], None] = None,
    keepdims: bool = False,
    initial: Union[jax.typing.ArrayLike, Quantity, None] = None,
    where: Union[jax.typing.ArrayLike, None] = None,
    **kwargs,
) -> Union[Quantity, jax.Array]:
    """
    Minimum of the array elements, ignoring NaNs, keeping the unit.

    Delegates to ``jax.numpy.nanmin``; the unit of ``x`` is re-attached to
    the result when ``x`` is a ``Quantity``.

    Parameters
    ----------
    x : array_like, Quantity
        Input array.
    axis : {int, tuple of int, None}, optional
        Axis or axes along which the minimum is computed. The default is
        to compute the minimum of the flattened array.
    keepdims : bool, optional
        If this is set to True, the axes which are reduced are left
        in the result as dimensions with size one. With this option,
        the result will broadcast correctly against the original `x`.
    initial : scalar, Quantity, optional
        The maximum value of an output element. Must be present to allow
        computation on an empty slice. When given, it is wrapped in a
        ``Quantity`` and converted to the unit of ``x`` before use.
    where : array_like of bool, optional
        Elements to compare for the minimum.
    **kwargs
        Additional keyword arguments forwarded to ``jax.numpy.nanmin``.

    Returns
    -------
    out : jax.Array, Quantity
        Quantity if `x` is a Quantity, else an array.

    Examples
    --------
    .. code-block:: python

        >>> import saiunit as u
        >>> import jax.numpy as jnp
        >>> a = [1.0, jnp.nan, 3.0] * u.meter
        >>> u.math.nanmin(a)
    """
    if initial is not None:
        # The reduction runs on the mantissa, so express the bound in the
        # unit of ``x`` first.
        initial_quantity = Quantity(initial)
        initial = initial_quantity.in_unit(get_unit(x)).mantissa
    reduce_options = dict(axis=axis, keepdims=keepdims, initial=initial, where=where)
    return _fun_keep_unit_unary(jnp.nanmin, x, **reduce_options, **kwargs)
@set_module_as('saiunit.math')
def nanmax(
    x: Union[Quantity, jax.typing.ArrayLike],
    axis: Union[int, Sequence[int], None] = None,
    keepdims: bool = False,
    initial: Union[jax.typing.ArrayLike, Quantity, None] = None,
    where: Union[jax.typing.ArrayLike, None] = None,
    **kwargs,
) -> Union[Quantity, jax.Array]:
    """
    Return the maximum of the array elements, ignoring NaNs.

    Delegates to ``jax.numpy.nanmax``; the unit of ``x`` is re-attached to
    the result when ``x`` is a ``Quantity``.

    Parameters
    ----------
    x : array_like, Quantity
        Input array.
    axis : {int, tuple of int, None}, optional
        Axis or axes along which the maximum is computed. The default is
        to compute the maximum of the flattened array.
    keepdims : bool, optional
        If this is set to True, the axes which are reduced are left
        in the result as dimensions with size one. With this option,
        the result will broadcast correctly against the original `x`.
    initial : scalar, Quantity, optional
        The minimum value of an output element. Must be present to allow
        computation on an empty slice. When given, it is wrapped in a
        ``Quantity`` and converted to the unit of ``x`` before use.
    where : array_like of bool, optional
        Elements to compare for the maximum.
    **kwargs
        Additional keyword arguments forwarded to ``jax.numpy.nanmax``.

    Returns
    -------
    out : jax.Array, Quantity
        Quantity if `x` is a Quantity, else an array.

    Examples
    --------
    .. code-block:: python

        >>> import saiunit as u
        >>> import jax.numpy as jnp
        >>> a = [1.0, jnp.nan, 3.0] * u.meter
        >>> u.math.nanmax(a)
    """
    if initial is not None:
        # The reduction runs on the mantissa, so express the bound in the
        # unit of ``x`` first.
        initial = Quantity(initial).in_unit(get_unit(x)).mantissa
    return _fun_keep_unit_unary(
        jnp.nanmax, x, axis=axis, keepdims=keepdims, initial=initial, where=where, **kwargs
    )
@set_module_as('saiunit.math')
def ptp(
    x: Union[Quantity, jax.typing.ArrayLike],
    axis: Union[int, Sequence[int], None] = None,
    keepdims: bool = False,
    **kwargs,
) -> Union[Quantity, jax.Array]:
    """
    Range of the array elements (maximum - minimum), keeping the unit.

    Delegates to ``jax.numpy.ptp``; the unit of ``x`` is re-attached to the
    result when ``x`` is a ``Quantity``.

    Parameters
    ----------
    x : array_like, Quantity
        Input array.
    axis : None or int or tuple of ints, optional
        Axis along which to find the peaks. By default, flatten the
        array. `axis` may be negative, in which case it counts from the
        last to the first axis. If this is a tuple of ints, a reduction is
        performed on multiple axes.
    keepdims : bool, optional
        If this is set to True, the axes which are reduced are left
        in the result as dimensions with size one. With this option,
        the result will broadcast correctly against the input array.
    **kwargs
        Additional keyword arguments forwarded to ``jax.numpy.ptp``.

    Returns
    -------
    out : jax.Array, Quantity
        Quantity if `x` is a Quantity, else an array.

    Examples
    --------
    .. code-block:: python

        >>> import saiunit as u
        >>> a = [1, 2, 3, 4, 5] * u.meter
        >>> u.math.ptp(a)
    """
    range_options = dict(axis=axis, keepdims=keepdims)
    return _fun_keep_unit_unary(jnp.ptp, x, **range_options, **kwargs)
@set_module_as('saiunit.math')
def average(
    x: Union[Quantity, jax.typing.ArrayLike],
    axis: Union[int, Sequence[int], None] = None,
    weights: Union[jax.typing.ArrayLike, None] = None,
    returned: bool = False,
    keepdims: bool = False,
    **kwargs,
) -> Union[Quantity, jax.Array, Tuple[Union[Quantity, jax.Array], jax.Array]]:
    """
    Return the weighted average of the array elements.

    The computation is delegated to ``jax.numpy.average``; the unit of
    ``x`` is re-attached to the average when ``x`` is a ``Quantity``.

    Parameters
    ----------
    x : array_like, Quantity
        Input array.
    axis : None or int or tuple of ints, optional
        Axis or axes along which to average `x`. The default,
        axis=None, will average over all of the elements of the input array.
        If axis is negative it counts from the last to the first axis.
        If axis is a tuple of ints, averaging is performed on all of the
        axes specified in the tuple.
    weights : array_like, optional
        An array of weights associated with the values in `x`. Each value in
        `x` contributes to the average according to its associated weight.
        The weights array can either be 1-D (in which case its length must be
        the size of `x` along the given axis) or of the same shape as `x`.
        If `weights=None`, then all data in `x` are assumed to have a
        weight equal to one. The 1-D calculation is::

            avg = sum(x * weights) / sum(weights)

        The only constraint on `weights` is that `sum(weights)` must not be 0.
    returned : bool, optional
        Default is `False`. If `True`, the tuple (`average`, `sum_of_weights`)
        is returned, otherwise only the average is returned.
        If `weights=None`, `sum_of_weights` is equivalent to the number of
        elements over which the average is taken.
    keepdims : bool, optional
        If this is set to True, the axes which are reduced are left
        in the result as dimensions with size one. With this option,
        the result will broadcast correctly against the original `x`.
    **kwargs
        Additional keyword arguments forwarded to ``jax.numpy.average``.

    Returns
    -------
    out : jax.Array, Quantity, or tuple
        The average: a Quantity if `x` is a Quantity, else an array. When
        ``returned=True`` a tuple ``(average, sum_of_weights)`` is returned
        in which only ``average`` carries the unit of ``x``.

    Examples
    --------
    .. code-block:: python

        >>> import saiunit as u
        >>> a = [1.0, 2.0, 3.0] * u.second
        >>> u.math.average(a)
    """
    if not returned:
        return _fun_keep_unit_unary(
            jnp.average, x, axis=axis, weights=weights, returned=False, keepdims=keepdims, **kwargs
        )

    # ``returned=True`` makes ``jnp.average`` return a pair. It is unpacked
    # here instead of being handed to ``_fun_keep_unit_unary``, which would
    # wrap the whole pair in one Quantity and thereby attach the unit of
    # ``x`` to ``sum_of_weights`` as well.
    x = maybe_custom_array(x)
    options = dict(axis=axis, weights=weights, returned=True, keepdims=keepdims, **kwargs)
    _, options = maybe_custom_array_tree(((), options))
    if isinstance(x, Quantity):
        avg, sum_of_weights = jnp.average(x.mantissa, **options)
        return Quantity(avg, unit=x.unit), sum_of_weights
    return jnp.average(x, **options)
@set_module_as('saiunit.math')
def mean(
    x: Union[Quantity, jax.typing.ArrayLike],
    axis: Union[int, Sequence[int], None] = None,
    dtype: Union[jax.typing.DTypeLike, None] = None,
    keepdims: bool = False, *,
    where: Union[jax.typing.ArrayLike, None] = None,
    **kwargs,
) -> Union[Quantity, jax.Array]:
    """
    Compute the arithmetic mean of ``x`` while keeping its unit.

    The reduction is delegated to ``jnp.mean`` through
    ``_fun_keep_unit_unary``.

    Parameters
    ----------
    x : array_like, Quantity
        Values to average.
    axis : int, sequence of int or None, optional
        Axis or axes to reduce over. ``None`` (default) averages over all
        elements of ``x``.
    dtype : data-type, optional
        Data type forwarded to ``jnp.mean`` for the computation.
    keepdims : bool, optional
        When True, reduced axes are retained as size-one dimensions so the
        result broadcasts against ``x``. Default is False.
    where : array_like of bool, optional
        Mask selecting the elements that take part in the mean.
    **kwargs
        Additional keyword arguments handed over together with the ones
        above.

    Returns
    -------
    out : jax.Array, Quantity
        Quantity if `x` is a Quantity, else an array.

    Examples
    --------
    .. code-block:: python

        >>> import saiunit as u
        >>> a = [1.0, 2.0, 3.0] * u.second
        >>> u.math.mean(a)
    """
    return _fun_keep_unit_unary(
        jnp.mean,
        x,
        axis=axis,
        dtype=dtype,
        keepdims=keepdims,
        where=where,
        **kwargs,
    )
@set_module_as('saiunit.math')
def std(
    x: Union[Quantity, jax.typing.ArrayLike],
    axis: Union[int, Sequence[int], None] = None,
    dtype: Union[jax.typing.DTypeLike, None] = None,
    ddof: int = 0,
    keepdims: bool = False, *,
    where: Union[jax.typing.ArrayLike, None] = None,
    **kwargs,
) -> Union[Quantity, jax.Array]:
    """
    Compute the standard deviation of ``x`` while keeping its unit.

    The reduction is delegated to ``jnp.std`` through
    ``_fun_keep_unit_unary``.

    Parameters
    ----------
    x : array_like, Quantity
        Input values.
    axis : int, sequence of int or None, optional
        Axis or axes to reduce over. ``None`` (default) uses all elements
        of ``x``.
    dtype : data-type, optional
        Data type forwarded to ``jnp.std`` for the computation.
    ddof : int, optional
        Delta degrees of freedom: the divisor is ``N - ddof`` where ``N`` is
        the number of elements. Default is 0.
    keepdims : bool, optional
        When True, reduced axes are retained as size-one dimensions so the
        result broadcasts against ``x``. Default is False.
    where : array_like of bool, optional
        Mask selecting the elements that take part in the computation.
    **kwargs
        Additional keyword arguments handed over together with the ones
        above.

    Returns
    -------
    out : jax.Array, Quantity
        Quantity if `x` is a Quantity, else an array.

    Examples
    --------
    .. code-block:: python

        >>> import saiunit as u
        >>> a = [1.0, 2.0, 3.0] * u.meter
        >>> u.math.std(a)
    """
    return _fun_keep_unit_unary(
        jnp.std,
        x,
        axis=axis,
        dtype=dtype,
        ddof=ddof,
        keepdims=keepdims,
        where=where,
        **kwargs,
    )
@set_module_as('saiunit.math')
def nanmedian(
    x: Union[Quantity, jax.typing.ArrayLike],
    axis: Union[int, tuple[int, ...], None] = None,
    overwrite_input: bool = False,
    keepdims: bool = False,
    **kwargs,
) -> Union[Quantity, jax.Array]:
    """
    Compute the median of ``x`` while skipping NaN entries and keeping the unit.

    The reduction is delegated to ``jnp.nanmedian`` through
    ``_fun_keep_unit_unary``.

    Parameters
    ----------
    x : array_like, Quantity
        Input values.
    axis : int, tuple of int or None, optional
        Axis or axes along which the median is taken. ``None`` (default)
        uses all elements of ``x``.
    overwrite_input : bool, optional
        Forwarded unchanged to ``jnp.nanmedian``. Default is False.
    keepdims : bool, optional
        When True, reduced axes are retained as size-one dimensions so the
        result broadcasts against ``x``. Default is False.
    **kwargs
        Additional keyword arguments handed over together with the ones
        above.

    Returns
    -------
    out : jax.Array, Quantity
        Quantity if `x` is a Quantity, else an array.

    Examples
    --------
    .. code-block:: python

        >>> import saiunit as u
        >>> import jax.numpy as jnp
        >>> a = [1.0, jnp.nan, 3.0] * u.second
        >>> u.math.nanmedian(a)
    """
    return _fun_keep_unit_unary(
        jnp.nanmedian,
        x,
        axis=axis,
        overwrite_input=overwrite_input,
        keepdims=keepdims,
        **kwargs,
    )
@set_module_as('saiunit.math')
def nanmean(
    x: Union[Quantity, jax.typing.ArrayLike],
    axis: Union[int, Sequence[int], None] = None,
    dtype: Union[jax.typing.DTypeLike, None] = None,
    keepdims: bool = False, *,
    where: Union[jax.typing.ArrayLike, None] = None,
    **kwargs,
) -> Union[Quantity, jax.Array]:
    """
    Compute the mean of ``x`` while skipping NaN entries and keeping the unit.

    The reduction is delegated to ``jnp.nanmean`` through
    ``_fun_keep_unit_unary``.

    Parameters
    ----------
    x : array_like, Quantity
        Values to average.
    axis : int, sequence of int or None, optional
        Axis or axes to reduce over. ``None`` (default) averages over all
        elements of ``x``.
    dtype : data-type, optional
        Data type forwarded to ``jnp.nanmean`` for the computation.
    keepdims : bool, optional
        When True, reduced axes are retained as size-one dimensions so the
        result broadcasts against ``x``. Default is False.
    where : array_like of bool, optional
        Mask selecting the elements that take part in the mean.
    **kwargs
        Additional keyword arguments handed over together with the ones
        above.

    Returns
    -------
    out : jax.Array, Quantity
        Quantity if `x` is a Quantity, else an array.

    Examples
    --------
    .. code-block:: python

        >>> import saiunit as u
        >>> import jax.numpy as jnp
        >>> a = [1.0, jnp.nan, 3.0] * u.meter
        >>> u.math.nanmean(a)
    """
    return _fun_keep_unit_unary(
        jnp.nanmean,
        x,
        axis=axis,
        dtype=dtype,
        keepdims=keepdims,
        where=where,
        **kwargs,
    )
@set_module_as('saiunit.math')
def nanstd(
    x: Union[Quantity, jax.typing.ArrayLike],
    axis: Union[int, Sequence[int], None] = None,
    dtype: Union[jax.typing.DTypeLike, None] = None,
    ddof: int = 0,
    keepdims: bool = False, *,
    where: Union[jax.typing.ArrayLike, None] = None,
    **kwargs,
) -> Union[Quantity, jax.Array]:
    """
    Compute the standard deviation of ``x``, skipping NaNs and keeping the unit.

    The reduction is delegated to ``jnp.nanstd`` through
    ``_fun_keep_unit_unary``.

    Parameters
    ----------
    x : array_like, Quantity
        Input values.
    axis : int, sequence of int or None, optional
        Axis or axes to reduce over. ``None`` (default) uses all elements
        of ``x``.
    dtype : data-type, optional
        Data type forwarded to ``jnp.nanstd`` for the computation.
    ddof : int, optional
        Delta degrees of freedom: the divisor is ``N - ddof`` where ``N`` is
        the number of elements. Default is 0.
    keepdims : bool, optional
        When True, reduced axes are retained as size-one dimensions so the
        result broadcasts against ``x``. Default is False.
    where : array_like of bool, optional
        Mask selecting the elements that take part in the computation.
    **kwargs
        Additional keyword arguments handed over together with the ones
        above.

    Returns
    -------
    out : jax.Array, Quantity
        Quantity if `x` is a Quantity, else an array.

    Examples
    --------
    .. code-block:: python

        >>> import saiunit as u
        >>> import jax.numpy as jnp
        >>> a = [1.0, jnp.nan, 3.0] * u.meter
        >>> u.math.nanstd(a)
    """
    return _fun_keep_unit_unary(
        jnp.nanstd,
        x,
        axis=axis,
        dtype=dtype,
        ddof=ddof,
        keepdims=keepdims,
        where=where,
        **kwargs,
    )
@set_module_as('saiunit.math')
def diff(
    x: Union[Quantity, jax.typing.ArrayLike],
    n: int = 1,
    axis: int = -1,
    prepend: Union[jax.typing.ArrayLike, Quantity, None] = None,
    append: Union[jax.typing.ArrayLike, Quantity, None] = None,
    **kwargs,
) -> Union[Quantity, jax.Array]:
    """
    Compute the n-th discrete difference of ``x`` along an axis, keeping its unit.

    ``prepend`` and ``append`` are expressed in the unit of ``x`` and reduced
    to their mantissa before ``jnp.diff`` is invoked through
    ``_fun_keep_unit_unary``.

    Parameters
    ----------
    x : array_like, Quantity
        Input values.
    n : int, optional
        How many times the differencing is applied. Default is 1.
    axis : int, optional
        Axis along which differences are taken. Default is the last axis.
    prepend, append : array_like, Quantity, optional
        Values attached before / after ``x`` along ``axis`` prior to
        differencing. Each one is wrapped in a ``Quantity`` and converted
        with ``in_unit`` to the unit of ``x``.
    **kwargs
        Additional keyword arguments handed over together with the ones
        above.

    Returns
    -------
    out : jax.Array, Quantity
        Quantity if `x` is a Quantity, else an array.

    Examples
    --------
    .. code-block:: python

        >>> import saiunit as u
        >>> a = [1, 2, 4, 7] * u.meter
        >>> u.math.diff(a)
    """
    target_unit = get_unit(x)

    def _as_mantissa(edge):
        # ``None`` means nothing is attached on that side; anything else is
        # expressed in the unit of ``x`` and stripped down to raw numbers.
        if edge is None:
            return None
        return Quantity(edge).in_unit(target_unit).mantissa

    return _fun_keep_unit_unary(
        jnp.diff,
        x,
        n=n,
        axis=axis,
        prepend=_as_mantissa(prepend),
        append=_as_mantissa(append),
        **kwargs,
    )
@set_module_as('saiunit.math')
def rot90(
    m: Union[jax.typing.ArrayLike, Quantity],
    k: int = 1,
    axes: Tuple[int, int] = (0, 1),
    **kwargs,
) -> Union[jax.Array, Quantity]:
    """
    Rotate ``m`` by multiples of 90 degrees within the plane given by ``axes``.

    The rotation goes from the first axis towards the second one and is
    delegated to ``jnp.rot90`` through ``_fun_keep_unit_unary``.

    Parameters
    ----------
    m : array_like, Quantity
        Array with at least two dimensions.
    k : int, optional
        Number of 90-degree rotations to apply. Default is 1.
    axes : tuple of two ints, optional
        The two distinct axes spanning the rotation plane. Default is
        ``(0, 1)``.
    **kwargs
        Additional keyword arguments handed over together with the ones
        above.

    Returns
    -------
    y : ndarray, Quantity
        The rotated array. This is a Quantity if `m` is a Quantity.

    Examples
    --------
    .. code-block:: python

        >>> import saiunit as u
        >>> a = [[1, 2], [3, 4]] * u.second
        >>> u.math.rot90(a)
    """
    return _fun_keep_unit_unary(
        jnp.rot90,
        m,
        k=k,
        axes=axes,
        **kwargs,
    )
@set_module_as('saiunit.math')
def intersect1d(
    ar1: Union[jax.typing.ArrayLike, Quantity],
    ar2: Union[jax.typing.ArrayLike, Quantity],
    assume_unique: bool = False,
    return_indices: bool = False,
    **kwargs,
) -> Union[jax.Array, Quantity, tuple[Union[jax.Array, Quantity], jax.Array, jax.Array]]:
    """
    Compute the sorted, unique values shared by two arrays.

    Both inputs are checked for matching dimensions, expressed in the unit of
    ``ar1`` (or treated as unitless when ``ar1`` is not a Quantity) and then
    passed to ``jnp.intersect1d``.

    Parameters
    ----------
    ar1, ar2 : array_like, Quantity
        Input arrays. They are flattened if not already 1D.
    assume_unique : bool, optional
        If True, both inputs are assumed to contain no duplicates, which can
        speed up the computation. Passing True for inputs that do contain
        duplicates may give incorrect results. Default is False.
    return_indices : bool, optional
        If True, also return the indices of the common values in each input
        (first occurrence when a value is repeated). Default is False.
    **kwargs
        Additional keyword arguments forwarded to ``jnp.intersect1d``.

    Returns
    -------
    intersect1d : ndarray, Quantity
        Sorted 1D array of the common, unique elements. It is a Quantity
        unless the working unit is unitless.
    comm1 : ndarray
        Indices of the first occurrences of the common values in `ar1`.
        Only provided if `return_indices` is True.
    comm2 : ndarray
        Indices of the first occurrences of the common values in `ar2`.
        Only provided if `return_indices` is True.

    Examples
    --------
    .. code-block:: python

        >>> import saiunit as u
        >>> a = [1, 2, 3, 4, 5] * u.second
        >>> b = [3, 4, 5, 6, 7] * u.second
        >>> u.math.intersect1d(a, b)
    """
    ar1 = maybe_custom_array(ar1)
    ar2 = maybe_custom_array(ar2)
    fail_for_dimension_mismatch(ar1, ar2, 'intersect1d')

    # The unit of the first operand is the working unit for both arrays.
    unit = ar1.unit if isinstance(ar1, Quantity) else UNITLESS
    if isinstance(ar1, Quantity):
        ar1 = ar1.in_unit(unit).mantissa
    if isinstance(ar2, Quantity):
        ar2 = ar2.in_unit(unit).mantissa

    result = jnp.intersect1d(
        ar1,
        ar2,
        assume_unique=assume_unique,
        return_indices=return_indices,
        **kwargs,
    )

    # Unitless results are handed back exactly as jax produced them.
    if unit.is_unitless:
        return result
    if return_indices:
        # Only the values carry a unit; the index arrays stay plain.
        common, idx1, idx2 = result[0], result[1], result[2]
        return Quantity(common, unit=unit), idx1, idx2
    return Quantity(result, unit=unit)
@set_module_as('saiunit.math')
def nan_to_num(
    x: Union[jax.typing.ArrayLike, Quantity],
    nan: float | Quantity = None,
    posinf: float | Quantity = None,
    neginf: float | Quantity = None,
    **kwargs,
) -> Union[jax.Array, Quantity]:
    """
    Replace NaN and infinite entries of ``x`` with finite numbers.

    NaN becomes ``nan`` (``0.0`` when not given). Positive and negative
    infinity become ``posinf`` / ``neginf``; when those are left as ``None``
    the choice is left to ``jnp.nan_to_num``. When ``x`` is a Quantity,
    every replacement value that is provided is converted to the unit of
    ``x`` before being used.

    Parameters
    ----------
    x : scalar, array_like or Quantity
        Input data.
    nan : int, float, Quantity, optional
        Replacement for NaN entries. Defaults to ``0.0``.
    posinf : int, float, Quantity, optional
        Replacement for positive infinity. ``None`` defers to
        ``jnp.nan_to_num``.
    neginf : int, float, Quantity, optional
        Replacement for negative infinity. ``None`` defers to
        ``jnp.nan_to_num``.
    **kwargs
        Additional keyword arguments forwarded to ``jnp.nan_to_num``.

    Returns
    -------
    out : ndarray, Quantity
        `x` with the non-finite values replaced. A Quantity is returned when
        `x` is a Quantity whose unit is not unitless, else an array.

    Examples
    --------
    .. code-block:: python

        >>> import saiunit as u
        >>> import jax.numpy as jnp
        >>> a = [1.0, jnp.nan, jnp.inf] * u.meter
        >>> u.math.nan_to_num(a)
    """
    x_unit = get_unit(x)

    # Plain arrays: only the NaN default needs filling in.
    if not isinstance(x, Quantity):
        nan_fill = nan if nan is not None else 0.0
        return jnp.nan_to_num(x, nan=nan_fill, posinf=posinf, neginf=neginf, **kwargs)

    def _in_x_unit(value):
        # Express a replacement value in the unit of ``x`` as raw numbers.
        return Quantity(value).in_unit(x_unit).mantissa

    nan_fill = 0.0 if nan is None else _in_x_unit(nan)
    posinf_fill = None if posinf is None else _in_x_unit(posinf)
    neginf_fill = None if neginf is None else _in_x_unit(neginf)

    cleaned = jnp.nan_to_num(
        x.mantissa,
        nan=nan_fill,
        posinf=posinf_fill,
        neginf=neginf_fill,
        **kwargs,
    )
    if x_unit.is_unitless:
        return cleaned
    return Quantity(cleaned, unit=x_unit)
@set_module_as('saiunit.math')
def trace(
    a: Union[jax.Array, Quantity],
    offset: int = 0,
    axis1: int = 0,
    axis2: int = 1,
    dtype: Optional[jax.typing.DTypeLike] = None,
    **kwargs,
) -> Union[jax.Array, Quantity]:
    """
    Sum the entries along a diagonal of ``a`` while keeping its unit.

    For a 2-D input this is the sum of ``a[i, i + offset]`` over all ``i``.
    For inputs with more dimensions, ``axis1`` and ``axis2`` select the 2-D
    sub-arrays whose traces are computed; the result has the shape of ``a``
    with those two axes removed. The work is delegated to ``jnp.trace``
    through ``_fun_keep_unit_unary``.

    Parameters
    ----------
    a : array_like, Quantity
        Array the diagonals are taken from.
    offset : int, optional
        Offset of the diagonal relative to the main diagonal; may be
        positive or negative. Default is 0.
    axis1, axis2 : int, optional
        Axes used as the first and second axis of the 2-D sub-arrays.
        Defaults are the first two axes of `a`.
    dtype : dtype, optional
        Data type forwarded to ``jnp.trace`` for the accumulator and result.
    **kwargs
        Additional keyword arguments handed over together with the ones
        above.

    Returns
    -------
    sum_along_diagonals : ndarray, Quantity
        The sum along the diagonal(s). This is a Quantity if `a` is a
        Quantity, else an array.

    Examples
    --------
    .. code-block:: python

        >>> import saiunit as u
        >>> a = [[1, 2], [3, 4]] * u.second
        >>> u.math.trace(a)
    """
    return _fun_keep_unit_unary(
        jnp.trace,
        a,
        offset=offset,
        axis1=axis1,
        axis2=axis2,
        dtype=dtype,
        **kwargs,
    )
@set_module_as('saiunit.math')
def percentile(
    a: Union[jax.Array, Quantity],
    q: Union[jax.typing.ArrayLike, Quantity],
    axis: Optional[Union[int, Tuple[int, ...]]] = None,
    method: str = 'linear',
    keepdims: Optional[bool] = False,
    **kwargs,
) -> Union[jax.Array, Quantity]:
    """
    Compute the q-th percentile of the data along the specified axis.

    The unit of ``a`` is preserved: the computation is delegated to
    ``jnp.percentile`` through ``_fun_keep_unit_unary``.

    Parameters
    ----------
    a : array_like, Quantity
        Input array or Quantity.
    q : array_like, Quantity
        Percentile or sequence of percentiles to compute, which must be
        between 0 and 100 inclusive. A Quantity is accepted only when it is
        dimensionless; its mantissa is then used.
    axis : int, tuple of int or None, optional
        Axis or axes along which the percentiles are computed. ``None``
        (default) is forwarded unchanged to ``jnp.percentile``.
    method : str, optional
        Estimation method forwarded to ``jnp.percentile``. Default is
        ``'linear'``. NumPy lists the following names (sorted by their R
        type as summarized in Hyndman & Fan, 1996); which of them the
        installed JAX version supports should be checked against its
        documentation:

        1. 'inverted_cdf'
        2. 'averaged_inverted_cdf'
        3. 'closest_observation'
        4. 'interpolated_inverted_cdf'
        5. 'hazen'
        6. 'weibull'
        7. 'linear' (default)
        8. 'median_unbiased'
        9. 'normal_unbiased'

        plus the discontinuous variations of 'linear': 'lower', 'higher',
        'midpoint' and 'nearest'.
    keepdims : bool, optional
        If True, the reduced axes are left in the result as dimensions with
        size one. Default is False.
    **kwargs
        Additional keyword arguments handed over together with the ones
        above.

    Returns
    -------
    out : jax.Array, Quantity
        Quantity if `a` is a Quantity, else an array.

    Raises
    ------
    TypeError
        If ``q`` is a Quantity that is not dimensionless.

    Examples
    --------
    .. code-block:: python

        >>> import saiunit as u
        >>> a = [1, 2, 3, 4, 5] * u.meter
        >>> u.math.percentile(a, 50)
    """
    if isinstance(q, Quantity):
        # ``q`` is a percentage, never a physical quantity: reject units and
        # continue with the bare numbers.
        if not q.is_unitless:
            raise TypeError(
                f'percentile requires "q" to be dimensionless (a percentage 0-100), '
                f'but got q with unit={q.unit}. Pass a plain number or dimensionless Quantity for q.'
            )
        q = q.mantissa
    return _fun_keep_unit_unary(
        jnp.percentile, a, q=q, axis=axis, method=method, keepdims=keepdims,
        **kwargs,
    )
@set_module_as('saiunit.math')
def nanpercentile(
    a: Union[jax.Array, Quantity],
    q: Union[jax.typing.ArrayLike, Quantity],
    axis: Optional[Union[int, Tuple[int, ...]]] = None,
    method: str = 'linear',
    keepdims: Optional[bool] = False,
    **kwargs,
) -> Union[jax.Array, Quantity]:
    """
    Compute the q-th percentile of the data along the specified axis, ignoring NaNs.

    The unit of ``a`` is preserved: the computation is delegated to
    ``jnp.nanpercentile`` through ``_fun_keep_unit_unary``.

    Parameters
    ----------
    a : array_like, Quantity
        Input array or Quantity.
    q : array_like, Quantity
        Percentile or sequence of percentiles to compute, which must be
        between 0 and 100 inclusive. A Quantity is accepted only when it is
        dimensionless; its mantissa is then used.
    axis : int, tuple of int or None, optional
        Axis or axes along which the percentiles are computed. ``None``
        (default) is forwarded unchanged to ``jnp.nanpercentile``.
    method : str, optional
        Estimation method forwarded to ``jnp.nanpercentile``. Default is
        ``'linear'``. NumPy lists the following names (sorted by their R
        type as summarized in Hyndman & Fan, 1996); which of them the
        installed JAX version supports should be checked against its
        documentation:

        1. 'inverted_cdf'
        2. 'averaged_inverted_cdf'
        3. 'closest_observation'
        4. 'interpolated_inverted_cdf'
        5. 'hazen'
        6. 'weibull'
        7. 'linear' (default)
        8. 'median_unbiased'
        9. 'normal_unbiased'

        plus the discontinuous variations of 'linear': 'lower', 'higher',
        'midpoint' and 'nearest'.
    keepdims : bool, optional
        If True, the reduced axes are left in the result as dimensions with
        size one. Default is False.
    **kwargs
        Additional keyword arguments handed over together with the ones
        above.

    Returns
    -------
    out : jax.Array, Quantity
        Quantity if `a` is a Quantity, else an array.

    Raises
    ------
    TypeError
        If ``q`` is a Quantity that is not dimensionless.

    Examples
    --------
    .. code-block:: python

        >>> import saiunit as u
        >>> import jax.numpy as jnp
        >>> a = [1.0, jnp.nan, 3.0, 4.0, 5.0] * u.meter
        >>> u.math.nanpercentile(a, 50)
    """
    if isinstance(q, Quantity):
        # ``q`` is a percentage, never a physical quantity: reject units and
        # continue with the bare numbers.
        if not q.is_unitless:
            raise TypeError(
                f'nanpercentile requires "q" to be dimensionless (a percentage 0-100), '
                f'but got q with unit={q.unit}. Pass a plain number or dimensionless Quantity for q.'
            )
        q = q.mantissa
    return _fun_keep_unit_unary(
        jnp.nanpercentile, a, q=q, axis=axis, method=method, keepdims=keepdims,
        **kwargs,
    )
@set_module_as('saiunit.math')
def quantile(
    a: Union[jax.Array, Quantity],
    q: Union[jax.typing.ArrayLike, Quantity],
    axis: Optional[Union[int, Tuple[int, ...]]] = None,
    method: str = 'linear',
    keepdims: Optional[bool] = False,
    **kwargs,
) -> Union[jax.Array, Quantity]:
    """
    Compute the q-th quantile of the data along the specified axis.

    The unit of ``a`` is preserved: the computation is delegated to
    ``jnp.quantile`` through ``_fun_keep_unit_unary``.

    Parameters
    ----------
    a : array_like, Quantity
        Input array or Quantity.
    q : array_like, Quantity
        Quantile or sequence of quantiles to compute, which must be between
        0 and 1 inclusive. A Quantity is accepted only when it is
        dimensionless; its mantissa is then used.
    axis : int, tuple of int or None, optional
        Axis or axes along which the quantiles are computed. ``None``
        (default) is forwarded unchanged to ``jnp.quantile``.
    method : str, optional
        Estimation method forwarded to ``jnp.quantile``. Default is
        ``'linear'``. NumPy lists the following names (sorted by their R
        type as summarized in Hyndman & Fan, 1996); which of them the
        installed JAX version supports should be checked against its
        documentation:

        1. 'inverted_cdf'
        2. 'averaged_inverted_cdf'
        3. 'closest_observation'
        4. 'interpolated_inverted_cdf'
        5. 'hazen'
        6. 'weibull'
        7. 'linear' (default)
        8. 'median_unbiased'
        9. 'normal_unbiased'

        plus the discontinuous variations of 'linear': 'lower', 'higher',
        'midpoint' and 'nearest'.
    keepdims : bool, optional
        If True, the reduced axes are left in the result as dimensions with
        size one. Default is False.
    **kwargs
        Additional keyword arguments handed over together with the ones
        above.

    Returns
    -------
    out : jax.Array, Quantity
        Quantity if `a` is a Quantity, else an array.

    Raises
    ------
    TypeError
        If ``q`` is a Quantity that is not dimensionless.

    Examples
    --------
    .. code-block:: python

        >>> import saiunit as u
        >>> a = [1, 2, 3, 4, 5] * u.meter
        >>> u.math.quantile(a, 0.5)
    """
    if isinstance(q, Quantity):
        # ``q`` is a probability, never a physical quantity: reject units and
        # continue with the bare numbers.
        if not q.is_unitless:
            raise TypeError(
                f'quantile requires "q" to be dimensionless (a quantile between 0 and 1), '
                f'but got q with unit={q.unit}. Pass a plain number or dimensionless Quantity for q.'
            )
        q = q.mantissa
    return _fun_keep_unit_unary(
        jnp.quantile, a, q=q, axis=axis, method=method, keepdims=keepdims,
        **kwargs,
    )
@set_module_as('saiunit.math')
def nanquantile(
    a: Union[jax.Array, Quantity],
    q: Union[jax.typing.ArrayLike, Quantity],
    axis: Optional[Union[int, Tuple[int, ...]]] = None,
    method: str = 'linear',
    keepdims: Optional[bool] = False,
    **kwargs,
) -> Union[jax.Array, Quantity]:
    """
    Compute the q-th quantile of the data along the specified axis, ignoring NaNs.

    The unit of ``a`` is preserved: the computation is delegated to
    ``jnp.nanquantile`` through ``_fun_keep_unit_unary``.

    Parameters
    ----------
    a : array_like, Quantity
        Input array or Quantity.
    q : array_like, Quantity
        Quantile or sequence of quantiles to compute, which must be between
        0 and 1 inclusive. A Quantity is accepted only when it is
        dimensionless; its mantissa is then used.
    axis : int, tuple of int or None, optional
        Axis or axes along which the quantiles are computed. ``None``
        (default) is forwarded unchanged to ``jnp.nanquantile``.
    method : str, optional
        Estimation method forwarded to ``jnp.nanquantile``. Default is
        ``'linear'``. NumPy lists the following names (sorted by their R
        type as summarized in Hyndman & Fan, 1996); which of them the
        installed JAX version supports should be checked against its
        documentation:

        1. 'inverted_cdf'
        2. 'averaged_inverted_cdf'
        3. 'closest_observation'
        4. 'interpolated_inverted_cdf'
        5. 'hazen'
        6. 'weibull'
        7. 'linear' (default)
        8. 'median_unbiased'
        9. 'normal_unbiased'

        plus the discontinuous variations of 'linear': 'lower', 'higher',
        'midpoint' and 'nearest'.
    keepdims : bool, optional
        If True, the reduced axes are left in the result as dimensions with
        size one. Default is False.
    **kwargs
        Additional keyword arguments handed over together with the ones
        above.

    Returns
    -------
    out : jax.Array, Quantity
        Quantity if `a` is a Quantity, else an array.

    Raises
    ------
    TypeError
        If ``q`` is a Quantity that is not dimensionless.

    Examples
    --------
    .. code-block:: python

        >>> import saiunit as u
        >>> import jax.numpy as jnp
        >>> a = [1.0, jnp.nan, 3.0, 4.0, 5.0] * u.meter
        >>> u.math.nanquantile(a, 0.5)
    """
    if isinstance(q, Quantity):
        # ``q`` is a probability, never a physical quantity: reject units and
        # continue with the bare numbers.
        if not q.is_unitless:
            raise TypeError(
                f'nanquantile requires "q" to be dimensionless (a quantile between 0 and 1), '
                f'but got q with unit={q.unit}. Pass a plain number or dimensionless Quantity for q.'
            )
        q = q.mantissa
    return _fun_keep_unit_unary(
        jnp.nanquantile, a, q=q, axis=axis, method=method, keepdims=keepdims,
        **kwargs,
    )
# math funcs keep unit (binary)
# -----------------------------
def _fun_keep_unit_binary(func, x1, x2, *args, **kwargs):
    """
    Apply a binary ``func`` to two operands while keeping their common unit.

    Parameters
    ----------
    func : callable
        Function invoked as ``func(v1, v2, *args, **kwargs)`` on the
        unit-free values of the operands.
    x1, x2 : array_like, Quantity
        Operands. Each one is first passed through ``maybe_custom_array``.
    *args, **kwargs
        Extra arguments for ``func``; they are passed through
        ``maybe_custom_array_tree`` beforehand.

    Returns
    -------
    out : jax.Array, Quantity
        With two Quantities: a Quantity in the unit of ``x1`` (``x2`` is
        converted to that unit first). Otherwise whatever ``func`` returns.

    Raises
    ------
    TypeError
        If exactly one operand is a Quantity and it is not dimensionless.
    """
    x1 = maybe_custom_array(x1)
    x2 = maybe_custom_array(x2)
    args, kwargs = maybe_custom_array_tree((args, kwargs))

    first_is_quantity = isinstance(x1, Quantity)
    second_is_quantity = isinstance(x2, Quantity)

    # Two quantities: compute in the unit of the first operand.
    if first_is_quantity and second_is_quantity:
        common_unit = x1.unit
        values = func(x1.mantissa, x2.in_unit(common_unit).mantissa, *args, **kwargs)
        return Quantity(values, unit=common_unit)

    # A lone quantity may only be combined with a plain array when it is
    # dimensionless; the result is then a plain array as well.
    if first_is_quantity:
        if x1.is_unitless:
            return func(x1.mantissa, x2, *args, **kwargs)
        raise TypeError(
            f'Expected "x1" to be dimensionless when "x2" is a plain array, '
            f'but got x1 with unit={x1.unit}. '
            f'Either pass a Quantity for x2 with matching units, or strip the unit from x1.'
        )

    if second_is_quantity:
        if x2.is_unitless:
            return func(x1, x2.mantissa, *args, **kwargs)
        raise TypeError(
            f'Expected "x2" to be dimensionless when "x1" is a plain array, '
            f'but got x2 with unit={x2.unit}. '
            f'Either pass a Quantity for x1 with matching units, or strip the unit from x2.'
        )

    # No units involved at all.
    return func(x1, x2, *args, **kwargs)
@set_module_as('saiunit.math')
def fmod(
    x1: Union[Quantity, jax.typing.ArrayLike],
    x2: Union[Quantity, jax.Array],
    **kwargs,
) -> Union[Quantity, jax.Array]:
    """
    Compute the element-wise remainder of ``x1 / x2`` using ``jnp.fmod``.

    Unit handling is done by ``_fun_keep_unit_binary``.

    Parameters
    ----------
    x1 : array_like, Quantity
        Dividend.
    x2 : array_like, Quantity
        Divisor.
    **kwargs
        Additional keyword arguments forwarded with the operands.

    Returns
    -------
    out : jax.Array, Quantity
        A Quantity in the unit of `x1` when both inputs are Quantities,
        else an array.

    Raises
    ------
    TypeError
        If exactly one input is a Quantity and it is not dimensionless.

    Examples
    --------
    .. code-block:: python

        >>> import saiunit as u
        >>> a = [5, 6, 7] * u.second
        >>> b = [2, 3, 4] * u.second
        >>> u.math.fmod(a, b)
    """
    remainder = _fun_keep_unit_binary(jnp.fmod, x1, x2, **kwargs)
    return remainder
@set_module_as('saiunit.math')
def mod(
    x1: Union[Quantity, jax.typing.ArrayLike],
    x2: Union[Quantity, jax.Array],
    **kwargs,
) -> Union[Quantity, jax.Array]:
    """
    Compute the element-wise modulus of ``x1`` by ``x2`` using ``jnp.mod``.

    Unit handling is done by ``_fun_keep_unit_binary``.

    Parameters
    ----------
    x1 : array_like, Quantity
        Dividend.
    x2 : array_like, Quantity
        Divisor.
    **kwargs
        Additional keyword arguments forwarded with the operands.

    Returns
    -------
    out : jax.Array, Quantity
        A Quantity in the unit of `x1` when both inputs are Quantities,
        else an array.

    Raises
    ------
    TypeError
        If exactly one input is a Quantity and it is not dimensionless.

    Examples
    --------
    .. code-block:: python

        >>> import saiunit as u
        >>> a = [5, 6, 7] * u.meter
        >>> b = [2, 3, 4] * u.meter
        >>> u.math.mod(a, b)
    """
    modulus = _fun_keep_unit_binary(jnp.mod, x1, x2, **kwargs)
    return modulus
@set_module_as('saiunit.math')
def copysign(
    x1: Union[Quantity, jax.typing.ArrayLike],
    x2: Union[Quantity, jax.typing.ArrayLike],
    **kwargs,
) -> Union[Quantity, jax.Array]:
    """
    Give the elements of ``x1`` the sign of the matching elements of ``x2``.

    Only the sign of ``x2`` matters, so when it is a Quantity its unit is
    dropped and just the mantissa is used. The magnitudes and the unit come
    from ``x1``; the call is delegated to ``jnp.copysign`` through
    ``_fun_keep_unit_unary``.

    Parameters
    ----------
    x1 : array_like, Quantity
        Values whose magnitudes are kept.
    x2 : array_like, Quantity
        Values whose signs are copied.
    **kwargs
        Additional keyword arguments forwarded with the operands.

    Returns
    -------
    out : jax.Array, Quantity
        Quantity if `x1` is a Quantity, else an array.

    Examples
    --------
    .. code-block:: python

        >>> import saiunit as u
        >>> a = [-1.0, 2.0] * u.meter
        >>> b = [1.0, -3.0] * u.meter
        >>> u.math.copysign(a, b)
    """
    if isinstance(x2, Quantity):
        # The sign does not depend on the unit, so no conversion is applied.
        x2 = x2.mantissa
    return _fun_keep_unit_unary(jnp.copysign, x1, x2, **kwargs)
@set_module_as('saiunit.math')
def maximum(
    x1: Union[Quantity, jax.typing.ArrayLike],
    x2: Union[Quantity, jax.typing.ArrayLike],
    **kwargs,
) -> Union[Quantity, jax.Array]:
    """
    Take the element-wise maximum of two arrays using ``jnp.maximum``.

    Unit handling is done by ``_fun_keep_unit_binary``.

    Parameters
    ----------
    x1 : array_like, Quantity
        First operand.
    x2 : array_like, Quantity
        Second operand.
    **kwargs
        Additional keyword arguments forwarded with the operands.

    Returns
    -------
    out : jax.Array, Quantity
        A Quantity in the unit of `x1` when both inputs are Quantities,
        else an array.

    Raises
    ------
    TypeError
        If exactly one input is a Quantity and it is not dimensionless.

    Examples
    --------
    .. code-block:: python

        >>> import saiunit as u
        >>> a = [1, 3, 5] * u.second
        >>> b = [2, 2, 4] * u.second
        >>> u.math.maximum(a, b)
    """
    larger = _fun_keep_unit_binary(jnp.maximum, x1, x2, **kwargs)
    return larger
@set_module_as('saiunit.math')
def minimum(
    x1: Union[Quantity, jax.typing.ArrayLike],
    x2: Union[Quantity, jax.typing.ArrayLike],
    **kwargs,
) -> Union[Quantity, jax.Array]:
    """
    Take the element-wise minimum of two arrays using ``jnp.minimum``.

    Unit handling is done by ``_fun_keep_unit_binary``.

    Parameters
    ----------
    x1 : array_like, Quantity
        First operand.
    x2 : array_like, Quantity
        Second operand.
    **kwargs
        Additional keyword arguments forwarded with the operands.

    Returns
    -------
    out : jax.Array, Quantity
        A Quantity in the unit of `x1` when both inputs are Quantities,
        else an array.

    Raises
    ------
    TypeError
        If exactly one input is a Quantity and it is not dimensionless.

    Examples
    --------
    .. code-block:: python

        >>> import saiunit as u
        >>> a = [1, 3, 5] * u.second
        >>> b = [2, 2, 4] * u.second
        >>> u.math.minimum(a, b)
    """
    smaller = _fun_keep_unit_binary(jnp.minimum, x1, x2, **kwargs)
    return smaller
@set_module_as('saiunit.math')
def fmax(
    x1: Union[Quantity, jax.typing.ArrayLike],
    x2: Union[Quantity, jax.typing.ArrayLike],
    **kwargs,
) -> Union[Quantity, jax.Array]:
    """
    Take the element-wise maximum of two arrays, ignoring NaNs, using ``jnp.fmax``.

    Unit handling is done by ``_fun_keep_unit_binary``.

    Parameters
    ----------
    x1 : array_like, Quantity
        First operand.
    x2 : array_like, Quantity
        Second operand.
    **kwargs
        Additional keyword arguments forwarded with the operands.

    Returns
    -------
    out : jax.Array, Quantity
        A Quantity in the unit of `x1` when both inputs are Quantities,
        else an array.

    Raises
    ------
    TypeError
        If exactly one input is a Quantity and it is not dimensionless.

    Examples
    --------
    .. code-block:: python

        >>> import saiunit as u
        >>> import jax.numpy as jnp
        >>> a = [1.0, jnp.nan, 5.0] * u.meter
        >>> b = [2.0, 2.0, 4.0] * u.meter
        >>> u.math.fmax(a, b)
    """
    larger = _fun_keep_unit_binary(jnp.fmax, x1, x2, **kwargs)
    return larger
@set_module_as('saiunit.math')
def fmin(
    x1: Union[Quantity, jax.typing.ArrayLike],
    x2: Union[Quantity, jax.typing.ArrayLike],
    **kwargs,
) -> Union[Quantity, jax.Array]:
    """
    Take the element-wise minimum of two arrays, ignoring NaNs, using ``jnp.fmin``.

    Unit handling is done by ``_fun_keep_unit_binary``.

    Parameters
    ----------
    x1 : array_like, Quantity
        First operand.
    x2 : array_like, Quantity
        Second operand.
    **kwargs
        Additional keyword arguments forwarded with the operands.

    Returns
    -------
    out : jax.Array, Quantity
        A Quantity in the unit of `x1` when both inputs are Quantities,
        else an array.

    Raises
    ------
    TypeError
        If exactly one input is a Quantity and it is not dimensionless.

    Examples
    --------
    .. code-block:: python

        >>> import saiunit as u
        >>> import jax.numpy as jnp
        >>> a = [1.0, jnp.nan, 5.0] * u.meter
        >>> b = [2.0, 2.0, 4.0] * u.meter
        >>> u.math.fmin(a, b)
    """
    smaller = _fun_keep_unit_binary(jnp.fmin, x1, x2, **kwargs)
    return smaller
@set_module_as('saiunit.math')
def lcm(
    x1: Union[Quantity, jax.typing.ArrayLike],
    x2: Union[Quantity, jax.typing.ArrayLike],
    **kwargs,
) -> Union[Quantity, jax.Array]:
    """
    Compute the element-wise least common multiple of `x1` and `x2` using ``jnp.lcm``.

    Unit handling is done by ``_fun_keep_unit_binary``.

    Parameters
    ----------
    x1 : array_like, Quantity
        First operand.
    x2 : array_like, Quantity
        Second operand.
    **kwargs
        Additional keyword arguments forwarded with the operands.

    Returns
    -------
    out : jax.Array, Quantity
        A Quantity in the unit of `x1` when both inputs are Quantities,
        else an array.

    Raises
    ------
    TypeError
        If exactly one input is a Quantity and it is not dimensionless.

    Examples
    --------
    .. code-block:: python

        >>> import saiunit as u
        >>> import jax.numpy as jnp
        >>> a = jnp.array([4, 6]) * u.second
        >>> b = jnp.array([6, 8]) * u.second
        >>> u.math.lcm(a.astype(jnp.int64), b.astype(jnp.int64))
    """
    multiple = _fun_keep_unit_binary(jnp.lcm, x1, x2, **kwargs)
    return multiple
@set_module_as('saiunit.math')
def gcd(
    x1: Union[Quantity, jax.typing.ArrayLike],
    x2: Union[Quantity, jax.typing.ArrayLike],
    **kwargs,
) -> Union[Quantity, jax.Array]:
    """
    Compute the element-wise greatest common divisor of `x1` and `x2` using ``jnp.gcd``.

    Unit handling is done by ``_fun_keep_unit_binary``.

    Parameters
    ----------
    x1 : array_like, Quantity
        First operand.
    x2 : array_like, Quantity
        Second operand.
    **kwargs
        Additional keyword arguments forwarded with the operands.

    Returns
    -------
    out : jax.Array, Quantity
        A Quantity in the unit of `x1` when both inputs are Quantities,
        else an array.

    Raises
    ------
    TypeError
        If exactly one input is a Quantity and it is not dimensionless.

    Examples
    --------
    .. code-block:: python

        >>> import saiunit as u
        >>> import jax.numpy as jnp
        >>> a = jnp.array([4, 6]) * u.second
        >>> b = jnp.array([6, 8]) * u.second
        >>> u.math.gcd(a.astype(jnp.int64), b.astype(jnp.int64))
    """
    divisor = _fun_keep_unit_binary(jnp.gcd, x1, x2, **kwargs)
    return divisor
@set_module_as('saiunit.math')
def add(
    x: Union[Quantity, jax.typing.ArrayLike],
    y: Union[Quantity, jax.typing.ArrayLike],
    **kwargs,
) -> Union[Quantity, jax.Array]:
    """
    Sum two inputs element-wise.

    This is a thin wrapper that hands :func:`jax.numpy.add` and both
    operands to ``_fun_keep_unit_binary``.

    Parameters
    ----------
    x, y : array_like, Quantity
        The operands to add. When ``x.shape != y.shape`` they must be
        broadcastable to a common shape, which is the shape of the output.
    **kwargs
        Extra keyword arguments forwarded unchanged.

    Returns
    -------
    add : ndarray or scalar
        Element-wise sum of `x` and `y`; a scalar when both inputs are
        scalars.

    Examples
    --------
    .. code-block:: python

        >>> import saiunit as u
        >>> a = [1, 2, 3] * u.meter
        >>> b = [4, 5, 6] * u.meter
        >>> u.math.add(a, b)
    """
    total = _fun_keep_unit_binary(jnp.add, x, y, **kwargs)
    return total
@set_module_as('saiunit.math')
def subtract(
    x: Union[Quantity, jax.typing.ArrayLike],
    y: Union[Quantity, jax.typing.ArrayLike],
    **kwargs,
) -> Union[Quantity, jax.Array]:
    """
    Subtract `y` from `x` element-wise.

    This is a thin wrapper that hands :func:`jax.numpy.subtract` and both
    operands to ``_fun_keep_unit_binary``.

    Parameters
    ----------
    x, y : array_like, Quantity
        Minuend and subtrahend. When ``x.shape != y.shape`` they must be
        broadcastable to a common shape, which is the shape of the output.
    **kwargs
        Extra keyword arguments forwarded unchanged.

    Returns
    -------
    subtract : ndarray
        Element-wise difference of `x` and `y`; a scalar when both inputs
        are scalars.

    Examples
    --------
    .. code-block:: python

        >>> import saiunit as u
        >>> a = [4, 5, 6] * u.meter
        >>> b = [1, 2, 3] * u.meter
        >>> u.math.subtract(a, b)
    """
    difference = _fun_keep_unit_binary(jnp.subtract, x, y, **kwargs)
    return difference
@set_module_as('saiunit.math')
def remainder(
    x: Union[Quantity, jax.typing.ArrayLike],
    y: Union[Quantity, jax.typing.ArrayLike],
    **kwargs,
) -> Union[Quantity, jax.Array]:
    """
    Compute the element-wise remainder of dividing `x` by `y`.

    The computation is :func:`jax.numpy.remainder`, the counterpart of
    ``floor_divide`` (the Python ``x % y`` operator). Both operands are
    passed through ``_fun_keep_unit_binary``, which takes care of the
    unit handling.

    Parameters
    ----------
    x : array_like, Quantity
        Dividend array.
    y : array_like, Quantity
        Divisor array. When ``x.shape != y.shape`` the two must be
        broadcastable to a common shape, which is the shape of the output.
    **kwargs
        Extra keyword arguments forwarded unchanged.

    Returns
    -------
    out : ndarray, Quantity
        Element-wise remainder of the quotient ``floor_divide(x, y)``;
        a scalar when both inputs are scalars.

    Examples
    --------
    .. code-block:: python

        >>> import saiunit as u
        >>> a = [5, 6, 7] * u.second
        >>> b = [2, 3, 4] * u.second
        >>> u.math.remainder(a, b)
    """
    rest = _fun_keep_unit_binary(jnp.remainder, x, y, **kwargs)
    return rest
@set_module_as('saiunit.math')
def nextafter(
    x: Union[Quantity, jax.typing.ArrayLike],
    y: Union[Quantity, jax.typing.ArrayLike],
    **kwargs,
) -> Union[Quantity, jax.Array]:
    """
    Step `x` to the adjacent floating-point value in the direction of `y`.

    This is a thin wrapper that hands :func:`jax.numpy.nextafter` and both
    operands to ``_fun_keep_unit_binary``.

    Parameters
    ----------
    x : array_like, Quantity
        Values whose next representable neighbour is wanted.
    y : array_like, Quantity
        Direction in which to look for the neighbour of `x`. When
        ``x.shape != y.shape`` the two must be broadcastable to a common
        shape, which is the shape of the output.
    **kwargs
        Extra keyword arguments forwarded unchanged.

    Returns
    -------
    out : ndarray or scalar
        The next representable values of `x` towards `y`; a scalar when
        both inputs are scalars.

    Examples
    --------
    .. code-block:: python

        >>> import saiunit as u
        >>> a = [1.0, 2.0] * u.meter
        >>> b = [2.0, 1.0] * u.meter
        >>> u.math.nextafter(a, b)
    """
    neighbour = _fun_keep_unit_binary(jnp.nextafter, x, y, **kwargs)
    return neighbour
# math funcs keep unit (n-ary)
# ----------------------------
@set_module_as('saiunit.math')
def interp(
    x: Union[Quantity, jax.typing.ArrayLike],
    xp: Union[Quantity, jax.typing.ArrayLike],
    fp: Union[Quantity, jax.typing.ArrayLike],
    left: Union[Quantity, jax.typing.ArrayLike] = None,
    right: Union[Quantity, jax.typing.ArrayLike] = None,
    period: Union[Quantity, jax.typing.ArrayLike] = None,
    **kwargs,
) -> Union[Quantity, jax.Array]:
    """
    One-dimensional linear interpolation.

    Parameters
    ----------
    x : array_like, Quantity
        The x-coordinates at which to evaluate the interpolated values.
    xp : array_like, Quantity
        The x-coordinates of the data points, must be increasing.
        Converted to the unit of `x` before interpolating.
    fp : array_like, Quantity
        The y-coordinates of the data points. Its unit is the unit of the
        result.
    left : array_like, Quantity, str, optional
        Value to return for ``x < xp[0]``. Because it stands in for an
        interpolated value, it is converted to the unit of `fp`. A string
        (e.g. ``'extrapolate'``) is passed to :func:`jax.numpy.interp`
        untouched.
    right : array_like, Quantity, str, optional
        Value to return for ``x > xp[-1]``. Handled like `left`.
    period : array_like, Quantity, optional
        A period for the x-coordinates, converted to the unit of `x`.
    **kwargs
        Extra keyword arguments forwarded to :func:`jax.numpy.interp`.

    Returns
    -------
    out : jax.Array, Quantity
        Quantity if `fp` is a Quantity, else an array.

    Examples
    --------
    .. code-block:: python

        >>> import saiunit as u
        >>> x = [1, 2, 3] * u.second
        >>> xp = [0, 1, 2, 3, 4] * u.second
        >>> fp = [0, 1, 2, 3, 4] * u.meter
        >>> u.math.interp(x, xp, fp)
    """
    x_unit = get_unit(x)
    fp, y_unit = split_mantissa_unit(fp)

    def _mantissa_in(value, unit):
        # ``None`` keeps the jnp.interp default; strings select a mode
        # (such as 'extrapolate') and carry no unit to convert.
        if value is None or isinstance(value, str):
            return value
        return Quantity(value).in_unit(unit).mantissa

    x = x.mantissa if isinstance(x, Quantity) else x
    xp = Quantity(xp).in_unit(x_unit).mantissa
    # `left`/`right` replace interpolated y-values, so they live in the
    # unit of `fp`; `period` is measured along the x-axis.
    left = _mantissa_in(left, y_unit)
    right = _mantissa_in(right, y_unit)
    period = _mantissa_in(period, x_unit)

    r = jnp.interp(x, xp=xp, fp=fp, left=left, right=right, period=period, **kwargs)
    return maybe_decimal(Quantity(r, unit=y_unit))
@set_module_as('saiunit.math')
def clip(
    a: Union[Quantity, jax.typing.ArrayLike],
    a_min: Union[Quantity, jax.typing.ArrayLike],
    a_max: Union[Quantity, jax.typing.ArrayLike],
    **kwargs,
) -> Union[Quantity, jax.Array]:
    """
    Limit the values of an array to the interval ``[a_min, a_max]``.

    Parameters
    ----------
    a : array_like, Quantity
        Array containing elements to clip.
    a_min : array_like, Quantity
        Lower bound, converted to the unit of `a`. If None, clipping is
        not performed on the lower interval edge.
    a_max : array_like, Quantity
        Upper bound, converted to the unit of `a`. If None, clipping is
        not performed on the upper interval edge.
    **kwargs
        Extra keyword arguments forwarded unchanged.

    Returns
    -------
    out : jax.Array, Quantity
        Quantity if `a` is a Quantity, else an array.

    Examples
    --------
    .. code-block:: python

        >>> import saiunit as u
        >>> a = [1, 2, 3, 4, 5] * u.meter
        >>> u.math.clip(a, 2 * u.meter, 4 * u.meter)
    """
    unit = get_unit(a)
    # Bounds are expressed in the unit of `a` and then stripped to plain
    # numbers; a missing bound stays None.
    lower = a_min if a_min is None else Quantity(a_min).in_unit(unit).mantissa
    upper = a_max if a_max is None else Quantity(a_max).in_unit(unit).mantissa
    return _fun_keep_unit_unary(jnp.clip, a, min=lower, max=upper, **kwargs)
@set_module_as('saiunit.math')
def histogram(
    x: Union[jax.Array, Quantity],
    bins: jax.typing.ArrayLike = 10,
    range: Optional[Sequence[jax.typing.ArrayLike | Quantity]] = None,
    weights: Optional[jax.typing.ArrayLike] = None,
    density: Optional[bool] = None,
    **kwargs,
) -> Tuple[jax.Array, jax.Array | Quantity]:
    """
    Compute the histogram of a set of data.

    Parameters
    ----------
    x : array_like, Quantity
        Input data. The histogram is computed over the flattened array.
    bins : int or sequence of scalars or str, optional
        An int gives the number of equal-width bins in the given range
        (10 by default). A sequence gives monotonically increasing bin
        edges, including the rightmost edge. A string names the method
        used to calculate the optimal bin width. Passed to
        :func:`jax.numpy.histogram` as is.
    range : (float, float), (Quantity, Quantity) optional
        The lower and upper range of the bins. Both entries are converted
        to the unit of `x`. If not provided, the range is
        ``(x.min(), x.max())``; values outside the range are ignored.
    weights : array_like, optional
        An array of weights, of the same shape as `x`. Each value in `x`
        contributes its associated weight towards the bin count.
    density : bool, optional
        If ``False``, the result contains the number of samples in each
        bin. If ``True``, the result is the value of the probability
        density function at the bin, normalized so that the integral over
        the range is 1.
    **kwargs
        Extra keyword arguments forwarded to :func:`jax.numpy.histogram`.

    Returns
    -------
    hist : array
        The values of the histogram, returned without a unit.
    bin_edges : array of dtype float, Quantity
        The bin edges ``(length(hist)+1)``. A Quantity carrying the unit
        of `x` unless that unit is unitless.

    Examples
    --------
    .. code-block:: python

        >>> import saiunit as u
        >>> a = [1, 2, 1, 3, 2] * u.second
        >>> hist, bin_edges = u.math.histogram(a)
    """
    if isinstance(x, Quantity):
        unit, x = x.unit, x.mantissa
    else:
        unit = UNITLESS

    if range is not None:
        # Express both bounds in the data's unit before dropping the unit.
        lower = Quantity(range[0]).in_unit(unit).mantissa
        upper = Quantity(range[1]).in_unit(unit).mantissa
        range = (lower, upper)

    hist, bin_edges = jnp.histogram(x, bins, range=range, weights=weights, density=density, **kwargs)
    edges = bin_edges if unit.is_unitless else Quantity(bin_edges, unit=unit)
    return hist, edges
@set_module_as('saiunit.math')
def compress(
    condition: jax.Array,
    a: Union[jax.Array, Quantity],
    axis: Optional[int] = None,
    *,
    size: Optional[int] = None,
    fill_value: Optional[jax.typing.ArrayLike] = None,
    **kwargs,
) -> Union[jax.Array, Quantity]:
    """
    Select slices of a quantity or an array along an axis using a mask.

    Parameters
    ----------
    condition : array_like, Quantity
        Boolean values selecting which slices to return. A Quantity is
        accepted only when it is unitless; its mantissa is then used.
    a : array_like, Quantity
        Array from which to extract a part.
    axis : int or None, optional
        The axis along which to take slices.
    size : int, optional
        The length of the returned axis.
    fill_value : scalar, optional
        Value used for padded entries. It is converted to the unit of `a`;
        when None, ``0`` is used.
    **kwargs
        Extra keyword arguments forwarded unchanged.

    Returns
    -------
    res : ndarray, Quantity
        The selected slices of `a`.

    Raises
    ------
    TypeError
        If `condition` is a Quantity that is not unitless.

    Examples
    --------
    .. code-block:: python

        >>> import saiunit as u
        >>> import jax.numpy as jnp
        >>> a = [1, 2, 3, 4] * u.meter
        >>> u.math.compress(jnp.array([0, 1, 1, 0]), a)
    """
    if isinstance(condition, Quantity):
        if condition.is_unitless:
            condition = condition.mantissa
        else:
            raise TypeError(
                f'compress requires "condition" to be dimensionless (a boolean mask), '
                f'but got condition with unit={condition.unit}. '
                f'Strip the unit from condition before passing it to compress.'
            )

    a_unit = get_unit(a)
    fill = 0 if fill_value is None else Quantity(fill_value).in_unit(a_unit).mantissa

    # Bind the mask first so the helper only has to supply the data array.
    masked_compress = functools.partial(jnp.compress, condition)
    return _fun_keep_unit_unary(masked_compress, a, axis=axis, size=size, fill_value=fill, **kwargs)
@set_module_as('saiunit.math')
def extract(
    condition: jax.Array,
    arr: Union[jax.Array, Quantity],
    *,
    size: Optional[int] = None,
    fill_value: Optional[jax.typing.ArrayLike | Quantity] = None,
    **kwargs,
) -> jax.Array | Quantity:
    """
    Pick the elements of an array that satisfy a condition.

    Parameters
    ----------
    condition : array_like, Quantity
        Boolean values selecting which elements to extract. A Quantity is
        accepted only when it is unitless; its mantissa is then used.
    arr : array_like, Quantity
        The array from which to extract elements.
    size : int
        Optional static size for output. Must be specified in order for
        ``extract`` to be compatible with JAX transformations like
        :func:`~jax.jit` or :func:`~jax.vmap`.
    fill_value : array_like, Quantity
        If ``size`` is specified, fill padded entries with this value.
        It is converted to the unit of `arr`; when None, ``0`` is used.
    **kwargs
        Extra keyword arguments forwarded unchanged.

    Returns
    -------
    res : ndarray, Quantity
        The extracted elements.

    Raises
    ------
    TypeError
        If `condition` is a Quantity that is not unitless.

    Examples
    --------
    .. code-block:: python

        >>> import saiunit as u
        >>> import jax.numpy as jnp
        >>> a = jnp.array([1, 2, 3]) * u.meter
        >>> u.math.extract(a.mantissa > 1, a)
    """
    if isinstance(condition, Quantity):
        if condition.is_unitless:
            condition = condition.mantissa
        else:
            raise TypeError(
                f'extract requires "condition" to be dimensionless (a boolean mask), '
                f'but got condition with unit={condition.unit}. '
                f'Strip the unit from condition before passing it to extract.'
            )

    a_unit = get_unit(arr)
    fill = 0 if fill_value is None else Quantity(fill_value).in_unit(a_unit).mantissa

    # Bind the mask first so the helper only has to supply the data array.
    masked_extract = functools.partial(jnp.extract, condition)
    return _fun_keep_unit_unary(masked_extract, arr, size=size, fill_value=fill, **kwargs)
@set_module_as('saiunit.math')
def take(
    a: Union[Quantity, jax.typing.ArrayLike],
    indices: Union[Quantity, jax.typing.ArrayLike],
    axis: Optional[int] = None,
    mode: Optional[str] = None,
    unique_indices: bool = False,
    indices_are_sorted: bool = False,
    fill_value: Optional[Union[Quantity, jax.typing.ArrayLike]] = None,
    **kwargs,
) -> Union[Quantity, jax.Array]:
    """
    Take elements from an array along an axis.

    A Quantity input is handled by its own ``take`` method; anything else
    goes to :func:`jax.numpy.take`. With ``axis`` given, a call such as
    ``take(arr, indices, axis=3)`` is equivalent to the fancy-indexing
    expression ``arr[:,:,:,indices,...]``.

    Parameters
    ----------
    a : array_like (Ni..., M, Nk...)
        The source array.
    indices : array_like (Nj...)
        The indices of the values to extract. Scalars are also allowed.
    axis : int, optional
        The axis over which to select values. By default, the flattened
        input array is used.
    mode : string, default="fill"
        Out-of-bounds indexing mode. The default mode="fill" returns
        invalid values (e.g. NaN) for out-of bounds indices (see also
        ``fill_value`` below). For more discussion of mode options, see
        :attr:`jax.numpy.ndarray.at`.
    unique_indices : bool, default=False
        If True, the implementation will assume that the indices are
        unique, which can result in more efficient execution on some
        backends.
    indices_are_sorted : bool, default=False
        If True, the implementation will assume that the indices are
        sorted in ascending order, which can lead to more efficient
        execution on some backends.
    fill_value : optional
        The fill value to return for out-of-bounds slices when mode is
        'fill'. Ignored otherwise.
    **kwargs
        Extra keyword arguments. They are forwarded to
        :func:`jax.numpy.take` only; the Quantity branch does not receive
        them.

    Returns
    -------
    out : ndarray (Ni..., Nj..., Nk...)
        The returned array has the same type as `a`.

    Examples
    --------
    .. code-block:: python

        >>> import saiunit as u
        >>> import jax.numpy as jnp
        >>> a = [4, 3, 5, 7, 6, 8] * u.second
        >>> u.math.take(a, jnp.array([0, 1, 4]))
    """
    # Options common to both dispatch targets.
    options = dict(
        axis=axis,
        mode=mode,
        unique_indices=unique_indices,
        indices_are_sorted=indices_are_sorted,
        fill_value=fill_value,
    )
    if not isinstance(a, Quantity):
        return jnp.take(a, indices, **options, **kwargs)
    return a.take(indices, **options)
@set_module_as('saiunit.math')
def select(
    condlist: list[Union[jax.typing.ArrayLike]],
    choicelist: Union[Quantity, jax.typing.ArrayLike],
    default: int = 0,
    **kwargs,
) -> Union[Quantity, jax.Array]:
    """
    Build an array from elements of `choicelist`, chosen by conditions.

    Parameters
    ----------
    condlist : list of bool ndarrays
        The conditions which determine from which array in `choicelist`
        the output elements are taken. When multiple conditions are
        satisfied, the first one encountered in `condlist` is used.
        Quantity entries are rejected.
    choicelist : list of ndarrays or Quantity
        The arrays from which the output elements are taken. It has to be
        of the same length as `condlist`.
    default : scalar, optional
        The element inserted in the output when all conditions evaluate
        to False.
    **kwargs
        Extra keyword arguments forwarded unchanged.

    Returns
    -------
    output : ndarray, Quantity
        The output at position m is the m-th element of the array in
        `choicelist` where the m-th element of the corresponding array in
        `condlist` is True.

    Raises
    ------
    TypeError
        If any element of `condlist` is a Quantity.

    Examples
    --------
    .. code-block:: python

        >>> import saiunit as u
        >>> import jax.numpy as jnp
        >>> conds = [jnp.array([True, False, True]), jnp.array([False, True, False])]
        >>> choices = [[1, 2, 3] * u.second, [4, 5, 6] * u.second]
        >>> u.math.select(conds, choices, default=0)
    """
    # Report the first offending condition, if there is one.
    offender = next((c for c in condlist if isinstance(c, Quantity)), None)
    if offender is not None:
        raise TypeError(
            f'select requires all elements of "condlist" to be plain boolean arrays, '
            f'but got a Quantity with unit={offender.unit}. '
            f'Strip units from all condition arrays before passing them to select.'
        )

    select_by_cond = functools.partial(jnp.select, condlist)
    return _fun_keep_unit_sequence(select_by_cond, choicelist, default=default, **kwargs)
@set_module_as('saiunit.math')
def where(condition, x=None, y=None, /, *, size=None, fill_value=None, **kwargs):
    """
    Choose elements from `x` or `y` depending on `condition`.

    .. note::
        When only `condition` is provided, the call is forwarded to
        ``jnp.where(condition, size=..., fill_value=...)``, which behaves
        like ``nonzero``. The rest of this documentation covers the case
        where all three arguments are provided.

    Parameters
    ----------
    condition : array_like, bool
        Where True, yield `x`, otherwise yield `y`. Must not be a
        Quantity.
    x, y : array_like, Quantity
        Values from which to choose. `x`, `y` and `condition` need to be
        broadcastable to some shape. When both are Quantities, `y` is
        converted to the unit of `x`. When only one is a Quantity, it
        must be unitless.
    size : int, optional
        The length of the output array. Only accepted when `x` and `y`
        are omitted.
    fill_value : scalar, optional
        The value to use for missing values. Only accepted when `x` and
        `y` are omitted, and must not be a Quantity.
    **kwargs
        Extra keyword arguments forwarded to :func:`jax.numpy.where`.

    Returns
    -------
    out : ndarray, Quantity
        An array with elements from `x` where `condition` is True, and
        elements from `y` elsewhere. A Quantity in the unit of `x` when
        both `x` and `y` are Quantities.

    Raises
    ------
    TypeError
        If `condition` is a Quantity, if `fill_value` is a Quantity in the
        one-argument form, or if exactly one of `x`/`y` is a Quantity that
        is not unitless.
    ValueError
        If `size` or `fill_value` is given together with `x` and `y`.

    See Also
    --------
    choose
    nonzero : The function that is called when x and y are omitted

    Examples
    --------
    .. code-block:: python

        >>> import saiunit as u
        >>> a = [1, 2, 3, 4, 5] * u.meter
        >>> u.math.where(a > 3 * u.meter, a, 0 * u.meter)
    """
    if isinstance(condition, Quantity):
        raise TypeError(
            f'where requires "condition" to be a plain boolean array, not a Quantity. '
            f'Got condition with unit={condition.unit}. '
            f'Strip the unit from condition before passing it to where.'
        )

    # One-argument form: only the condition was supplied.
    only_condition = x is None and y is None
    if only_condition:
        if isinstance(fill_value, Quantity):
            raise TypeError(
                f'where requires "fill_value" to be a plain scalar when x and y are not provided, '
                f'but got fill_value with unit={fill_value.unit}.'
            )
        return jnp.where(condition, size=size, fill_value=fill_value, **kwargs)

    if not (size is None and fill_value is None):
        raise ValueError(
            f'where does not support "size" or "fill_value" when x and y are provided. '
            f'Got size={size}, fill_value={fill_value}.'
        )

    x_has_unit = isinstance(x, Quantity)
    y_has_unit = isinstance(y, Quantity)

    if x_has_unit and y_has_unit:
        # Align `y` to the unit of `x`; the result carries that unit.
        y = y.in_unit(x.unit)
        chosen = jnp.where(condition, x.mantissa, y.mantissa, **kwargs)
        return Quantity(chosen, unit=x.unit)

    if x_has_unit:
        if not x.is_unitless:
            raise TypeError(
                f'where requires "x" to be dimensionless when "y" is a plain array, '
                f'but got x with unit={x.unit}. '
                f'Either pass a Quantity for y with matching units, or strip the unit from x.'
            )
        x = x.mantissa
    elif y_has_unit:
        if not y.is_unitless:
            raise TypeError(
                f'where requires "y" to be dimensionless when "x" is a plain array, '
                f'but got y with unit={y.unit}. '
                f'Either pass a Quantity for x with matching units, or strip the unit from y.'
            )
        y = y.mantissa

    return jnp.where(condition, x, y, **kwargs)
@set_module_as('saiunit.math')
def unique(
    a: Union[jax.Array, Quantity],
    return_index: bool = False,
    return_inverse: bool = False,
    return_counts: bool = False,
    axis: Optional[int] = None,
    *,
    equal_nan: bool = False,
    size: Optional[int] = None,
    fill_value: Optional[Union[jax.typing.ArrayLike, Quantity]] = None,
    **kwargs,
) -> Sequence[jax.Array | Quantity] | jax.Array | Quantity:
    """
    Find the unique elements of a quantity or an array.

    Parameters
    ----------
    a : array_like, Quantity
        Input array.
    return_index : bool, optional
        If True, also return the indices of `a` (along the specified axis,
        if provided) that result in the unique array.
    return_inverse : bool, optional
        If True, also return the indices of the unique array (for the
        specified axis, if provided) that can be used to reconstruct `a`.
    return_counts : bool, optional
        If True, also return the number of times each unique item appears
        in `a`.
    axis : int, optional
        The axis along which to operate. If None, the array is flattened
        before use. Default is None.
    equal_nan : bool, optional
        Whether to compare NaN's as equal.
    size : int, optional
        The length of the output array.
    fill_value : scalar, Quantity, optional
        The value to use for missing values. It is converted to the unit
        of `a` before being passed to :func:`jax.numpy.unique`.
    **kwargs
        Extra keyword arguments forwarded to :func:`jax.numpy.unique`.

    Returns
    -------
    res : ndarray, Quantity, or tuple
        The unique values, wrapped in a Quantity with the unit of `a` when
        `a` is a Quantity. When :func:`jax.numpy.unique` returns a tuple,
        only its first entry is wrapped; the remaining entries are
        returned as they are.

    Examples
    --------
    .. code-block:: python

        >>> import saiunit as u
        >>> a = [0, 1, 2, 1, 0] * u.second
        >>> u.math.unique(a)
    """
    a_unit = get_unit(a)
    if fill_value is not None:
        fill_value = Quantity(fill_value).in_unit(a_unit).mantissa

    has_unit = isinstance(a, Quantity)
    result = jnp.unique(
        a.mantissa if has_unit else a,
        return_index=return_index,
        return_inverse=return_inverse,
        return_counts=return_counts,
        axis=axis,
        equal_nan=equal_nan,
        size=size,
        fill_value=fill_value,
        **kwargs,
    )
    if not has_unit:
        return result

    if isinstance(result, tuple):
        # Only the unique values carry the unit; indices, inverse and
        # counts stay plain arrays.
        values, *extras = result
        return (Quantity(values, unit=a_unit), *extras)
    return Quantity(result, unit=a_unit)
@set_module_as('saiunit.math')
def round(
    x: Union[Quantity, jax.typing.ArrayLike],
    decimals: int = 0,
    **kwargs,
) -> jax.Array | Quantity:
    """
    Round an array to the given number of decimals.

    This is a thin wrapper that hands :func:`jax.numpy.round` and `x` to
    ``_fun_keep_unit_unary``.

    Parameters
    ----------
    x : array_like, Quantity
        Input array.
    decimals : int
        Number of decimal places to round to (default is 0, i.e. the
        nearest integer).
    **kwargs
        Extra keyword arguments forwarded unchanged.

    Returns
    -------
    out : jax.Array, Quantity
        Rounded values. Quantity if `x` is a Quantity.

    Examples
    --------
    .. code-block:: python

        >>> import saiunit as u
        >>> a = [1.2, 2.7, 3.1] * u.meter
        >>> u.math.round(a)
    """
    rounded = _fun_keep_unit_unary(jnp.round, x, decimals=decimals, **kwargs)
    return rounded
@set_module_as('saiunit.math')
def around(
    x: Union[Quantity, jax.typing.ArrayLike],
    decimals: int = 0,
    **kwargs,
) -> jax.Array | Quantity:
    """
    Round an array to the given number of decimals.

    This is a thin wrapper that hands :func:`jax.numpy.around` and `x` to
    ``_fun_keep_unit_unary``.

    Parameters
    ----------
    x : array_like, Quantity
        Input array.
    decimals : int, optional
        Number of decimal places to round to (default is 0, i.e. the
        nearest integer).
    **kwargs
        Extra keyword arguments forwarded unchanged.

    Returns
    -------
    out : jax.Array, Quantity
        Rounded values. Quantity if `x` is a Quantity.

    Examples
    --------
    .. code-block:: python

        >>> import saiunit as u
        >>> a = [1.2, 2.7, 3.1] * u.second
        >>> u.math.around(a)
    """
    rounded = _fun_keep_unit_unary(jnp.around, x, decimals=decimals, **kwargs)
    return rounded
@set_module_as('saiunit.math')
def rint(
    x: Union[Quantity, jax.typing.ArrayLike],
    **kwargs,
) -> Union[Quantity, jax.Array]:
    """
    Round the elements of an array to the nearest integer.

    This is a thin wrapper that hands :func:`jax.numpy.rint` and `x` to
    ``_fun_keep_unit_unary``.

    Parameters
    ----------
    x : array_like, Quantity
        Input array.
    **kwargs
        Extra keyword arguments forwarded unchanged.

    Returns
    -------
    out : jax.Array, Quantity
        Rounded values. Quantity if `x` is a Quantity.

    Examples
    --------
    .. code-block:: python

        >>> import saiunit as u
        >>> a = [1.2, 2.7, 3.1] * u.meter
        >>> u.math.rint(a)
    """
    nearest = _fun_keep_unit_unary(jnp.rint, x, **kwargs)
    return nearest
@set_module_as('saiunit.math')
def floor(
    x: Union[Quantity, jax.typing.ArrayLike],
    **kwargs,
) -> jax.Array | Quantity:
    """
    Compute the floor of the input, element-wise.

    This is a thin wrapper that hands :func:`jax.numpy.floor` and `x` to
    ``_fun_keep_unit_unary``.

    Parameters
    ----------
    x : array_like, Quantity
        Input array.
    **kwargs
        Extra keyword arguments forwarded unchanged.

    Returns
    -------
    out : jax.Array, Quantity
        Floor values. Quantity if `x` is a Quantity.

    Examples
    --------
    .. code-block:: python

        >>> import saiunit as u
        >>> a = [1.2, 2.7, 3.1] * u.meter
        >>> u.math.floor(a)
    """
    floored = _fun_keep_unit_unary(jnp.floor, x, **kwargs)
    return floored
@set_module_as('saiunit.math')
def ceil(
    x: Union[Quantity, jax.typing.ArrayLike],
    **kwargs,
) -> jax.Array | Quantity:
    """
    Compute the ceiling of the input, element-wise.

    This is a thin wrapper that hands :func:`jax.numpy.ceil` and `x` to
    ``_fun_keep_unit_unary``.

    Parameters
    ----------
    x : array_like, Quantity
        Input array.
    **kwargs
        Extra keyword arguments forwarded unchanged.

    Returns
    -------
    out : jax.Array, Quantity
        Ceiling values. Quantity if `x` is a Quantity.

    Examples
    --------
    .. code-block:: python

        >>> import saiunit as u
        >>> a = [1.2, 2.7, 3.1] * u.meter
        >>> u.math.ceil(a)
    """
    ceiled = _fun_keep_unit_unary(jnp.ceil, x, **kwargs)
    return ceiled
@set_module_as('saiunit.math')
def trunc(
    x: Union[Quantity, jax.typing.ArrayLike],
    **kwargs,
) -> jax.Array | Quantity:
    """
    Compute the truncated value of the input, element-wise.

    This is a thin wrapper that hands :func:`jax.numpy.trunc` and `x` to
    ``_fun_keep_unit_unary``.

    Parameters
    ----------
    x : array_like, Quantity
        Input array.
    **kwargs
        Extra keyword arguments forwarded unchanged.

    Returns
    -------
    out : jax.Array, Quantity
        Truncated values. Quantity if `x` is a Quantity.

    Examples
    --------
    .. code-block:: python

        >>> import saiunit as u
        >>> a = [1.7, -2.3, 3.9] * u.meter
        >>> u.math.trunc(a)
    """
    truncated = _fun_keep_unit_unary(jnp.trunc, x, **kwargs)
    return truncated
@set_module_as('saiunit.math')
def fix(
    x: Union[Quantity, jax.typing.ArrayLike],
    **kwargs,
) -> jax.Array | Quantity:
    """
    Round the input to the nearest integer towards zero, element-wise.

    This is a thin wrapper that hands :func:`jax.numpy.trunc` and `x` to
    ``_fun_keep_unit_unary``.

    Parameters
    ----------
    x : array_like, Quantity
        Input array.
    **kwargs
        Extra keyword arguments forwarded unchanged.

    Returns
    -------
    out : jax.Array, Quantity
        Values rounded towards zero. Quantity if `x` is a Quantity.

    Examples
    --------
    .. code-block:: python

        >>> import saiunit as u
        >>> a = [1.7, -2.3, 3.9] * u.meter
        >>> u.math.fix(a)
    """
    # Implemented with jnp.trunc, exactly like `trunc`.
    towards_zero = _fun_keep_unit_unary(jnp.trunc, x, **kwargs)
    return towards_zero
@set_module_as('saiunit.math')
def modf(
    x: Union[Quantity, jax.typing.ArrayLike],
    **kwargs,
) -> Tuple[jax.Array | Quantity, jax.Array | Quantity]:
    """
    Split the array elements into fractional and integer parts.

    Parameters
    ----------
    x : array_like, Quantity
        Input array.
    **kwargs
        Extra keyword arguments forwarded to :func:`jax.numpy.modf`.

    Returns
    -------
    The fractional and integral parts of the input. When `x` is a
    Quantity, both parts are Quantities with the unit of `x`.

    Examples
    --------
    .. code-block:: python

        >>> import saiunit as u
        >>> a = [1.5, 2.7] * u.second
        >>> frac, intg = u.math.modf(a)
    """
    if not isinstance(x, Quantity):
        return jnp.modf(x, **kwargs)

    parts = jnp.modf(x.mantissa, **kwargs)
    # Re-attach the input's unit to every part of the result.
    return jax.tree.map(lambda part: Quantity(part, unit=x.unit), parts)
@set_module_as('saiunit.math')
def gather(input: jax.Array | Quantity, dim: int, index: jax.Array, **kwargs):
    """
    Gather values along an axis specified by dim, according to index.

    JAX implementation of ``torch.gather``: for a 2-D input and
    ``dim == 1``, ``out[i, j] = input[i, index[i, j]]``.

    Parameters
    ----------
    input : jax.Array, Quantity
        The source array or Quantity.
    dim : int
        The axis along which to index. Negative values count from the
        last axis.
    index : jax.Array
        The indices of elements to gather. Along every axis other than
        `dim` it may be shorter than `input`; only the leading
        ``index.shape[d]`` positions of that axis are then visited.
    **kwargs
        Extra keyword arguments, forwarded to both ``jnp.arange`` and
        ``jnp.broadcast_to``.

    Returns
    -------
    out : jax.Array, Quantity
        Array with the gathered elements, shaped like `index`. Quantity
        if `input` is a Quantity with a non-unitless unit.

    Examples
    --------
    .. code-block:: python

        >>> import saiunit as u
        >>> import jax.numpy as jnp
        >>> a = jnp.array([[1, 2], [3, 4]]) * u.mV
        >>> index = jnp.array([[0, 0], [1, 0]])
        >>> u.math.gather(a, 1, index)
    """
    input = maybe_custom_array(input)

    # Work on the bare mantissa and remember the unit for the result.
    if isinstance(input, Quantity):
        unit = input.unit
        input = input.mantissa
    else:
        unit = UNITLESS

    ndim = input.ndim
    # Normalize dim to be non-negative.
    if dim < 0:
        dim = ndim + dim

    # Build one integer index array per input axis.
    idx_shape = index.shape
    indices = []
    for i in range(ndim):
        if i == dim:
            # Use the provided index for the gather dimension.
            indices.append(index)
            continue
        # Other axes are walked in order. The coordinate range follows the
        # extent of `index` along this axis (not that of `input`), so an
        # `index` smaller than `input` selects the leading positions
        # instead of failing in the reshape.
        extent = idx_shape[i] if i < len(idx_shape) else input.shape[i]
        shape = [1] * ndim
        shape[i] = extent
        # NOTE(review): the same kwargs go to both calls, as before.
        idx = jnp.arange(extent, **kwargs).reshape(shape)
        indices.append(jnp.broadcast_to(idx, idx_shape, **kwargs))

    result = input[tuple(indices)]
    if unit.is_unitless:
        return result
    return Quantity(result, unit=unit)