# Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from __future__ import annotations
import builtins
from typing import Union, Sequence, Callable
import jax
import numpy as np
from jax import lax
from jax._src.typing import Shape
from saiunit._base_getters import has_same_unit, maybe_decimal
from saiunit._base_quantity import Quantity
from saiunit._misc import set_module_as, maybe_custom_array, maybe_custom_array_tree
from saiunit.math._fun_keep_unit import _fun_keep_unit_unary, _fun_keep_unit_binary
# Names exported by this module; the inline comments group them by category.
__all__ = [
    # sequence inputs
    # array manipulation
    'slice', 'dynamic_slice', 'dynamic_update_slice', 'gather',
    'index_take', 'slice_in_dim', 'index_in_dim', 'dynamic_slice_ind_dim', 'dynamic_index_in_dim',
    'dynamic_update_slice_in_dim', 'dynamic_update_index_in_dim',
    'sort', 'sort_key_val',
    # math funcs keep unit (unary)
    'neg',
    'cummax', 'cummin', 'cumsum',
    # scatter operations
    'scatter', 'scatter_add', 'scatter_sub', 'scatter_mul', 'scatter_min', 'scatter_max', 'scatter_apply',
    # math funcs keep unit (binary)
    'sub', 'complex', 'pad',
    # math funcs keep unit (n-ary)
    'clamp',
    # type conversion
    'convert_element_type', 'bitcast_convert_type',
    # math funcs keep unit (return Quantity and index)
    'approx_max_k', 'approx_min_k', 'top_k',
    # math funcs only accept unitless (unary) can return Quantity
    # broadcasting arrays
    'broadcast', 'broadcast_in_dim', 'broadcast_to_rank',
]
# array manipulation
@set_module_as('saiunit.math')
def slice(
    operand: Union[Quantity, jax.typing.ArrayLike],
    start_indices: Sequence[int],
    limit_indices: Sequence[int],
    strides: Sequence[int] | None = None,
    **kwargs,
) -> Union[Quantity, jax.Array]:
    """Take a static, optionally strided slice of ``operand``.

    Delegates to :func:`jax.lax.slice`, which wraps XLA's `Slice
    <https://www.tensorflow.org/xla/operation_semantics#slice>`_ operator,
    by way of ``_fun_keep_unit_unary``.

    Parameters
    ----------
    operand : array_like or Quantity
        The array to slice.
    start_indices : sequence of int
        ``operand.ndim`` start indices.
    limit_indices : sequence of int
        ``operand.ndim`` limit indices.
    strides : sequence of int, optional
        ``operand.ndim`` strides.
    **kwargs
        Extra keyword arguments handed to ``_fun_keep_unit_unary`` together
        with the slicing arguments.

    Returns
    -------
    result : jax.Array or Quantity
        The sliced array.

    Examples
    --------
    Simple two-dimensional slices:

    >>> x = jnp.arange(12).reshape(3, 4)
    >>> x
    Array([[ 0,  1,  2,  3],
           [ 4,  5,  6,  7],
           [ 8,  9, 10, 11]], dtype=int32)
    >>> lax.slice(x, (1, 0), (3, 2))
    Array([[4, 5],
           [8, 9]], dtype=int32)
    >>> lax.slice(x, (0, 0), (3, 4), (1, 2))
    Array([[ 0,  2],
           [ 4,  6],
           [ 8, 10]], dtype=int32)

    These are equivalent to the Python slicing expressions ``x[1:3, 0:2]``
    and ``x[0:3, 0:4:2]`` respectively.
    """
    slice_args = (start_indices, limit_indices, strides)
    return _fun_keep_unit_unary(lax.slice, operand, *slice_args, **kwargs)
@set_module_as('saiunit.math')
def dynamic_slice(
    operand: Union[Quantity, jax.typing.ArrayLike],
    start_indices: jax.typing.ArrayLike | Sequence[jax.typing.ArrayLike],
    slice_sizes: Shape,
    **kwargs,
) -> Union[Quantity, jax.Array]:
    """Slice ``operand`` at possibly dynamic start positions.

    Delegates to :func:`jax.lax.dynamic_slice`, which wraps XLA's `DynamicSlice
    <https://www.tensorflow.org/xla/operation_semantics#dynamicslice>`_
    operator, by way of ``_fun_keep_unit_unary``.

    Parameters
    ----------
    operand : array_like or Quantity
        The array to slice.
    start_indices : array_like or sequence of array_like
        Scalar indices, one per dimension. These values may be dynamic.
    slice_sizes : sequence of int
        The size of the slice: non-negative integers, one per dimension of
        ``operand``. Inside a JIT-compiled function only static values are
        supported (all JAX arrays inside JIT must have statically known size).
    **kwargs
        Extra keyword arguments handed to ``_fun_keep_unit_unary`` together
        with the slicing arguments.

    Returns
    -------
    result : jax.Array or Quantity
        An array containing the slice.

    Examples
    --------
    A simple two-dimensional dynamic slice:

    >>> x = jnp.arange(12).reshape(3, 4)
    >>> x
    Array([[ 0,  1,  2,  3],
           [ 4,  5,  6,  7],
           [ 8,  9, 10, 11]], dtype=int32)
    >>> dynamic_slice(x, (1, 1), (2, 3))
    Array([[ 5,  6,  7],
           [ 9, 10, 11]], dtype=int32)

    When the requested slice would overrun the bounds of the array, the start
    index is adjusted so that a slice of the requested size is returned:

    >>> dynamic_slice(x, (1, 1), (2, 4))
    Array([[ 4,  5,  6,  7],
           [ 8,  9, 10, 11]], dtype=int32)
    """
    sliced = _fun_keep_unit_unary(lax.dynamic_slice, operand, start_indices, slice_sizes, **kwargs)
    return sliced
@set_module_as('saiunit.math')
def dynamic_update_slice(
    operand: Union[Quantity, jax.typing.ArrayLike],
    update: Union[Quantity, jax.typing.ArrayLike],
    start_indices: jax.typing.ArrayLike | Sequence[jax.typing.ArrayLike],
    **kwargs,
) -> Union[Quantity, jax.Array]:
    """Write ``update`` into ``operand`` starting at ``start_indices``.

    Delegates to :func:`jax.lax.dynamic_update_slice`, which wraps XLA's
    `DynamicUpdateSlice
    <https://www.tensorflow.org/xla/operation_semantics#dynamicupdateslice>`_
    operator. Unit handling for the ``operand``/``update`` pair is left to
    ``_fun_keep_unit_binary``.

    Parameters
    ----------
    operand : array_like or Quantity
        The array to update.
    update : array_like or Quantity
        The new values to write onto ``operand``.
    start_indices : array_like or sequence of array_like
        Scalar indices, one per dimension.
    **kwargs
        Extra keyword arguments handed to ``_fun_keep_unit_binary`` together
        with ``start_indices``.

    Returns
    -------
    result : jax.Array or Quantity
        The updated array.

    Examples
    --------
    Updating a one-dimensional slice:

    >>> x = jnp.zeros(6)
    >>> y = jnp.ones(3)
    >>> dynamic_update_slice(x, y, (2,))
    Array([0., 0., 1., 1., 1., 0.], dtype=float32)

    If the update is too large to fit at the requested position, the start
    index is adjusted to make it fit:

    >>> dynamic_update_slice(x, y, (5,))
    Array([0., 0., 0., 1., 1., 1.], dtype=float32)

    A two-dimensional update:

    >>> x = jnp.zeros((4, 4))
    >>> y = jnp.ones((2, 2))
    >>> dynamic_update_slice(x, y, (1, 2))
    Array([[0., 0., 0., 0.],
           [0., 0., 1., 1.],
           [0., 0., 1., 1.],
           [0., 0., 0., 0.]], dtype=float32)
    """
    updated = _fun_keep_unit_binary(lax.dynamic_update_slice, operand, update, start_indices, **kwargs)
    return updated
