Quantity

class saiunit.Quantity(mantissa, unit=Unit('1'), dtype=None)

A numerical value paired with a physical unit.

Quantity is the central data structure in saiunit. It stores a mantissa (the raw numerical data, typically a JAX array) together with a Unit that describes the physical dimensions and scale. Arithmetic on Quantity objects automatically tracks and checks units, raising UnitMismatchError when incompatible quantities are combined.

Quantity is registered as a JAX pytree, so it works transparently with jax.jit, jax.grad, jax.vmap, and other JAX transformations.
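
To illustrate what pytree registration buys, here is a minimal sketch built around a hypothetical ToyQuantity class (not saiunit's actual implementation). The mantissa is the only array leaf, while the unit travels as static auxiliary data, so jax.jit and jax.vmap trace the numbers and carry the unit through unchanged. The ValueError stands in for saiunit's UnitMismatchError.

```python
import jax
import jax.numpy as jnp
from jax import tree_util


@tree_util.register_pytree_node_class
class ToyQuantity:
    """Hypothetical mantissa + unit pair; not saiunit's implementation."""

    def __init__(self, mantissa, unit):
        self.mantissa = mantissa
        self.unit = unit  # a plain string such as "mV" in this sketch

    def tree_flatten(self):
        # The mantissa is the only leaf JAX traces; the unit is static aux data.
        return (self.mantissa,), self.unit

    @classmethod
    def tree_unflatten(cls, unit, children):
        return cls(children[0], unit)

    def __add__(self, other):
        # Stand-in for saiunit's unit check (which raises UnitMismatchError).
        if self.unit != other.unit:
            raise ValueError(f"unit mismatch: {self.unit} vs {other.unit}")
        return ToyQuantity(self.mantissa + other.mantissa, self.unit)


@jax.jit
def double(q):
    return q + q


v = double(ToyQuantity(jnp.array([1.0, 2.0]), "mV"))
# vmap maps over the leading axis of the mantissa; the unit is untouched.
batch = jax.vmap(double)(ToyQuantity(jnp.ones((3, 2)), "mV"))
print(v.mantissa, v.unit)              # [2. 4.] mV
print(batch.mantissa.shape, batch.unit)  # (3, 2) mV
```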

Parameters:
  • mantissa (Any | Unit) – The numerical value(s). If a Unit is passed, the mantissa is set to 1.0 and that unit is adopted. If a Quantity is passed, its mantissa and unit are used (converted to unit when given).

  • unit (saiunit.Unit | Array | ndarray | bool | number | bool | int | float | complex | str | None) – The physical unit. Defaults to UNITLESS.

  • dtype (str | type[Any] | dtype | SupportsDType | None) – If provided, the mantissa is cast to this dtype on construction.

Examples

>>> import saiunit as u
>>> import jax.numpy as jnp
>>> # Scalar with unit
>>> q = u.Quantity(3.0, unit=u.mV)
>>> q
Quantity(3., "mV")
>>> # Array with unit via multiplication shorthand
>>> arr = jnp.array([1.0, 2.0, 3.0]) * u.mV
>>> arr.shape
(3,)
>>> # From a Unit object directly
>>> u.Quantity(u.metre)
Quantity(1., "m")

See also

Unit

Represents a physical unit (dimension + scale).

compatible_with_equinox

Toggle Equinox interoperability.

all(axis=None, out=None, keepdims=False, *, where=None)

Test whether all array elements along a given axis evaluate to True.

JAX implementation of numpy.all().

Parameters:
  • a (Array | ndarray | bool | number | bool | int | float | complex) – Input array.

  • axis (int | Sequence[int] | None) – int or sequence of ints, default=None. Axis or axes along which to test. If None, tests along all the axes.

  • keepdims (bool) – bool, default=False. If true, reduced axes are left in the result with size 1.

  • where (Array | ndarray | bool | number | bool | int | float | complex | None) – boolean array, default=None. The elements to include in the test. Must be broadcast-compatible with the input.

  • out (None) – Unused by JAX.

Return type:

Array

Returns:

An array of boolean values.

Examples

By default, jnp.all tests for True values along all the axes.

>>> x = jnp.array([[True, True, True, False],
...                [True, False, True, False],
...                [True, True, False, False]])
>>> jnp.all(x)
Array(False, dtype=bool)

If axis=0, tests for True values along axis 0.

>>> jnp.all(x, axis=0)
Array([ True, False, False, False], dtype=bool)

If keepdims=True, the output has the same number of dimensions as the input.

>>> jnp.all(x, axis=0, keepdims=True)
Array([[ True, False, False, False]], dtype=bool)

To test only specific elements, pass a boolean mask as where.

>>> where=jnp.array([[1, 0, 1, 0],
...                  [0, 0, 1, 1],
...                  [1, 1, 1, 0]], dtype=bool)
>>> jnp.all(x, axis=0, keepdims=True, where=where)
Array([[ True,  True, False, False]], dtype=bool)

any(axis=None, out=None, keepdims=False, *, where=None)

Test whether any of the array elements along a given axis evaluate to True.

JAX implementation of numpy.any().

Parameters:
  • a (Array | ndarray | bool | number | bool | int | float | complex) – Input array.

  • axis (int | Sequence[int] | None) – int or sequence of ints, default=None. Axis or axes along which to test. If None, tests along all the axes.

  • keepdims (bool) – bool, default=False. If true, reduced axes are left in the result with size 1.

  • where (Array | ndarray | bool | number | bool | int | float | complex | None) – boolean array, default=None. The elements to include in the test. Must be broadcast-compatible with the input.

  • out (None) – Unused by JAX.

Return type:

Array

Returns:

An array of boolean values.

Examples

By default, jnp.any tests along all the axes.

>>> x = jnp.array([[True, True, True, False],
...                [True, False, True, False],
...                [True, True, False, False]])
>>> jnp.any(x)
Array(True, dtype=bool)

If axis=0, tests along axis 0.

>>> jnp.any(x, axis=0)
Array([ True,  True,  True, False], dtype=bool)

If keepdims=True, the output has the same number of dimensions as the input.

>>> jnp.any(x, axis=0, keepdims=True)
Array([[ True,  True,  True, False]], dtype=bool)

To test only specific elements, pass a boolean mask as where.

>>> where=jnp.array([[1, 0, 1, 0],
...                  [0, 1, 0, 1],
...                  [1, 0, 1, 0]], dtype=bool)
>>> jnp.any(x, axis=0, keepdims=True, where=where)
Array([[ True, False,  True, False]], dtype=bool)
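
The where argument of all and any can be read as substituting each reduction's identity element for the masked-out entries: True for all (a logical AND) and False for any (a logical OR). This short jax.numpy check compares the masked reductions with that explicit substitution.

```python
import jax.numpy as jnp

x = jnp.array([[True, True, True, False],
               [True, False, True, False],
               [True, True, False, False]])
mask = jnp.array([[1, 0, 1, 0],
                  [0, 1, 0, 1],
                  [1, 0, 1, 0]], dtype=bool)

# Masked-out elements behave like the identity of each reduction.
all_masked = jnp.all(x, axis=0, where=mask)
all_manual = jnp.all(jnp.where(mask, x, True), axis=0)   # identity of AND

any_masked = jnp.any(x, axis=0, where=mask)
any_manual = jnp.any(jnp.where(mask, x, False), axis=0)  # identity of OR

print(all_masked, any_masked)
```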

argmax(axis=None, out=None, keepdims=None)

Return the index of the maximum value of an array.

JAX implementation of numpy.argmax().

Parameters:
  • a (Array | ndarray | bool | number | bool | int | float | complex) – input array

  • axis (int | None) – optional integer specifying the axis along which to find the maximum value. If axis is not specified, a will be flattened.

  • out (None) – unused by JAX

  • keepdims (bool | None) – if True, then return an array with the same number of dimensions as a.

Return type:

Array

Returns:

an array containing the index of the maximum value along the specified axis.

See also

  • jax.numpy.argmin(): return the index of the minimum value.

  • jax.numpy.nanargmax(): compute argmax while ignoring NaN values.

Note

When the maximum value occurs more than once along a particular axis, the smallest index is returned.

Examples

>>> x = jnp.array([1, 3, 5, 4, 2])
>>> jnp.argmax(x)
Array(2, dtype=int32)
>>> x = jnp.array([[1, 3, 2],
...                [5, 4, 1]])
>>> jnp.argmax(x, axis=1)
Array([1, 0], dtype=int32)
>>> jnp.argmax(x, axis=1, keepdims=True)
Array([[1],
       [0]], dtype=int32)

argmin(axis=None, out=None, keepdims=None)

Return the index of the minimum value of an array.

JAX implementation of numpy.argmin().

Parameters:
  • a (Array | ndarray | bool | number | bool | int | float | complex) – input array

  • axis (int | None) – optional integer specifying the axis along which to find the minimum value. If axis is not specified, a will be flattened.

  • out (None) – unused by JAX

  • keepdims (bool | None) – if True, then return an array with the same number of dimensions as a.

Return type:

Array

Returns:

an array containing the index of the minimum value along the specified axis.

Note

When the minimum value occurs more than once along a particular axis, the smallest index is returned.

See also

  • jax.numpy.argmax(): return the index of the maximum value.

  • jax.numpy.nanargmin(): compute argmin while ignoring NaN values.

Examples

>>> x = jnp.array([1, 3, 5, 4, 2])
>>> jnp.argmin(x)
Array(0, dtype=int32)
>>> x = jnp.array([[1, 3, 2],
...                [5, 4, 1]])
>>> jnp.argmin(x, axis=1)
Array([0, 2], dtype=int32)
>>> jnp.argmin(x, axis=1, keepdims=True)
Array([[0],
       [2]], dtype=int32)
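
The tie-breaking rule stated in the notes for argmax and argmin (the smallest index wins) is easy to check directly. This small jax.numpy example uses an array whose maximum and minimum each occur twice.

```python
import jax.numpy as jnp

x = jnp.array([2, 7, 1, 7, 1])

i_max = jnp.argmax(x)  # 7 occurs at indices 1 and 3: the first one wins
i_min = jnp.argmin(x)  # 1 occurs at indices 2 and 4: the first one wins
print(i_max, i_min)
```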

argsort(axis: int | None = -1, *, kind: None = None, order: None = None, stable: bool = True, descending: bool = False, dtype: str | type[Any] | dtype | SupportsDType | None = None) -> Array

Return indices that sort an array.

JAX implementation of numpy.argsort().

Parameters:
  • a (Array | ndarray | bool | number | bool | int | float | complex) – array to sort

  • axis (int | None) – integer axis along which to sort. Defaults to -1, i.e. the last axis. If None, then a is flattened before being sorted.

  • stable (bool) – boolean specifying whether a stable sort should be used. Default=True.

  • descending (bool) – boolean specifying whether to sort in descending order. Default=False.

  • kind (None) – deprecated; instead specify sort algorithm using stable=True or stable=False.

  • order (None) – not supported by JAX

  • dtype (str | type[Any] | dtype | SupportsDType | None) – optionally specify the dtype of the resulting indices. If not specified, the default integer dtype will be used.

Return type:

Array

Returns:

Array of indices that sort an array. Returned array will be of shape a.shape (if axis is an integer) or of shape (a.size,) (if axis is None).

Examples

Simple 1-dimensional sort:

>>> x = jnp.array([1, 3, 5, 4, 2, 1])
>>> indices = jnp.argsort(x)
>>> indices
Array([0, 5, 4, 1, 3, 2], dtype=int32)
>>> x[indices]
Array([1, 1, 2, 3, 4, 5], dtype=int32)

Sort along the last axis of an array:

>>> x = jnp.array([[2, 1, 3],
...                [6, 4, 3]])
>>> indices = jnp.argsort(x, axis=1)
>>> indices
Array([[1, 0, 2],
       [2, 1, 0]], dtype=int32)
>>> jnp.take_along_axis(x, indices, axis=1)
Array([[1, 2, 3],
       [3, 4, 6]], dtype=int32)

See also

  • jax.numpy.sort(): return sorted values directly.

  • jax.numpy.lexsort(): lexicographical sort of multiple arrays.

  • jax.lax.sort(): lower-level function wrapping XLA’s Sort operator.
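
The stable and descending parameters combine as follows: with stable=True, equal elements keep their original relative order even in a descending sort. A short check with jax.numpy (it needs a JAX version that accepts descending, as the signature above does):

```python
import jax.numpy as jnp

x = jnp.array([1, 3, 5, 4, 2, 1])

# Stable descending sort: the two 1s (indices 0 and 5) stay in that order.
idx_desc = jnp.argsort(x, descending=True, stable=True)
print(idx_desc)      # indices of x from largest to smallest value
print(x[idx_desc])   # the values in descending order
```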

astype(dtype)

Return a copy of this quantity with the mantissa cast to dtype.

Parameters:

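dtype (str | type[Any] | dtype | SupportsDType) – Target data type (e.g. jnp.float16). When the mantissa is a JAX array, 64-bit types such as jnp.float64 take effect only if jax_enable_x64 is enabled; otherwise they are truncated to 32 bits.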

Returns:

A new quantity with the converted dtype.

Return type:

Quantity

Examples

>>> import saiunit as u
>>> import jax.numpy as jnp
>>> q = u.Quantity(jnp.array([1.0, 2.0]), unit=u.mV)
>>> q.astype(jnp.float16).dtype
dtype('float16')
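
Because the mantissa is usually a JAX array, casting follows JAX's dtype rules. A sketch on a bare jax.numpy array shows the difference between a narrowing cast and a 64-bit request under the default configuration (the exact warning text may vary between JAX versions):

```python
import jax
import jax.numpy as jnp

m = jnp.array([1.0, 2.0])  # float32 under JAX's default configuration

half = m.astype(jnp.float16)  # narrowing casts behave as expected
wide = m.astype(jnp.float64)  # may be truncated, with a UserWarning

# With jax_enable_x64 off (the default), 64-bit requests are canonicalized
# to 32 bits; canonicalize_dtype reports what such a request will become.
expected = jax.dtypes.canonicalize_dtype(jnp.float64)
print(half.dtype, wide.dtype, expected)
```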

property at

Helper property for index update functionality.

The at property provides a functionally pure equivalent of in-place array modifications.

In particular:

  Alternate syntax                  Equivalent in-place expression
  x = x.at[idx].set(y)              x[idx] = y
  x = x.at[idx].add(y)              x[idx] += y
  x = x.at[idx].multiply(y)         x[idx] *= y
  x = x.at[idx].divide(y)           x[idx] /= y
  x = x.at[idx].power(y)            x[idx] **= y
  x = x.at[idx].min(y)              x[idx] = minimum(x[idx], y)
  x = x.at[idx].max(y)              x[idx] = maximum(x[idx], y)
  x = x.at[idx].apply(ufunc)        ufunc.at(x, idx)
  x = x.at[idx].get()               x = x[idx]

None of the x.at expressions modify the original x; instead they return a modified copy of x. However, inside a jit() compiled function, expressions like x = x.at[idx].set(y) are guaranteed to be applied in-place.

Unlike NumPy in-place operations such as x[idx] += y, if multiple indices refer to the same location, all updates will be applied (NumPy would only apply the last update, rather than applying all updates.) The order in which conflicting updates are applied is implementation-defined and may be nondeterministic (e.g., due to concurrency on some hardware platforms).

By default, JAX assumes that all indices are in bounds. Alternative out-of-bounds index semantics can be specified via the mode parameter (see below).

Parameters:
  • mode (str) –

    Specify out-of-bound indexing mode. Options are:

    • "promise_in_bounds": (default) The user promises that indices are in bounds. No additional checking will be performed. In practice, this means that out-of-bounds indices in get() will be clipped, and out-of-bounds indices in set(), add(), etc. will be dropped.

    • "clip": clamp out-of-bounds indices into the valid range.

    • "drop": ignore out-of-bounds indices.

    • "fill": alias for "drop". For get(), the optional fill_value argument specifies the value that will be returned.

  • indices_are_sorted (bool) – If True, the implementation will assume that the indices passed to at[] are sorted in ascending order, which can lead to more efficient execution on some backends.

  • unique_indices (bool) – If True, the implementation will assume that the indices passed to at[] are unique, which can result in more efficient execution on some backends.

  • fill_value (Any) – Only applies to the get() method: the fill value to return for out-of-bounds slices when mode is 'fill'. Ignored otherwise. Defaults to NaN for inexact types, the largest negative value for signed types, the largest positive value for unsigned types, and True for booleans.

Examples

>>> import saiunit as u
>>> import jax.numpy as jnp
>>> x = jnp.arange(5.0) * u.mV
>>> x.at[2].add(10 * u.mV)
Quantity([ 0.  1. 12.  3.  4.], "mV")
>>> x.at[2].get()
Quantity(2., "mV")
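
Two points from the description above, that repeated indices accumulate and that mode selects the out-of-bounds behavior, can be checked on a plain jax.numpy array; Quantity.at additionally carries the unit, as in the example above.

```python
import jax.numpy as jnp

x = jnp.zeros(3)

# Repeated indices accumulate: both updates to index 0 are applied.
acc = x.at[jnp.array([0, 0, 2])].add(1.0)            # [2. 0. 1.]

y = jnp.arange(3.0)                                   # [0. 1. 2.]
# Out-of-bounds handling is chosen per call with `mode`.
dropped = y.at[5].set(9.0, mode="drop")               # update ignored
clipped = y.at[5].get(mode="clip")                    # index clamped to 2
filled = y.at[5].get(mode="fill", fill_value=-1.0)    # fill value returned

print(acc, dropped, clipped, filled)
```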

clip(min=None, max=None)

Clip (limit) the values in the array to [min, max].

At least one of min or max must be given. Both must be compatible with the unit of self.

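Parameters:
  • min – Lower bound, or None for no lower bound. Must be compatible with the unit of self.

  • max – Upper bound, or None for no upper bound. Must be compatible with the unit of self.

Returns: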

The clipped quantity.

Return type:

Quantity

Examples

>>> import saiunit as u
>>> import jax.numpy as jnp
>>> q = u.Quantity(jnp.array([1.0, 2.0, 3.0]), unit=u.mV)
>>> q.clip(min=u.Quantity(1.5, unit=u.mV), max=u.Quantity(2.5, unit=u.mV))
Quantity([1.5 2.  2.5], "mV")
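
Because the bounds only need to be compatible with the unit of self, they may use a different scale (for example volts when self is in millivolts). A minimal sketch of the conversion step, using a hypothetical SCALE table and clip_in_unit helper in place of saiunit's unit objects:

```python
import jax.numpy as jnp

# Hypothetical scale factors relative to the volt; they stand in for
# saiunit's unit objects in this sketch.
SCALE = {"V": 1.0, "mV": 1e-3}


def clip_in_unit(mantissa, unit, lo, lo_unit, hi, hi_unit):
    """Hypothetical helper: clip `mantissa` (in `unit`) to bounds in other units."""
    # Express both bounds in the unit of the quantity being clipped ...
    lo_conv = lo * SCALE[lo_unit] / SCALE[unit]
    hi_conv = hi * SCALE[hi_unit] / SCALE[unit]
    # ... then clip the raw numbers; the result keeps `unit`.
    return jnp.clip(mantissa, lo_conv, hi_conv), unit


m, unit = clip_in_unit(jnp.array([1.0, 2.0, 3.0]), "mV", 0.0015, "V", 2.5, "mV")
print(m, unit)  # approximately [1.5 2. 2.5] mV
```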

clone()

Return a copy of this quantity (PyTorch-style alias for copy()).

Returns:

An independent copy.

Return type:

Quantity

Examples

>>> import saiunit as u
>>> q = u.Quantity(3.0, unit=u.mV)
>>> q.clone()
Quantity(3., "mV")

conj()

Return the complex conjugate, element-wise, preserving units.

Returns:

The conjugated quantity.

Return type:

Quantity

Examples

>>> import saiunit as u
>>> q = u.Quantity(1.0 + 2.0j, unit=u.mV)
>>> q.conj()
Quantity((1-2j), "mV")

conjugate()

Return the complex conjugate, element-wise.

Alias for conj().

Returns:

The conjugated quantity.

Return type:

Quantity

copy()

Return a deep copy of this quantity.

Returns:

An independent copy with the same mantissa and unit.

Return type:

Quantity

Examples

>>> import saiunit as u
>>> q = u.Quantity(3.0, unit=u.mV)
>>> q2 = q.copy()
>>> q2
Quantity(3., "mV")

cross(b, axisa=-1, axisb=-1, axisc=-1, axis=None)

Cross product of two arrays.

The resulting unit is self.unit * b.unit.

Parameters:
  • b (Quantity) – Second operand.

  • axisa (int) – Axis of self that defines the vector(s) (default -1).

  • axisb (int) – Axis of b that defines the vector(s) (default -1).

  • axisc (int) – Axis of the result containing the cross product (default -1).

  • axis (int) – Overrides axisa, axisb, and axisc simultaneously.

Returns:

The cross product.

Return type:

Quantity

Examples

>>> import saiunit as u
>>> import jax.numpy as jnp
>>> a = u.Quantity(jnp.array([1.0, 0.0, 0.0]), unit=u.mV)
>>> b = u.Quantity(jnp.array([0.0, 1.0, 0.0]), unit=u.second)
>>> a.cross(b)
Quantity([0. 0. 1.], "mV * s")

cumprod(*args, **kwds)

Return the cumulative product of elements along a given axis.

Because each position in the result corresponds to a different number of multiplied elements, the unit exponent varies across the output. This is only representable when the quantity is dimensionless.

Returns:

The cumulative product.

Return type:

Quantity

Raises:

TypeError – If the quantity is not dimensionless.
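
The reasoning above can be made concrete by tracking the unit exponent at each output position. In this sketch the mantissas are assumed to belong to a hypothetical quantity [2 mV, 3 mV, 4 mV]; position k of the cumulative product multiplies k + 1 factors of mV, so no single unit describes the whole result.

```python
import jax.numpy as jnp

# Mantissas of a hypothetical quantity [2 mV, 3 mV, 4 mV].
m = jnp.array([2.0, 3.0, 4.0])

values = jnp.cumprod(m)                      # [2., 6., 24.]
# Position k multiplies k + 1 factors of mV, so its unit is mV**(k + 1).
exponents = jnp.arange(1, m.shape[0] + 1)    # [1, 2, 3]
for v, e in zip(values.tolist(), exponents.tolist()):
    print(f"{v} mV^{e}")
```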

cumsum(axis=None, dtype=None, out=None)

Cumulative sum of elements along an axis.

JAX implementation of numpy.cumsum().

Parameters:
  • a (Array | ndarray | bool | number | bool | int | float | complex) – N-dimensional array to be accumulated.

  • axis (int | None) – integer axis along which to accumulate. If None (default), then array will be flattened and accumulated along the flattened axis.

  • dtype (str | type[Any] | dtype | SupportsDType | None) – optionally specify the dtype of the output. If not specified, then the output dtype will match the input dtype.

  • out (None) – unused by JAX

Return type:

Array

Returns:

An array containing the accumulated sum along the given axis.

See also

  • jax.numpy.cumulative_sum(): cumulative sum via the array API standard.

  • jax.numpy.add.accumulate(): cumulative sum via ufunc methods.

  • jax.numpy.nancumsum(): cumulative sum ignoring NaN values.

  • jax.numpy.sum(): sum along axis

Examples

>>> x = jnp.array([[1, 2, 3],
...                [4, 5, 6]])
>>> jnp.cumsum(x)  # flattened cumulative sum
Array([ 1,  3,  6, 10, 15, 21], dtype=int32)
>>> jnp.cumsum(x, axis=1)  # cumulative sum along axis 1
Array([[ 1,  3,  6],
       [ 4,  9, 15]], dtype=int32)

diagonal(offset=0, axis1=0, axis2=1)

Return specified diagonals, preserving units.

Parameters:
  • offset (int) – Offset from the main diagonal (default 0).

  • axis1 (int) – First axis (default 0).

  • axis2 (int) – Second axis (default 1).

Returns:

The diagonal elements.

Return type:

Quantity

Examples

>>> import saiunit as u
>>> import jax.numpy as jnp
>>> q = u.Quantity(jnp.array([[1.0, 2.0], [3.0, 4.0]]), unit=u.mV)
>>> q.diagonal()
Quantity([1. 4.], "mV")
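
Quantity.diagonal documents the same offset, axis1 and axis2 parameters as jnp.diagonal. Assuming it applies that function to the mantissa and reattaches the unit, the offset semantics can be checked on a plain array:

```python
import jax.numpy as jnp

a = jnp.array([[1.0, 2.0, 3.0],
               [4.0, 5.0, 6.0],
               [7.0, 8.0, 9.0]])

main = jnp.diagonal(a)              # main diagonal: [1. 5. 9.]
upper = jnp.diagonal(a, offset=1)   # first superdiagonal: [2. 6.]
lower = jnp.diagonal(a, offset=-1)  # first subdiagonal: [4. 8.]
print(main, upper, lower)
```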

property dim: Dimension

The physical dimension of this quantity (e.g. length, mass, time).

The dimension is independent of scale (metres vs kilometres both have the length dimension).

Returns:

The physical dimension object.

Return type:

Dimension

Examples

>>> import saiunit as u
>>> q = u.Quantity(5.0, unit=u.metre)
>>> q.dim
m

See also

unit

The full unit (dimension + scale).

dot(b)

Dot product of two arrays.

The resulting unit is self.unit * b.unit.

Parameters:

b (Quantity or array_like) – Second operand.

Returns:

The dot product.

Return type:

Quantity

Examples

>>> import saiunit as u
>>> import jax.numpy as jnp
>>> a = u.Quantity(jnp.array([1.0, 2.0, 3.0]), unit=u.mV)
>>> b = u.Quantity(jnp.array([1.0, 1.0, 1.0]), unit=u.mV)
>>> a.dot(b)
Quantity(6., "mV^2")

property dtype

The data type of the mantissa.

Returns:

The JAX/NumPy dtype of the underlying array.

Return type:

dtype

Examples

>>> import saiunit as u
>>> import jax.numpy as jnp
>>> q = u.Quantity(jnp.array([1.0, 2.0]), unit=u.mV)
>>> q.dtype
dtype('float32')

expand_as(array)

Broadcast this quantity to the shape of another array.

Parameters:

array (Quantity | Array | ndarray | bool | number | bool | int | float | complex) – The array or quantity whose shape is used as the target.

Returns:

expanded – A read-only view of the original array with the shape of array. It is typically not contiguous, and more than one element of an expanded array may refer to a single memory location.

Return type:

Quantity
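
The Returns description above mirrors numpy.broadcast_to. Assuming expand_as broadcasts the mantissa to array.shape and keeps the unit (an inference from that description, not a confirmed implementation detail), the mantissa-level operation looks like this:

```python
import jax.numpy as jnp

m = jnp.array([1.0, 2.0, 3.0])   # mantissa of a shape-(3,) quantity
target = jnp.zeros((2, 3))       # any array whose shape should be matched

expanded = jnp.broadcast_to(m, target.shape)
print(expanded.shape)  # (2, 3); each row repeats m
```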

expand_dims(axis)

Insert new axes at the given positions.

Parameters:

axis (int | Sequence[int]) – Position(s) where the new axis (axes) are placed.

Returns:

The expanded quantity.

Return type:

Quantity

Examples

>>> import saiunit as u
>>> import jax.numpy as jnp
>>> q = u.Quantity(jnp.array([1.0, 2.0]), unit=u.mV)
>>> q.expand_dims(0).shape
(1, 2)

factorless()

Return an equivalent quantity whose unit has factor == 1.0.

If the unit already has no extra factor the original object is returned unchanged.

Returns:

A quantity with the factor folded into the mantissa.

Return type:

Quantity

Examples

>>> import saiunit as u
>>> q = u.Quantity(3.0, unit=u.mV)
>>> q.factorless()
Quantity(3., "mV")
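
A minimal sketch of folding a unit's factor into the mantissa. ToyUnit, with a power-of-ten scale and a separate non-decimal factor, is a hypothetical stand-in for saiunit's Unit, and the inch entry is illustrative only:

```python
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class ToyUnit:
    """Hypothetical unit record; saiunit's Unit stores this differently."""
    scale: int      # power-of-ten exponent, e.g. -3 for milli
    factor: float   # extra non-decimal multiplier; 1.0 for SI-prefixed units


def factorless(mantissa, unit):
    if unit.factor == 1.0:
        # Nothing to fold: hand back the inputs unchanged.
        return mantissa, unit
    # Fold the factor into the number and reset it on the unit.
    return mantissa * unit.factor, replace(unit, factor=1.0)


millivolt = ToyUnit(scale=-3, factor=1.0)
inch = ToyUnit(scale=-2, factor=2.54)   # 1 inch = 2.54 * 10**-2 m

print(factorless(3.0, millivolt))  # (3.0, ToyUnit(scale=-3, factor=1.0))
print(factorless(3.0, inch))       # approximately (7.62, ToyUnit(scale=-2, factor=1.0))
```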

fill(value)

Fill the array with a scalar mantissa.

Return type:

Quantity

property flat

1-D iterator over the mantissa elements, unit preserved.

flatten()

Return a 1-D copy of this quantity.

Returns:

Flattened quantity with the same unit.

Return type:

Quantity

Examples

>>> import saiunit as u
>>> import jax.numpy as jnp
>>> q = u.Quantity(jnp.array([[1.0, 2.0], [3.0, 4.0]]), unit=u.mV)
>>> q.flatten()
Quantity([1. 2. 3. 4.], "mV")

has_same_unit(other)

Check whether this quantity shares the same physical dimension as other.

Two quantities that differ only in scale (e.g. mV vs V) are considered to have the same unit dimension.

Parameters:

other (Quantity or Unit) – The object to compare with.

Returns:

True if both have identical physical dimensions.

Return type:

bool

Examples

>>> import saiunit as u
>>> a = u.Quantity(1.0, unit=u.mV)
>>> b = u.Quantity(2.0, unit=u.volt)
>>> a.has_same_unit(b)
True
>>> c = u.Quantity(1.0, unit=u.second)
>>> a.has_same_unit(c)
False

in_unit(unit, err_msg=None)

Convert this quantity to a compatible unit.

Behaves identically to to(); kept for API compatibility.

Parameters:
  • unit (Unit) – Target unit. Must share the same dimension as self.unit.

  • err_msg (str) – Custom error message used when the dimensions do not match.

Returns:

A new Quantity expressed in unit.

Return type:

Quantity

Raises:

UnitMismatchError – If unit has a different dimension.

Examples

>>> import saiunit as u
>>> import jax.numpy as jnp
>>> q = u.Quantity(jnp.array([1.0, 2.0, 3.0]), unit=u.mV)
>>> q.in_unit(u.volt)
Quantity([0.001 0.002 0.003], "V")
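
Conversion by in_unit (or to()) amounts to a dimension check, the same comparison has_same_unit() performs, followed by a rescaling of the mantissa. The sketch below models units as hypothetical (dimension, power-of-ten scale) pairs, with dimensions as exponent tuples over metre, kilogram, second and ampere; saiunit's own Unit class is richer, and the ValueError stands in for UnitMismatchError.

```python
import jax.numpy as jnp

# Hypothetical toy units: (dimension exponents over m, kg, s, A; power-of-ten scale).
VOLT_DIM = (2, 1, -3, -1)     # V = kg m^2 s^-3 A^-1
SECOND_DIM = (0, 0, 1, 0)
UNITS = {
    "V": (VOLT_DIM, 0),
    "mV": (VOLT_DIM, -3),
    "s": (SECOND_DIM, 0),
}


def in_unit(mantissa, unit, target):
    dim, scale = UNITS[unit]
    target_dim, target_scale = UNITS[target]
    # Only dimensions must match; scales may differ (mV vs V).
    if dim != target_dim:
        raise ValueError(f"cannot convert {unit} to {target}")
    # Rescale so that mantissa * 10**scale == new_mantissa * 10**target_scale.
    return mantissa * 10.0 ** (scale - target_scale), target


m, unit = in_unit(jnp.array([1.0, 2.0, 3.0]), "mV", "V")
print(m, unit)  # approximately [0.001 0.002 0.003] V
```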

property is_unitless: bool

True if this quantity is dimensionless (has no physical unit).

Returns:

Whether the quantity is unitless.

Return type:

bool

Examples

>>> import saiunit as u
>>> u.Quantity(5.0).is_unitless
True
>>> u.Quantity(5.0, unit=u.mV).is_unitless
False

item(*args)

Extract a single element as a scalar Quantity.

Parameters:

*args (int) – Index into the flat array.

Returns:

A 0-D Quantity containing the selected element.

Return type:

Quantity

Examples

>>> import saiunit as u
>>> import jax.numpy as jnp
>>> q = u.Quantity(jnp.array([10.0, 20.0]), unit=u.mV)
>>> q.item(0)
Quantity(10., "mV")

property itemsize: int

Length (in bytes) of one array element.

property mT: saiunit.Quantity

Matrix transpose of the last two dimensions, preserving units.

The array must be at least 2-D.

Returns:

The matrix-transposed quantity.

Return type:

Quantity

Examples

>>> import saiunit as u
>>> import jax.numpy as jnp
>>> q = u.Quantity(jnp.ones((2, 3)), unit=u.mV)
>>> q.mT.shape
(3, 2)

property magnitude: Array | ndarray | bool | number | bool | int | float | complex

Alias for mantissa.

Returns:

The raw numerical data of this quantity.

Return type:

array_like

Examples

>>> import saiunit as u
>>> q = u.Quantity(5.0, unit=u.metre)
>>> q.magnitude
5.0

See also

mantissa

Primary accessor for the numerical data.

property mantissa: Array | ndarray | bool | number | bool | int | float | complex

The raw numerical data of this quantity (without the unit).

In scientific notation \(x = a \times 10^{b}\), the mantissa is the coefficient \(a\). For a Quantity, it is the underlying JAX/NumPy array (or Python scalar) that stores the numeric value.

Returns:

The mantissa array or scalar.

Return type:

array_like

Examples

>>> import saiunit as u
>>> q = u.Quantity(3.0, unit=u.mV)
>>> q.mantissa
3.0

See also

magnitude

Alias for mantissa.

unit

The physical unit attached to this quantity.

max(axis=None, out=None, keepdims=False, initial=None, where=None)

Return the maximum of the array elements along a given axis.

JAX implementation of numpy.max().

Parameters:
  • a (Array | ndarray | bool | number | bool | int | float | complex) – Input array.

  • axis (int | Sequence[int] | None) – int or sequence of ints, default=None. Axis or axes along which to compute the maximum. If None, the maximum is computed along all the axes.

  • keepdims (bool) – bool, default=False. If true, reduced axes are left in the result with size 1.

  • initial (Array | ndarray | bool | number | bool | int | float | complex | None) – int or array, default=None. Initial value for the maximum.

  • where (Array | ndarray | bool | number | bool | int | float | complex | None) – boolean array, default=None. The elements to include in the maximum. Must be broadcast-compatible with the input. initial must be specified when where is used.

  • out (None) – Unused by JAX.

Return type:

Array

Returns:

An array of maximum values along the given axis.

See also

  • jax.numpy.min(): Compute the minimum of array elements along a given axis.

  • jax.numpy.sum(): Compute the sum of array elements along a given axis.

  • jax.numpy.prod(): Compute the product of array elements along a given axis.

Examples

By default, jnp.max computes the maximum of elements along all the axes.

>>> x = jnp.array([[9, 3, 4, 5],
...                [5, 2, 7, 4],
...                [8, 1, 3, 6]])
>>> jnp.max(x)
Array(9, dtype=int32)

If axis=1, the maximum will be computed along axis 1.

>>> jnp.max(x, axis=1)
Array([9, 7, 8], dtype=int32)

If keepdims=True, the output has the same number of dimensions as the input.

>>> jnp.max(x, axis=1, keepdims=True)
Array([[9],
       [7],
       [8]], dtype=int32)

To include only specific elements in computing the maximum, use where. It can either have the same shape as the input

>>> where=jnp.array([[0, 0, 1, 0],
...                  [0, 0, 1, 1],
...                  [1, 1, 1, 0]], dtype=bool)
>>> jnp.max(x, axis=1, keepdims=True, initial=0, where=where)
Array([[4],
       [7],
       [8]], dtype=int32)

or be broadcast-compatible with it.

>>> where = jnp.array([[False],
...                    [False],
...                    [False]])
>>> jnp.max(x, axis=0, keepdims=True, initial=0, where=where)
Array([[0, 0, 0, 0]], dtype=int32)
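
initial is required alongside where because a row with no selected elements would otherwise have no maximum. An equivalent by-hand formulation replaces the excluded entries with initial before reducing and lets initial itself take part in the comparison; this jax.numpy check reuses the data from the examples above.

```python
import jax.numpy as jnp

x = jnp.array([[9, 3, 4, 5],
               [5, 2, 7, 4],
               [8, 1, 3, 6]])
mask = jnp.array([[0, 0, 1, 0],
                  [0, 0, 1, 1],
                  [1, 1, 1, 0]], dtype=bool)

masked = jnp.max(x, axis=1, initial=0, where=mask)
# By hand: excluded entries become `initial` (0), and `initial` is always
# included in the comparison, even when a row is fully selected.
manual = jnp.maximum(0, jnp.max(jnp.where(mask, x, 0), axis=1))
print(masked, manual)  # both [4 7 8]
```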

mean(axis=None, dtype=None, out=None, keepdims=False, *, where=None)

Return the mean of array elements along a given axis.

JAX implementation of numpy.mean().

Parameters:
  • a (Array | ndarray | bool | number | bool | int | float | complex) – input array.

  • axis (int | Sequence[int] | None) – optional, int or sequence of ints, default=None. Axis or axes along which to compute the mean. If None, the mean is computed along all the axes.

  • dtype (str | type[Any] | dtype | SupportsDType | None) – The dtype of the output array. If None (default), the output dtype matches the input dtype for floating-point inputs, or is set to float32 or float64 for non-floating-point inputs.

  • keepdims (bool) – bool, default=False. If true, reduced axes are left in the result with size 1.

  • where (Array | ndarray | bool | number | bool | int | float | complex | None) – optional, boolean array, default=None. The elements to include in the mean. Must be broadcast-compatible with the input.

  • out (None) – Unused by JAX.

Return type:

Array

Returns:

An array of the mean along the given axis.

Notes

For inputs of type float16 or bfloat16, the reductions will be performed at float32 precision.

See also

  • jax.numpy.average(): Compute the weighted average of array elements

  • jax.numpy.sum(): Compute the sum of array elements.

Examples

By default, the mean is computed along all the axes.

>>> x = jnp.array([[1, 3, 4, 2],
...                [5, 2, 6, 3],
...                [8, 1, 2, 9]])
>>> jnp.mean(x)
Array(3.8333335, dtype=float32)

If axis=1, the mean is computed along axis 1.

>>> jnp.mean(x, axis=1)
Array([2.5, 4. , 5. ], dtype=float32)

If keepdims=True, ndim of the output is equal to that of the input.

>>> jnp.mean(x, axis=1, keepdims=True)
Array([[2.5],
       [4. ],
       [5. ]], dtype=float32)

To use only specific elements of x to compute the mean, you can use where.

>>> where = jnp.array([[1, 0, 1, 0],
...                    [0, 1, 0, 1],
...                    [1, 1, 0, 1]], dtype=bool)
>>> jnp.mean(x, axis=1, keepdims=True, where=where)
Array([[2.5],
       [2.5],
       [6. ]], dtype=float32)
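
A masked mean is equivalent to an explicit formula: the sum of the selected entries divided by how many entries were selected. A quick jax.numpy check with the same data as the last example:

```python
import jax.numpy as jnp

x = jnp.array([[1, 3, 4, 2],
               [5, 2, 6, 3],
               [8, 1, 2, 9]])
mask = jnp.array([[1, 0, 1, 0],
                  [0, 1, 0, 1],
                  [1, 1, 0, 1]], dtype=bool)

masked_mean = jnp.mean(x, axis=1, where=mask)
# Same result by hand: sum of the selected entries / number selected.
manual = jnp.sum(jnp.where(mask, x, 0), axis=1) / jnp.sum(mask, axis=1)
print(masked_mean, manual)  # both [2.5 2.5 6. ]
```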

min(axis=None, out=None, keepdims=False, initial=None, where=None)

Return the minimum of array elements along a given axis.

JAX implementation of numpy.min().

Parameters:
  • a (Array | ndarray | bool | number | bool | int | float | complex) – Input array.

  • axis (int | Sequence[int] | None) – int or sequence of ints, default=None. Axis or axes along which to compute the minimum. If None, the minimum is computed along all the axes.

  • keepdims (bool) – bool, default=False. If true, reduced axes are left in the result with size 1.

  • initial (Array | ndarray | bool | number | bool | int | float | complex | None) – int or array, default=None. Initial value for the minimum.

  • where (Array | ndarray | bool | number | bool | int | float | complex | None) – boolean array, default=None. The elements to include in the minimum. Must be broadcast-compatible with the input. initial must be specified when where is used.

  • out (None) – Unused by JAX.

Return type:

Array

Returns:

An array of minimum values along the given axis.

See also

  • jax.numpy.max(): Compute the maximum of array elements along a given axis.

  • jax.numpy.sum(): Compute the sum of array elements along a given axis.

  • jax.numpy.prod(): Compute the product of array elements along a given axis.

Examples

By default, the minimum is computed along all the axes.

>>> x = jnp.array([[2, 5, 1, 6],
...                [3, -7, -2, 4],
...                [8, -4, 1, -3]])
>>> jnp.min(x)
Array(-7, dtype=int32)

If axis=1, the minimum is computed along axis 1.

>>> jnp.min(x, axis=1)
Array([ 1, -7, -4], dtype=int32)

If keepdims=True, the output has the same number of dimensions as the input.

>>> jnp.min(x, axis=1, keepdims=True)
Array([[ 1],
       [-7],
       [-4]], dtype=int32)

To include only specific elements in computing the minimum, use where. It can either have the same shape as the input

>>> where=jnp.array([[1, 0, 1, 0],
...                  [0, 0, 1, 1],
...                  [1, 1, 1, 0]], dtype=bool)
>>> jnp.min(x, axis=1, keepdims=True, initial=0, where=where)
Array([[ 0],
       [-2],
       [-4]], dtype=int32)

or be broadcast-compatible with it.

>>> where = jnp.array([[False],
...                    [False],
...                    [False]])
>>> jnp.min(x, axis=0, keepdims=True, initial=0, where=where)
Array([[0, 0, 0, 0]], dtype=int32)

nancumprod(*args, **kwds)

Return the cumulative product of elements along a given axis, treating NaNs as ones.

Because each position in the result corresponds to a different number of multiplied elements, the unit exponent varies across the output. This is only representable when the quantity is dimensionless.

Returns:

The cumulative product (NaNs treated as ones).

Return type:

Quantity

Raises:

TypeError – If the quantity is not dimensionless.

nanprod(*args, **kwds)

Return the product of array elements over a given axis, treating NaNs as ones.

When reducing along a specific axis, the number of non-NaN elements must be the same for every position in the result so that a single unit exponent can be assigned. If the non-NaN counts differ and the quantity is not dimensionless, a ValueError is raised.

Returns:

The product (NaNs treated as ones).

Return type:

Quantity

Raises:

ValueError – If the non-NaN counts are not uniform along the reduction axis for a non-dimensionless quantity.
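
The uniformity rule can be made concrete by counting non-NaN entries per output position. In this sketch the mantissas are assumed to belong to a hypothetical quantity in mV, so a position with k non-NaN factors would carry the unit mV^k:

```python
import jax.numpy as jnp

# Mantissas of a hypothetical quantity in mV; NaNs mark missing values.
m = jnp.array([[2.0, jnp.nan, 5.0],
               [3.0, 4.0, jnp.nan]])


def non_nan_counts(a, axis):
    # How many real (non-NaN) factors feed each output position.
    return jnp.sum(~jnp.isnan(a), axis=axis)


counts0 = non_nan_counts(m, axis=0)  # [2, 1, 1]: mV^2 vs mV^1, rejected
counts1 = non_nan_counts(m, axis=1)  # [2, 2]: every entry is in mV^2
prod1 = jnp.nanprod(m, axis=1)       # [10., 12.]
print(counts0, counts1, prod1)
```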

property nbytes: int

Total bytes consumed by the mantissa array.

nonzero(*, size=None, fill_value=None)

Return indices of nonzero elements of an array.

JAX implementation of numpy.nonzero().

Because the size of the output of nonzero is data-dependent, the function is not compatible with JIT and other transformations. The JAX version adds the optional size argument which must be specified statically for jnp.nonzero to be used within JAX’s transformations.

Parameters:
  • a (Array | ndarray | bool | number | bool | int | float | complex) – N-dimensional array.

  • size (int | None) – optional static integer specifying the number of nonzero entries to return. If there are more nonzero elements than the specified size, then indices will be truncated at the end. If there are fewer nonzero elements than the specified size, then indices will be padded with fill_value, which defaults to zero.

  • fill_value (None | Array | ndarray | bool | number | bool | int | float | complex | tuple[Array | ndarray | bool | number | bool | int | float | complex, ...]) – optional padding value when size is specified. Defaults to 0.

Return type:

tuple[Array, ...]

Returns:

Tuple of JAX Arrays of length a.ndim, containing the indices of each nonzero value.

See also

  • jax.numpy.flatnonzero()

  • jax.numpy.where()

Examples

One-dimensional array returns a length-1 tuple of indices:

>>> x = jnp.array([0, 5, 0, 6, 0, 7])
>>> jnp.nonzero(x)
(Array([1, 3, 5], dtype=int32),)

Two-dimensional array returns a length-2 tuple of indices:

>>> x = jnp.array([[0, 5, 0],
...                [6, 0, 7]])
>>> jnp.nonzero(x)
(Array([0, 1, 1], dtype=int32), Array([1, 0, 2], dtype=int32))

In either case, the resulting tuple of indices can be used directly to extract the nonzero values:

>>> indices = jnp.nonzero(x)
>>> x[indices]
Array([5, 6, 7], dtype=int32)

The output of nonzero has a dynamic shape, because the number of returned indices depends on the contents of the input array. As such, it is incompatible with JIT and other JAX transformations:

>>> x = jnp.array([0, 5, 0, 6, 0, 7])
>>> jax.jit(jnp.nonzero)(x)
Traceback (most recent call last):
  ...
ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: traced array with shape int32[].
The size argument of jnp.nonzero must be statically specified to use jnp.nonzero within JAX transformations.

This can be addressed by passing a static size parameter to specify the desired output shape:

>>> nonzero_jit = jax.jit(jnp.nonzero, static_argnames='size')
>>> nonzero_jit(x, size=3)
(Array([1, 3, 5], dtype=int32),)

If size does not match the true size, the result will be either truncated or padded:

>>> nonzero_jit(x, size=2)  # size < 3: indices are truncated
(Array([1, 3], dtype=int32),)
>>> nonzero_jit(x, size=5)  # size > 3: indices are padded with zeros.
(Array([1, 3, 5, 0, 0], dtype=int32),)

You can specify a custom fill value for the padding using the fill_value argument:

>>> nonzero_jit(x, size=5, fill_value=len(x))
(Array([1, 3, 5, 6, 6], dtype=int32),)
outer(b)[source]#

Outer product of two 1-D arrays.

The resulting unit is self.unit * b.unit.

Parameters:

b (Quantity) – Second operand.

Returns:

The outer product matrix.

Return type:

Quantity

Examples

>>> import saiunit as u
>>> import jax.numpy as jnp
>>> a = u.Quantity(jnp.array([1.0, 2.0]), unit=u.mV)
>>> b = u.Quantity(jnp.array([3.0, 4.0]), unit=u.second)
>>> a.outer(b).shape
(2, 2)
pow(oc)[source]#

Raise this quantity to the power oc.

The exponent must be dimensionless. The resulting unit is self.unit ** oc.

Parameters:

oc (int, float, or dimensionless Quantity) – The exponent.

Returns:

self ** oc.

Return type:

Quantity

Examples

>>> import saiunit as u
>>> q = u.Quantity(2.0, unit=u.mV)
>>> q.pow(2)
Quantity(4., "mV^2")
prod(*args, **kwds)[source]#

Return the product of array elements over the given axis.

The unit of the result is self.unit ** n where n is the number of elements multiplied together.

Returns:

The product.

Return type:

Quantity

Examples

>>> import saiunit as u
>>> import jax.numpy as jnp
>>> q = u.Quantity(jnp.array([2.0, 3.0]), unit=u.mV)
>>> q.prod()
Quantity(6., "mV^2")
ptp(axis=None, out=None, keepdims=False)#

Return the peak-to-peak range along a given axis.

JAX implementation of numpy.ptp().

Parameters:
  • a (Array | ndarray | bool | number | bool | int | float | complex) – input array.

  • axis (int | Sequence[int] | None) – optional, int or sequence of ints, default=None. Axis along which the range is computed. If None, the range is computed on the flattened array.

  • keepdims (bool) – bool, default=False. If true, reduced axes are left in the result with size 1.

  • out (None) – Unused by JAX.

Return type:

Array

Returns:

An array with the range of elements along specified axis of input.

Examples

By default, jnp.ptp computes the range along all axes.

>>> x = jnp.array([[1, 3, 5, 2],
...                [4, 6, 8, 1],
...                [7, 9, 3, 4]])
>>> jnp.ptp(x)
Array(8, dtype=int32)

If axis=1, computes the range along axis 1.

>>> jnp.ptp(x, axis=1)
Array([4, 7, 6], dtype=int32)

To preserve the dimensions of input, you can set keepdims=True.

>>> jnp.ptp(x, axis=1, keepdims=True)
Array([[4],
       [7],
       [6]], dtype=int32)
put(indices, values)[source]#

Replaces specified elements of an array with given values.

Parameters:
  • indices (array_like) – Target indices, interpreted as integers.

  • values (array_like) – Values to place in the array at target indices.

Return type:

Quantity

ravel(order='C', *, out_sharding=None)#

Flatten array into a 1-dimensional shape.

JAX implementation of numpy.ravel(), implemented in terms of jax.lax.reshape().

ravel(arr, order=order) is equivalent to reshape(arr, -1, order=order).

Parameters:
  • a (Array | ndarray | bool | number | bool | int | float | complex) – array to be flattened.

  • order (str) – 'F' or 'C', specifies whether the reshape should apply column-major (Fortran-style, "F") or row-major (C-style, "C") order; default is "C". JAX does not support order="A" or order="K".

Return type:

Array

Returns:

flattened copy of input array.

Notes

Unlike numpy.ravel(), jax.numpy.ravel() returns a copy rather than a view of the input array. However, under JIT the compiler optimizes away such copies when possible, so in practice this does not affect performance.

See also

  • jax.Array.ravel(): equivalent functionality via an array method.

  • jax.numpy.reshape(): general array reshape.

Examples

>>> x = jnp.array([[1, 2, 3],
...                [4, 5, 6]])

By default, ravel flattens in C-style, row-major order:

>>> jnp.ravel(x)
Array([1, 2, 3, 4, 5, 6], dtype=int32)

Optionally, ravel in Fortran-style, column-major order:

>>> jnp.ravel(x, order='F')
Array([1, 4, 2, 5, 3, 6], dtype=int32)

For convenience, the same functionality is available via the jax.Array.ravel() method:

>>> x.ravel()
Array([1, 2, 3, 4, 5, 6], dtype=int32)
repeat(repeats, axis=None)[source]#

Repeat elements of the array.

Parameters:
  • repeats (int or array of ints) – Number of repetitions for each element.

  • axis (int, optional) – Axis along which to repeat.

Returns:

The repeated quantity.

Return type:

Quantity

Examples

>>> import saiunit as u
>>> import jax.numpy as jnp
>>> q = u.Quantity(jnp.array([1.0, 2.0]), unit=u.mV)
>>> q.repeat(2)
Quantity([1. 1. 2. 2.], "mV")
repr_in_unit(precision=None)[source]#

Return a human-readable string of this quantity in its current unit.

The format is "<value> <unit>", e.g. "3. mV" or "[1. 2. 3.] mV".

Parameters:

precision (int | None) – Number of significant digits. When None the value from numpy.get_printoptions is used.

Returns:

The formatted string.

Return type:

str

Examples

>>> import saiunit as u
>>> x = u.Quantity(25.0, unit=u.mV)
>>> x.repr_in_unit()
'25. mV'
>>> x.to(u.volt).repr_in_unit(3)
'0.025 V'
reshape(shape, order='C')[source]#

Return a quantity with the same data but a new shape.

Parameters:
  • shape (int or tuple of ints) – New shape.

  • order ({'C', 'F'}, optional) – Memory layout order (default 'C').

Returns:

Reshaped quantity.

Return type:

Quantity

Examples

>>> import saiunit as u
>>> import jax.numpy as jnp
>>> q = u.Quantity(jnp.array([1.0, 2.0, 3.0]), unit=u.mV)
>>> q.reshape((3, 1)).shape
(3, 1)
resize(new_shape)[source]#

Change shape and size of array in-place.

Return type:

Quantity

round(decimals=0)[source]#

Evenly round the mantissa to the given number of decimals.

Parameters:

decimals (int) – Number of decimal places (default 0). Negative values round to positions left of the decimal point.

Returns:

A new quantity with the rounded mantissa.

Return type:

Quantity

Examples

>>> import saiunit as u
>>> q = u.Quantity(1.567, unit=u.mV)
>>> q.round(1)
Quantity(1.6, "mV")
scatter_add(index, value)[source]#

Return a copy with value added at index.

Parameters:
  • index (Array | ndarray | bool | number | bool | int | float | complex) – Target index (indices).

  • value (Quantity | Array | ndarray | bool | number | bool | int | float | complex) – The value to add. Must have the same unit dimension.

Returns:

A new quantity with the update applied.

Return type:

Quantity

Examples

>>> import saiunit as u
>>> import jax.numpy as jnp
>>> q = u.Quantity(jnp.array([1.0, 2.0, 3.0]), unit=u.mV)
>>> q.scatter_add(0, u.Quantity(10.0, unit=u.mV))
Quantity([11.  2.  3.], "mV")
scatter_div(index, value)[source]#

Return a copy with the element at index divided by value.

value must be dimensionless (a pure scale factor).

Parameters:
  • index – Target index (indices).

  • value – The divisor. Must be dimensionless.

Returns:

A new quantity with the update applied.

Return type:

Quantity

Raises:

TypeError – If value is not dimensionless.

Examples

>>> import saiunit as u
>>> import jax.numpy as jnp
>>> q = u.Quantity(jnp.array([1.0, 2.0, 3.0]), unit=u.mV)
>>> q.scatter_div(0, u.Quantity(2.0))
Quantity([0.5 2.  3. ], "mV")
scatter_max(index, value)[source]#

Return a copy where the element at index is the maximum of the current value and value.

Parameters:
  • index (Array | ndarray | bool | number | bool | int | float | complex) – Target index (indices).

  • value (Quantity | Array | ndarray | bool | number | bool | int | float | complex) – The comparison value. Must have the same unit dimension.

Returns:

A new quantity with the update applied.

Return type:

Quantity

Examples

>>> import saiunit as u
>>> import jax.numpy as jnp
>>> q = u.Quantity(jnp.array([1.0, 2.0, 3.0]), unit=u.mV)
>>> q.scatter_max(0, u.Quantity(10.0, unit=u.mV))
Quantity([10.  2.  3.], "mV")
scatter_min(index, value)[source]#

Return a copy where the element at index is the minimum of the current value and value.

Parameters:
  • index (Array | ndarray | bool | number | bool | int | float | complex) – Target index (indices).

  • value (Quantity | Array | ndarray | bool | number | bool | int | float | complex) – The comparison value. Must have the same unit dimension.

Returns:

A new quantity with the update applied.

Return type:

Quantity

Examples

>>> import saiunit as u
>>> import jax.numpy as jnp
>>> q = u.Quantity(jnp.array([1.0, 2.0, 3.0]), unit=u.mV)
>>> q.scatter_min(0, u.Quantity(0.5, unit=u.mV))
Quantity([0.5 2.  3. ], "mV")
scatter_mul(index, value)[source]#

Return a copy with the element at index multiplied by value.

value must be dimensionless (a pure scale factor).

Parameters:
  • index – Target index (indices).

  • value – The multiplier. Must be dimensionless.

Returns:

A new quantity with the update applied.

Return type:

Quantity

Raises:

TypeError – If value is not dimensionless.

Examples

>>> import saiunit as u
>>> import jax.numpy as jnp
>>> q = u.Quantity(jnp.array([1.0, 2.0, 3.0]), unit=u.mV)
>>> q.scatter_mul(0, u.Quantity(10.0))
Quantity([10.  2.  3.], "mV")
scatter_sub(index, value)[source]#

Return a copy with value subtracted at index.

Parameters:
  • index (Array | ndarray | bool | number | bool | int | float | complex) – Target index (indices).

  • value (Quantity | Array | ndarray | bool | number | bool | int | float | complex) – The value to subtract. Must have the same unit dimension.

Returns:

A new quantity with the update applied.

Return type:

Quantity

Examples

>>> import saiunit as u
>>> import jax.numpy as jnp
>>> q = u.Quantity(jnp.array([1.0, 2.0, 3.0]), unit=u.mV)
>>> q.scatter_sub(0, u.Quantity(1.0, unit=u.mV))
Quantity([0. 2. 3.], "mV")
searchsorted(v, side='left', sorter=None)[source]#

Find indices where elements should be inserted to maintain order.

Return type:

Array

property shape: tuple[int, ...]#

The shape of the mantissa array.

Returns:

Shape tuple, identical to jnp.shape(self.mantissa).

Return type:

tuple of int

Examples

>>> import saiunit as u
>>> import jax.numpy as jnp
>>> q = u.Quantity(jnp.array([[1.0, 2.0], [3.0, 4.0]]), unit=u.mV)
>>> q.shape
(2, 2)
sort(axis=-1, stable=True, order=None)[source]#

Sort the array in-place along the given axis.

Parameters:
  • axis (int, optional) – Axis along which to sort (default -1).

  • stable (bool, optional) – Whether to use a stable sort (default True).

  • order (str or list of str, optional) – Field ordering for structured arrays.

Returns:

self, with the mantissa sorted in-place.

Return type:

Quantity

Examples

>>> import saiunit as u
>>> import jax.numpy as jnp
>>> q = u.Quantity(jnp.array([3.0, 1.0, 2.0]), unit=u.mV)
>>> q.sort()
Quantity([1. 2. 3.], "mV")
split(indices_or_sections, axis=0)[source]#

Split the array into multiple sub-arrays.

Parameters:
  • indices_or_sections (int or 1-D array) – If an integer N, the array is divided into N equal parts. If a sorted 1-D array of indices, the entries indicate split points along axis.

  • axis (int, optional) – Axis along which to split (default 0).

Returns:

Sub-arrays, each carrying the same unit.

Return type:

list[Quantity]

Examples

>>> import saiunit as u
>>> import jax.numpy as jnp
>>> q = u.Quantity(jnp.array([1.0, 2.0, 3.0]), unit=u.mV)
>>> parts = q.split(3)
>>> len(parts)
3
squeeze(axis=None)[source]#

Remove length-one axes from the array.

Parameters:

axis (int or tuple of ints, optional) – Axes to remove. If None, all length-one axes are removed.

Returns:

The squeezed quantity.

Return type:

Quantity

Examples

>>> import saiunit as u
>>> import jax.numpy as jnp
>>> q = u.Quantity(jnp.array([[[1.0]]]), unit=u.mV)
>>> q.squeeze().shape
()
std(axis=None, dtype=None, out=None, ddof=0, keepdims=False, *, where=None, mean=None, correction=None)#

Compute the standard deviation along a given axis.

JAX implementation of numpy.std().

Parameters:
  • a (Array | ndarray | bool | number | bool | int | float | complex) – input array.

  • axis (int | Sequence[int] | None) – optional, int or sequence of ints, default=None. Axis along which the standard deviation is computed. If None, the standard deviation is computed along all the axes.

  • dtype (str | type[Any] | dtype | SupportsDType | None) – The type of the output array. Default=None.

  • ddof (int) – int, default=0. Degrees of freedom. The divisor in the standard deviation computation is N-ddof, where N is the number of elements along the given axis.

  • keepdims (bool) – bool, default=False. If true, reduced axes are left in the result with size 1.

  • where (Array | ndarray | bool | number | bool | int | float | complex | None) – optional, boolean array, default=None. The elements to be used in the standard deviation. Array should be broadcast compatible to the input.

  • mean (Array | ndarray | bool | number | bool | int | float | complex | None) – optional, mean of the input array, computed along the given axis. If provided, it will be used to compute the standard deviation instead of computing it from the input array. If specified, mean must be broadcast-compatible with the input array. In the general case, this can be achieved by computing the mean with keepdims=True and axis matching this function’s axis argument.

  • correction (int | float | None) – int or float, default=None. Alternative name for ddof. Both ddof and correction can’t be provided simultaneously.

  • out (None) – Unused by JAX.

Return type:

Array

Returns:

An array of the standard deviation along the given axis.

See also

  • jax.numpy.var(): Compute the variance of array elements over given axis.

  • jax.numpy.mean(): Compute the mean of array elements over a given axis.

  • jax.numpy.nanvar(): Compute the variance along a given axis, ignoring NaN values.

  • jax.numpy.nanstd(): Compute the standard deviation along a given axis, ignoring NaN values.

Examples

By default, jnp.std computes the standard deviation along all axes.

>>> x = jnp.array([[1, 3, 4, 2],
...                [4, 2, 5, 3],
...                [5, 4, 2, 3]])
>>> with jnp.printoptions(precision=2, suppress=True):
...   jnp.std(x)
Array(1.21, dtype=float32)

If axis=0, computes along axis 0.

>>> with jnp.printoptions(precision=2, suppress=True):
...   print(jnp.std(x, axis=0))
[1.7  0.82 1.25 0.47]

To preserve the dimensions of input, you can set keepdims=True.

>>> with jnp.printoptions(precision=2, suppress=True):
...   print(jnp.std(x, axis=0, keepdims=True))
[[1.7  0.82 1.25 0.47]]

If ddof=1:

>>> with jnp.printoptions(precision=2, suppress=True):
...   print(jnp.std(x, axis=0, keepdims=True, ddof=1))
[[2.08 1.   1.53 0.58]]

To compute the standard deviation over selected elements only, pass a boolean mask as where.

>>> where = jnp.array([[1, 0, 1, 0],
...                    [0, 1, 0, 1],
...                    [1, 1, 1, 0]], dtype=bool)
>>> jnp.std(x, axis=0, keepdims=True, where=where)
Array([[2., 1., 1., 0.]], dtype=float32)
property strides#

Tuple of byte-steps in each dimension (mirrors numpy.ndarray.strides).

sum(axis=None, dtype=None, out=None, keepdims=False, initial=None, where=None, promote_integers=True)#

Sum of the elements of the array over a given axis.

JAX implementation of numpy.sum().

Parameters:
  • a (Array | ndarray | bool | number | bool | int | float | complex) – Input array.

  • axis (int | Sequence[int] | None) – int or array, default=None. Axis along which the sum is computed. If None, the sum is computed along all the axes.

  • dtype (str | type[Any] | dtype | SupportsDType | None) – The type of the output array. Default=None.

  • out (None) – Unused by JAX.

  • keepdims (bool) – bool, default=False. If true, reduced axes are left in the result with size 1.

  • initial (Array | ndarray | bool | number | bool | int | float | complex | None) – int or array, default=None. Initial value for the sum.

  • where (Array | ndarray | bool | number | bool | int | float | complex | None) – int or array, default=None. The elements to be used in the sum. Array should be broadcast compatible to the input.

  • promote_integers (bool) – bool, default=True. If True, then integer inputs will be promoted to the widest available integer dtype, following numpy’s behavior. If False, the result will have the same dtype as the input. promote_integers is ignored if dtype is specified.

Return type:

Array

Returns:

An array of the sum along the given axis.

See also

  • jax.numpy.prod(): Compute the product of array elements over a given axis.

  • jax.numpy.max(): Compute the maximum of array elements over given axis.

  • jax.numpy.min(): Compute the minimum of array elements over given axis.

Examples

By default, the sum is computed along all the axes.

>>> x = jnp.array([[1, 3, 4, 2],
...                [5, 2, 6, 3],
...                [8, 1, 3, 9]])
>>> jnp.sum(x)
Array(47, dtype=int32)

If axis=1, the sum is computed along axis 1.

>>> jnp.sum(x, axis=1)
Array([10, 16, 21], dtype=int32)

If keepdims=True, ndim of the output is equal to that of the input.

>>> jnp.sum(x, axis=1, keepdims=True)
Array([[10],
       [16],
       [21]], dtype=int32)

To include only specific elements in the sum, you can use where.

>>> where=jnp.array([[0, 0, 1, 0],
...                  [0, 0, 1, 1],
...                  [1, 1, 1, 0]], dtype=bool)
>>> jnp.sum(x, axis=1, keepdims=True, where=where)
Array([[ 4],
       [ 9],
       [12]], dtype=int32)
>>> where=jnp.array([[False],
...                  [False],
...                  [False]])
>>> jnp.sum(x, axis=0, keepdims=True, where=where)
Array([[0, 0, 0, 0]], dtype=int32)
swapaxes(axis1, axis2)[source]#

Interchange two axes of the array.

Parameters:
  • axis1 (int) – First axis.

  • axis2 (int) – Second axis.

Returns:

The quantity with axes swapped.

Return type:

Quantity

Examples

>>> import saiunit as u
>>> import jax.numpy as jnp
>>> q = u.Quantity(jnp.array([[1.0, 2.0], [3.0, 4.0]]), unit=u.mV)
>>> q.swapaxes(0, 1).shape
(2, 2)
take(indices, axis=None, mode=None, unique_indices=False, indices_are_sorted=False, fill_value=None)[source]#

Select elements from the array at the given indices.

Parameters:
  • indices (array_like) – Indices of the values to extract.

  • axis (int, optional) – Axis along which to take (default flattened).

  • mode (str, optional) – Out-of-bounds index handling.

  • unique_indices (bool, optional) – Hint that indices are unique.

  • indices_are_sorted (bool, optional) – Hint that indices are sorted.

  • fill_value (Quantity or scalar, optional) – Value for out-of-bounds positions when mode is 'fill'.

Returns:

The selected elements.

Return type:

Quantity

Examples

>>> import saiunit as u
>>> import jax.numpy as jnp
>>> q = u.Quantity(jnp.array([10.0, 20.0, 30.0]), unit=u.mV)
>>> q.take(jnp.array([0, 2]))
Quantity([10. 30.], "mV")
tile(reps)[source]#

Construct an array by repeating this quantity.

Parameters:

reps (int or array_like) – Number of repetitions along each axis.

Returns:

The tiled quantity.

Return type:

Quantity

Examples

>>> import saiunit as u
>>> import jax.numpy as jnp
>>> q = u.Quantity(jnp.array([1.0, 2.0]), unit=u.mV)
>>> q.tile(2)
Quantity([1. 2. 1. 2.], "mV")
to(new_unit)[source]#

Convert this quantity to a different (compatible) unit.

The mantissa is rescaled so that the physical value stays the same, and the returned Quantity carries new_unit.

Parameters:

new_unit (Unit) – Target unit. Must have the same dimension as self.unit.

Returns:

A new Quantity expressed in new_unit.

Return type:

Quantity

Raises:

UnitMismatchError – If new_unit has a different dimension.

Examples

>>> import saiunit as u
>>> import jax.numpy as jnp
>>> q = u.Quantity(jnp.array([1.0, 2.0, 3.0]), unit=u.mV)
>>> q.to(u.volt)
Quantity([0.001 0.002 0.003], "V")

See also

in_unit

Identical behaviour; to() delegates to in_unit().

to_decimal

Convert to a plain number in the target unit.

to_decimal(unit=Unit('1'))[source]#

Return the numerical value expressed in the given unit, without wrapping the result in a Quantity.

This is useful when you need a plain JAX array for downstream computation that does not support units.

Parameters:

unit (Unit) – The reference unit. Defaults to UNITLESS.

Returns:

A plain number or JAX array representing the quantity in unit.

Return type:

Array | ndarray | bool | number | bool | int | float | complex

Raises:

UnitMismatchError – If unit does not have the same dimension as self.unit.

Examples

>>> import saiunit as u
>>> import jax.numpy as jnp
>>> q = u.Quantity(jnp.array([1.0, 2.0, 3.0]), unit=u.mV)
>>> q.to_decimal(u.volt)
Array([0.001, 0.002, 0.003], dtype=float32)

See also

to

Convert while keeping the Quantity wrapper.

tolist()[source]#

Convert the array to a (nested) Python list of Quantity scalars.

Each leaf element is a 0-D Quantity with the same unit.

Returns:

A nested list of scalar Quantity objects, or a single Quantity for 0-D arrays.

Return type:

list or Quantity

Examples

>>> import saiunit as u
>>> import jax.numpy as jnp
>>> q = u.Quantity(jnp.array([1.0, 2.0]), unit=u.mV)
>>> q.tolist()
[Quantity(1., "mV"), Quantity(2., "mV")]
trace(offset=0, axis1=0, axis2=1)[source]#

Sum along diagonals of the array, preserving units.

Parameters:
  • offset (int) – Offset of the diagonal from the main diagonal (default 0).

  • axis1 (int) – First axis of the 2-D sub-arrays (default 0).

  • axis2 (int) – Second axis of the 2-D sub-arrays (default 1).

Returns:

The trace value(s).

Return type:

Quantity

Examples

>>> import saiunit as u
>>> import jax.numpy as jnp
>>> q = u.Quantity(jnp.eye(3), unit=u.mV)
>>> q.trace()
Quantity(3., "mV")
transpose(*axes)[source]#

Return the array with axes transposed.

For a 2-D array this is the standard matrix transpose.

Parameters:

*axes (None, tuple of ints, or n ints) – If omitted, axes are reversed. Otherwise specifies the permutation.

Returns:

Transposed quantity.

Return type:

Quantity

Examples

>>> import saiunit as u
>>> import jax.numpy as jnp
>>> q = u.Quantity(jnp.array([[1.0, 2.0], [3.0, 4.0]]), unit=u.mV)
>>> q.transpose().shape
(2, 2)
tree_flatten()[source]#

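Flatten the quantity into its pytree children and auxiliary data.

Return type:

tuple[tuple[Array | ndarray | bool | number | bool | int | float | complex], Unit]

Returns:

The mantissa (wrapped in a one-element tuple) as the pytree children, and the unit as auxiliary data.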

classmethod tree_unflatten(unit, values)[source]#

Reconstruct a Quantity from its unit (the auxiliary data) and its pytree children.

Parameters:
  • unit – The unit.

  • values – The data.

Return type:

Quantity

Returns:

The Quantity object.

property unit: saiunit.Unit#

The Unit attached to this quantity.

The unit carries both the physical dimension and the scale factor (e.g. mV has dimension voltage with scale 1e-3).

Returns:

The unit of this quantity.

Return type:

Unit

Examples

>>> import saiunit as u
>>> q = u.Quantity(5.0, unit=u.mV)
>>> q.unit
mV

See also

dim

The physical dimension without scale information.

mantissa

The numerical value.

unsqueeze(axis)[source]#

Insert a length-one axis (PyTorch-style alias for expand_dims()).

Parameters:

axis (int) – Position where the new axis is inserted.

Returns:

The quantity with an extra dimension.

Return type:

Quantity

Examples

>>> import saiunit as u
>>> import jax.numpy as jnp
>>> q = u.Quantity(jnp.array([1.0, 2.0]), unit=u.mV)
>>> q.unsqueeze(0).shape
(1, 2)
update_mantissa(mantissa)[source]#

Replace the mantissa in-place, keeping the same unit.

The new mantissa must have the same shape and dtype as the current one.

Parameters:

mantissa (Any) – The new numerical data. Must not be a Quantity.

Raises:

ValueError – If mantissa is a Quantity, or if shape/dtype do not match.

Return type:

None

Examples

>>> import saiunit as u
>>> import jax.numpy as jnp
>>> q = u.Quantity(jnp.array([1.0, 2.0, 3.0]), unit=u.mV)
>>> q.update_mantissa(jnp.array([4.0, 5.0, 6.0]))
>>> q
Quantity([4. 5. 6.], "mV")
var(axis=None, dtype=None, out=None, ddof=0, keepdims=False, *, where=None, mean=None, correction=None)#

Compute the variance along a given axis.

JAX implementation of numpy.var().

Parameters:
  • a (Array | ndarray | bool | number | bool | int | float | complex) – input array.

  • axis (int | Sequence[int] | None) – optional, int or sequence of ints, default=None. Axis along which the variance is computed. If None, variance is computed along all the axes.

  • dtype (str | type[Any] | dtype | SupportsDType | None) – The type of the output array. Default=None.

  • ddof (int) – int, default=0. Degrees of freedom. The divisor in the variance computation is N-ddof, where N is the number of elements along the given axis.

  • keepdims (bool) – bool, default=False. If true, reduced axes are left in the result with size 1.

  • where (Array | ndarray | bool | number | bool | int | float | complex | None) – optional, boolean array, default=None. The elements to be used in the variance. Array should be broadcast compatible to the input.

  • mean (Array | ndarray | bool | number | bool | int | float | complex | None) – optional, mean of the input array, computed along the given axis. If provided, it will be used to compute the variance instead of computing it from the input array. If specified, mean must be broadcast-compatible with the input array. In the general case, this can be achieved by computing the mean with keepdims=True and axis matching this function’s axis argument.

  • correction (int | float | None) – int or float, default=None. Alternative name for ddof. Both ddof and correction can’t be provided simultaneously.

  • out (None) – Unused by JAX.

Return type:

Array

Returns:

An array of the variance along the given axis.

See also

  • jax.numpy.mean(): Compute the mean of array elements over a given axis.

  • jax.numpy.std(): Compute the standard deviation of array elements over given axis.

  • jax.numpy.nanvar(): Compute the variance along a given axis, ignoring NaN values.

  • jax.numpy.nanstd(): Compute the standard deviation along a given axis, ignoring NaN values.

Examples

By default, jnp.var computes the variance along all axes.

>>> x = jnp.array([[1, 3, 4, 2],
...                [5, 2, 6, 3],
...                [8, 4, 2, 9]])
>>> with jnp.printoptions(precision=2, suppress=True):
...   jnp.var(x)
Array(5.74, dtype=float32)

If axis=1, variance is computed along axis 1.

>>> jnp.var(x, axis=1)
Array([1.25  , 2.5   , 8.1875], dtype=float32)

To preserve the dimensions of input, you can set keepdims=True.

>>> jnp.var(x, axis=1, keepdims=True)
Array([[1.25  ],
       [2.5   ],
       [8.1875]], dtype=float32)

If ddof=1:

>>> with jnp.printoptions(precision=2, suppress=True):
...   print(jnp.var(x, axis=1, keepdims=True, ddof=1))
[[ 1.67]
 [ 3.33]
 [10.92]]

To compute the variance over selected elements only, pass a boolean mask as where.

>>> where = jnp.array([[1, 0, 1, 0],
...                    [0, 1, 1, 0],
...                    [1, 1, 1, 0]], dtype=bool)
>>> with jnp.printoptions(precision=2, suppress=True):
...   print(jnp.var(x, axis=1, keepdims=True, where=where))
[[2.25]
 [4.  ]
 [6.22]]
view(*args, dtype=None)[source]#

New view of array with the same data.

This method follows PyTorch's Tensor.view() syntax.

Returns a new quantity with the same data as self but with a different shape.

The returned quantity must have the same number of elements as self, but may have a different shape. In the PyTorch semantics this method mirrors, the new view size must be compatible with the original size and stride, i.e., each new view dimension must either be a subspace of an original dimension, or only span across original dimensions \(d, d+1, \dots, d+k\) that satisfy the following contiguity-like condition for all \(i = d, \dots, d+k-1\):

\[\text{stride}[i] = \text{stride}[i+1] \times \text{size}[i+1]\]

Otherwise, self cannot be viewed with the requested shape without copying it (e.g., via contiguous() in PyTorch). When it is unclear whether view() can be performed, use reshape(), which returns a view if the shapes are compatible and a copy (equivalent to calling contiguous()) otherwise.

Parameters:

shape (int...) – the desired size

Return type:

Quantity

Example:

>>> import jax.numpy as jnp, saiunit
>>> x = saiunit.Quantity(jnp.ones((4, 4)))
>>> x.shape
(4, 4)
>>> y = x.view(16)
>>> y.shape
(16,)
>>> z = x.view(2, 8)
>>> z.shape
(2, 8)
view(dtype) → Quantity[source]

Returns a new quantity with the same data as self, reinterpreted as a different dtype.

If the element size of dtype differs from that of self.dtype, the size of the last dimension of the output is scaled proportionally. For instance, if the element size of dtype is twice that of self.dtype, each pair of elements in the last dimension of self is combined, and the last dimension of the output is half as long as that of self. If the element size of dtype is half that of self.dtype, each element in the last dimension of self is split in two, and the last dimension of the output is twice as long. For this to be possible, the following condition must be true:

  • self.ndim must be greater than 0.

Additionally, if the element size of dtype is greater than that of self.dtype, the following condition must be true as well:

  • self.shape[-1] must be divisible by the ratio between the element sizes of the dtypes.

If any of the above conditions are not met, an error is raised.

Parameters:

dtype (dtype) – the desired dtype

Example:

>>> import jax.numpy as jnp, saiunit
>>> x = saiunit.Quantity(jnp.ones((2, 2), dtype=jnp.float32))
>>> y = x.view(jnp.int32)  # reinterpret the float32 bits as int32
>>> y.mantissa
Array([[1065353216, 1065353216],
       [1065353216, 1065353216]], dtype=int32)
>>> x.view(jnp.uint8).shape  # each 4-byte float32 becomes 4 uint8 values
(2, 8)
static with_unit(mantissa, unit)[source]#

Create a Quantity from a raw value and a unit.

This is a convenience factory that reads more naturally in some contexts than the standard constructor.

Parameters:
  • mantissa (Any) – The numerical value(s).

  • unit (Unit) – The physical unit.

Returns:

A new Quantity with the given mantissa and unit.

Return type:

Quantity

Examples

>>> import saiunit as u
>>> u.Quantity.with_unit(2.0, unit=u.metre)
Quantity(2., "m")