# Copyright 2026 BrainX Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import functools
import numbers
import operator
import re
from collections.abc import Callable, Sequence
from copy import deepcopy
from typing import Any
import jax
import jax.numpy as jnp
import numpy as np
from jax.tree_util import register_pytree_node_class
from ._base_dimension import Dimension, UnitMismatchError, _is_tracer
from ._base_getters import (
get_dim,
fail_for_dimension_mismatch,
maybe_decimal,
_to_quantity,
unit_scale_align_to_first,
)
from ._base_unit import Unit, UNITLESS
from ._misc import maybe_custom_array_tree
from ._sparse_base import SparseMatrix
__all__ = [
'Quantity',
'compatible_with_equinox',
]
# ---------------------------------------------------------------------------
# Module-level type aliases and globals
# ---------------------------------------------------------------------------
StaticScalar = (
np.bool_ | np.number | # NumPy scalar types
bool | int | float | complex # Python scalar types
)
PyTree = Any
_all_slice = slice(None, None, None)
compat_with_equinox = False
[docs]
def compatible_with_equinox(mode: bool = True):
"""
Enable or disable compatibility with the Equinox library.
When enabled, ``Quantity`` objects interact correctly with Equinox
transformations such as those used in
`unit-aware diffrax <https://github.com/chaoming0625/diffrax>`_.
Parameters
----------
mode : bool, optional
If ``True`` (default), enable Equinox compatibility.
If ``False``, disable it.
Examples
--------
.. code-block:: python
>>> import saiunit as u
>>> u.compatible_with_equinox(True) # enable
>>> u.compatible_with_equinox(False) # disable
See Also
--------
Quantity : The core physical-quantity class affected by this setting.
"""
global compat_with_equinox
compat_with_equinox = mode
# ---------------------------------------------------------------------------
# Wrapping functions
# ---------------------------------------------------------------------------
def _wrap_function_keep_unit(func):
"""
Returns a new function that wraps the given function `func` so that it
keeps the dimensions of its input. Arrays are transformed to
unitless jax numpy arrays before calling `func`, the output is a array
with the original dimensions re-attached.
These transformations apply only to the very first argument, all
other arguments are ignored/untouched, allowing to work functions like
``sum`` to work as expected with additional ``axis`` etc. arguments.
"""
@functools.wraps(func)
def f(x: 'Quantity', *args, **kwds): # pylint: disable=C0111
return Quantity(func(x.mantissa, *args, **kwds), unit=x.unit)
f._arg_units = [None]
f._return_unit = lambda u: u
f._do_not_run_doctests = True
return f
def _wrap_function_change_unit(func, unit_fun):
"""
Returns a new function that wraps the given function `func` so that it
changes the dimensions of its input. Arrays are transformed to
unitless jax numpy arrays before calling `func`, the output is a array
with the original dimensions passed through the function
`unit_fun`. A typical use would be a ``sqrt`` function that uses
``lambda d: d ** 0.5`` as ``unit_fun``.
These transformations apply only to the very first argument, all
other arguments are ignored/untouched.
"""
@functools.wraps(func)
def f(x, *args, **kwds): # pylint: disable=C0111
assert isinstance(x, Quantity), "Only Quantity objects can be passed to this function"
x = x.factorless()
return maybe_decimal(Quantity(func(x.mantissa, *args, **kwds), unit=unit_fun(x.unit, x.unit)))
f._arg_units = [None]
f._return_unit = unit_fun
f._do_not_run_doctests = True
return f
def _wrap_function_remove_unit(func):
"""
Returns a new function that wraps the given function `func` so that it
removes any dimensions from its input. Useful for functions that are
returning integers (indices) or booleans, irrespective of the datatype
contained in the array.
These transformations apply only to the very first argument, all
other arguments are ignored/untouched.
"""
@functools.wraps(func)
def f(x, *args, **kwds): # pylint: disable=C0111
assert isinstance(x, Quantity), "Only Quantity objects can be passed to this function"
return func(x.mantissa, *args, **kwds)
f._arg_units = [None]
f._return_unit = 1
f._do_not_run_doctests = True
return f
# ---------------------------------------------------------------------------
# List processing helpers
# ---------------------------------------------------------------------------
def _zoom_values_with_units(
values: Sequence[jax.typing.ArrayLike],
units: Sequence[Unit]
):
"""
Zoom values with units.
Parameters
----------
values : `Array`
The values to zoom.
units : `Array`
The units to use for zooming.
Returns
-------
zoomed_values : `Array`
The zoomed values.
"""
assert len(values) == len(units), "The number of values and units must be the same"
values = list(values)
first_unit = units[0]
for i in range(1, len(values)):
if not units[i].has_same_magnitude(first_unit):
values[i] = values[i] * (units[i].magnitude / first_unit.magnitude)
return values
def _check_units_and_collect_values(lst) -> tuple[jax.typing.ArrayLike, Unit]:
units = []
values = []
for item in lst:
if isinstance(item, (list, tuple)):
val, unit = _check_units_and_collect_values(item)
values.append(val)
if unit != UNITLESS:
units.append(unit)
elif isinstance(item, Quantity):
values.append(item.mantissa)
units.append(item.unit)
elif isinstance(item, Unit):
values.append(1)
units.append(item)
else:
values.append(item)
units.append(None)
if len(units):
# Normalize None (plain scalars) to UNITLESS so they are
# compatible with explicitly unitless Quantity values.
units = [UNITLESS if u is None else u for u in units]
first_unit = units[0]
if not all(first_unit.has_same_dim(unit) for unit in units):
raise TypeError(f"All elements must have the same units, but got {units}")
return jnp.asarray(_zoom_values_with_units(values, units)), first_unit
else:
return jnp.asarray(values), UNITLESS
def _process_list_with_units(value: list) -> tuple[jax.typing.ArrayLike, Unit]:
values, unit = _check_units_and_collect_values(value)
return values, unit
def _element_not_quantity(x):
assert not isinstance(x, Quantity), f"Expected not a Quantity object, but got {x}"
return x
# ---------------------------------------------------------------------------
# Pickle helper
# ---------------------------------------------------------------------------
def _quantity_with_unit(mantissa, unit):
"""Private reconstruction helper for Quantity pickling.
"""
return Quantity(mantissa, unit=unit)
_quantity_with_unit.__module__ = 'saiunit._base_quantity'
# ---------------------------------------------------------------------------
# Quantity class
# ---------------------------------------------------------------------------
@register_pytree_node_class
class Quantity:
"""
A numerical value paired with a physical unit.
``Quantity`` is the central data structure in ``saiunit``. It stores a
*mantissa* (the raw numerical data, typically a JAX array) together with a
:class:`Unit` that describes the physical dimensions and scale. Arithmetic
on ``Quantity`` objects automatically tracks and checks units, raising
:class:`UnitMismatchError` when incompatible quantities are combined.
``Quantity`` is registered as a JAX pytree, so it works transparently with
``jax.jit``, ``jax.grad``, ``jax.vmap``, and other JAX transformations.
Parameters
----------
mantissa : array_like, number, Unit, or Quantity
The numerical value(s). If a :class:`Unit` is passed, the mantissa is
set to ``1.0`` and that unit is adopted. If a :class:`Quantity` is
passed, its mantissa and unit are used (converted to *unit* when
given).
unit : Unit, optional
The physical unit. Defaults to ``UNITLESS``.
dtype : dtype, optional
If provided, the mantissa is cast to this dtype on construction.
Examples
--------
.. code-block:: python
>>> import saiunit as u
>>> import jax.numpy as jnp
>>> # Scalar with unit
>>> q = u.Quantity(3.0, unit=u.mV)
>>> q
Quantity(3., "mV")
>>> # Array with unit via multiplication shorthand
>>> arr = jnp.array([1.0, 2.0, 3.0]) * u.mV
>>> arr.shape
(3,)
>>> # From a Unit object directly
>>> u.Quantity(u.metre)
Quantity(1., "m")
See Also
--------
Unit : Represents a physical unit (dimension + scale).
compatible_with_equinox : Toggle Equinox interoperability.
"""
__module__ = "saiunit"
__slots__ = ('_mantissa', '_unit')
__array_priority__ = 1000
_mantissa: jax.Array | np.ndarray
_unit: Unit
def __class_getitem__(cls, item: Unit | str) -> type['Quantity']:
"""Enable ``Quantity[unit]`` and ``Quantity["physical_type"]`` annotations.
Returns a type that supports ``isinstance`` checks and can be used as
a type annotation.
Parameters
----------
item : Unit or str
A :class:`Unit` instance (e.g. ``u.meter``) or a string naming a
physical type (e.g. ``"length"``, ``"speed"``).
Returns
-------
type
A class supporting ``isinstance(quantity, Quantity[unit])``.
Examples
--------
>>> import saiunit as u
>>> x = 2.0 * u.kmeter
>>> isinstance(x, u.Quantity[u.meter]) # dimension check
True
>>> isinstance(x, u.Quantity["length"]) # physical type check
True
>>> isinstance(x, u.Quantity["mass"]) # wrong dimension
False
Notes
-----
Some static analyzers may report warnings for
``isinstance(x, Quantity["..."])`` because they interpret this syntax
as parameterized generics. For IDE-safe runtime checks, use
:func:`saiunit.typing.quantity_type`.
"""
from .typing import _make_annotated_quantity_type
return _make_annotated_quantity_type(item)
def __init__(
self,
mantissa: PyTree | Unit,
unit: 'Unit | jax.typing.ArrayLike | str | None' = UNITLESS,
dtype: jax.typing.DTypeLike | None = None,
):
with jax.ensure_compile_time_eval(): # inside JIT, this can avoid to trace the constant mantissa value
# String-based unit: Quantity(1.0, "mV")
if isinstance(unit, str):
from ._base_unit import parse_unit
unit = parse_unit(unit)
# Handle custom arrays in the mantissa tree structure
mantissa = maybe_custom_array_tree(mantissa)
if isinstance(mantissa, Unit):
if unit is not UNITLESS:
raise ValueError(
"Cannot create a Quantity object with a unit and a mantissa that is a Unit object.")
unit = mantissa
mantissa = 1.
if isinstance(mantissa, (list, tuple)):
mantissa, new_unit = _process_list_with_units(mantissa)
if unit is UNITLESS:
unit = new_unit
elif new_unit != UNITLESS:
if not new_unit.has_same_dim(unit):
raise TypeError(f"All elements must have the same unit. But got {unit} != {new_unit}")
if not new_unit.has_same_magnitude(unit):
mantissa = mantissa * (new_unit.magnitude / unit.magnitude)
mantissa = jnp.array(mantissa, dtype=dtype)
# array mantissa
elif isinstance(mantissa, Quantity):
if unit is UNITLESS:
unit = mantissa.unit
elif not unit.has_same_dim(mantissa.unit):
raise ValueError("Cannot create a Quantity object with a different unit.")
mantissa = mantissa.in_unit(unit)
mantissa = mantissa.mantissa
elif isinstance(mantissa, (np.ndarray, jax.Array)):
if dtype is not None:
mantissa = jnp.array(mantissa, dtype=dtype)
# skip 'asarray' if dtype is not provided
elif isinstance(mantissa, (jnp.number, numbers.Number)):
pass # keep as-is; jnp.array conversion deferred to use-site
else:
pass # keep as-is for other pytree types
# mantissa
self._mantissa = mantissa
# dimension
self._unit = unit
@property
def at(self):
"""
Helper property for index update functionality.
The ``at`` property provides a functionally pure equivalent of in-place
array modifications.
In particular:
============================== ================================
Alternate syntax Equivalent In-place expression
============================== ================================
``x = x.at[idx].set(y)`` ``x[idx] = y``
``x = x.at[idx].add(y)`` ``x[idx] += y``
``x = x.at[idx].multiply(y)`` ``x[idx] *= y``
``x = x.at[idx].divide(y)`` ``x[idx] /= y``
``x = x.at[idx].power(y)`` ``x[idx] **= y``
``x = x.at[idx].min(y)`` ``x[idx] = minimum(x[idx], y)``
``x = x.at[idx].max(y)`` ``x[idx] = maximum(x[idx], y)``
``x = x.at[idx].apply(ufunc)`` ``ufunc.at(x, idx)``
``x = x.at[idx].get()`` ``x = x[idx]``
============================== ================================
None of the ``x.at`` expressions modify the original ``x``; instead they return
a modified copy of ``x``. However, inside a :py:func:`~jax.jit` compiled function,
expressions like :code:`x = x.at[idx].set(y)` are guaranteed to be applied in-place.
Unlike NumPy in-place operations such as :code:`x[idx] += y`, if multiple
indices refer to the same location, all updates will be applied (NumPy would
only apply the last update, rather than applying all updates.) The order
in which conflicting updates are applied is implementation-defined and may be
nondeterministic (e.g., due to concurrency on some hardware platforms).
By default, JAX assumes that all indices are in-bounds. Alternative out-of-bound
index semantics can be specified via the ``mode`` parameter (see below).
Arguments
---------
mode : str
Specify out-of-bound indexing mode. Options are:
- ``"promise_in_bounds"``: (default) The user promises that indices are in bounds.
No additional checking will be performed. In practice, this means that
out-of-bounds indices in ``get()`` will be clipped, and out-of-bounds indices
in ``set()``, ``add()``, etc. will be dropped.
- ``"clip"``: clamp out of bounds indices into valid range.
- ``"drop"``: ignore out-of-bound indices.
- ``"fill"``: alias for ``"drop"``. For `get()`, the optional ``fill_value``
argument specifies the value that will be returned.
indices_are_sorted : bool
If True, the implementation will assume that the indices passed to ``at[]``
are sorted in ascending order, which can lead to more efficient execution
on some backends.
unique_indices : bool
If True, the implementation will assume that the indices passed to ``at[]``
are unique, which can result in more efficient execution on some backends.
fill_value : Any
Only applies to the ``get()`` method: the fill value to return for out-of-bounds
slices when `mode` is ``'fill'``. Ignored otherwise. Defaults to ``NaN`` for
inexact types, the largest negative value for signed types, the largest positive
value for unsigned types, and ``True`` for booleans.
Examples
--------
.. code-block:: python
>>> import saiunit as u
>>> import jax.numpy as jnp
>>> x = jnp.arange(5.0) * u.mV
>>> x.at[2].add(10 * u.mV)
Quantity([ 0. 1. 12. 3. 4.], "mV")
>>> x.at[2].get()
Quantity(2., "mV")
"""
return _IndexUpdateHelper(self)
@property
def mantissa(self) -> jax.typing.ArrayLike:
r"""
The raw numerical data of this quantity (without the unit).
In scientific notation :math:`x = a \times 10^{b}`, the *mantissa* is
the coefficient :math:`a`. For a ``Quantity``, it is the underlying
JAX/NumPy array (or Python scalar) that stores the numeric value.
Returns
-------
array_like
The mantissa array or scalar.
Examples
--------
.. code-block:: python
>>> import saiunit as u
>>> q = u.Quantity(3.0, unit=u.mV)
>>> q.mantissa
3.0
See Also
--------
magnitude : Alias for ``mantissa``.
unit : The physical unit attached to this quantity.
"""
return self._mantissa
@property
def magnitude(self) -> jax.typing.ArrayLike:
"""
Alias for :attr:`mantissa`.
Returns
-------
array_like
The raw numerical data of this quantity.
Examples
--------
.. code-block:: python
>>> import saiunit as u
>>> q = u.Quantity(5.0, unit=u.metre)
>>> q.magnitude
5.0
See Also
--------
mantissa : Primary accessor for the numerical data.
"""
return self.mantissa
[docs]
def update_mantissa(self, mantissa: PyTree) -> None:
"""
Replace the mantissa in-place, keeping the same unit.
The new mantissa must have the same shape and dtype as the current one.
Parameters
----------
mantissa : array_like
The new numerical data. Must not be a :class:`Quantity`.
Raises
------
ValueError
If *mantissa* is a ``Quantity``, or if shape/dtype do not match.
Examples
--------
.. code-block:: python
>>> import saiunit as u
>>> import jax.numpy as jnp
>>> q = u.Quantity(jnp.array([1.0, 2.0, 3.0]), unit=u.mV)
>>> q.update_mantissa(jnp.array([4.0, 5.0, 6.0]))
>>> q
Quantity([4. 5. 6.], "mV")
"""
self_value = self.mantissa
if isinstance(mantissa, Quantity):
raise ValueError("Cannot set the mantissa of a Quantity to another Quantity.")
if isinstance(mantissa, np.ndarray):
mantissa = jnp.asarray(mantissa, dtype=self.dtype)
elif isinstance(mantissa, jax.Array):
pass
else:
mantissa = jnp.asarray(mantissa, dtype=self.dtype)
# check
if mantissa.shape != jnp.shape(self_value):
raise ValueError(f"The shape of the original data is {jnp.shape(self_value)}, "
f"while we got {mantissa.shape}.")
if mantissa.dtype != jax.dtypes.result_type(self_value):
raise ValueError(f"The dtype of the original data is {jax.dtypes.result_type(self_value)}, "
f"while we got {mantissa.dtype}.")
self._mantissa = mantissa
@property
def dim(self) -> Dimension:
"""
The physical dimension of this quantity (e.g. length, mass, time).
The dimension is independent of scale (metres vs kilometres both have
the *length* dimension).
Returns
-------
Dimension
The physical dimension object.
Examples
--------
.. code-block:: python
>>> import saiunit as u
>>> q = u.Quantity(5.0, unit=u.metre)
>>> q.dim
m
See Also
--------
unit : The full unit (dimension + scale).
"""
return self.unit.dim
@dim.setter
def dim(self, value):
# Do not support setting the unit directly
raise NotImplementedError(
"Cannot set the dimension of a Quantity object directly,"
"Please create a new Quantity object with the dimension you want."
)
@property
def unit(self) -> 'Unit':
"""
The :class:`Unit` attached to this quantity.
The unit carries both the physical dimension and the scale factor
(e.g. ``mV`` has dimension ``voltage`` with scale ``1e-3``).
Returns
-------
Unit
The unit of this quantity.
Examples
--------
.. code-block:: python
>>> import saiunit as u
>>> q = u.Quantity(5.0, unit=u.mV)
>>> q.unit
mV
See Also
--------
dim : The physical dimension without scale information.
mantissa : The numerical value.
"""
return self._unit
@unit.setter
def unit(self, value):
# Do not support setting the unit directly
raise NotImplementedError(
"Cannot set the unit of a Quantity object directly,"
"Please create a new Quantity object with the unit you want."
)
[docs]
def to(self, new_unit: Unit) -> 'Quantity':
"""
Convert this quantity to a different (compatible) unit.
The mantissa is rescaled so that the physical value stays the same,
and the returned ``Quantity`` carries *new_unit*.
Parameters
----------
new_unit : Unit
Target unit. Must have the same dimension as ``self.unit``.
Returns
-------
Quantity
A new ``Quantity`` expressed in *new_unit*.
Raises
------
UnitMismatchError
If *new_unit* has a different dimension.
Examples
--------
.. code-block:: python
>>> import saiunit as u
>>> import jax.numpy as jnp
>>> q = u.Quantity(jnp.array([1.0, 2.0, 3.0]), unit=u.mV)
>>> q.to(u.volt)
Quantity([0.001 0.002 0.003], "V")
See Also
--------
in_unit : Identical behaviour (``to`` delegates to ``in_unit``).
to_decimal : Convert to a plain number in the target unit.
"""
return self.in_unit(new_unit)
[docs]
def to_decimal(self, unit: Unit = UNITLESS) -> jax.typing.ArrayLike:
"""
Return the numerical value expressed in the given unit, without wrapping
the result in a ``Quantity``.
This is useful when you need a plain JAX array for downstream
computation that does not support units.
Parameters
----------
unit : Unit, optional
The reference unit. Defaults to ``UNITLESS``.
Returns
-------
array_like
A plain number or JAX array representing the quantity in *unit*.
Raises
------
TypeError
If *unit* is not a :class:`Unit`.
UnitMismatchError
If *unit* has a different dimension than ``self.unit``.
Examples
--------
.. code-block:: python
>>> import saiunit as u
>>> import jax.numpy as jnp
>>> q = u.Quantity(jnp.array([1.0, 2.0, 3.0]), unit=u.mV)
>>> q.to_decimal(u.volt)
Array([0.001, 0.002, 0.003], dtype=float32)
See Also
--------
to : Convert while keeping the ``Quantity`` wrapper.
"""
if not isinstance(unit, Unit):
raise TypeError(f"Expected a Unit, but got {unit}.")
if not unit.has_same_dim(self.unit):
raise UnitMismatchError(
f"Cannot convert to the decimal number using a unit with different dimensions.",
self.unit,
unit,
)
if not unit.has_same_magnitude(self.unit):
return self.mantissa * (self.unit.magnitude / unit.magnitude)
else:
return self.mantissa
[docs]
def in_unit(self, unit: Unit, err_msg: str = None) -> 'Quantity':
"""
Convert this quantity to a compatible unit.
Behaves identically to :meth:`to`; kept for API compatibility.
Parameters
----------
unit : Unit
Target unit. Must share the same dimension as ``self.unit``.
err_msg : str, optional
Custom error message used when the dimensions do not match.
Returns
-------
Quantity
A new ``Quantity`` expressed in *unit*.
Raises
------
UnitMismatchError
If *unit* has a different dimension.
Examples
--------
.. code-block:: python
>>> import saiunit as u
>>> import jax.numpy as jnp
>>> q = u.Quantity(jnp.array([1.0, 2.0, 3.0]), unit=u.mV)
>>> q.in_unit(u.volt)
Quantity([0.001 0.002 0.003], "V")
"""
if not isinstance(unit, Unit):
raise TypeError(f"Expected a Unit, but got {unit}.")
if not unit.has_same_dim(self.unit):
if err_msg is None:
raise UnitMismatchError(f"Cannot convert to a unit with different dimensions.", self.unit, unit)
else:
raise UnitMismatchError(err_msg)
self_mag = self.unit.magnitude
target_mag = unit.magnitude
if self_mag == target_mag:
u = Quantity(self.mantissa, unit=unit)
else:
u = Quantity(self.mantissa * (self_mag / target_mag), unit=unit)
return u
[docs]
@staticmethod
def with_unit(mantissa: PyTree, unit: Unit):
"""
Create a :class:`Quantity` from a raw value and a unit.
This is a convenience factory that reads more naturally in some
contexts than the standard constructor.
Parameters
----------
mantissa : array_like or number
The numerical value(s).
unit : Unit
The physical unit.
Returns
-------
Quantity
A new ``Quantity`` with the given mantissa and unit.
Examples
--------
.. code-block:: python
>>> import saiunit as u
>>> u.Quantity.with_unit(2.0, unit=u.metre)
Quantity(2., "m")
"""
return Quantity(mantissa, unit=unit)
@property
def is_unitless(self) -> bool:
"""
``True`` if this quantity is dimensionless (has no physical unit).
Returns
-------
bool
Whether the quantity is unitless.
Examples
--------
.. code-block:: python
>>> import saiunit as u
>>> u.Quantity(5.0).is_unitless
True
>>> u.Quantity(5.0, unit=u.mV).is_unitless
False
"""
return self.unit.is_unitless
[docs]
def has_same_unit(self, other):
"""
Check whether this quantity shares the same physical dimension as *other*.
Two quantities that differ only in scale (e.g. ``mV`` vs ``V``) are
considered to have the same unit dimension.
Parameters
----------
other : Quantity or Unit
The object to compare with.
Returns
-------
bool
``True`` if both have identical physical dimensions.
Examples
--------
.. code-block:: python
>>> import saiunit as u
>>> a = u.Quantity(1.0, unit=u.mV)
>>> b = u.Quantity(2.0, unit=u.volt)
>>> a.has_same_unit(b)
True
>>> c = u.Quantity(1.0, unit=u.second)
>>> a.has_same_unit(c)
False
"""
self_dim = get_dim(self.dim)
other_dim = get_dim(other.dim)
return (self_dim is other_dim) or (self_dim == other_dim)
def _format_value(self, precision: int | None = None) -> str:
"""Format the mantissa value as a string."""
m = self.mantissa
if isinstance(m, jax.Array):
value = m
else:
try:
value = jnp.asarray(m)
except TypeError:
value = m
if _is_tracer(value):
return str(value)
try:
if value.shape == ():
s = np.array_str(np.array([value]), precision=precision)
return s.replace("[", "").replace("]", "").strip()
# Use numpy's built-in summarization for large arrays
if value.size > 100:
kw = {}
if precision is not None:
kw['precision'] = precision
with np.printoptions(threshold=10, **kw):
return np.array_str(value)
return np.array_str(value, precision=precision)
except (TypeError, AttributeError):
return str(value)
[docs]
def repr_in_unit(
self,
precision: int | None = None,
) -> str:
"""
Return a human-readable string of this quantity in its current unit.
The format is ``"<value> <unit>"``, e.g. ``"3. mV"`` or
``"[1. 2. 3.] mV"``.
Parameters
----------
precision : int, optional
Number of significant digits. When *None* the value from
``numpy.get_printoptions`` is used.
Returns
-------
str
The formatted string.
Examples
--------
.. code-block:: python
>>> import saiunit as u
>>> x = u.Quantity(25.0, unit=u.mV)
>>> x.repr_in_unit()
'25. mV'
>>> x.to(u.volt).repr_in_unit(3)
'0.025 V'
"""
s = self._format_value(precision=precision)
if self.unit.should_display_unit:
s += f" {str(self.unit)}"
return s.strip()
[docs]
def factorless(self) -> 'Quantity':
"""
Return an equivalent quantity whose unit has ``factor == 1.0``.
If the unit already has no extra factor the original object is
returned unchanged.
Returns
-------
Quantity
A quantity with the factor folded into the mantissa.
Examples
--------
.. code-block:: python
>>> import saiunit as u
>>> q = u.Quantity(3.0, unit=u.mV)
>>> q.factorless()
Quantity(3., "mV")
"""
if self.unit.factor != 1.0:
return Quantity(self.mantissa * self.unit.factor, unit=self.unit.factorless())
else:
return self
@property
def dtype(self):
"""
The data type of the mantissa.
Returns
-------
dtype
The JAX/NumPy dtype of the underlying array.
Examples
--------
.. code-block:: python
>>> import saiunit as u
>>> import jax.numpy as jnp
>>> q = u.Quantity(jnp.array([1.0, 2.0]), unit=u.mV)
>>> q.dtype
float32
"""
a = self.mantissa
if hasattr(a, 'dtype'):
return a.dtype
else:
if isinstance(a, bool):
return bool
elif isinstance(a, int):
return jax.dtypes.canonicalize_dtype(int)
elif isinstance(a, float):
return jax.dtypes.canonicalize_dtype(float)
elif isinstance(a, complex):
return jax.dtypes.canonicalize_dtype(complex)
else:
raise TypeError(f'Can not get dtype of {a}.')
@property
def shape(self) -> tuple[int, ...]:
"""
The shape of the mantissa array.
Returns
-------
tuple of int
Shape tuple, identical to ``jnp.shape(self.mantissa)``.
Examples
--------
.. code-block:: python
>>> import saiunit as u
>>> import jax.numpy as jnp
>>> q = u.Quantity(jnp.array([[1.0, 2.0], [3.0, 4.0]]), unit=u.mV)
>>> q.shape
(2, 2)
"""
return jnp.shape(self.mantissa)
@property
def ndim(self) -> int:
return jnp.ndim(self.mantissa)
@property
def imag(self) -> 'Quantity':
return Quantity(jnp.imag(self.mantissa), unit=self.unit)
@property
def real(self) -> 'Quantity':
return Quantity(jnp.real(self.mantissa), unit=self.unit)
@property
def size(self) -> int:
return jnp.size(self.mantissa)
@property
def nbytes(self) -> int:
"""Total bytes consumed by the mantissa array."""
return jnp.asarray(self.mantissa).nbytes
@property
def itemsize(self) -> int:
"""Length (in bytes) of one array element."""
return jnp.asarray(self.mantissa).itemsize
@property
def strides(self):
"""Tuple of byte-steps in each dimension (mirrors numpy.ndarray.strides)."""
return np.asarray(self.mantissa).strides
@property
def flat(self):
"""1-D iterator over the mantissa elements, unit preserved."""
for v in jnp.asarray(self.mantissa).flat:
yield Quantity(v, unit=self.unit)
@property
def T(self) -> 'Quantity':
return Quantity(jnp.asarray(self.mantissa).T, unit=self.unit)
@property
def mT(self) -> 'Quantity':
"""
Matrix transpose of the last two dimensions, preserving units.
The array must be at least 2-D.
Returns
-------
Quantity
The matrix-transposed quantity.
Examples
--------
.. code-block:: python
>>> import saiunit as u
>>> import jax.numpy as jnp
>>> q = u.Quantity(jnp.array([[1.0, 2.0], [3.0, 4.0]]), unit=u.mV)
>>> q.mT.shape
(2, 2)
"""
return Quantity(jnp.asarray(self.mantissa).mT, unit=self.unit)
@property
def isreal(self) -> jax.Array:
return jnp.isreal(self.mantissa)
@property
def isscalar(self) -> bool:
return self.ndim == 0
@property
def isfinite(self) -> jax.Array:
return jnp.isfinite(self.mantissa)
@property
def isinfinite(self) -> jax.Array:
return jnp.isinf(self.mantissa)
@property
def isinf(self) -> jax.Array:
return jnp.isinf(self.mantissa)
@property
def isnan(self) -> jax.Array:
return jnp.isnan(self.mantissa)
# ----------------------- #
# Python inherent methods #
# ----------------------- #
def __hash__(self):
"""
Hash the Quantity object.
Returns:
int: The hash value of the Quantity object.
"""
try:
return hash((np.asarray(self.mantissa).tobytes(), self.unit))
except Exception:
return hash((id(self.mantissa), self.unit))
def __repr__(self) -> str:
value_str = self._format_value()
unit_str = str(self.unit)
prefix = "Quantity("
if self.unit.should_display_unit:
suffix = f", \"{unit_str}\")"
else:
suffix = ")"
# Indent continuation lines to align with prefix
if "\n" in value_str:
indent = " " * len(prefix)
lines = value_str.split("\n")
value_str = lines[0] + "\n" + "\n".join(indent + line for line in lines[1:])
return f"{prefix}{value_str}{suffix}"
def __str__(self) -> str:
return self.repr_in_unit()
def __format__(self, format_spec) -> str:
if not format_spec:
return str(self)
# Block '%' format on quantities with units — "50% mV" is meaningless
if '%' in format_spec and not self.unit.is_unitless:
raise ValueError(
f"'%' format is not supported for Quantity with unit {str(self.unit)!r}. "
f"Convert to a dimensionless value first."
)
unit_str = str(self.unit)
show_unit = self.unit.should_display_unit
if self.shape == ():
formatted_value = format(self.mantissa, format_spec)
if not show_unit:
return formatted_value
return f"{formatted_value} {unit_str}"
else:
# Parse precision from standard format specs like .2f, .3e, .4g,
# 10.2f, +.2f, etc. Use a regex to extract the precision field.
m = re.match(r'^[^.]*\.(\d+)[feEgGn%]?$', format_spec)
if m is not None:
precision = int(m.group(1))
value = np.asarray(self.mantissa)
s = np.array_str(np.round(value, precision), precision=precision)
if not show_unit:
return s
return f"{s} {unit_str}"
return str(self)
def __iter__(self):
"""Solve the issue of DeviceArray.__iter__.
Details please see JAX issues:
- https://github.com/google/jax/issues/7713
- https://github.com/google/jax/pull/3821
"""
if self.ndim == 0:
raise TypeError("iteration over a 0-d Quantity is not allowed")
for i in range(self.shape[0]):
yield Quantity(self.mantissa[i], unit=self.unit)
def __getitem__(self, index) -> 'Quantity':
if isinstance(index, slice) and (index == _all_slice):
return Quantity(self.mantissa, unit=self.unit)
elif isinstance(index, tuple):
for x in index:
if isinstance(x, Quantity):
raise TypeError("Array indices must be integers or slices, not Array")
elif isinstance(index, Quantity):
raise TypeError("Array indices must be integers or slices, not Array")
return Quantity(self.mantissa[index], unit=self.unit)
def __setitem__(self, index, value: 'Quantity | jax.typing.ArrayLike'):
# check value
if not isinstance(value, Quantity):
if self.is_unitless:
value = Quantity(value)
else:
raise TypeError(f"Only Quantity can be assigned to Quantity. But got {value}")
value = value.in_unit(self.unit)
# check index
index = jax.tree.map(_element_not_quantity, index, is_leaf=lambda x: isinstance(x, Quantity))
# update
self_value = jnp.asarray(self.mantissa).at[index].set(value.mantissa)
self.update_mantissa(self_value)
[docs]
def scatter_add(
self,
index: jax.typing.ArrayLike,
value: 'Quantity | jax.typing.ArrayLike'
) -> 'Quantity':
"""
Return a copy with *value* added at *index*.
Parameters
----------
index : int or array_like
Target index (indices).
value : Quantity
The value to add. Must have the same unit dimension.
Returns
-------
Quantity
A new quantity with the update applied.
Examples
--------
.. code-block:: python
>>> import saiunit as u
>>> import jax.numpy as jnp
>>> q = u.Quantity(jnp.array([1.0, 2.0, 3.0]), unit=u.mV)
>>> q.scatter_add(0, u.Quantity(10.0, unit=u.mV))
Quantity([11. 2. 3.], "mV")
"""
# check value
if not isinstance(value, Quantity):
if self.is_unitless:
value = Quantity(value)
else:
raise TypeError(f"Only Quantity can be assigned to Quantity. But got {value}")
value = value.in_unit(self.unit)
# check index
index = jax.tree.map(_element_not_quantity, index, is_leaf=lambda x: isinstance(x, Quantity))
# scatter-add
self_value = jnp.asarray(self.mantissa)
self_value = self_value.at[index].add(value.mantissa)
return Quantity(self_value, unit=self.unit)
[docs]
def scatter_sub(
self,
index: jax.typing.ArrayLike,
value: 'Quantity | jax.typing.ArrayLike'
) -> 'Quantity':
"""
Return a copy with *value* subtracted at *index*.
Parameters
----------
index : int or array_like
Target index (indices).
value : Quantity
The value to subtract. Must have the same unit dimension.
Returns
-------
Quantity
A new quantity with the update applied.
Examples
--------
.. code-block:: python
>>> import saiunit as u
>>> import jax.numpy as jnp
>>> q = u.Quantity(jnp.array([1.0, 2.0, 3.0]), unit=u.mV)
>>> q.scatter_sub(0, u.Quantity(1.0, unit=u.mV))
Quantity([0. 2. 3.], "mV")
"""
return self.scatter_add(index, -value)
[docs]
def scatter_mul(
self,
index: jax.typing.ArrayLike,
value: 'Quantity | jax.typing.ArrayLike'
) -> 'Quantity':
"""
Return a copy with the element at *index* multiplied by *value*.
*value* must be dimensionless (a pure scale factor).
Parameters
----------
index : int or array_like
Target index (indices).
value : Quantity or number
Dimensionless scale factor.
Returns
-------
Quantity
A new quantity with the update applied.
Raises
------
TypeError
If *value* is not dimensionless.
Examples
--------
.. code-block:: python
>>> import saiunit as u
>>> import jax.numpy as jnp
>>> q = u.Quantity(jnp.array([1.0, 2.0, 3.0]), unit=u.mV)
>>> q.scatter_mul(0, u.Quantity(10.0))
Quantity([10. 2. 3.], "mV")
"""
# check value: scatter_mul requires a dimensionless scale factor
if not isinstance(value, Quantity):
value = Quantity(value)
if not value.is_unitless:
raise TypeError(
f"scatter_mul requires a dimensionless scale factor, "
f"but got {value}. Use Quantity.__mul__ for unit-changing multiplication."
)
# check index
index = jax.tree.map(_element_not_quantity, index, is_leaf=lambda x: isinstance(x, Quantity))
# scatter-mul
self_value = jnp.asarray(self.mantissa)
self_value = self_value.at[index].mul(value.mantissa)
return Quantity(self_value, unit=self.unit)
[docs]
def scatter_div(
self,
index: jax.typing.ArrayLike,
value: 'Quantity | jax.typing.ArrayLike'
) -> 'Quantity':
"""
Return a copy with the element at *index* divided by *value*.
*value* must be dimensionless (a pure scale factor).
Parameters
----------
index : int or array_like
Target index (indices).
value : Quantity or number
Dimensionless scale factor.
Returns
-------
Quantity
A new quantity with the update applied.
Raises
------
TypeError
If *value* is not dimensionless.
Examples
--------
.. code-block:: python
>>> import saiunit as u
>>> import jax.numpy as jnp
>>> q = u.Quantity(jnp.array([1.0, 2.0, 3.0]), unit=u.mV)
>>> q.scatter_div(0, u.Quantity(2.0))
Quantity([0.5 2. 3. ], "mV")
"""
# check value: scatter_div requires a dimensionless scale factor
if not isinstance(value, Quantity):
value = Quantity(value)
if not value.is_unitless:
raise TypeError(
f"scatter_div requires a dimensionless scale factor, "
f"but got {value}. Use Quantity.__truediv__ for unit-changing division."
)
# check index
index = jax.tree.map(_element_not_quantity, index, is_leaf=lambda x: isinstance(x, Quantity))
# scatter-div
self_value = jnp.asarray(self.mantissa)
self_value = self_value.at[index].divide(value.mantissa)
return Quantity(self_value, unit=self.unit)
[docs]
def scatter_max(
self,
index: jax.typing.ArrayLike,
value: 'Quantity | jax.typing.ArrayLike'
) -> 'Quantity':
"""
Return a copy where the element at *index* is the maximum of
the current value and *value*.
Parameters
----------
index : int or array_like
Target index (indices).
value : Quantity
The comparison value. Must have the same unit dimension.
Returns
-------
Quantity
A new quantity with the update applied.
Examples
--------
.. code-block:: python
>>> import saiunit as u
>>> import jax.numpy as jnp
>>> q = u.Quantity(jnp.array([1.0, 2.0, 3.0]), unit=u.mV)
>>> q.scatter_max(0, u.Quantity(10.0, unit=u.mV))
Quantity([10. 2. 3.], "mV")
"""
# check value
if not isinstance(value, Quantity):
if self.is_unitless:
value = Quantity(value)
else:
raise TypeError(f"Only Quantity can be assigned to Quantity. But got {value}")
value = value.in_unit(self.unit)
# check index
index = jax.tree.map(_element_not_quantity, index, is_leaf=lambda x: isinstance(x, Quantity))
# scatter-max
self_value = jnp.asarray(self.mantissa)
self_value = self_value.at[index].max(value.mantissa)
return Quantity(self_value, unit=self.unit)
[docs]
def scatter_min(
self,
index: jax.typing.ArrayLike,
value: 'Quantity | jax.typing.ArrayLike'
) -> 'Quantity':
"""
Return a copy where the element at *index* is the minimum of
the current value and *value*.
Parameters
----------
index : int or array_like
Target index (indices).
value : Quantity
The comparison value. Must have the same unit dimension.
Returns
-------
Quantity
A new quantity with the update applied.
Examples
--------
.. code-block:: python
>>> import saiunit as u
>>> import jax.numpy as jnp
>>> q = u.Quantity(jnp.array([1.0, 2.0, 3.0]), unit=u.mV)
>>> q.scatter_min(0, u.Quantity(0.5, unit=u.mV))
Quantity([0.5 2. 3. ], "mV")
"""
# check value
if not isinstance(value, Quantity):
if self.is_unitless:
value = Quantity(value)
else:
raise TypeError(f"Only Quantity can be assigned to Quantity. But got {value}")
value = value.in_unit(self.unit)
# check index
index = jax.tree.map(_element_not_quantity, index, is_leaf=lambda x: isinstance(x, Quantity))
# scatter-min
self_value = jnp.asarray(self.mantissa)
self_value = self_value.at[index].min(value.mantissa)
return Quantity(self_value, unit=self.unit)
# ---------- #
# operations #
# ---------- #
def __len__(self) -> int:
return len(self.mantissa)
def __neg__(self) -> 'Quantity':
return Quantity(self.mantissa.__neg__(), unit=self.unit)
def __pos__(self) -> 'Quantity':
return Quantity(self.mantissa.__pos__(), unit=self.unit)
def __abs__(self) -> 'Quantity':
return Quantity(self.mantissa.__abs__(), unit=self.unit)
def __invert__(self) -> 'Quantity':
return Quantity(self.mantissa.__invert__(), unit=self.unit)
def _comparison(self, other: Any, operator_str: str, operation: Callable):
other = _to_quantity(other)
try:
other_value = other.in_unit(self.unit).mantissa
except UnitMismatchError as e:
raise UnitMismatchError(
f"Cannot compare {self} {operator_str} {other}",
self.unit, other.unit,
) from e
return operation(self.mantissa, other_value)
def __eq__(self, oc) -> jax.typing.ArrayLike:
return self._comparison(oc, "==", operator.eq)
def __ne__(self, oc) -> jax.typing.ArrayLike:
return self._comparison(oc, "!=", operator.ne)
def __lt__(self, oc) -> jax.typing.ArrayLike:
return self._comparison(oc, "<", operator.lt)
def __le__(self, oc) -> jax.typing.ArrayLike:
return self._comparison(oc, "<=", operator.le)
def __gt__(self, oc) -> jax.typing.ArrayLike:
return self._comparison(oc, ">", operator.gt)
def __ge__(self, oc) -> jax.typing.ArrayLike:
return self._comparison(oc, ">=", operator.ge)
def _binary_operation(
self,
other,
value_operation: Callable,
unit_operation: Callable = lambda a, b: a,
fail_for_mismatch: bool = False,
operator_str: str = None,
inplace: bool = False,
):
"""
General implementation for binary operations.
Parameters
----------
other : {`Array`, `ndarray`, scalar}
The object with which the operation should be performed.
value_operation : function of two variables
The function with which the two objects are combined. For example,
`operator.mul` for a multiplication.
unit_operation : function of two variables, optional
The function with which the dimension of the resulting object is
calculated (as a function of the dimensions of the two involved
objects). For example, `operator.mul` for a multiplication. If not
specified, the dimensions of `self` are used for the resulting
object.
fail_for_mismatch : bool, optional
Whether to fail for a dimension mismatch between `self` and `other`
(defaults to ``False``)
operator_str : str, optional
The string to use for the operator in an error message.
inplace: bool, optional
Whether to do the operation in-place (defaults to ``False``).
"""
# format "other"
if not isinstance(other, Quantity):
other = _to_quantity(other)
# format the unit and mantissa of "other"
if fail_for_mismatch:
other = other.in_unit(
self.unit,
err_msg=f"Cannot calculate \n"
f"{self} {operator_str} {other}, "
f"because units do not match: {self.unit} != {other.unit}"
)
other_value = other.mantissa
other_unit = other.unit
# calculate the new unit and mantissa
r = Quantity(
value_operation(self.mantissa, other_value),
unit=unit_operation(self.unit, other_unit)
)
# update the mantissa in-place or not
if inplace:
self.update_mantissa(r.mantissa)
return self
else:
return r
def __add__(self, oc):
if isinstance(oc, SparseMatrix):
return oc.__radd__(self)
return self._binary_operation(oc, operator.add, fail_for_mismatch=True, operator_str="+")
def __radd__(self, oc):
return self.__add__(oc)
def __iadd__(self, oc):
# a += b
return self._binary_operation(oc, operator.add, fail_for_mismatch=True, operator_str="+=", inplace=True)
def __sub__(self, oc):
if isinstance(oc, SparseMatrix):
return oc.__rsub__(self)
return self._binary_operation(oc, operator.sub, fail_for_mismatch=True, operator_str="-")
def __rsub__(self, oc):
return Quantity(oc).__sub__(self)
def __isub__(self, oc):
# a -= b
return self._binary_operation(oc, operator.sub, fail_for_mismatch=True, operator_str="-=", inplace=True)
def __mul__(self, oc):
if isinstance(oc, SparseMatrix):
return oc.__rmul__(self)
r = self._binary_operation(oc, operator.mul, operator.mul)
return maybe_decimal(r)
def __rmul__(self, oc):
return self.__mul__(oc)
def __imul__(self, oc):
# a *= b
raise NotImplementedError("In-place multiplication is not supported, since it changes the unit.")
def __div__(self, oc):
# self / oc
if isinstance(oc, SparseMatrix):
return oc.__rdiv__(self)
r = self._binary_operation(oc, operator.truediv, operator.truediv)
return maybe_decimal(r)
def __idiv__(self, oc):
raise NotImplementedError("In-place division is not supported, since it changes the unit.")
def __truediv__(self, oc):
# self / oc
if isinstance(oc, SparseMatrix):
return oc.__rtruediv__(self)
return self.__div__(oc)
def __rdiv__(self, oc):
# oc / self
# division with swapped arguments
rdiv = lambda a, b: operator.truediv(b, a)
r = self._binary_operation(oc, rdiv, rdiv)
return maybe_decimal(r)
def __rtruediv__(self, oc):
# oc / self
return self.__rdiv__(oc)
def __itruediv__(self, oc):
# a /= b
raise NotImplementedError("In-place true division is not supported, since it changes the unit.")
def __floordiv__(self, oc):
# self // oc
if isinstance(oc, SparseMatrix):
return oc.__rfloordiv__(self)
r = self._binary_operation(oc, operator.floordiv, operator.truediv)
return maybe_decimal(r)
def __rfloordiv__(self, oc):
# oc // self
rdiv = lambda a, b: operator.truediv(b, a)
rfloordiv = lambda a, b: operator.floordiv(b, a)
r = self._binary_operation(oc, rfloordiv, rdiv)
return maybe_decimal(r)
def __ifloordiv__(self, oc):
# a //= b
raise NotImplementedError("In-place floor division is not supported, since it changes the unit.")
def __mod__(self, oc):
# self % oc
if isinstance(oc, SparseMatrix):
return oc.__rmod__(self)
r = self._binary_operation(oc, operator.mod, lambda ua, ub: ua, fail_for_mismatch=True, operator_str=r"%")
return maybe_decimal(r)
def __rmod__(self, oc):
# oc % self
oc = _to_quantity(oc)
r = oc._binary_operation(self, operator.mod, lambda ua, ub: ua, fail_for_mismatch=True, operator_str=r"%")
return maybe_decimal(r)
def __imod__(self, oc):
raise NotImplementedError("In-place mod is not supported, since it changes the unit.")
def __divmod__(self, oc):
return self.__floordiv__(oc), self.__mod__(oc)
def __rdivmod__(self, oc):
return self.__rfloordiv__(oc), self.__rmod__(oc)
def __matmul__(self, oc):
if isinstance(oc, SparseMatrix):
return oc.__rmatmul__(self)
r = self._binary_operation(oc, operator.matmul, operator.mul, operator_str="@")
return maybe_decimal(r)
def __rmatmul__(self, oc):
oc = _to_quantity(oc)
r = oc._binary_operation(self, operator.matmul, operator.mul, operator_str="@")
return maybe_decimal(r)
def __imatmul__(self, oc):
# a @= b
raise NotImplementedError("In-place matrix multiplication is not supported, since it changes the unit.")
# -------------------- #
def __pow__(self, oc):
self = self.factorless()
if compat_with_equinox:
try:
from equinox.internal._omega import ω # noqa
if isinstance(oc, ω):
return ω(self)
except (ImportError, ModuleNotFoundError):
pass
if isinstance(oc, Quantity):
if not oc.is_unitless:
raise ValueError(f"Cannot calculate {self} ** {oc}, the exponent has to be dimensionless")
oc = oc.mantissa
r = Quantity(jnp.array(self.mantissa) ** oc, unit=self.unit ** oc)
return maybe_decimal(r)
def __rpow__(self, oc):
# oc ** self
if not self.is_unitless:
raise ValueError(f"Cannot calculate {oc} ** {self}, the exponent has to be dimensionless")
return oc ** self.mantissa
def __ipow__(self, oc):
# a **= b
raise NotImplementedError("In-place power is not supported, since it changes the unit.")
def __and__(self, oc):
# Remove the unit from the result
raise NotImplementedError("Bitwise operations are not supported")
def __rand__(self, oc):
# Remove the unit from the result
raise NotImplementedError("Bitwise operations are not supported")
def __iand__(self, oc):
# Remove the unit from the result
raise NotImplementedError("Bitwise operations are not supported")
def __or__(self, oc):
# Remove the unit from the result
raise NotImplementedError("Bitwise operations are not supported")
def __ror__(self, oc):
# Remove the unit from the result
raise NotImplementedError("Bitwise operations are not supported")
def __ior__(self, oc):
# Remove the unit from the result
# a |= b
raise NotImplementedError("Bitwise operations are not supported")
def __xor__(self, oc):
# Remove the unit from the result
raise NotImplementedError("Bitwise operations are not supported")
def __rxor__(self, oc):
# Remove the unit from the result
raise NotImplementedError("Bitwise operations are not supported")
def __ixor__(self, oc) -> 'Quantity':
# Remove the unit from the result
# a ^= b
raise NotImplementedError("Bitwise operations are not supported")
def __lshift__(self, oc) -> 'Quantity':
# self << oc
if isinstance(oc, Quantity):
if not oc.is_unitless:
raise ValueError("The shift amount must be dimensionless")
oc = oc.mantissa
r = Quantity(self.mantissa << oc, unit=self.unit)
return maybe_decimal(r)
def __rlshift__(self, oc) -> 'Quantity | jax.typing.ArrayLike':
# oc << self
if not self.is_unitless:
raise ValueError("The shift amount must be dimensionless")
return oc << self.mantissa
def __ilshift__(self, oc) -> 'Quantity':
# self <<= oc
r = self.__lshift__(oc)
self.update_mantissa(r.mantissa)
return self
def __rshift__(self, oc) -> 'Quantity':
# self >> oc
if isinstance(oc, Quantity):
if not oc.is_unitless:
raise ValueError("The shift amount must be dimensionless")
oc = oc.mantissa
r = Quantity(self.mantissa >> oc, unit=self.unit)
return maybe_decimal(r)
def __rrshift__(self, oc) -> 'Quantity | jax.typing.ArrayLike':
# oc >> self
if not self.is_unitless:
raise ValueError("The shift amount must be dimensionless")
return oc >> self.mantissa
def __irshift__(self, oc) -> 'Quantity':
# self >>= oc
r = self.__rshift__(oc)
self.update_mantissa(r.mantissa)
return self
def __round__(self, ndigits: int = None) -> 'Quantity':
"""
Round the mantissa to the given number of decimals.
:param ndigits: The number of decimals to round to.
:return: The rounded Quantity.
"""
return Quantity(self.mantissa.__round__(ndigits), unit=self.unit)
def __reduce__(self):
"""
Method used by Pickle object serialization.
Returns ``(array_with_unit, (mantissa, unit))`` so that
``pickle.loads(pickle.dumps(q))`` reconstructs an identical Quantity
without bypassing ``__init__`` validation. Using ``array_with_unit``
(rather than ``Quantity`` directly) mirrors the pattern used by
``Unit.__reduce__`` and avoids issues with ``__slots__``.
Returns
-------
tuple
``(callable, args)`` such that ``callable(*args)`` reconstructs
the object.
"""
return _quantity_with_unit, (self.mantissa, self.unit)
# ----------------------- #
# NumPy methods #
# ----------------------- #
all = _wrap_function_remove_unit(jnp.all)
any = _wrap_function_remove_unit(jnp.any)
nonzero = _wrap_function_remove_unit(jnp.nonzero)
argmax = _wrap_function_remove_unit(jnp.argmax)
argmin = _wrap_function_remove_unit(jnp.argmin)
argsort = _wrap_function_remove_unit(jnp.argsort)
var = _wrap_function_change_unit(jnp.var, lambda val, unit: unit ** 2)
std = _wrap_function_keep_unit(jnp.std)
sum = _wrap_function_keep_unit(jnp.sum)
trace = _wrap_function_keep_unit(jnp.trace)
cumsum = _wrap_function_keep_unit(jnp.cumsum)
diagonal = _wrap_function_keep_unit(jnp.diagonal)
max = _wrap_function_keep_unit(jnp.max)
mean = _wrap_function_keep_unit(jnp.mean)
min = _wrap_function_keep_unit(jnp.min)
ptp = _wrap_function_keep_unit(jnp.ptp)
ravel = _wrap_function_keep_unit(jnp.ravel)
def __deepcopy__(self, memodict: dict):
return Quantity(
deepcopy(self.mantissa),
unit=self.unit.__deepcopy__(memodict)
)
[docs]
def round(
self,
decimals: int = 0,
) -> 'Quantity':
"""
Evenly round the mantissa to the given number of decimals.
Parameters
----------
decimals : int, optional
Number of decimal places (default ``0``). Negative values
round to positions left of the decimal point.
Returns
-------
Quantity
A new quantity with the rounded mantissa.
Examples
--------
.. code-block:: python
>>> import saiunit as u
>>> q = u.Quantity(1.567, unit=u.mV)
>>> q.round(1)
Quantity(1.6, "mV")
"""
return Quantity(jnp.round(self.mantissa, decimals), unit=self.unit)
[docs]
def astype(
self,
dtype: jax.typing.DTypeLike
) -> 'Quantity':
"""
Return a copy of this quantity with the mantissa cast to *dtype*.
Parameters
----------
dtype : str or dtype
Target data type (e.g. ``jnp.float64``).
Returns
-------
Quantity
A new quantity with the converted dtype.
Examples
--------
.. code-block:: python
>>> import saiunit as u
>>> import jax.numpy as jnp
>>> q = u.Quantity(jnp.array([1.0, 2.0]), unit=u.mV)
>>> q.astype(jnp.float64).dtype
float64
"""
if dtype is None:
return Quantity(self.mantissa, unit=self.unit)
else:
return Quantity(jnp.astype(self.mantissa, dtype), unit=self.unit)
[docs]
def clip(
self,
min: 'Quantity | jax.typing.ArrayLike' = None,
max: 'Quantity | jax.typing.ArrayLike' = None,
) -> 'Quantity':
"""
Clip (limit) the values in the array to ``[min, max]``.
At least one of *min* or *max* must be given. Both must be
compatible with the unit of ``self``.
Parameters
----------
min : Quantity or array_like, optional
Minimum value.
max : Quantity or array_like, optional
Maximum value.
Returns
-------
Quantity
The clipped quantity.
Examples
--------
.. code-block:: python
>>> import saiunit as u
>>> import jax.numpy as jnp
>>> q = u.Quantity(jnp.array([1.0, 2.0, 3.0]), unit=u.mV)
>>> q.clip(min=u.Quantity(1.5, unit=u.mV), max=u.Quantity(2.5, unit=u.mV))
Quantity([1.5 2. 2.5], "mV")
"""
_, min = unit_scale_align_to_first(self, min)
_, max = unit_scale_align_to_first(self, max)
return Quantity(jnp.clip(self.mantissa, min.mantissa, max.mantissa), unit=self.unit)
[docs]
def conj(self) -> 'Quantity':
"""
Return the complex conjugate, element-wise, preserving units.
Returns
-------
Quantity
The conjugated quantity.
Examples
--------
.. code-block:: python
>>> import saiunit as u
>>> q = u.Quantity(1.0 + 2.0j, unit=u.mV)
>>> q.conj()
Quantity((1-2j), "mV")
"""
return Quantity(jnp.conj(self.mantissa), unit=self.unit)
[docs]
def conjugate(self) -> 'Quantity':
"""
Return the complex conjugate, element-wise.
Alias for :meth:`conj`.
Returns
-------
Quantity
The conjugated quantity.
"""
return Quantity(jnp.conjugate(self.mantissa), unit=self.unit)
[docs]
def copy(self) -> 'Quantity':
"""
Return a deep copy of this quantity.
Returns
-------
Quantity
An independent copy with the same mantissa and unit.
Examples
--------
.. code-block:: python
>>> import saiunit as u
>>> q = u.Quantity(3.0, unit=u.mV)
>>> q2 = q.copy()
>>> q2
Quantity(3., "mV")
"""
return type(self)(jnp.copy(self.mantissa), unit=self.unit)
[docs]
def dot(self, b) -> 'Quantity':
"""
Dot product of two arrays.
The resulting unit is ``self.unit * b.unit``.
Parameters
----------
b : Quantity or array_like
Second operand.
Returns
-------
Quantity
The dot product.
Examples
--------
.. code-block:: python
>>> import saiunit as u
>>> import jax.numpy as jnp
>>> a = u.Quantity(jnp.array([1.0, 2.0, 3.0]), unit=u.mV)
>>> b = u.Quantity(jnp.array([1.0, 1.0, 1.0]), unit=u.mV)
>>> a.dot(b)
Quantity(6., "mV^2")
"""
r = self._binary_operation(b, jnp.dot, operator.mul, operator_str="@")
return maybe_decimal(r)
[docs]
def trace(self, offset: int = 0, axis1: int = 0, axis2: int = 1) -> 'Quantity':
"""
Sum along diagonals of the array, preserving units.
Parameters
----------
offset : int, optional
Offset of the diagonal from the main diagonal (default ``0``).
axis1 : int, optional
First axis of the 2-D sub-arrays (default ``0``).
axis2 : int, optional
Second axis of the 2-D sub-arrays (default ``1``).
Returns
-------
Quantity
The trace value(s).
Examples
--------
.. code-block:: python
>>> import saiunit as u
>>> import jax.numpy as jnp
>>> q = u.Quantity(jnp.eye(3), unit=u.mV)
>>> q.trace()
Quantity(3., "mV")
"""
return Quantity(jnp.trace(self.mantissa, offset=offset, axis1=axis1, axis2=axis2), unit=self.unit)
[docs]
def diagonal(self, offset: int = 0, axis1: int = 0, axis2: int = 1) -> 'Quantity':
"""
Return specified diagonals, preserving units.
Parameters
----------
offset : int, optional
Offset from the main diagonal (default ``0``).
axis1 : int, optional
First axis (default ``0``).
axis2 : int, optional
Second axis (default ``1``).
Returns
-------
Quantity
The diagonal elements.
Examples
--------
.. code-block:: python
>>> import saiunit as u
>>> import jax.numpy as jnp
>>> q = u.Quantity(jnp.array([[1.0, 2.0], [3.0, 4.0]]), unit=u.mV)
>>> q.diagonal()
Quantity([1. 4.], "mV")
"""
return Quantity(jnp.diagonal(self.mantissa, offset=offset, axis1=axis1, axis2=axis2), unit=self.unit)
[docs]
def outer(self, b: 'Quantity') -> 'Quantity':
"""
Outer product of two 1-D arrays.
The resulting unit is ``self.unit * b.unit``.
Parameters
----------
b : Quantity or array_like
Second operand.
Returns
-------
Quantity
The outer product matrix.
Examples
--------
.. code-block:: python
>>> import saiunit as u
>>> import jax.numpy as jnp
>>> a = u.Quantity(jnp.array([1.0, 2.0]), unit=u.mV)
>>> b = u.Quantity(jnp.array([3.0, 4.0]), unit=u.second)
>>> a.outer(b).shape
(2, 2)
"""
b = _to_quantity(b)
r = self._binary_operation(b, jnp.outer, operator.mul, operator_str="outer")
return maybe_decimal(r)
[docs]
def cross(self, b: 'Quantity', axisa: int = -1, axisb: int = -1, axisc: int = -1, axis: int = None) -> 'Quantity':
"""
Cross product of two arrays.
The resulting unit is ``self.unit * b.unit``.
Parameters
----------
b : Quantity
Second operand.
axisa : int, optional
Axis of *self* that defines the vector(s) (default ``-1``).
axisb : int, optional
Axis of *b* that defines the vector(s) (default ``-1``).
axisc : int, optional
Axis of the result containing the cross product (default ``-1``).
axis : int, optional
Overrides *axisa*, *axisb*, and *axisc* simultaneously.
Returns
-------
Quantity
The cross product.
Examples
--------
.. code-block:: python
>>> import saiunit as u
>>> import jax.numpy as jnp
>>> a = u.Quantity(jnp.array([1.0, 0.0, 0.0]), unit=u.mV)
>>> b = u.Quantity(jnp.array([0.0, 1.0, 0.0]), unit=u.second)
>>> a.cross(b)
Quantity([0. 0. 1.], "mV * s")
"""
b = _to_quantity(b)
kwargs = dict(axisa=axisa, axisb=axisb, axisc=axisc)
if axis is not None:
kwargs['axis'] = axis
result_mantissa = jnp.cross(self.mantissa, b.mantissa, **kwargs)
result_unit = self.unit * b.unit
r = Quantity(result_mantissa, unit=result_unit)
return maybe_decimal(r)
[docs]
def searchsorted(self, v, side: str = 'left', sorter=None) -> jax.Array:
"""Find indices where elements should be inserted to maintain order."""
if isinstance(v, Quantity):
v = v.in_unit(self.unit).mantissa
return jnp.searchsorted(self.mantissa, v, side=side, sorter=sorter)
[docs]
def fill(self, value: 'Quantity') -> 'Quantity':
"""Fill the array with a scalar mantissa."""
fail_for_dimension_mismatch(self, value, "fill")
self[:] = value
return self
[docs]
def flatten(self) -> 'Quantity':
"""
Return a 1-D copy of this quantity.
Returns
-------
Quantity
Flattened quantity with the same unit.
Examples
--------
.. code-block:: python
>>> import saiunit as u
>>> import jax.numpy as jnp
>>> q = u.Quantity(jnp.array([[1.0, 2.0], [3.0, 4.0]]), unit=u.mV)
>>> q.flatten()
Quantity([1. 2. 3. 4.], "mV")
"""
return Quantity(jnp.reshape(self.mantissa, -1), unit=self.unit)
[docs]
def item(self, *args) -> 'Quantity':
"""
Extract a single element as a scalar ``Quantity``.
Parameters
----------
*args : int
Index into the flat array.
Returns
-------
Quantity
A 0-D ``Quantity`` containing the selected element.
Examples
--------
.. code-block:: python
>>> import saiunit as u
>>> import jax.numpy as jnp
>>> q = u.Quantity(jnp.array([10.0, 20.0]), unit=u.mV)
>>> q.item(0)
Quantity(10., "mV")
"""
return Quantity(self.mantissa.item(*args), unit=self.unit)
[docs]
def prod(self, *args, **kwds) -> 'Quantity':
"""
Return the product of array elements over the given axis.
The unit of the result is ``self.unit ** n`` where *n* is the number
of elements multiplied together.
Returns
-------
Quantity
The product.
Examples
--------
.. code-block:: python
>>> import saiunit as u
>>> import jax.numpy as jnp
>>> q = u.Quantity(jnp.array([2.0, 3.0]), unit=u.mV)
>>> q.prod()
Quantity(6., "mV^2")
"""
self = self.factorless()
prod_res = jnp.prod(self.mantissa, *args, **kwds)
# Calculating the correct dimensions is not completly trivial (e.g.
# like doing self.dim**self.size) because prod can be called on
# multidimensional arrays along a certain axis.
# Our solution: Use a "dummy matrix" containing a 1 (without units) at
# each entry and sum it, using the same keyword arguments as provided.
# The result gives the exponent for the dimensions.
# This relies on sum and prod having the same arguments, which is true
# now and probably remains like this in the future
dim_exponent = jnp.ones_like(self.mantissa).sum(*args, **kwds)
# The result is possibly multidimensional but all entries should be
# identical
if dim_exponent.size > 1:
dim_exponent = dim_exponent.ravel()[0]
r = Quantity(jnp.array(prod_res), unit=self.unit ** dim_exponent)
return maybe_decimal(r)
[docs]
def nanprod(self, *args, **kwds) -> 'Quantity':
"""
Return the product of array elements over a given axis treating Not a Numbers (NaNs) as ones.
When reducing along a specific axis, the number of non-NaN elements
must be the same for every position in the result so that a single
unit exponent can be assigned. If the non-NaN counts differ and the
quantity is not dimensionless, a ``ValueError`` is raised.
Returns
-------
Quantity
The product (NaNs treated as ones).
Raises
------
ValueError
If the non-NaN counts are not uniform along the reduction axis
for a non-dimensionless quantity.
"""
self = self.factorless()
prod_res = jnp.nanprod(self.mantissa, *args, **kwds)
if self.is_unitless:
return maybe_decimal(Quantity(jnp.array(prod_res), unit=self.unit))
# Count non-NaN elements along the reduction axis.
nan_mask = jnp.isnan(self.mantissa)
non_nan_counts = jnp.sum(jnp.where(nan_mask, 0, 1), *args, **kwds)
# Verify uniform counts when axis is not None (result is not scalar).
if non_nan_counts.ndim > 0:
if not jnp.all(non_nan_counts == non_nan_counts.ravel()[0]):
raise ValueError(
"nanprod over an axis with non-uniform NaN counts is not "
"supported for quantities with units, because the resulting "
"elements would have different unit exponents."
)
dim_exponent = non_nan_counts.ravel()[0]
else:
dim_exponent = non_nan_counts
r = Quantity(jnp.array(prod_res), unit=self.unit ** dim_exponent)
return maybe_decimal(r)
[docs]
def cumprod(self, *args, **kwds):
"""
Return the cumulative product of elements along a given axis.
Because each position in the result corresponds to a different number
of multiplied elements, the unit exponent varies across the output.
This is only representable when the quantity is dimensionless.
Returns
-------
Quantity
The cumulative product.
Raises
------
TypeError
If the quantity is not dimensionless.
"""
if not self.is_unitless:
raise TypeError(
"cumprod is not supported for quantities with units "
f"(has unit {self.unit}), because each element of the result "
"would have a different unit exponent. "
"Use .prod() for a single reduction, or convert to "
"dimensionless first."
)
return maybe_decimal(
Quantity(jnp.cumprod(self.mantissa, *args, **kwds), unit=self.unit)
)
[docs]
def nancumprod(self, *args, **kwds):
"""
Return the cumulative product of elements along a given axis,
treating NaNs as ones.
Because each position in the result corresponds to a different number
of multiplied elements, the unit exponent varies across the output.
This is only representable when the quantity is dimensionless.
Returns
-------
Quantity
The cumulative product (NaNs treated as ones).
Raises
------
TypeError
If the quantity is not dimensionless.
"""
if not self.is_unitless:
raise TypeError(
"nancumprod is not supported for quantities with units "
f"(has unit {self.unit}), because each element of the result "
"would have a different unit exponent. "
"Use .nanprod() for a single reduction, or convert to "
"dimensionless first."
)
return maybe_decimal(
Quantity(jnp.nancumprod(self.mantissa, *args, **kwds), unit=self.unit)
)
[docs]
def put(self, indices, values) -> 'Quantity':
"""Replaces specified elements of an array with given values.
Parameters
----------
indices: array_like
Target indices, interpreted as integers.
values: array_like
Values to place in the array at target indices.
"""
fail_for_dimension_mismatch(self, values, "put")
self.__setitem__(indices, values)
return self
[docs]
def repeat(self, repeats, axis=None) -> 'Quantity':
"""
Repeat elements of the array.
Parameters
----------
repeats : int or array of ints
Number of repetitions for each element.
axis : int, optional
Axis along which to repeat.
Returns
-------
Quantity
The repeated quantity.
Examples
--------
.. code-block:: python
>>> import saiunit as u
>>> import jax.numpy as jnp
>>> q = u.Quantity(jnp.array([1.0, 2.0]), unit=u.mV)
>>> q.repeat(2)
Quantity([1. 1. 2. 2.], "mV")
"""
r = jnp.repeat(self.mantissa, repeats=repeats, axis=axis)
return Quantity(r, unit=self.unit)
[docs]
def reshape(self, shape, order='C') -> 'Quantity':
"""
Return a quantity with the same data but a new shape.
Parameters
----------
shape : int or tuple of ints
New shape.
order : {'C', 'F'}, optional
Memory layout order (default ``'C'``).
Returns
-------
Quantity
Reshaped quantity.
Examples
--------
.. code-block:: python
>>> import saiunit as u
>>> import jax.numpy as jnp
>>> q = u.Quantity(jnp.array([1.0, 2.0, 3.0]), unit=u.mV)
>>> q.reshape((3, 1)).shape
(3, 1)
"""
return Quantity(jnp.reshape(self.mantissa, shape, order=order), unit=self.unit)
[docs]
def resize(self, new_shape) -> 'Quantity':
"""Change shape and size of array in-place."""
self.update_mantissa(jnp.resize(self.mantissa, new_shape))
return self
[docs]
def sort(self, axis=-1, stable=True, order=None) -> 'Quantity':
"""
Sort the array in-place along the given axis.
Parameters
----------
axis : int, optional
Axis along which to sort (default ``-1``).
stable : bool, optional
Whether to use a stable sort (default ``True``).
order : str or list of str, optional
Field ordering for structured arrays.
Returns
-------
Quantity
``self``, with the mantissa sorted in-place.
Examples
--------
.. code-block:: python
>>> import saiunit as u
>>> import jax.numpy as jnp
>>> q = u.Quantity(jnp.array([3.0, 1.0, 2.0]), unit=u.mV)
>>> q.sort()
Quantity([1. 2. 3.], "mV")
"""
self.update_mantissa(jnp.sort(self.mantissa, axis=axis, stable=stable, order=order))
return self
[docs]
def squeeze(self, axis=None) -> 'Quantity':
"""
Remove length-one axes from the array.
Parameters
----------
axis : int or tuple of ints, optional
Axes to remove. If ``None``, all length-one axes are removed.
Returns
-------
Quantity
The squeezed quantity.
Examples
--------
.. code-block:: python
>>> import saiunit as u
>>> import jax.numpy as jnp
>>> q = u.Quantity(jnp.array([[[1.0]]]), unit=u.mV)
>>> q.squeeze().shape
()
"""
return Quantity(jnp.squeeze(self.mantissa, axis=axis), unit=self.unit)
[docs]
def swapaxes(self, axis1, axis2) -> 'Quantity':
"""
Interchange two axes of the array.
Parameters
----------
axis1 : int
First axis.
axis2 : int
Second axis.
Returns
-------
Quantity
The quantity with axes swapped.
Examples
--------
.. code-block:: python
>>> import saiunit as u
>>> import jax.numpy as jnp
>>> q = u.Quantity(jnp.array([[1.0, 2.0], [3.0, 4.0]]), unit=u.mV)
>>> q.swapaxes(0, 1).shape
(2, 2)
"""
return Quantity(jnp.swapaxes(self.mantissa, axis1, axis2), unit=self.unit)
[docs]
def split(self, indices_or_sections, axis=0) -> 'list[Quantity]':
"""
Split the array into multiple sub-arrays.
Parameters
----------
indices_or_sections : int or 1-D array
If an integer *N*, the array is divided into *N* equal parts.
If a sorted 1-D array of indices, the entries indicate split
points along *axis*.
axis : int, optional
Axis along which to split (default ``0``).
Returns
-------
list of Quantity
Sub-arrays, each carrying the same unit.
Examples
--------
.. code-block:: python
>>> import saiunit as u
>>> import jax.numpy as jnp
>>> q = u.Quantity(jnp.array([1.0, 2.0, 3.0]), unit=u.mV)
>>> parts = q.split(3)
>>> len(parts)
3
"""
return [Quantity(a, unit=self.unit) for a in jnp.split(self.mantissa, indices_or_sections, axis=axis)]
[docs]
def take(
self,
indices,
axis=None,
mode=None,
unique_indices=False,
indices_are_sorted=False,
fill_value=None,
) -> 'Quantity':
"""
Select elements from the array at the given indices.
Parameters
----------
indices : array_like
Indices of the values to extract.
axis : int, optional
Axis along which to take (default flattened).
mode : str, optional
Out-of-bounds index handling.
unique_indices : bool, optional
Hint that indices are unique.
indices_are_sorted : bool, optional
Hint that indices are sorted.
fill_value : Quantity or scalar, optional
Value for out-of-bounds positions when *mode* is ``'fill'``.
Returns
-------
Quantity
The selected elements.
Examples
--------
.. code-block:: python
>>> import saiunit as u
>>> import jax.numpy as jnp
>>> q = u.Quantity(jnp.array([10.0, 20.0, 30.0]), unit=u.mV)
>>> q.take(jnp.array([0, 2]))
Quantity([10. 30.], "mV")
"""
if isinstance(fill_value, Quantity):
fail_for_dimension_mismatch(self, fill_value, "take")
fill_value = unit_scale_align_to_first(self, fill_value)[1].mantissa
elif fill_value is not None:
if not self.is_unitless:
raise TypeError(f"fill_value must be a Quantity when the unit {self.unit}. But got {fill_value}")
return Quantity(
jnp.take(
self.mantissa,
indices=indices,
axis=axis,
mode=mode,
unique_indices=unique_indices,
indices_are_sorted=indices_are_sorted,
fill_value=fill_value
),
unit=self.unit
)
[docs]
def tolist(self):
"""
Convert the array to a (nested) Python list of ``Quantity`` scalars.
Each leaf element is a 0-D ``Quantity`` with the same unit.
Returns
-------
list or Quantity
A nested list of scalar ``Quantity`` objects, or a single
``Quantity`` for 0-D arrays.
Examples
--------
.. code-block:: python
>>> import saiunit as u
>>> import jax.numpy as jnp
>>> q = u.Quantity(jnp.array([1.0, 2.0]), unit=u.mV)
>>> q.tolist()
[Quantity(1., "mV"), Quantity(2., "mV")]
"""
if isinstance(self.mantissa, numbers.Number):
list_mantissa = self.mantissa
else:
list_mantissa = self.mantissa.tolist()
return _replace_with_array(list_mantissa, self.unit)
[docs]
def transpose(self, *axes) -> 'Quantity':
"""
Return the array with axes transposed.
For a 2-D array this is the standard matrix transpose.
Parameters
----------
*axes : None, tuple of ints, or n ints
If omitted, axes are reversed. Otherwise specifies the
permutation.
Returns
-------
Quantity
Transposed quantity.
Examples
--------
.. code-block:: python
>>> import saiunit as u
>>> import jax.numpy as jnp
>>> q = u.Quantity(jnp.array([[1.0, 2.0], [3.0, 4.0]]), unit=u.mV)
>>> q.transpose().shape
(2, 2)
"""
return Quantity(jnp.transpose(self.mantissa, *axes), unit=self.unit)
[docs]
def tile(self, reps) -> 'Quantity':
"""
Construct an array by repeating this quantity.
Parameters
----------
reps : int or array_like
Number of repetitions along each axis.
Returns
-------
Quantity
The tiled quantity.
Examples
--------
.. code-block:: python
>>> import saiunit as u
>>> import jax.numpy as jnp
>>> q = u.Quantity(jnp.array([1.0, 2.0]), unit=u.mV)
>>> q.tile(2)
Quantity([1. 2. 1. 2.], "mV")
"""
return Quantity(jnp.tile(self.mantissa, reps), unit=self.unit)
[docs]
def view(self, *args, dtype=None) -> 'Quantity':
r"""New view of array with the same data.
This function is compatible with pytorch syntax.
Returns a new tensor with the same data as the :attr:`self` tensor but of a
different :attr:`shape`.
The returned tensor shares the same data and must have the same number
of elements, but may have a different size. For a tensor to be viewed, the new
view size must be compatible with its original size and stride, i.e., each new
view dimension must either be a subspace of an original dimension, or only span
across original dimensions :math:`d, d+1, \dots, d+k` that satisfy the following
contiguity-like condition that :math:`\forall i = d, \dots, d+k-1`,
.. math::
\text{stride}[i] = \text{stride}[i+1] \times \text{size}[i+1]
Otherwise, it will not be possible to view :attr:`self` tensor as :attr:`shape`
without copying it (e.g., via :meth:`contiguous`). When it is unclear whether a
:meth:`view` can be performed, it is advisable to use :meth:`reshape`, which
returns a view if the shapes are compatible, and copies (equivalent to calling
:meth:`contiguous`) otherwise.
Args:
shape (int...): the desired size
Example::
>>> import jax.numpy as jnp, saiunit
>>> x = saiunit.Quantity(jnp.ones((4, 4)))
>>> x.shape
(4, 4)
>>> y = x.view(16)
>>> y.shape
(16,)
>>> z = x.view(2, 8)
>>> z.shape
(2, 8)
.. method:: view(dtype) -> Tensor
:noindex:
Returns a new tensor with the same data as the :attr:`self` tensor but of a
different :attr:`dtype`.
If the element size of :attr:`dtype` is different than that of ``self.dtype``,
then the size of the last dimension of the output will be scaled
proportionally. For instance, if :attr:`dtype` element size is twice that of
``self.dtype``, then each pair of elements in the last dimension of
:attr:`self` will be combined, and the size of the last dimension of the output
will be half that of :attr:`self`. If :attr:`dtype` element size is half that
of ``self.dtype``, then each element in the last dimension of :attr:`self` will
be split in two, and the size of the last dimension of the output will be
double that of :attr:`self`. For this to be possible, the following conditions
must be true:
* ``self.dim()`` must be greater than 0.
* ``self.stride(-1)`` must be 1.
Additionally, if the element size of :attr:`dtype` is greater than that of
``self.dtype``, the following conditions must be true as well:
* ``self.size(-1)`` must be divisible by the ratio between the element
sizes of the dtypes.
* ``self.storage_offset()`` must be divisible by the ratio between the
element sizes of the dtypes.
* The strides of all dimensions, except the last dimension, must be
divisible by the ratio between the element sizes of the dtypes.
If any of the above conditions are not met, an error is thrown.
Args:
dtype (:class:`dtype`): the desired dtype
Example::
>>> x = brainstate.random.randn(4, 4)
>>> x
Array([[ 0.9482, -0.0310, 1.4999, -0.5316],
[-0.1520, 0.7472, 0.5617, -0.8649],
[-2.4724, -0.0334, -0.2976, -0.8499],
[-0.2109, 1.9913, -0.9607, -0.6123]])
>>> x.dtype
brainstate.math.float32
>>> y = x.view(numpy.int32)
>>> y
tensor([[ 1064483442, -1124191867, 1069546515, -1089989247],
[-1105482831, 1061112040, 1057999968, -1084397505],
[-1071760287, -1123489973, -1097310419, -1084649136],
[-1101533110, 1073668768, -1082790149, -1088634448]],
dtype=numpy.int32)
>>> y[0, 0] = 1000000000
>>> x
tensor([[ 0.0047, -0.0310, 1.4999, -0.5316],
[-0.1520, 0.7472, 0.5617, -0.8649],
[-2.4724, -0.0334, -0.2976, -0.8499],
[-0.2109, 1.9913, -0.9607, -0.6123]])
>>> x.view(numpy.complex64)
tensor([[ 0.0047-0.0310j, 1.4999-0.5316j],
[-0.1520+0.7472j, 0.5617-0.8649j],
[-2.4724-0.0334j, -0.2976-0.8499j],
[-0.2109+1.9913j, -0.9607-0.6123j]])
>>> x.view(numpy.complex64).size
[4, 2]
>>> x.view(numpy.uint8)
tensor([[ 0, 202, 154, 59, 182, 243, 253, 188, 185, 252, 191, 63, 240, 22,
8, 191],
[227, 165, 27, 190, 128, 72, 63, 63, 146, 203, 15, 63, 22, 106,
93, 191],
[205, 59, 30, 192, 112, 206, 8, 189, 7, 95, 152, 190, 12, 147,
89, 191],
[ 43, 246, 87, 190, 235, 226, 254, 63, 111, 240, 117, 191, 177, 191,
28, 191]], dtype=uint8)
>>> x.view(numpy.uint8).size
[4, 16]
"""
if len(args) == 0:
if dtype is None:
raise ValueError('Provide dtype or shape.')
else:
return Quantity(self.mantissa.view(dtype), unit=self.unit)
else:
if isinstance(args[0], int): # shape
if dtype is not None:
raise ValueError('Provide one of dtype or shape. Not both.')
return Quantity(self.mantissa.reshape(*args), unit=self.unit)
else: # dtype
assert not isinstance(args[0], int)
assert dtype is None
return Quantity(self.mantissa.view(args[0]), unit=self.unit)
# ------------------
# NumPy support
# ------------------
def __array__(self, dtype: jax.typing.DTypeLike | None = None) -> np.ndarray:
"""Support ``numpy.array()`` and ``numpy.asarray()`` functions."""
if self.dim.is_dimensionless:
return np.asarray(self.to_decimal(), dtype=dtype)
else:
raise TypeError(
f"Only dimensionless quantities can be "
f"converted to NumPy arrays. But got {self}"
)
def __float__(self):
if self.dim.is_dimensionless and self.ndim == 0:
return float(self.to_decimal())
else:
raise TypeError(
"Only dimensionless scalar quantities can be "
f"converted to Python scalars. But got {self}"
)
def __int__(self):
if self.dim.is_dimensionless and self.ndim == 0:
return int(self.to_decimal())
else:
raise TypeError(
"only dimensionless scalar quantities can be "
f"converted to Python scalars. But got {self}"
)
def __index__(self):
if self.dim.is_dimensionless:
return operator.index(self.to_decimal())
else:
raise TypeError(
"only dimensionless quantities can be "
f"converted to a Python index. But got {self}"
)
# ----------------------
# PyTorch compatibility
# ----------------------
[docs]
def unsqueeze(self, axis: int) -> 'Quantity':
"""
Insert a length-one axis (PyTorch-style alias for :meth:`expand_dims`).
Parameters
----------
axis : int
Position where the new axis is inserted.
Returns
-------
Quantity
The quantity with an extra dimension.
Examples
--------
.. code-block:: python
>>> import saiunit as u
>>> import jax.numpy as jnp
>>> q = u.Quantity(jnp.array([1.0, 2.0]), unit=u.mV)
>>> q.unsqueeze(0).shape
(1, 2)
"""
return Quantity(jnp.expand_dims(self.mantissa, axis), unit=self.unit)
[docs]
def expand_dims(self, axis: int | Sequence[int]) -> 'Quantity':
"""
Insert new axes at the given positions.
Parameters
----------
axis : int or tuple of ints
Position(s) where the new axis (axes) are placed.
Returns
-------
Quantity
The expanded quantity.
Examples
--------
.. code-block:: python
>>> import saiunit as u
>>> import jax.numpy as jnp
>>> q = u.Quantity(jnp.array([1.0, 2.0]), unit=u.mV)
>>> q.expand_dims(0).shape
(1, 2)
"""
return Quantity(jnp.expand_dims(self.mantissa, axis), unit=self.unit)
[docs]
def expand_as(self, array: 'Quantity | jax.typing.ArrayLike') -> 'Quantity':
"""
Expand an array to a shape of another array.
Parameters
----------
array : Quantity
Returns
-------
expanded : Quantity
A readonly view on the original array with the given shape of array. It is
typically not contiguous. Furthermore, more than one element of a
expanded array may refer to a single memory location.
"""
if isinstance(array, Quantity):
fail_for_dimension_mismatch(self, array, "expand_as (Quantity)")
array = array.mantissa
return Quantity(jnp.broadcast_to(self.mantissa, array), unit=self.unit)
[docs]
def pow(self, oc) -> 'Quantity':
"""
Raise this quantity to the power *oc*.
The exponent must be dimensionless. The resulting unit is
``self.unit ** oc``.
Parameters
----------
oc : int, float, or dimensionless Quantity
The exponent.
Returns
-------
Quantity
``self ** oc``.
Examples
--------
.. code-block:: python
>>> import saiunit as u
>>> q = u.Quantity(2.0, unit=u.mV)
>>> q.pow(2)
Quantity(4., "mV^2")
"""
return self.__pow__(oc)
[docs]
def clone(self) -> 'Quantity':
"""
Return a copy of this quantity (PyTorch-style alias for :meth:`copy`).
Returns
-------
Quantity
An independent copy.
Examples
--------
.. code-block:: python
>>> import saiunit as u
>>> q = u.Quantity(3.0, unit=u.mV)
>>> q.clone()
Quantity(3., "mV")
"""
return self.copy()
[docs]
def tree_flatten(self) -> tuple[tuple[jax.typing.ArrayLike], Unit]:
"""
Tree flattens the data.
Returns:
The data and the dimension.
"""
return (self.mantissa,), self.unit
[docs]
@classmethod
def tree_unflatten(cls, unit, values) -> 'Quantity':
"""
Tree unflattens the data.
Args:
unit: The unit.
values: The data.
Returns:
The Quantity object.
"""
return cls(*values, unit=unit)
def cuda(self, device=None) -> 'Quantity':
device = jax.devices('cuda')[0] if device is None else device
self.update_mantissa(jax.device_put(self.mantissa, device))
return self
def cpu(self, device=None) -> 'Quantity':
device = jax.devices('cpu')[0] if device is None else device
self.update_mantissa(jax.device_put(self.mantissa, device))
return self
# dtype exchanging #
# ---------------- #
def half(self) -> 'Quantity':
return Quantity(jnp.asarray(self.mantissa, dtype=jnp.float16), unit=self.unit)
def float(self) -> 'Quantity':
return Quantity(jnp.asarray(self.mantissa, dtype=jnp.float32), unit=self.unit)
def double(self) -> 'Quantity':
return Quantity(jnp.asarray(self.mantissa, dtype=jnp.float64), unit=self.unit)
# ---------------------------------------------------------------------------
# _IndexUpdateHelper
# ---------------------------------------------------------------------------
class _IndexUpdateHelper:
"""
Helper property for index update functionality.
"""
__slots__ = ("quantity",)
def __init__(self, quantity: Quantity):
if not isinstance(quantity, Quantity):
raise TypeError(f"quantity must be a Quantity object, but got {quantity}")
self.quantity = quantity
def __getitem__(self, index: Any) -> '_IndexUpdateRef':
return _IndexUpdateRef(index, self.quantity)
def __repr__(self):
return f"_IndexUpdateHelper({self.quantity})"
# ---------------------------------------------------------------------------
# _IndexUpdateRef
# ---------------------------------------------------------------------------
class _IndexUpdateRef:
"""
Helper object to call indexed update functions for an (advanced) index.
This object references a source array and a specific indexer into that array.
Methods on this object return copies of the source array that have been
modified at the positions specified by the indexer.
"""
__slots__ = ("quantity", "index", "mantissa_at", "unit")
def __init__(self, index, quantity: Quantity):
self.index = jax.tree.map(_element_not_quantity, index, is_leaf=lambda x: isinstance(x, Quantity))
self.quantity = quantity
self.mantissa_at = jnp.asarray(quantity.mantissa).at
self.unit = quantity.unit
def __repr__(self) -> str:
return f"_IndexUpdateRef({self.quantity}, {self.index!r})"
def get(
self,
indices_are_sorted: bool = False,
unique_indices: bool = False,
mode: str | None = None,
fill_value: StaticScalar | None = None
) -> Quantity:
"""Equivalent to ``x[idx]``.
Returns the value of ``x`` that would result from the NumPy-style
:mod:indexing <numpy.doc.indexing>` ``x[idx]``. This function differs from
the usual array indexing syntax in that it allows additional keyword
arguments ``indices_are_sorted`` and ``unique_indices`` to be passed.
"""
if fill_value is not None:
fill_value = Quantity(fill_value).in_unit(self.unit).mantissa
return Quantity(
self.mantissa_at[self.index].get(
indices_are_sorted=indices_are_sorted,
unique_indices=unique_indices,
mode=mode,
fill_value=fill_value
),
unit=self.unit
)
def set(
self,
values: Any,
indices_are_sorted: bool = False,
unique_indices: bool = False,
mode: str | None = None,
) -> Quantity:
"""Pure equivalent of ``x[idx] = y``.
Returns the value of ``x`` that would result from the NumPy-style
:mod:`indexed assignment <numpy.doc.indexing>` ``x[idx] = y``.
"""
values = Quantity(values).in_unit(self.unit).mantissa
return Quantity(
self.mantissa_at[self.index].set(
values,
indices_are_sorted=indices_are_sorted,
unique_indices=unique_indices,
mode=mode,
),
unit=self.unit
)
def add(
self,
values: Any,
indices_are_sorted: bool = False,
unique_indices: bool = False,
mode: str | None = None
) -> Quantity:
"""Pure equivalent of ``x[idx] += y``.
Returns the value of ``x`` that would result from the NumPy-style
:mod:indexed assignment <numpy.doc.indexing>` ``x[idx] += y``.
"""
values = Quantity(values).in_unit(self.unit).mantissa
return Quantity(
self.mantissa_at[self.index].add(
values,
indices_are_sorted=indices_are_sorted,
unique_indices=unique_indices,
mode=mode
),
unit=self.unit
)
def multiply(
self,
values: Any,
indices_are_sorted: bool = False,
unique_indices: bool = False,
mode: str | None = None
) -> Quantity:
"""Pure equivalent of ``x[idx] *= y``.
Returns the value of ``x`` that would result from the NumPy-style
:mod:indexed assignment <numpy.doc.indexing>` ``x[idx] *= y``.
"""
values = Quantity(values)
return Quantity(
self.mantissa_at[self.index].multiply(
values.mantissa,
indices_are_sorted=indices_are_sorted,
unique_indices=unique_indices,
mode=mode
),
unit=self.unit * values.unit
)
mul = multiply
def divide(
self,
values: Any,
indices_are_sorted: bool = False,
unique_indices: bool = False,
mode: str | None = None
) -> Quantity:
"""Pure equivalent of ``x[idx] /= y``.
Returns the value of ``x`` that would result from the NumPy-style
:mod:indexed assignment <numpy.doc.indexing>` ``x[idx] /= y``.
"""
values = Quantity(values)
return Quantity(
self.mantissa_at[self.index].divide(
values.mantissa,
indices_are_sorted=indices_are_sorted,
unique_indices=unique_indices,
mode=mode
),
unit=self.unit / values.unit
)
div = divide
def power(
self,
values: Any,
indices_are_sorted: bool = False,
unique_indices: bool = False,
mode: str | None = None
) -> Quantity:
"""Pure equivalent of ``x[idx] **= y``.
Returns the value of ``x`` that would result from the NumPy-style
:mod:indexed assignment <numpy.doc.indexing>` ``x[idx] **= y``.
"""
if not isinstance(values, int):
raise TypeError(f"values must be an integer, but got {values}")
return Quantity(
self.mantissa_at[self.index].power(
values,
indices_are_sorted=indices_are_sorted,
unique_indices=unique_indices,
mode=mode
),
unit=self.unit ** values
)
def min(
self,
values: Any,
indices_are_sorted: bool = False,
unique_indices: bool = False,
mode: str | None = None
) -> Quantity:
"""Pure equivalent of ``x[idx] = minimum(x[idx], y)``.
Returns the value of ``x`` that would result from the NumPy-style
:mod:indexed assignment <numpy.doc.indexing>`
``x[idx] = minimum(x[idx], y)``.
"""
values = Quantity(values).in_unit(self.unit).mantissa
return Quantity(
self.mantissa_at[self.index].min(
values,
indices_are_sorted=indices_are_sorted,
unique_indices=unique_indices,
mode=mode
),
unit=self.unit
)
def max(
self,
values: Any,
indices_are_sorted: bool = False,
unique_indices: bool = False,
mode: str | None = None
) -> Quantity:
"""Pure equivalent of ``x[idx] = maximum(x[idx], y)``.
Returns the value of ``x`` that would result from the NumPy-style
:mod:indexed assignment <numpy.doc.indexing>`
``x[idx] = maximum(x[idx], y)``.
"""
values = Quantity(values).in_unit(self.unit).mantissa
return Quantity(
self.mantissa_at[self.index].max(
values,
indices_are_sorted=indices_are_sorted,
unique_indices=unique_indices,
mode=mode
),
unit=self.unit
)
def apply(
self,
mantissa_fun: Callable[[jax.typing.ArrayLike], jax.typing.ArrayLike],
unit_fun: Callable[[Unit], Unit] | None = None,
indices_are_sorted: bool = False,
unique_indices: bool = False,
mode: str | None = None
) -> Quantity:
"""Pure equivalent of ``func.at(x, idx)`` for a unary ufunc ``func``.
Returns the value of ``x`` that would result from applying the unary
function ``func`` to ``x`` at the given indices. This is similar to
``x.at[idx].set(func(x[idx]))``, but differs in the case of repeated indices:
in ``x.at[idx].apply(func)``, repeated indices result in the function being
applied multiple times.
Note that in the current implementation, ``scatter_apply`` is not compatible
with automatic differentiation.
Parameters
----------
mantissa_fun : callable
Applied to the mantissa values at the given indices.
unit_fun : callable, optional
Transforms the unit of the result. If omitted the unit is preserved
(unit-preserving operations such as ``jnp.abs``).
"""
result_unit = unit_fun(self.unit) if unit_fun is not None else self.unit
return Quantity(
self.mantissa_at[self.index].apply(
mantissa_fun,
indices_are_sorted=indices_are_sorted,
unique_indices=unique_indices,
mode=mode
),
unit=result_unit
)
# ---------------------------------------------------------------------------
# _replace_with_array
# ---------------------------------------------------------------------------
def _replace_with_array(seq, unit):
"""
Replace all the elements in the list with an equivalent `Array`
with the given `unit`.
"""
# No recursion needed for single values
if not isinstance(seq, list):
return Quantity(seq, unit=unit)
def top_replace(s):
"""
Recursively descend into the list.
"""
for i in s:
if not isinstance(i, list):
yield Quantity(i, unit=unit)
else:
yield list(top_replace(i))
return list(top_replace(seq))