@set_module_as('saiunit.math')
def gather(
    operand: Union[Quantity, jax.typing.ArrayLike],
    start_indices: jax.typing.ArrayLike,
    dimension_numbers: jax.lax.GatherDimensionNumbers,
    slice_sizes: Shape,
    *,
    unique_indices: bool = False,
    indices_are_sorted: bool = False,
    mode: str | jax.lax.GatherScatterMode | None = None,
    fill_value: Union[Quantity, jax.typing.ArrayLike] = None,
    **kwargs,
) -> Union[Quantity, jax.Array]:
    """Gather slices from ``operand`` at ``start_indices``.

    Wraps :func:`jax.lax.gather` (`XLA's Gather operator
    <https://www.tensorflow.org/xla/operation_semantics#gather>`_). When
    ``operand`` is a ``Quantity`` the gather runs on its mantissa and the
    result is rebuilt with the unit of ``operand``.

    :func:`gather` is a low-level operator with complicated semantics; most
    users should prefer NumPy-style indexing and/or
    :func:`jax.numpy.ndarray.at`, perhaps combined with :func:`jax.vmap`.

    Parameters
    ----------
    operand : array_like or Quantity
        The array from which slices are taken.
    start_indices : array_like
        The indices at which slices are taken.
    dimension_numbers : jax.lax.GatherDimensionNumbers
        Describes how dimensions of ``operand``, ``start_indices`` and the
        output relate.
    slice_sizes : sequence of int
        The size of each slice: non-negative integers, one per dimension of
        ``operand``.
    unique_indices : bool, optional
        Whether the gathered elements are guaranteed not to overlap. If
        ``True`` this may improve performance on some backends. JAX does not
        check this promise; overlapping elements give undefined behavior.
    indices_are_sorted : bool, optional
        Whether the indices are known to be sorted. If ``True`` this may
        improve performance on some backends.
    mode : str or jax.lax.GatherScatterMode, optional
        How out-of-bounds indices are handled: ``'clip'`` clamps indices so
        the slice is within bounds; ``'fill'`` or ``'drop'`` return a slice
        full of ``fill_value``; ``'promise_in_bounds'`` is
        implementation-defined.
    fill_value : array_like or Quantity, optional
        Value returned for out-of-bounds slices when ``mode`` is ``'fill'``;
        ignored otherwise. When omitted, the default is ``NaN`` for inexact
        types, the largest negative value for signed types, the largest
        positive value for unsigned types, and ``True`` for booleans. A
        ``Quantity`` fill value is expressed in the unit of ``operand``
        before being passed on.
    **kwargs
        Extra keyword arguments forwarded to :func:`jax.lax.gather`.

    Returns
    -------
    result : jax.Array or Quantity
        The gather output.

    Raises
    ------
    ValueError
        If ``operand`` is a ``Quantity`` and ``fill_value`` is given but is
        not a ``Quantity``, or if ``fill_value`` is a ``Quantity`` while
        ``operand`` is not.

    Examples
    --------
    Plain indexing lowers to XLA's Gather operator:

    >>> import jax.numpy as jnp
    >>> x = jnp.array([10, 11, 12])
    >>> indices = jnp.array([0, 1, 1, 2, 2, 2])
    >>> x[indices]
    Array([10, 11, 11, 12, 12, 12], dtype=int32)

    The equivalent direct call, which typical users should not need:

    >>> from jax import lax
    >>> lax.gather(x, indices[:, None], slice_sizes=(1,),
    ...            dimension_numbers=lax.GatherDimensionNumbers(
    ...                offset_dims=(),
    ...                collapsed_slice_dims=(0,),
    ...                start_index_map=(0,)),
    ...            indices_are_sorted=True,
    ...            mode=lax.GatherScatterMode.PROMISE_IN_BOUNDS)
    Array([10, 11, 11, 12, 12, 12], dtype=int32)
    """
    operand = maybe_custom_array(operand)
    fill_value = maybe_custom_array(fill_value)

    def _gather(array, **fill):
        # Single call site for lax.gather; ``fill`` is either empty or
        # carries the ``fill_value`` keyword.
        return lax.gather(array, start_indices, dimension_numbers, slice_sizes,
                          unique_indices=unique_indices,
                          indices_are_sorted=indices_are_sorted,
                          mode=mode, **fill, **kwargs)

    fill_has_unit = isinstance(fill_value, Quantity)

    if not isinstance(operand, Quantity):
        # A unit-carrying fill value makes no sense for a plain operand.
        if fill_has_unit:
            raise ValueError('fill_value must be None if operand is not a Quantity')
        return _gather(operand, fill_value=fill_value)

    if fill_has_unit:
        gathered = _gather(operand.mantissa,
                           fill_value=fill_value.in_unit(operand.unit).mantissa)
    elif fill_value is not None:
        raise ValueError('fill_value must be a Quantity if operand is a Quantity')
    else:
        gathered = _gather(operand.mantissa)
    return maybe_decimal(Quantity(gathered, unit=operand.unit))
@set_module_as('saiunit.math')
def index_take(
    src: Union[Quantity, jax.typing.ArrayLike],
    idxs: jax.typing.ArrayLike,
    axes: Sequence[int],
    **kwargs,
) -> Union[Quantity, jax.Array]:
    """Pick elements of ``src`` at ``idxs`` along the listed ``axes``.

    Delegates to :func:`jax.lax.index_take` by way of
    ``_fun_keep_unit_unary``.

    Parameters
    ----------
    src : array_like or Quantity
        The source array from which to take elements.
    idxs : array_like
        The indices of the elements to extract (the example below passes one
        index array per entry of ``axes``).
    axes : sequence of int
        The axes along which to index.
    **kwargs
        Extra keyword arguments handed to ``_fun_keep_unit_unary`` together
        with ``idxs`` and ``axes``.

    Returns
    -------
    result : jax.Array or Quantity
        The gathered elements. If ``src`` is a ``Quantity``, the result
        preserves the same unit.

    Examples
    --------
    .. code-block:: python

        >>> import saiunit as u
        >>> import saiunit.lax as sulax
        >>> import jax.numpy as jnp
        >>> x = jnp.array([[1., 2., 3.], [4., 5., 6.]]) * u.meter
        >>> idx = jnp.array([0, 2])
        >>> result = sulax.index_take(x, (idx,), axes=(1,))
        >>> result.mantissa
        Array([[1., 3.],
               [4., 6.]], dtype=float32)
    """
    taken = _fun_keep_unit_unary(lax.index_take, src, idxs, axes, **kwargs)
    return taken
@set_module_as('saiunit.math')
def slice_in_dim(
    operand: Union[Quantity, jax.typing.ArrayLike],
    start_index: int | None,
    limit_index: int | None,
    stride: int = 1,
    axis: int = 0,
    **kwargs,
) -> Union[Quantity, jax.Array]:
    """Slice ``operand`` along a single axis.

    Delegates to :func:`jax.lax.slice_in_dim` by way of
    ``_fun_keep_unit_unary``. The operation is effectively
    ``operand[..., start_index:limit_index:stride]`` with the indexing applied
    on the chosen ``axis``.

    Parameters
    ----------
    operand : array_like or Quantity
        The array to slice.
    start_index : int or None
        Start index (defaults to zero).
    limit_index : int or None
        End index (defaults to ``operand.shape[axis]``).
    stride : int, optional
        Stride of the slice. Default is 1.
    axis : int, optional
        The axis along which to slice. Default is 0.
    **kwargs
        Extra keyword arguments handed to ``_fun_keep_unit_unary`` together
        with the slicing arguments.

    Returns
    -------
    result : jax.Array or Quantity
        An array containing the slice.

    Examples
    --------
    One-dimensional:

    >>> x = jnp.arange(4)
    >>> lax.slice_in_dim(x, 1, 3)
    Array([1, 2], dtype=int32)

    Two-dimensional:

    >>> x = jnp.arange(12).reshape(4, 3)
    >>> x
    Array([[ 0,  1,  2],
           [ 3,  4,  5],
           [ 6,  7,  8],
           [ 9, 10, 11]], dtype=int32)
    >>> lax.slice_in_dim(x, 1, 3)
    Array([[3, 4, 5],
           [6, 7, 8]], dtype=int32)
    >>> lax.slice_in_dim(x, 1, 3, axis=1)
    Array([[ 1,  2],
           [ 4,  5],
           [ 7,  8],
           [10, 11]], dtype=int32)
    """
    slice_args = (start_index, limit_index, stride, axis)
    return _fun_keep_unit_unary(lax.slice_in_dim, operand, *slice_args, **kwargs)
@set_module_as('saiunit.math')
def index_in_dim(
    operand: Union[Quantity, jax.typing.ArrayLike],
    index: int,
    axis: int = 0,
    keepdims: bool = True,
    **kwargs,
) -> Union[Quantity, jax.Array]:
    """Select the entry at integer position ``index`` along ``axis``.

    Delegates to :func:`jax.lax.index_in_dim` by way of
    ``_fun_keep_unit_unary``. As the examples show, with ``keepdims=True``
    the indexed axis is kept with length one; with ``keepdims=False`` it is
    removed.

    Parameters
    ----------
    operand : array_like or Quantity
        The array to index.
    index : int
        Integer index.
    axis : int, optional
        The axis along which to index. Default is 0.
    keepdims : bool, optional
        Whether the output keeps the rank of the input. Default is ``True``.
    **kwargs
        Extra keyword arguments handed to ``_fun_keep_unit_unary`` together
        with the indexing arguments.

    Returns
    -------
    result : jax.Array or Quantity
        The subarray at the specified index.

    Examples
    --------
    One-dimensional:

    >>> x = jnp.arange(4)
    >>> lax.index_in_dim(x, 2)
    Array([2], dtype=int32)
    >>> lax.index_in_dim(x, 2, keepdims=False)
    Array(2, dtype=int32)

    Two-dimensional:

    >>> x = jnp.arange(12).reshape(3, 4)
    >>> x
    Array([[ 0,  1,  2,  3],
           [ 4,  5,  6,  7],
           [ 8,  9, 10, 11]], dtype=int32)
    >>> lax.index_in_dim(x, 1)
    Array([[4, 5, 6, 7]], dtype=int32)
    >>> lax.index_in_dim(x, 1, axis=1, keepdims=False)
    Array([1, 5, 9], dtype=int32)
    """
    index_args = (index, axis, keepdims)
    return _fun_keep_unit_unary(lax.index_in_dim, operand, *index_args, **kwargs)
@set_module_as('saiunit.math')
def dynamic_slice_ind_dim(
    operand: Union[Quantity, jax.typing.ArrayLike],
    start_index: jax.typing.ArrayLike,
    slice_size: int,
    axis: int = 0,
    **kwargs,
) -> Union[Quantity, jax.Array]:
    """Dynamic slice of fixed size along a single axis.

    Despite the ``ind_dim`` spelling of its name, this function delegates to
    :func:`jax.lax.dynamic_slice_in_dim` (by way of ``_fun_keep_unit_unary``).
    It is roughly ``operand[..., start_index:start_index + slice_size]``
    applied along ``axis``.

    Parameters
    ----------
    operand : array_like or Quantity
        The array to slice.
    start_index : array_like
        The (possibly dynamic) start index.
    slice_size : int
        The static slice size.
    axis : int, optional
        The axis along which to slice. Default is 0.
    **kwargs
        Extra keyword arguments handed to ``_fun_keep_unit_unary`` together
        with the slicing arguments.

    Returns
    -------
    result : jax.Array or Quantity
        An array containing the slice.

    Examples
    --------
    One-dimensional:

    >>> x = jnp.arange(5)
    >>> dynamic_slice_ind_dim(x, 1, 3)
    Array([1, 2, 3], dtype=int32)

    Like ``jax.lax.dynamic_slice``, out-of-bound slices are clipped to the
    valid range:

    >>> dynamic_slice_ind_dim(x, 4, 3)
    Array([2, 3, 4], dtype=int32)

    Two-dimensional:

    >>> x = jnp.arange(12).reshape(3, 4)
    >>> dynamic_slice_ind_dim(x, 1, 2, axis=1)
    Array([[ 1,  2],
           [ 5,  6],
           [ 9, 10]], dtype=int32)
    """
    slice_args = (start_index, slice_size, axis)
    return _fun_keep_unit_unary(lax.dynamic_slice_in_dim, operand, *slice_args, **kwargs)
@set_module_as('saiunit.math')
def dynamic_index_in_dim(
    operand: Union[Quantity, jax.typing.ArrayLike],
    index: int | jax.typing.ArrayLike,
    axis: int = 0,
    keepdims: bool = True,
    **kwargs,
) -> Union[Quantity, jax.Array]:
    """Index ``operand`` at a possibly dynamic position along one axis.

    Delegates to :func:`jax.lax.dynamic_index_in_dim` by way of
    ``_fun_keep_unit_unary``. It is roughly ``operand[..., index]`` applied
    along ``axis``.

    Parameters
    ----------
    operand : array_like or Quantity
        The array to index.
    index : int or array_like
        The (possibly dynamic) index.
    axis : int, optional
        The axis along which to index. Default is 0.
    keepdims : bool, optional
        Whether the output has the same rank as the input. Default is
        ``True``.
    **kwargs
        Extra keyword arguments handed to ``_fun_keep_unit_unary`` together
        with the indexing arguments.

    Returns
    -------
    result : jax.Array or Quantity
        An array containing the selected entry.

    Examples
    --------
    One-dimensional:

    >>> x = jnp.arange(5)
    >>> dynamic_index_in_dim(x, 1)
    Array([1], dtype=int32)
    >>> dynamic_index_in_dim(x, 1, keepdims=False)
    Array(1, dtype=int32)

    Two-dimensional:

    >>> x = jnp.arange(12).reshape(3, 4)
    >>> dynamic_index_in_dim(x, 1, axis=1, keepdims=False)
    Array([1, 5, 9], dtype=int32)
    """
    index_args = (index, axis, keepdims)
    return _fun_keep_unit_unary(lax.dynamic_index_in_dim, operand, *index_args, **kwargs)
@set_module_as('saiunit.math')
def dynamic_update_slice_in_dim(
    operand: Union[Quantity, jax.typing.ArrayLike],
    update: Union[Quantity, jax.typing.ArrayLike],
    start_index: jax.typing.ArrayLike,
    axis: int,
    **kwargs,
) -> Union[Quantity, jax.Array]:
    """Write ``update`` into ``operand`` along a single ``axis``.

    Delegates to :func:`jax.lax.dynamic_update_slice_in_dim`. Unit handling
    for the ``operand``/``update`` pair is left to ``_fun_keep_unit_binary``.

    Parameters
    ----------
    operand : array_like or Quantity
        The array to update.
    update : array_like or Quantity
        The new values to write onto ``operand``.
    start_index : array_like
        A single scalar index.
    axis : int
        The axis of the update.
    **kwargs
        Extra keyword arguments handed to ``_fun_keep_unit_binary`` together
        with ``start_index`` and ``axis``.

    Returns
    -------
    result : jax.Array or Quantity
        The updated array.

    Examples
    --------
    >>> x = jnp.zeros(6)
    >>> y = jnp.ones(3)
    >>> dynamic_update_slice_in_dim(x, y, 2, axis=0)
    Array([0., 0., 1., 1., 1., 0.], dtype=float32)

    If the update is too large to fit at the requested position, the start
    index is adjusted to make it fit:

    >>> dynamic_update_slice_in_dim(x, y, 5, axis=0)
    Array([0., 0., 0., 1., 1., 1.], dtype=float32)

    The shape of the other axes of ``update`` need not match the
    corresponding dimensions of ``operand``:

    >>> x = jnp.zeros((4, 4))
    >>> y = jnp.ones((2, 3))
    >>> dynamic_update_slice_in_dim(x, y, 1, axis=0)
    Array([[0., 0., 0., 0.],
           [1., 1., 1., 0.],
           [1., 1., 1., 0.],
           [0., 0., 0., 0.]], dtype=float32)
    """
    where = (start_index, axis)
    return _fun_keep_unit_binary(lax.dynamic_update_slice_in_dim, operand, update, *where, **kwargs)
@set_module_as('saiunit.math')
def dynamic_update_index_in_dim(
    operand: Union[Quantity, jax.typing.ArrayLike],
    update: Union[Quantity, jax.typing.ArrayLike],
    index: jax.typing.ArrayLike,
    axis: int,
    **kwargs,
) -> Union[Quantity, jax.Array]:
    """Overwrite a size-1 slice of ``operand`` along a single ``axis``.

    Delegates to :func:`jax.lax.dynamic_update_index_in_dim`. Unit handling
    for the ``operand``/``update`` pair is left to ``_fun_keep_unit_binary``.

    Parameters
    ----------
    operand : array_like or Quantity
        The array to update.
    update : array_like or Quantity
        The new values to write onto ``operand``.
    index : array_like
        A single scalar index.
    axis : int
        The axis of the update.
    **kwargs
        Extra keyword arguments handed to ``_fun_keep_unit_binary`` together
        with ``index`` and ``axis``.

    Returns
    -------
    result : jax.Array or Quantity
        The updated array.

    Examples
    --------
    >>> x = jnp.zeros(6)
    >>> y = 1.0
    >>> dynamic_update_index_in_dim(x, y, 2, axis=0)
    Array([0., 0., 1., 0., 0., 0.], dtype=float32)

    An out-of-bounds index is clipped to the valid range:

    >>> y = jnp.array([1.0])
    >>> dynamic_update_index_in_dim(x, y, 10, axis=0)
    Array([0., 0., 0., 0., 0., 1.], dtype=float32)

    A two-dimensional update:

    >>> x = jnp.zeros((4, 4))
    >>> y = jnp.ones(4)
    >>> dynamic_update_index_in_dim(x, y, 1, axis=0)
    Array([[0., 0., 0., 0.],
           [1., 1., 1., 1.],
           [0., 0., 0., 0.],
           [0., 0., 0., 0.]], dtype=float32)
    """
    where = (index, axis)
    return _fun_keep_unit_binary(lax.dynamic_update_index_in_dim, operand, update, *where, **kwargs)
@set_module_as('saiunit.math')
def sort(
    operand: Union[Quantity, jax.typing.ArrayLike] | Sequence[Union[Quantity, jax.typing.ArrayLike]],
    dimension: int = -1,
    is_stable: bool = True,
    num_keys: int = 1,
    **kwargs,
) -> Union[Quantity, jax.Array] | Sequence[Union[Quantity, jax.Array]]:
    """Sort one array, or several arrays jointly, along ``dimension``.

    Wraps :func:`jax.lax.sort` (XLA's `Sort
    <https://www.tensorflow.org/xla/operation_semantics#sort>`_ operator).
    ``Quantity`` inputs are sorted by their mantissa and each result is
    rebuilt with the unit of the corresponding input.

    For floating point inputs, -0.0 and 0.0 are treated as equivalent, and
    NaN values are sorted to the end of the array. For complex inputs, the
    sort order is lexicographic over the real and imaginary parts, with the
    real part primary.

    Parameters
    ----------
    operand : array_like, Quantity, or sequence of these
        The array, or sequence of arrays, to sort.
    dimension : int, optional
        The dimension along which to sort. Default is -1.
    is_stable : bool, optional
        Whether to use a stable sort. Default is ``True``.
    num_keys : int, optional
        Number of operands to treat as sort keys. Default is 1. For
        ``num_keys > 1`` the order is determined lexicographically using the
        first ``num_keys`` arrays, with the first key being primary; the
        remaining operands are returned with the same permutation.
    **kwargs
        Extra keyword arguments forwarded to :func:`jax.lax.sort`.

    Returns
    -------
    result : jax.Array, Quantity, or list
        The sorted input. For a sequence input, a list with one sorted entry
        per input, in the same order.
    """
    operand = maybe_custom_array_tree(operand)

    # Single array: sort it directly, re-attaching the unit if there is one.
    if not isinstance(operand, Sequence):
        if isinstance(operand, Quantity):
            sorted_mantissa = lax.sort(operand.mantissa, dimension, is_stable, num_keys, **kwargs)
            return maybe_decimal(Quantity(sorted_mantissa, unit=operand.unit))
        return lax.sort(operand, dimension, is_stable, num_keys, **kwargs)

    # Several arrays: remember each unit (None for plain arrays) and sort the
    # bare values together.
    units = [item.unit if isinstance(item, Quantity) else None for item in operand]
    bare_values = [item.mantissa if isinstance(item, Quantity) else item for item in operand]
    sorted_values = lax.sort(bare_values, dimension, is_stable, num_keys, **kwargs)

    return [
        arr if unit is None else maybe_decimal(Quantity(arr, unit=unit))
        for arr, unit in zip(sorted_values, units)
    ]
@set_module_as('saiunit.math')
def sort_key_val(
    keys: Union[Quantity, jax.typing.ArrayLike],
    values: Union[Quantity, jax.typing.ArrayLike],
    dimension: int = -1,
    is_stable: bool = True,
    **kwargs,
) -> tuple[Union[Quantity, jax.Array], Union[Quantity, jax.Array]]:
    """Sort ``keys`` along ``dimension`` and reorder ``values`` the same way.

    Wraps :func:`jax.lax.sort_key_val`. ``Quantity`` inputs are sorted by
    their mantissa and the matching output is rebuilt with the same unit;
    ``keys`` and ``values`` are treated independently in that respect.

    Parameters
    ----------
    keys : array_like or Quantity
        The array of keys to sort.
    values : array_like or Quantity
        The array of values to permute according to the sorted order of
        ``keys``.
    dimension : int, optional
        The dimension along which to sort. Default is -1.
    is_stable : bool, optional
        Whether to use a stable sort. Default is ``True``.
    **kwargs
        Extra keyword arguments forwarded to :func:`jax.lax.sort_key_val`.

    Returns
    -------
    sorted_keys : jax.Array or Quantity
        The sorted keys. Preserves the unit of ``keys``.
    sorted_values : jax.Array or Quantity
        The values permuted to match the sorted order of ``keys``.
        Preserves the unit of ``values``.

    Examples
    --------
    .. code-block:: python

        >>> import saiunit as u
        >>> import saiunit.lax as sulax
        >>> import jax.numpy as jnp
        >>> keys = jnp.array([3.0, 1.0, 2.0]) * u.meter
        >>> vals = jnp.array([30, 10, 20])
        >>> sorted_keys, sorted_vals = sulax.sort_key_val(keys, vals)
        >>> sorted_keys.mantissa
        Array([1., 2., 3.], dtype=float32)
        >>> sorted_vals
        Array([10, 20, 30], dtype=int32)
    """
    keys = maybe_custom_array(keys)
    values = maybe_custom_array(values)
    keys_have_unit = isinstance(keys, Quantity)
    values_have_unit = isinstance(values, Quantity)

    # Nothing carries a unit: hand back the lax result untouched.
    if not (keys_have_unit or values_have_unit):
        return lax.sort_key_val(keys, values, dimension, is_stable, **kwargs)

    # Sort on bare values, then re-attach whichever units were present.
    bare_keys = keys.mantissa if keys_have_unit else keys
    bare_values = values.mantissa if values_have_unit else values
    sorted_keys, sorted_values = lax.sort_key_val(bare_keys, bare_values, dimension, is_stable, **kwargs)

    if keys_have_unit:
        sorted_keys = maybe_decimal(Quantity(sorted_keys, unit=keys.unit))
    if values_have_unit:
        sorted_values = maybe_decimal(Quantity(sorted_values, unit=values.unit))
    return sorted_keys, sorted_values
# math funcs keep unit (unary)
@set_module_as('saiunit.math')
def neg(
    x: Union[Quantity, jax.typing.ArrayLike],
    **kwargs,
) -> Union[Quantity, jax.Array]:
    r"""Negate every element: :math:`-x`.

    Delegates to :func:`jax.lax.neg` by way of ``_fun_keep_unit_unary``.

    Parameters
    ----------
    x : array_like or Quantity
        Input array.
    **kwargs
        Extra keyword arguments handed to ``_fun_keep_unit_unary``.

    Returns
    -------
    result : jax.Array or Quantity
        The negated array. If ``x`` is a ``Quantity``, the result
        preserves the same unit.

    Examples
    --------
    .. code-block:: python

        >>> import saiunit as u
        >>> import saiunit.lax as sulax
        >>> import jax.numpy as jnp
        >>> q = jnp.array([1.0, -2.0, 3.0]) * u.meter
        >>> result = sulax.neg(q)
        >>> result.mantissa
        Array([-1.,  2., -3.], dtype=float32)
        >>> result.unit
        meter
    """
    negated = _fun_keep_unit_unary(lax.neg, x, **kwargs)
    return negated
@set_module_as('saiunit.math')
def cummax(
    operand: Union[Quantity, jax.typing.ArrayLike],
    axis: int = 0,
    reverse: bool = False,
    **kwargs,
) -> Union[Quantity, jax.Array]:
    """Running maximum of ``operand`` along ``axis``.

    Delegates to :func:`jax.lax.cummax` by way of ``_fun_keep_unit_unary``.

    Parameters
    ----------
    operand : array_like or Quantity
        Input array.
    axis : int, optional
        The axis along which the cumulative maximum is taken. Default is 0.
    reverse : bool, optional
        If ``True``, accumulate in reverse. Default is ``False``.
    **kwargs
        Extra keyword arguments handed to ``_fun_keep_unit_unary`` together
        with ``axis`` and ``reverse``.

    Returns
    -------
    result : jax.Array or Quantity
        The cumulative maximum array. Preserves the unit of ``operand``.

    Examples
    --------
    .. code-block:: python

        >>> import saiunit as u
        >>> import saiunit.lax as sulax
        >>> import jax.numpy as jnp
        >>> q = jnp.array([3.0, 1.0, 4.0, 1.0]) * u.second
        >>> result = sulax.cummax(q)
        >>> result.mantissa
        Array([3., 3., 4., 4.], dtype=float32)
    """
    scan_args = (axis, reverse)
    return _fun_keep_unit_unary(lax.cummax, operand, *scan_args, **kwargs)
@set_module_as('saiunit.math')
def cummin(
    operand: Union[Quantity, jax.typing.ArrayLike],
    axis: int = 0,
    reverse: bool = False,
    **kwargs,
) -> Union[Quantity, jax.Array]:
    """Running minimum of ``operand`` along ``axis``.

    Delegates to :func:`jax.lax.cummin` by way of ``_fun_keep_unit_unary``.

    Parameters
    ----------
    operand : array_like or Quantity
        Input array.
    axis : int, optional
        The axis along which the cumulative minimum is taken. Default is 0.
    reverse : bool, optional
        If ``True``, accumulate in reverse. Default is ``False``.
    **kwargs
        Extra keyword arguments handed to ``_fun_keep_unit_unary`` together
        with ``axis`` and ``reverse``.

    Returns
    -------
    result : jax.Array or Quantity
        The cumulative minimum array. Preserves the unit of ``operand``.

    Examples
    --------
    .. code-block:: python

        >>> import saiunit as u
        >>> import saiunit.lax as sulax
        >>> import jax.numpy as jnp
        >>> q = jnp.array([3.0, 1.0, 4.0, 1.0]) * u.second
        >>> result = sulax.cummin(q)
        >>> result.mantissa
        Array([3., 1., 1., 1.], dtype=float32)
    """
    scan_args = (axis, reverse)
    return _fun_keep_unit_unary(lax.cummin, operand, *scan_args, **kwargs)
@set_module_as('saiunit.math')
def cumsum(
    operand: Union[Quantity, jax.typing.ArrayLike],
    axis: int = 0,
    reverse: bool = False,
    **kwargs,
) -> Union[Quantity, jax.Array]:
    """Running sum of ``operand`` along ``axis``.

    Delegates to :func:`jax.lax.cumsum` by way of ``_fun_keep_unit_unary``.

    Parameters
    ----------
    operand : array_like or Quantity
        Input array.
    axis : int, optional
        The axis along which the cumulative sum is taken. Default is 0.
    reverse : bool, optional
        If ``True``, accumulate in reverse. Default is ``False``.
    **kwargs
        Extra keyword arguments handed to ``_fun_keep_unit_unary`` together
        with ``axis`` and ``reverse``.

    Returns
    -------
    result : jax.Array or Quantity
        The cumulative sum array. Preserves the unit of ``operand``.

    Examples
    --------
    .. code-block:: python

        >>> import saiunit as u
        >>> import saiunit.lax as sulax
        >>> import jax.numpy as jnp
        >>> q = jnp.array([1.0, 2.0, 3.0, 4.0]) * u.meter
        >>> result = sulax.cumsum(q)
        >>> result.mantissa
        Array([ 1.,  3.,  6., 10.], dtype=float32)
    """
    scan_args = (axis, reverse)
    return _fun_keep_unit_unary(lax.cumsum, operand, *scan_args, **kwargs)
def _fun_lax_scatter(
    fun: Callable,
    operand,
    scatter_indices,
    updates,
    dimension_numbers,
    indices_are_sorted,
    unique_indices,
    mode,
    **kwargs,
) -> Union[Quantity, jax.Array]:
    """Run a ``lax.scatter*`` function with unit checking.

    ``operand`` and ``updates`` must either both be ``Quantity`` objects with
    the same unit, or both be plain arrays. In the ``Quantity`` case ``fun``
    is applied to the mantissas and the result is rebuilt with the unit of
    ``operand``.

    Parameters
    ----------
    fun : callable
        The scatter function to call. It is invoked as
        ``fun(operand, scatter_indices, updates, dimension_numbers,
        indices_are_sorted=..., unique_indices=..., mode=..., **kwargs)``.
    operand : array_like or Quantity
        The array to which the scatter is applied.
    scatter_indices : array_like
        Indices passed through to ``fun`` unchanged.
    updates : array_like or Quantity
        The updates to scatter onto ``operand``.
    dimension_numbers
        Dimension numbers passed through to ``fun`` unchanged.
    indices_are_sorted, unique_indices, mode
        Passed through to ``fun`` as keyword arguments.
    **kwargs
        Extra keyword arguments forwarded to ``fun``. The public scatter
        wrappers pass their own ``**kwargs`` here, so this helper has to
        accept them.

    Returns
    -------
    result : jax.Array or Quantity
        The output of ``fun``, as a ``Quantity`` when the inputs were
        quantities.

    Raises
    ------
    TypeError
        If ``operand`` and ``updates`` are quantities with different units,
        or if only one of them is a ``Quantity``.
    """
    operand = maybe_custom_array(operand)
    updates = maybe_custom_array(updates)
    if isinstance(operand, Quantity) and isinstance(updates, Quantity):
        if not has_same_unit(operand, updates):
            raise TypeError(
                f'operand(unit:{operand.unit}) and updates(unit:{updates.unit}) must have the same unit.'
            )
        return maybe_decimal(Quantity(fun(operand.mantissa, scatter_indices, updates.mantissa, dimension_numbers,
                                          indices_are_sorted=indices_are_sorted,
                                          unique_indices=unique_indices,
                                          mode=mode,
                                          **kwargs), unit=operand.unit))
    elif isinstance(operand, Quantity) or isinstance(updates, Quantity):
        # Mixing a Quantity with a plain array is rejected rather than guessed at.
        raise TypeError(
            f'operand and updates should both be `Quantity` or Array, now we got {type(operand)} and {type(updates)}')
    else:
        return fun(operand, scatter_indices, updates, dimension_numbers,
                   indices_are_sorted=indices_are_sorted,
                   unique_indices=unique_indices,
                   mode=mode,
                   **kwargs)
@set_module_as('saiunit.math')
def scatter(
    operand: Union[Quantity, jax.typing.ArrayLike],
    scatter_indices: jax.typing.ArrayLike,
    updates: Union[Quantity, jax.typing.ArrayLike],
    dimension_numbers: jax.lax.ScatterDimensionNumbers,
    *,
    indices_are_sorted: bool = False,
    unique_indices: bool = False,
    mode: str | jax.lax.GatherScatterMode | None = None,
    **kwargs,
) -> Union[Quantity, jax.Array]:
    """Scatter-update operator.

    Wraps `XLA's Scatter operator
    <https://www.tensorflow.org/xla/operation_semantics#scatter>`_, where
    updates replace values from ``operand``. If multiple updates are performed
    to the same index of ``operand``, they may be applied in any order.

    ``operand`` and ``updates`` must either both be ``Quantity`` objects with
    the same unit or both be plain arrays; the check is performed by
    ``_fun_lax_scatter``.

    :func:`scatter` is a low-level operator with complicated semantics, and
    most JAX users will never need to call it directly. Prefer
    :func:`jax.numpy.ndarray.at` for more familiar NumPy-style indexing syntax.

    Parameters
    ----------
    operand : array_like or Quantity
        The array to which the scatter is applied.
    scatter_indices : array_like
        The indices in ``operand`` to which each update in ``updates`` is
        applied.
    updates : array_like or Quantity
        The updates that are scattered onto ``operand``.
    dimension_numbers : jax.lax.ScatterDimensionNumbers
        Describes how dimensions of ``operand``, ``scatter_indices``,
        ``updates`` and the output relate.
    indices_are_sorted : bool, optional
        Whether ``scatter_indices`` is known to be sorted. If ``True`` this
        may improve performance on some backends.
    unique_indices : bool, optional
        Whether the elements to be updated in ``operand`` are guaranteed not
        to overlap. If ``True`` this may improve performance on some
        backends. JAX does not check this promise: overlapping updates with
        ``unique_indices=True`` give undefined behavior.
    mode : str or jax.lax.GatherScatterMode, optional
        How out-of-bounds indices are handled: ``'clip'`` clamps indices so
        the slice is within bounds; ``'fill'`` or ``'drop'`` drop
        out-of-bounds updates; ``'promise_in_bounds'`` is
        implementation-defined.
    **kwargs
        Extra keyword arguments passed on to ``_fun_lax_scatter``.

    Returns
    -------
    result : jax.Array or Quantity
        A copy of ``operand`` in which the scattered updates have replaced
        the values at the addressed positions.

    Raises
    ------
    TypeError
        If ``operand`` and ``updates`` are quantities with different units,
        or if only one of them is a ``Quantity``.

    Examples
    --------
    You should basically never use :func:`scatter` directly; use NumPy-style
    indexing via :attr:`jax.numpy.ndarray.at` instead, which lowers to an XLA
    Scatter operation:

    >>> x = jnp.zeros(5)
    >>> indices = jnp.array([1, 2, 4])
    >>> values = jnp.array([2.0, 3.0, 4.0])
    >>> x.at[indices].set(values)
    Array([0., 2., 3., 0., 4.], dtype=float32)
    >>> x.at[indices].set(values, indices_are_sorted=True, mode='promise_in_bounds')
    Array([0., 2., 3., 0., 4.], dtype=float32)

    The equivalent direct call, which typical users should not need:

    >>> lax.scatter(x, indices[:, None], values,
    ...             dimension_numbers=lax.ScatterDimensionNumbers(
    ...                 update_window_dims=(),
    ...                 inserted_window_dims=(0,),
    ...                 scatter_dims_to_operand_dims=(0,)),
    ...             indices_are_sorted=True,
    ...             mode=lax.GatherScatterMode.PROMISE_IN_BOUNDS)
    Array([0., 2., 3., 0., 4.], dtype=float32)
    """
    return _fun_lax_scatter(lax.scatter, operand, scatter_indices, updates, dimension_numbers, indices_are_sorted,
                            unique_indices, mode, **kwargs)
@set_module_as('saiunit.math')
def scatter_add(
    operand: Union[Quantity, jax.typing.ArrayLike],
    scatter_indices: jax.typing.ArrayLike,
    updates: jax.typing.ArrayLike,
    dimension_numbers: jax.lax.ScatterDimensionNumbers,
    *,
    indices_are_sorted: bool = False,
    unique_indices: bool = False,
    mode: str | jax.lax.GatherScatterMode | None = None,
    **kwargs,
) -> Union[Quantity, jax.Array]:
    """Scatter-add operator.

    Wraps `XLA's Scatter operator
    <https://www.tensorflow.org/xla/operation_semantics#scatter>`_, where
    addition is used to combine updates and values from `operand`.

    The semantics of scatter are complicated, and its API might change in the
    future. For most use cases, you should prefer the
    :attr:`jax.numpy.ndarray.at` property on JAX arrays which uses
    the familiar NumPy indexing syntax.

    Args:
      operand: an array to which the scatter should be applied
      scatter_indices: an array that gives the indices in `operand` to which each
        update in `updates` should be applied.
      updates: the updates that should be scattered onto `operand`.
      dimension_numbers: a `lax.ScatterDimensionNumbers` object that describes
        how dimensions of `operand`, `scatter_indices`, `updates` and the output
        relate.
      indices_are_sorted: whether `scatter_indices` is known to be sorted. If
        true, may improve performance on some backends.
      unique_indices: whether the elements to be updated in ``operand`` are
        guaranteed to not overlap with each other. If true, may improve performance on
        some backends. JAX does not check this promise: if the updated elements
        overlap when ``unique_indices`` is ``True`` the behavior is undefined.
      mode: how to handle indices that are out of bounds: when set to 'clip',
        indices are clamped so that the slice is within bounds, and when
        set to 'fill' or 'drop' out-of-bounds updates are dropped. The behavior
        for out-of-bounds indices when set to 'promise_in_bounds' is
        implementation-defined.
      **kwargs: extra keyword arguments, forwarded unchanged to
        ``_fun_lax_scatter``.

    Returns:
      An array containing the sum of `operand` and the scattered updates.
    """
    # Argument and unit handling is delegated entirely to the shared scatter helper;
    # only the underlying lax primitive differs between the scatter_* wrappers.
    return _fun_lax_scatter(lax.scatter_add, operand, scatter_indices, updates, dimension_numbers, indices_are_sorted,
                            unique_indices, mode, **kwargs)
@set_module_as('saiunit.math')
def scatter_sub(
    operand: Union[Quantity, jax.typing.ArrayLike],
    scatter_indices: jax.typing.ArrayLike,
    updates: jax.typing.ArrayLike,
    dimension_numbers: jax.lax.ScatterDimensionNumbers,
    *,
    indices_are_sorted: bool = False,
    unique_indices: bool = False,
    mode: str | jax.lax.GatherScatterMode | None = None,
    **kwargs,
) -> Union[Quantity, jax.Array]:
    """Scatter-sub operator.

    Wraps `XLA's Scatter operator
    <https://www.tensorflow.org/xla/operation_semantics#scatter>`_, where
    subtraction is used to combine updates and values from `operand`.

    The semantics of scatter are complicated, and its API might change in the
    future. For most use cases, you should prefer the
    :attr:`jax.numpy.ndarray.at` property on JAX arrays which uses
    the familiar NumPy indexing syntax.

    Args:
      operand: an array to which the scatter should be applied
      scatter_indices: an array that gives the indices in `operand` to which each
        update in `updates` should be applied.
      updates: the updates that should be scattered onto `operand`.
      dimension_numbers: a `lax.ScatterDimensionNumbers` object that describes
        how dimensions of `operand`, `scatter_indices`, `updates` and the output
        relate.
      indices_are_sorted: whether `scatter_indices` is known to be sorted. If
        true, may improve performance on some backends.
      unique_indices: whether the elements to be updated in ``operand`` are
        guaranteed to not overlap with each other. If true, may improve performance on
        some backends. JAX does not check this promise: if the updated elements
        overlap when ``unique_indices`` is ``True`` the behavior is undefined.
      mode: how to handle indices that are out of bounds: when set to 'clip',
        indices are clamped so that the slice is within bounds, and when
        set to 'fill' or 'drop' out-of-bounds updates are dropped. The behavior
        for out-of-bounds indices when set to 'promise_in_bounds' is
        implementation-defined.
      **kwargs: extra keyword arguments, forwarded unchanged to
        ``_fun_lax_scatter``.

    Returns:
      An array containing `operand` with the scattered updates subtracted
      from the indexed elements.
    """
    # Argument and unit handling is delegated entirely to the shared scatter helper;
    # only the underlying lax primitive differs between the scatter_* wrappers.
    return _fun_lax_scatter(lax.scatter_sub, operand, scatter_indices, updates, dimension_numbers, indices_are_sorted,
                            unique_indices, mode, **kwargs)
@set_module_as('saiunit.math')
def scatter_mul(
    operand: Union[Quantity, jax.typing.ArrayLike],
    scatter_indices: jax.typing.ArrayLike,
    updates: jax.typing.ArrayLike,
    dimension_numbers: jax.lax.ScatterDimensionNumbers,
    *,
    indices_are_sorted: bool = False,
    unique_indices: bool = False,
    mode: str | jax.lax.GatherScatterMode | None = None,
    **kwargs,
) -> Union[Quantity, jax.Array]:
    """Scatter-multiply operator.

    Wraps `XLA's Scatter operator
    <https://www.tensorflow.org/xla/operation_semantics#scatter>`_, where
    multiplication is used to combine updates and values from `operand`.

    The semantics of scatter are complicated, and its API might change in the
    future. For most use cases, you should prefer the
    :attr:`jax.numpy.ndarray.at` property on JAX arrays which uses
    the familiar NumPy indexing syntax.

    Args:
      operand: an array to which the scatter should be applied
      scatter_indices: an array that gives the indices in `operand` to which each
        update in `updates` should be applied.
      updates: the updates that should be scattered onto `operand`.
      dimension_numbers: a `lax.ScatterDimensionNumbers` object that describes
        how dimensions of `operand`, `scatter_indices`, `updates` and the output
        relate.
      indices_are_sorted: whether `scatter_indices` is known to be sorted. If
        true, may improve performance on some backends.
      unique_indices: whether the elements to be updated in ``operand`` are
        guaranteed to not overlap with each other. If true, may improve performance on
        some backends. JAX does not check this promise: if the updated elements
        overlap when ``unique_indices`` is ``True`` the behavior is undefined.
      mode: how to handle indices that are out of bounds: when set to 'clip',
        indices are clamped so that the slice is within bounds, and when
        set to 'fill' or 'drop' out-of-bounds updates are dropped. The behavior
        for out-of-bounds indices when set to 'promise_in_bounds' is
        implementation-defined.
      **kwargs: extra keyword arguments, forwarded unchanged to
        ``_fun_lax_scatter``.

    Returns:
      An array containing `operand` with the indexed elements multiplied by
      the scattered updates.
    """
    # Argument and unit handling is delegated entirely to the shared scatter helper;
    # only the underlying lax primitive differs between the scatter_* wrappers.
    return _fun_lax_scatter(lax.scatter_mul, operand, scatter_indices, updates, dimension_numbers, indices_are_sorted,
                            unique_indices, mode, **kwargs)
@set_module_as('saiunit.math')
def scatter_min(
    operand: Union[Quantity, jax.typing.ArrayLike],
    scatter_indices: jax.typing.ArrayLike,
    updates: jax.typing.ArrayLike,
    dimension_numbers: jax.lax.ScatterDimensionNumbers,
    *,
    indices_are_sorted: bool = False,
    unique_indices: bool = False,
    mode: str | jax.lax.GatherScatterMode | None = None,
    **kwargs,
) -> Union[Quantity, jax.Array]:
    """Scatter-min operator.

    Wraps `XLA's Scatter operator
    <https://www.tensorflow.org/xla/operation_semantics#scatter>`_, where
    the `min` function is used to combine updates and values from `operand`.

    The semantics of scatter are complicated, and its API might change in the
    future. For most use cases, you should prefer the
    :attr:`jax.numpy.ndarray.at` property on JAX arrays which uses
    the familiar NumPy indexing syntax.

    Args:
      operand: an array to which the scatter should be applied
      scatter_indices: an array that gives the indices in `operand` to which each
        update in `updates` should be applied.
      updates: the updates that should be scattered onto `operand`.
      dimension_numbers: a `lax.ScatterDimensionNumbers` object that describes
        how dimensions of `operand`, `scatter_indices`, `updates` and the output
        relate.
      indices_are_sorted: whether `scatter_indices` is known to be sorted. If
        true, may improve performance on some backends.
      unique_indices: whether the elements to be updated in ``operand`` are
        guaranteed to not overlap with each other. If true, may improve performance on
        some backends. JAX does not check this promise: if the updated elements
        overlap when ``unique_indices`` is ``True`` the behavior is undefined.
      mode: how to handle indices that are out of bounds: when set to 'clip',
        indices are clamped so that the slice is within bounds, and when
        set to 'fill' or 'drop' out-of-bounds updates are dropped. The behavior
        for out-of-bounds indices when set to 'promise_in_bounds' is
        implementation-defined.
      **kwargs: extra keyword arguments, forwarded unchanged to
        ``_fun_lax_scatter``.

    Returns:
      An array in which each indexed element is the minimum of the `operand`
      value and the scattered updates.
    """
    # Argument and unit handling is delegated entirely to the shared scatter helper;
    # only the underlying lax primitive differs between the scatter_* wrappers.
    return _fun_lax_scatter(lax.scatter_min, operand, scatter_indices, updates, dimension_numbers, indices_are_sorted,
                            unique_indices, mode, **kwargs)
@set_module_as('saiunit.math')
def scatter_max(
    operand: Union[Quantity, jax.typing.ArrayLike],
    scatter_indices: jax.typing.ArrayLike,
    updates: jax.typing.ArrayLike,
    dimension_numbers: jax.lax.ScatterDimensionNumbers,
    *,
    indices_are_sorted: bool = False,
    unique_indices: bool = False,
    mode: str | jax.lax.GatherScatterMode | None = None,
    **kwargs,
) -> Union[Quantity, jax.Array]:
    """Scatter-max operator.

    Wraps `XLA's Scatter operator
    <https://www.tensorflow.org/xla/operation_semantics#scatter>`_, where
    the `max` function is used to combine updates and values from `operand`.

    The semantics of scatter are complicated, and its API might change in the
    future. For most use cases, you should prefer the
    :attr:`jax.numpy.ndarray.at` property on JAX arrays which uses
    the familiar NumPy indexing syntax.

    Args:
      operand: an array to which the scatter should be applied
      scatter_indices: an array that gives the indices in `operand` to which each
        update in `updates` should be applied.
      updates: the updates that should be scattered onto `operand`.
      dimension_numbers: a `lax.ScatterDimensionNumbers` object that describes
        how dimensions of `operand`, `scatter_indices`, `updates` and the output
        relate.
      indices_are_sorted: whether `scatter_indices` is known to be sorted. If
        true, may improve performance on some backends.
      unique_indices: whether the elements to be updated in ``operand`` are
        guaranteed to not overlap with each other. If true, may improve performance on
        some backends. JAX does not check this promise: if the updated elements
        overlap when ``unique_indices`` is ``True`` the behavior is undefined.
      mode: how to handle indices that are out of bounds: when set to 'clip',
        indices are clamped so that the slice is within bounds, and when
        set to 'fill' or 'drop' out-of-bounds updates are dropped. The behavior
        for out-of-bounds indices when set to 'promise_in_bounds' is
        implementation-defined.
      **kwargs: extra keyword arguments, forwarded unchanged to
        ``_fun_lax_scatter``.

    Returns:
      An array in which each indexed element is the maximum of the `operand`
      value and the scattered updates.
    """
    # Argument and unit handling is delegated entirely to the shared scatter helper;
    # only the underlying lax primitive differs between the scatter_* wrappers.
    return _fun_lax_scatter(lax.scatter_max, operand, scatter_indices, updates, dimension_numbers, indices_are_sorted,
                            unique_indices, mode, **kwargs)
@set_module_as('saiunit.math')
def scatter_apply(
    operand: Union[Quantity, jax.typing.ArrayLike],
    scatter_indices: jax.typing.ArrayLike,
    func: Callable,
    dimension_numbers: jax.lax.ScatterDimensionNumbers,
    *,
    update_shape: Shape = (),
    indices_are_sorted: bool = False,
    unique_indices: bool = False,
    mode: str | jax.lax.GatherScatterMode | None = None,
    **kwargs,
) -> Union[Quantity, jax.Array]:
    """Scatter-apply operator.

    Wraps `XLA's Scatter operator
    <https://www.tensorflow.org/xla/operation_semantics#scatter>`_, where values
    from ``operand`` are replaced with ``func(operand)``, with duplicate indices
    resulting in multiple applications of ``func``.

    The semantics of scatter are complicated, and its API might change in the
    future. For most use cases, you should prefer the
    :attr:`jax.numpy.ndarray.at` property on JAX arrays which uses
    the familiar NumPy indexing syntax.

    Note that in the current implementation, ``scatter_apply`` is not compatible
    with automatic differentiation.

    When ``operand`` is a ``Quantity``, ``func`` is applied to its mantissa and
    the result is re-wrapped with the unit of ``operand``.

    Args:
      operand: an array (or ``Quantity``) to which the scatter should be applied.
      scatter_indices: an array that gives the indices in `operand` at which
        ``func`` should be applied.
      func: unary function that will be applied at each index.
      dimension_numbers: a `lax.ScatterDimensionNumbers` object that describes
        how dimensions of `operand`, `scatter_indices` and the output relate.
      update_shape: the shape of the updates at the given indices.
      indices_are_sorted: whether `scatter_indices` is known to be sorted. If
        true, may improve performance on some backends.
      unique_indices: whether the elements to be updated in ``operand`` are
        guaranteed to not overlap with each other. If true, may improve performance on
        some backends. JAX does not check this promise: if the updated elements
        overlap when ``unique_indices`` is ``True`` the behavior is undefined.
      mode: how to handle indices that are out of bounds: when set to 'clip',
        indices are clamped so that the slice is within bounds, and when
        set to 'fill' or 'drop' out-of-bounds updates are dropped. The behavior
        for out-of-bounds indices when set to 'promise_in_bounds' is
        implementation-defined.
      **kwargs: extra keyword arguments forwarded to ``lax.scatter_apply``.

    Returns:
      An array containing the result of applying `func` to `operand` at the given indices.
    """
    operand = maybe_custom_array(operand)
    carries_unit = isinstance(operand, Quantity)

    # The primitive only ever sees raw numbers: strip the unit first if there is one.
    raw_operand = operand.mantissa if carries_unit else operand
    scattered = lax.scatter_apply(
        raw_operand,
        scatter_indices,
        func,
        dimension_numbers,
        update_shape=update_shape,
        indices_are_sorted=indices_are_sorted,
        unique_indices=unique_indices,
        mode=mode,
        **kwargs,
    )

    if not carries_unit:
        return scattered
    # Re-attach the unit of the input operand to the scattered result.
    return maybe_decimal(Quantity(scattered, unit=operand.unit))
# math funcs keep unit (binary)
@set_module_as('saiunit.math')
def complex(
    x: Union[Quantity, jax.typing.ArrayLike],
    y: Union[Quantity, jax.typing.ArrayLike],
    **kwargs,
) -> Union[Quantity, jax.Array]:
    r"""Elementwise make complex number: :math:`x + jy`.

    Build a complex number from real and imaginary parts.

    Parameters
    ----------
    x : array_like or Quantity
        The real part.
    y : array_like or Quantity
        The imaginary part. Must have the same unit as ``x``.
    **kwargs
        Extra keyword arguments, forwarded unchanged to
        ``_fun_keep_unit_binary``.

    Returns
    -------
    result : jax.Array or Quantity
        The complex array. If inputs are ``Quantity``, the result
        preserves the same unit.

    Examples
    --------
    .. code-block:: python

        >>> import saiunit as u
        >>> import saiunit.lax as sulax
        >>> import jax.numpy as jnp
        >>> real = jnp.array([1.0, 2.0]) * u.volt
        >>> imag = jnp.array([3.0, 4.0]) * u.volt
        >>> result = sulax.complex(real, imag)
        >>> result.mantissa
        Array([1.+3.j, 2.+4.j], dtype=complex64)
    """
    # NOTE: this function shadows the builtin ``complex`` at module level, which
    # is why other code in this file spells the builtin as ``builtins.complex``.
    return _fun_keep_unit_binary(lax.complex, x, y, **kwargs)
@set_module_as('saiunit.math')
def pad(
    operand: Union[Quantity, jax.typing.ArrayLike],
    padding_value: Union[Quantity, jax.typing.ArrayLike],
    padding_config: Sequence[tuple[int, int, int]],
    **kwargs,
) -> Union[Quantity, jax.Array]:
    """Applies low, high, and/or interior padding to an array.

    Wraps XLA's `Pad
    <https://www.tensorflow.org/xla/operation_semantics#pad>`_
    operator.

    Args:
      operand: an array to be padded.
      padding_value: the value to be inserted as padding. Must have the same dtype
        as ``operand``.
      padding_config: a sequence of ``(low, high, interior)`` tuples of integers,
        giving the amount of low, high, and interior (dilation) padding to insert
        in each dimension.
      **kwargs: extra keyword arguments, forwarded unchanged to
        ``_fun_keep_unit_binary``.

    Returns:
      The ``operand`` array with padding value ``padding_value`` inserted in each
      dimension according to the ``padding_config``.
    """
    # ``operand`` and ``padding_value`` are the two unit-carrying inputs handed to
    # the binary helper; ``padding_config`` rides along as an extra positional argument.
    return _fun_keep_unit_binary(lax.pad, operand, padding_value, padding_config, **kwargs)
@set_module_as('saiunit.math')
def sub(
    x: Union[Quantity, jax.typing.ArrayLike],
    y: Union[Quantity, jax.typing.ArrayLike],
    **kwargs,
) -> Union[Quantity, jax.Array]:
    r"""Elementwise subtraction: :math:`x - y`.

    Parameters
    ----------
    x : array_like or Quantity
        The minuend.
    y : array_like or Quantity
        The subtrahend. Must have the same unit as ``x``.
    **kwargs
        Extra keyword arguments, forwarded unchanged to
        ``_fun_keep_unit_binary``.

    Returns
    -------
    result : jax.Array or Quantity
        The difference. Preserves the unit of the inputs.

    Examples
    --------
    .. code-block:: python

        >>> import saiunit as u
        >>> import saiunit.lax as sulax
        >>> import jax.numpy as jnp
        >>> a = jnp.array([5.0, 8.0]) * u.meter
        >>> b = jnp.array([1.0, 3.0]) * u.meter
        >>> result = sulax.sub(a, b)
        >>> result.mantissa
        Array([4., 5.], dtype=float32)
    """
    # Unit checking and re-wrapping are delegated to the shared binary helper.
    return _fun_keep_unit_binary(lax.sub, x, y, **kwargs)
# type conversion
@set_module_as('saiunit.math')
def convert_element_type(
    operand: Union[Quantity, jax.typing.ArrayLike],
    new_dtype: jax.typing.DTypeLike,
    **kwargs,
) -> Union[Quantity, jax.Array]:
    """Elementwise cast.

    Wraps XLA's `ConvertElementType
    <https://www.tensorflow.org/xla/operation_semantics#convertelementtype>`_
    operator, which performs an elementwise conversion from one type to another.
    Similar to a C++ `static_cast`.

    Args:
      operand: an array or scalar value to be cast.
      new_dtype: a NumPy dtype representing the target type.
      **kwargs: extra keyword arguments, forwarded unchanged to
        ``_fun_keep_unit_unary``.

    Returns:
      An array with the same shape as `operand`, cast elementwise to `new_dtype`.
    """
    # ``operand`` is the single unit-carrying input; ``new_dtype`` is passed through
    # to the lax primitive as an extra positional argument.
    return _fun_keep_unit_unary(lax.convert_element_type, operand, new_dtype, **kwargs)
@set_module_as('saiunit.math')
def bitcast_convert_type(
    operand: Union[Quantity, jax.typing.ArrayLike],
    new_dtype: jax.typing.DTypeLike,
    **kwargs,
) -> Union[Quantity, jax.Array]:
    """Elementwise bitcast.

    Wraps XLA's `BitcastConvertType
    <https://www.tensorflow.org/xla/operation_semantics#bitcastconverttype>`_
    operator, which performs a bit cast from one type to another.

    The output shape depends on the size of the input and output dtypes::

        if new_dtype.itemsize == operand.dtype.itemsize:
            output_shape = operand.shape
        if new_dtype.itemsize < operand.dtype.itemsize:
            output_shape = (*operand.shape, operand.dtype.itemsize // new_dtype.itemsize)
        if new_dtype.itemsize > operand.dtype.itemsize:
            assert operand.shape[-1] * operand.dtype.itemsize == new_dtype.itemsize
            output_shape = operand.shape[:-1]

    Parameters
    ----------
    operand : array_like, Quantity
        An array or scalar value to be cast.
    new_dtype : dtype
        The new type. Should be a NumPy dtype.
    **kwargs
        Extra keyword arguments, forwarded unchanged to
        ``_fun_keep_unit_unary``.

    Returns
    -------
    out : Quantity or jax.Array
        An array of shape ``output_shape`` (see above) and type ``new_dtype``,
        constructed from the same bits as ``operand``.
    """
    # ``operand`` is the single unit-carrying input; ``new_dtype`` is passed through
    # to the lax primitive as an extra positional argument.
    return _fun_keep_unit_unary(lax.bitcast_convert_type, operand, new_dtype, **kwargs)
# math funcs keep unit (n-ary)
@set_module_as('saiunit.math')
def clamp(
    min: Union[Quantity, jax.typing.ArrayLike],
    x: Union[Quantity, jax.typing.ArrayLike],
    max: Union[Quantity, jax.typing.ArrayLike],
    **kwargs,
) -> Union[Quantity, jax.Array]:
    r"""Elementwise clamp.

    Returns :math:`\mathrm{clamp}(x) = \begin{cases}
    \mathit{min} & \text{if } x < \mathit{min},\\
    \mathit{max} & \text{if } x > \mathit{max},\\
    x & \text{otherwise}
    \end{cases}`.

    The three inputs must either all be ``Quantity`` objects or all be plain
    arrays / scalars; mixing the two raises ``TypeError``.

    Parameters
    ----------
    min : array_like or Quantity
        The lower bound. When the inputs are quantities, its unit is the unit
        of the result.
    x : array_like or Quantity
        The values to clamp.
    max : array_like or Quantity
        The upper bound.
    **kwargs
        Extra keyword arguments forwarded to ``lax.clamp``.

    Returns
    -------
    result : jax.Array or Quantity
        The clamped values.

    Raises
    ------
    TypeError
        If the inputs are neither all ``Quantity`` nor all plain array-likes.
    """
    lower, value, upper = (maybe_custom_array(arg) for arg in (min, x, max))
    bounds_and_value = (lower, value, upper)

    # Case 1: every input carries a unit. Express everything in the unit of the lower bound.
    if all(isinstance(arg, Quantity) for arg in bounds_and_value):
        target_unit = lower.unit
        clamped = lax.clamp(
            lower.mantissa,
            value.to_decimal(target_unit),
            upper.to_decimal(target_unit),
            **kwargs,
        )
        return maybe_decimal(Quantity(clamped, unit=target_unit))

    # Case 2: every input is a plain array or Python/NumPy scalar.
    # ``builtins.complex`` is spelled out because this module defines its own ``complex``.
    plain_types = (jax.Array, np.ndarray, np.bool_, np.number, bool, int, float, builtins.complex)
    if all(isinstance(arg, plain_types) for arg in bounds_and_value):
        return lax.clamp(lower, value, upper, **kwargs)

    # Anything else (including a mix of quantities and plain values) is rejected.
    raise TypeError('All inputs must be Quantity or jax.typing.ArrayLike')
# math funcs keep unit (return Quantity and index)
@set_module_as('saiunit.math')
def approx_max_k(
    operand: Union[Quantity, jax.typing.ArrayLike],
    k: int,
    reduction_dimension: int = -1,
    recall_target: float = 0.95,
    reduction_input_size_override: int = -1,
    aggregate_to_topk: bool = True,
    **kwargs,
) -> tuple[Union[Quantity, jax.Array], jax.typing.ArrayLike]:
    """Returns max ``k`` values and their indices of the ``operand`` in an approximate manner.

    See https://arxiv.org/abs/2206.14286 for the algorithm details.

    When ``operand`` is a ``Quantity`` the search runs on its mantissa and the
    returned values are re-wrapped with the unit of ``operand``; the indices
    are returned as-is.

    Args:
      operand : Array to search for max-k. Must be a floating number type.
      k : Specifies the number of max-k.
      reduction_dimension : Integer dimension along which to search. Default: -1.
      recall_target : Recall target for the approximation.
      reduction_input_size_override : When set to a positive value, it overrides
        the size determined by ``operand[reduction_dim]`` for evaluating the
        recall. This option is useful when the given ``operand`` is only a subset
        of the overall computation in SPMD or distributed pipelines, where the
        true input size cannot be deferred by the operand shape.
      aggregate_to_topk : When true, aggregates approximate results to the top-k
        in sorted order. When false, returns the approximate results unsorted. In
        this case, the number of the approximate results is implementation defined
        and is greater or equal to the specified ``k``.
      **kwargs : Extra keyword arguments forwarded to ``lax.approx_max_k``.

    Returns:
      Tuple of two arrays. The arrays are the max ``k`` values and the
      corresponding indices along the ``reduction_dimension`` of the input
      ``operand``. The arrays' dimensions are the same as the input ``operand``
      except for the ``reduction_dimension``: when ``aggregate_to_topk`` is true,
      the reduction dimension is ``k``; otherwise, it is greater equals to ``k``
      where the size is implementation-defined.

    We encourage users to wrap ``approx_max_k`` with jit. See the following
    example for maximal inner production search (MIPS):

    >>> import functools
    >>> import jax
    >>> import numpy as np
    >>> @functools.partial(jax.jit, static_argnames=["k", "recall_target"])
    ... def mips(qy, db, k=10, recall_target=0.95):
    ...   dists = jax.lax.dot(qy, db.transpose())
    ...   # returns (f32[qy_size, k], i32[qy_size, k])
    ...   return jax.lax.approx_max_k(dists, k=k, recall_target=recall_target)
    >>>
    >>> qy = jax.numpy.array(np.random.rand(50, 64))
    >>> db = jax.numpy.array(np.random.rand(1024, 64))
    >>> dot_products, neighbors = mips(qy, db, k=10)
    """
    operand = maybe_custom_array(operand)

    # Plain arrays need no unit bookkeeping: hand them straight to lax.
    if not isinstance(operand, Quantity):
        return lax.approx_max_k(
            operand,
            k,
            reduction_dimension,
            recall_target,
            reduction_input_size_override,
            aggregate_to_topk,
            **kwargs,
        )

    # Search on the bare mantissa, then put the unit back on the values only.
    values, indices = lax.approx_max_k(
        operand.mantissa,
        k,
        reduction_dimension,
        recall_target,
        reduction_input_size_override,
        aggregate_to_topk,
        **kwargs,
    )
    return maybe_decimal(Quantity(values, unit=operand.unit)), indices
@set_module_as('saiunit.math')
def approx_min_k(
    operand: Union[Quantity, jax.typing.ArrayLike],
    k: int,
    reduction_dimension: int = -1,
    recall_target: float = 0.95,
    reduction_input_size_override: int = -1,
    aggregate_to_topk: bool = True,
    **kwargs,
) -> tuple[Union[Quantity, jax.Array], jax.typing.ArrayLike]:
    """Returns min ``k`` values and their indices of the ``operand`` in an approximate manner.

    See https://arxiv.org/abs/2206.14286 for the algorithm details.

    When ``operand`` is a ``Quantity`` the search runs on its mantissa and the
    returned values are re-wrapped with the unit of ``operand``; the indices
    are returned as-is.

    Args:
      operand : Array to search for min-k. Must be a floating number type.
      k : Specifies the number of min-k.
      reduction_dimension : Integer dimension along which to search. Default: -1.
      recall_target : Recall target for the approximation.
      reduction_input_size_override : When set to a positive value, it overrides
        the size determined by ``operand[reduction_dim]`` for evaluating the
        recall. This option is useful when the given operand is only a subset of
        the overall computation in SPMD or distributed pipelines, where the true
        input size cannot be deferred by the ``operand`` shape.
      aggregate_to_topk : When true, aggregates approximate results to the top-k
        in sorted order. When false, returns the approximate results unsorted. In
        this case, the number of the approximate results is implementation defined
        and is greater or equal to the specified ``k``.
      **kwargs : Extra keyword arguments forwarded to ``lax.approx_min_k``.

    Returns:
      Tuple of two arrays. The arrays are the least ``k`` values and the
      corresponding indices along the ``reduction_dimension`` of the input
      ``operand``. The arrays' dimensions are the same as the input ``operand``
      except for the ``reduction_dimension``: when ``aggregate_to_topk`` is true,
      the reduction dimension is ``k``; otherwise, it is greater equals to ``k``
      where the size is implementation-defined.

    We encourage users to wrap ``approx_min_k`` with jit. See the following example
    for nearest neighbor search over the squared l2 distance:

    >>> import functools
    >>> import jax
    >>> import numpy as np
    >>> @functools.partial(jax.jit, static_argnames=["k", "recall_target"])
    ... def l2_ann(qy, db, half_db_norms, k=10, recall_target=0.95):
    ...   dists = half_db_norms - jax.lax.dot(qy, db.transpose())
    ...   return jax.lax.approx_min_k(dists, k=k, recall_target=recall_target)
    >>>
    >>> qy = jax.numpy.array(np.random.rand(50, 64))
    >>> db = jax.numpy.array(np.random.rand(1024, 64))
    >>> half_db_norm_sq = jax.numpy.linalg.norm(db, axis=1)**2 / 2
    >>> dists, neighbors = l2_ann(qy, db, half_db_norm_sq, k=10)

    In the example above, we compute ``db^2/2 - dot(qy, db^T)`` instead of
    ``qy^2 - 2 dot(qy, db^T) + db^2`` for performance reason. The former uses less
    arithmetic and produces the same set of neighbors.
    """
    operand = maybe_custom_array(operand)

    # Plain arrays need no unit bookkeeping: hand them straight to lax.
    if not isinstance(operand, Quantity):
        return lax.approx_min_k(
            operand,
            k,
            reduction_dimension,
            recall_target,
            reduction_input_size_override,
            aggregate_to_topk,
            **kwargs,
        )

    # Search on the bare mantissa, then put the unit back on the values only.
    values, indices = lax.approx_min_k(
        operand.mantissa,
        k,
        reduction_dimension,
        recall_target,
        reduction_input_size_override,
        aggregate_to_topk,
        **kwargs,
    )
    return maybe_decimal(Quantity(values, unit=operand.unit)), indices
@set_module_as('saiunit.math')
def top_k(
    operand: Union[Quantity, jax.typing.ArrayLike],
    k: int,
    **kwargs,
) -> tuple[Union[Quantity, jax.Array], jax.typing.ArrayLike]:
    """Returns top ``k`` values and their indices along the last axis of ``operand``.

    When ``operand`` is a ``Quantity`` the selection runs on its mantissa and
    the returned values are re-wrapped with the unit of ``operand``; the
    indices are returned as-is.

    Args:
      operand: N-dimensional array of non-complex type.
      k: integer specifying the number of top entries.
      **kwargs: extra keyword arguments forwarded to ``lax.top_k``.

    Returns:
      A tuple ``(values, indices)`` where

      - ``values`` is an array containing the top k values along the last axis.
      - ``indices`` is an array containing the indices corresponding to values.

    Examples:
      Find the largest three values, and their indices, within an array:

      >>> x = jnp.array([9., 3., 6., 4., 10.])
      >>> values, indices = jax.lax.top_k(x, 3)
      >>> values
      Array([10.,  9.,  6.], dtype=float32)
      >>> indices
      Array([4, 0, 2], dtype=int32)
    """
    operand = maybe_custom_array(operand)

    # Plain arrays need no unit bookkeeping: hand them straight to lax.
    if not isinstance(operand, Quantity):
        return lax.top_k(operand, k, **kwargs)

    # Select on the bare mantissa, then put the unit back on the values only.
    values, indices = lax.top_k(operand.mantissa, k, **kwargs)
    return maybe_decimal(Quantity(values, unit=operand.unit)), indices
# broadcasting arrays
def broadcast(
    operand: Union[Quantity, jax.typing.ArrayLike],
    sizes: Sequence[int]
) -> Union[Quantity, jax.Array]:
    """Broadcast an array by adding new leading dimensions.

    Parameters
    ----------
    operand : array_like or Quantity
        The input array.
    sizes : sequence of int
        Sizes of new leading dimensions to prepend to the array shape.

    Returns
    -------
    result : jax.Array or Quantity
        The broadcasted array. Preserves the unit of ``operand``.

    Examples
    --------
    .. code-block:: python

        >>> import saiunit as u
        >>> import saiunit.lax as sulax
        >>> import jax.numpy as jnp
        >>> q = jnp.array([1.0, 2.0]) * u.second
        >>> result = sulax.broadcast(q, sizes=(3,))
        >>> result.mantissa.shape
        (3, 2)
        >>> result.unit
        second
    """
    # ``operand`` is the single unit-carrying input; ``sizes`` is passed through
    # to the lax primitive as an extra positional argument.
    return _fun_keep_unit_unary(lax.broadcast, operand, sizes)
def broadcast_in_dim(
    operand: Union[Quantity, jax.typing.ArrayLike],
    shape: Shape,
    broadcast_dimensions: Sequence[int]
) -> Union[Quantity, jax.Array]:
    """Broadcast an array into a target shape (XLA BroadcastInDim).

    Parameters
    ----------
    operand : array_like or Quantity
        The input array.
    shape : Shape
        The target shape for the broadcast.
    broadcast_dimensions : sequence of int
        Mapping from operand dimensions to target dimensions: dimension *i*
        of the operand becomes dimension ``broadcast_dimensions[i]``
        of the result.

    Returns
    -------
    result : jax.Array or Quantity
        The broadcasted array. Preserves the unit of ``operand``.

    Examples
    --------
    .. code-block:: python

        >>> import saiunit as u
        >>> import saiunit.lax as sulax
        >>> import jax.numpy as jnp
        >>> q = jnp.array([1.0, 2.0]) * u.meter
        >>> result = sulax.broadcast_in_dim(q, shape=(3, 2), broadcast_dimensions=(1,))
        >>> result.mantissa.shape
        (3, 2)
    """
    # ``operand`` is the single unit-carrying input; ``shape`` and
    # ``broadcast_dimensions`` are passed through to the lax primitive positionally.
    return _fun_keep_unit_unary(lax.broadcast_in_dim, operand, shape, broadcast_dimensions)
def broadcast_to_rank(
    x: Union[Quantity, jax.typing.ArrayLike],
    rank: int
) -> Union[Quantity, jax.Array]:
    """Add leading dimensions of size 1 to give ``x`` rank ``rank``.

    Parameters
    ----------
    x : array_like or Quantity
        The input array.
    rank : int
        The desired rank of the output.

    Returns
    -------
    result : jax.Array or Quantity
        The array with added leading dimensions. Preserves the unit of ``x``.

    Examples
    --------
    .. code-block:: python

        >>> import saiunit as u
        >>> import saiunit.lax as sulax
        >>> import jax.numpy as jnp
        >>> q = jnp.array([1.0, 2.0]) * u.meter
        >>> result = sulax.broadcast_to_rank(q, rank=3)
        >>> result.mantissa.shape
        (1, 1, 2)
    """
    # ``x`` is the single unit-carrying input; ``rank`` is passed through to the
    # lax primitive as an extra positional argument.
    return _fun_keep_unit_unary(lax.broadcast_to_rank, x, rank)