Quantity
- class saiunit.Quantity(mantissa, unit=Unit('1'), dtype=None)
A numerical value paired with a physical unit.
Quantity is the central data structure in saiunit. It stores a mantissa (the raw numerical data, typically a JAX array) together with a Unit that describes the physical dimensions and scale. Arithmetic on Quantity objects automatically tracks and checks units, raising UnitMismatchError when incompatible quantities are combined. Quantity is registered as a JAX pytree, so it works transparently with jax.jit, jax.grad, jax.vmap, and other JAX transformations.
- Parameters:
mantissa (Any | Unit) – The numerical value(s). If a Unit is passed, the mantissa is set to 1.0 and that unit is adopted. If a Quantity is passed, its mantissa and unit are used (converted to unit when given).
unit (saiunit.Unit | Array | ndarray | bool | number | int | float | complex | str | None) – The physical unit. Defaults to UNITLESS.
dtype (str | type[Any] | dtype | SupportsDType | None) – If provided, the mantissa is cast to this dtype on construction.
Examples
>>> import saiunit as u
>>> import jax.numpy as jnp
>>> # Scalar with unit
>>> q = u.Quantity(3.0, unit=u.mV)
>>> q
Quantity(3., "mV")
>>> # Array with unit via multiplication shorthand
>>> arr = jnp.array([1.0, 2.0, 3.0]) * u.mV
>>> arr.shape
(3,)
>>> # From a Unit object directly
>>> u.Quantity(u.metre)
Quantity(1., "m")
See also
Unit – Represents a physical unit (dimension + scale).
compatible_with_equinox – Toggle Equinox interoperability.
- all(axis=None, out=None, keepdims=False, *, where=None)
Test whether all array elements along a given axis evaluate to True.
JAX implementation of numpy.all().
- Parameters:
a (Array | ndarray | bool | number | int | float | complex) – Input array.
axis (int | Sequence[int] | None) – int or sequence of ints, default=None. Axis or axes along which to test. If None, tests along all the axes.
keepdims (bool) – bool, default=False. If True, reduced axes are left in the result with size 1.
where (Array | ndarray | bool | number | int | float | complex | None) – int or array of boolean dtype, default=None. The elements to be used in the test. Must be broadcast-compatible with the input.
out (None) – Unused by JAX.
- Return type:
Array
- Returns:
An array of boolean values.
Examples
By default, jnp.all tests for True values along all the axes.
>>> x = jnp.array([[True, True, True, False],
...                [True, False, True, False],
...                [True, True, False, False]])
>>> jnp.all(x)
Array(False, dtype=bool)
If axis=0, tests for True values along axis 0.
>>> jnp.all(x, axis=0)
Array([ True, False, False, False], dtype=bool)
If keepdims=True, the ndim of the output is the same as that of the input.
>>> jnp.all(x, axis=0, keepdims=True)
Array([[ True, False, False, False]], dtype=bool)
To test only specific elements for True values, use where.
>>> where=jnp.array([[1, 0, 1, 0],
...                  [0, 0, 1, 1],
...                  [1, 1, 1, 0]], dtype=bool)
>>> jnp.all(x, axis=0, keepdims=True, where=where)
Array([[ True, True, False, False]], dtype=bool)
- any(axis=None, out=None, keepdims=False, *, where=None)
Test whether any of the array elements along a given axis evaluate to True.
JAX implementation of numpy.any().
- Parameters:
a (Array | ndarray | bool | number | int | float | complex) – Input array.
axis (int | Sequence[int] | None) – int or sequence of ints, default=None. Axis or axes along which to test. If None, tests along all the axes.
keepdims (bool) – bool, default=False. If True, reduced axes are left in the result with size 1.
where (Array | ndarray | bool | number | int | float | complex | None) – int or array of boolean dtype, default=None. The elements to be used in the test. Must be broadcast-compatible with the input.
out (None) – Unused by JAX.
- Return type:
Array
- Returns:
An array of boolean values.
Examples
By default, jnp.any tests along all the axes.
>>> x = jnp.array([[True, True, True, False],
...                [True, False, True, False],
...                [True, True, False, False]])
>>> jnp.any(x)
Array(True, dtype=bool)
If axis=0, tests along axis 0.
>>> jnp.any(x, axis=0)
Array([ True, True, True, False], dtype=bool)
If keepdims=True, the ndim of the output is the same as that of the input.
>>> jnp.any(x, axis=0, keepdims=True)
Array([[ True, True, True, False]], dtype=bool)
To test only specific elements for True values, use where.
>>> where=jnp.array([[1, 0, 1, 0],
...                  [0, 1, 0, 1],
...                  [1, 0, 1, 0]], dtype=bool)
>>> jnp.any(x, axis=0, keepdims=True, where=where)
Array([[ True, False, True, False]], dtype=bool)
- argmax(axis=None, out=None, keepdims=None)
Return the index of the maximum value of an array.
JAX implementation of numpy.argmax().
- Parameters:
a (Array | ndarray | bool | number | int | float | complex) – input array
axis (int | None) – optional integer specifying the axis along which to find the maximum value. If axis is not specified, a will be flattened.
out (None) – unused by JAX
keepdims (bool | None) – if True, then return an array with the same number of dimensions as a.
- Return type:
Array
- Returns:
an array containing the index of the maximum value along the specified axis.
See also
jax.numpy.argmin(): return the index of the minimum value.
jax.numpy.nanargmax(): compute argmax while ignoring NaN values.
Note
When the maximum value occurs more than once along a particular axis, the smallest index is returned.
Examples
>>> x = jnp.array([1, 3, 5, 4, 2])
>>> jnp.argmax(x)
Array(2, dtype=int32)
>>> x = jnp.array([[1, 3, 2],
...                [5, 4, 1]])
>>> jnp.argmax(x, axis=1)
Array([1, 0], dtype=int32)
>>> jnp.argmax(x, axis=1, keepdims=True)
Array([[1],
       [0]], dtype=int32)
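The tie-breaking rule in the note above can be checked directly on plain JAX arrays. This small sketch uses jax.numpy only; since the method is documented as the JAX implementation, the same rule is assumed to hold for a Quantity.

```python
import jax.numpy as jnp

# The maximum value 5 occurs at indices 1 and 2; argmax reports the smaller index.
x = jnp.array([1, 5, 5, 2])
first = jnp.argmax(x)           # -> 1

# Row-wise: row 0 ties at indices 0 and 1, row 1 has a unique maximum at index 1.
m = jnp.array([[3, 3],
               [1, 2]])
per_row = jnp.argmax(m, axis=1)  # -> [0, 1]
```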
- argmin(axis=None, out=None, keepdims=None)
Return the index of the minimum value of an array.
JAX implementation of numpy.argmin().
- Parameters:
a (Array | ndarray | bool | number | int | float | complex) – input array
axis (int | None) – optional integer specifying the axis along which to find the minimum value. If axis is not specified, a will be flattened.
out (None) – unused by JAX
keepdims (bool | None) – if True, then return an array with the same number of dimensions as a.
- Return type:
Array
- Returns:
an array containing the index of the minimum value along the specified axis.
Note
When the minimum value occurs more than once along a particular axis, the smallest index is returned.
See also
jax.numpy.argmax(): return the index of the maximum value.
jax.numpy.nanargmin(): compute argmin while ignoring NaN values.
Examples
>>> x = jnp.array([1, 3, 5, 4, 2])
>>> jnp.argmin(x)
Array(0, dtype=int32)
>>> x = jnp.array([[1, 3, 2],
...                [5, 4, 1]])
>>> jnp.argmin(x, axis=1)
Array([0, 2], dtype=int32)
>>> jnp.argmin(x, axis=1, keepdims=True)
Array([[0],
       [2]], dtype=int32)
- argsort(axis: int | None = -1, *, kind: None = None, order: None = None, stable: bool = True, descending: bool = False, dtype: str | type[Any] | dtype | SupportsDType | None = None) → Array
Return indices that sort an array.
JAX implementation of numpy.argsort().
- Parameters:
a (Array | ndarray | bool | number | int | float | complex) – array to sort
axis (int | None) – integer axis along which to sort. Defaults to -1, i.e. the last axis. If None, then a is flattened before being sorted.
stable (bool) – boolean specifying whether a stable sort should be used. Default=True.
descending (bool) – boolean specifying whether to sort in descending order. Default=False.
kind (None) – deprecated; instead specify sort algorithm using stable=True or stable=False.
order (None) – not supported by JAX
dtype (str | type[Any] | dtype | SupportsDType | None) – optionally specify the dtype of the resulting indices. If not specified, the default integer dtype will be used.
- Return type:
Array
- Returns:
Array of indices that sort an array. The returned array has shape a.shape (if axis is an integer) or shape (a.size,) (if axis is None).
Examples
Simple 1-dimensional sort
>>> x = jnp.array([1, 3, 5, 4, 2, 1])
>>> indices = jnp.argsort(x)
>>> indices
Array([0, 5, 4, 1, 3, 2], dtype=int32)
>>> x[indices]
Array([1, 1, 2, 3, 4, 5], dtype=int32)
Sort along the last axis of an array:
>>> x = jnp.array([[2, 1, 3],
...                [6, 4, 3]])
>>> indices = jnp.argsort(x, axis=1)
>>> indices
Array([[1, 0, 2],
       [2, 1, 0]], dtype=int32)
>>> jnp.take_along_axis(x, indices, axis=1)
Array([[1, 2, 3],
       [3, 4, 6]], dtype=int32)
See also
jax.numpy.sort(): return sorted values directly.
jax.numpy.lexsort(): lexicographical sort of multiple arrays.
jax.lax.sort(): lower-level function wrapping XLA’s Sort operator.
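The stable and descending parameters are not exercised by the examples above. The sketch below uses plain jax.numpy and assumes a JAX release recent enough to accept descending= (older releases did not have it).

```python
import jax.numpy as jnp

x = jnp.array([1, 3, 5, 4, 2, 1])

# Descending, stable sort: largest values first; the two 1s (indices 0 and 5)
# keep their original relative order.
desc = jnp.argsort(x, descending=True, stable=True)
sorted_desc = x[desc]   # -> [5, 4, 3, 2, 1, 1]
```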
- astype(dtype)
Return a copy of this quantity with the mantissa cast to dtype.
- Parameters:
dtype (str | type[Any] | dtype | SupportsDType) – Target data type (e.g. jnp.float64).
- Returns:
A new quantity with the converted dtype.
- Return type:
Quantity
Examples
>>> import saiunit as u
>>> import jax.numpy as jnp
>>> q = u.Quantity(jnp.array([1.0, 2.0]), unit=u.mV)
>>> q.astype(jnp.float16).dtype
float16
- property at
Helper property for index update functionality.
The at property provides a functionally pure equivalent of in-place array modifications. In particular:
Alternate syntax                     Equivalent in-place expression
x = x.at[idx].set(y)                 x[idx] = y
x = x.at[idx].add(y)                 x[idx] += y
x = x.at[idx].multiply(y)            x[idx] *= y
x = x.at[idx].divide(y)              x[idx] /= y
x = x.at[idx].power(y)               x[idx] **= y
x = x.at[idx].min(y)                 x[idx] = minimum(x[idx], y)
x = x.at[idx].max(y)                 x[idx] = maximum(x[idx], y)
x = x.at[idx].apply(ufunc)           ufunc.at(x, idx)
x = x.at[idx].get()                  x = x[idx]
None of the x.at expressions modify the original x; instead they return a modified copy of x. However, inside a jit() compiled function, expressions like x = x.at[idx].set(y) are guaranteed to be applied in-place.
Unlike NumPy in-place operations such as x[idx] += y, if multiple indices refer to the same location, all updates will be applied (NumPy would only apply the last update, rather than applying all updates). The order in which conflicting updates are applied is implementation-defined and may be nondeterministic (e.g., due to concurrency on some hardware platforms).
By default, JAX assumes that all indices are in-bounds. Alternative out-of-bound index semantics can be specified via the mode parameter (see below).
- Parameters:
mode (str) – Specify out-of-bound indexing mode. Options are:
  "promise_in_bounds": (default) The user promises that indices are in bounds. No additional checking will be performed. In practice, this means that out-of-bounds indices in get() will be clipped, and out-of-bounds indices in set(), add(), etc. will be dropped.
  "clip": clamp out-of-bounds indices into the valid range.
  "drop": ignore out-of-bounds indices.
  "fill": alias for "drop". For get(), the optional fill_value argument specifies the value that will be returned.
indices_are_sorted (bool) – If True, the implementation will assume that the indices passed to at[] are sorted in ascending order, which can lead to more efficient execution on some backends.
unique_indices (bool) – If True, the implementation will assume that the indices passed to at[] are unique, which can result in more efficient execution on some backends.
fill_value (Any) – Only applies to the get() method: the fill value to return for out-of-bounds slices when mode is 'fill'. Ignored otherwise. Defaults to NaN for inexact types, the largest negative value for signed types, the largest positive value for unsigned types, and True for booleans.
Examples
>>> import saiunit as u
>>> import jax.numpy as jnp
>>> x = jnp.arange(5.0) * u.mV
>>> x.at[2].add(10 * u.mV)
Quantity([ 0. 1. 12. 3. 4.], "mV")
>>> x.at[2].get()
Quantity(2., "mV")
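Two points above, the handling of duplicate indices and the mode parameter, are easiest to see on a plain JAX array. This sketch uses only jax.numpy and the documented .at[] options; for a Quantity the same calls are assumed to act on the mantissa while keeping the unit.

```python
import jax
import jax.numpy as jnp

x = jnp.zeros(3)
idx = jnp.array([0, 0, 2])

# Duplicate indices: every update is applied, so index 0 receives 1.0 twice.
accumulated = x.at[idx].add(1.0)                     # -> [2., 0., 1.]

y = jnp.arange(3.0)
# Out-of-bounds read with mode="fill" returns fill_value instead of clipping.
filled = y.at[5].get(mode="fill", fill_value=-1.0)   # -> -1.0
# Out-of-bounds read with mode="clip" clamps to the last valid element.
clipped = y.at[5].get(mode="clip")                   # -> 2.0
# Out-of-bounds write with mode="drop" is ignored.
dropped = y.at[5].set(100.0, mode="drop")            # -> [0., 1., 2.]

# Inside jit the update may be done in place, but y itself is still untouched.
@jax.jit
def bump(a):
    return a.at[1].add(10.0)

bumped = bump(y)                                     # -> [0., 11., 2.]
```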
- clip(min=None, max=None)
Clip (limit) the values in the array to [min, max].
At least one of min or max must be given. Both must be compatible with the unit of self.
- Parameters:
- Returns:
The clipped quantity.
- Return type:
Quantity
Examples
>>> import saiunit as u
>>> import jax.numpy as jnp
>>> q = u.Quantity(jnp.array([1.0, 2.0, 3.0]), unit=u.mV)
>>> q.clip(min=u.Quantity(1.5, unit=u.mV), max=u.Quantity(2.5, unit=u.mV))
Quantity([1.5 2. 2.5], "mV")
- clone()
Return a copy of this quantity (PyTorch-style alias for copy()).
- Returns:
An independent copy.
- Return type:
Quantity
Examples
>>> import saiunit as u
>>> q = u.Quantity(3.0, unit=u.mV)
>>> q.clone()
Quantity(3., "mV")
- conj()
Return the complex conjugate, element-wise, preserving units.
- Returns:
The conjugated quantity.
- Return type:
Quantity
Examples
>>> import saiunit as u
>>> q = u.Quantity(1.0 + 2.0j, unit=u.mV)
>>> q.conj()
Quantity((1-2j), "mV")
- conjugate()
Return the complex conjugate, element-wise.
Alias for conj().
- Returns:
The conjugated quantity.
- Return type:
Quantity
- copy()
Return a deep copy of this quantity.
- Returns:
An independent copy with the same mantissa and unit.
- Return type:
Quantity
Examples
>>> import saiunit as u
>>> q = u.Quantity(3.0, unit=u.mV)
>>> q2 = q.copy()
>>> q2
Quantity(3., "mV")
- cross(b, axisa=-1, axisb=-1, axisc=-1, axis=None)
Cross product of two arrays.
The resulting unit is self.unit * b.unit.
- Parameters:
b (Quantity) – Second operand.
axisa (int) – Axis of self that defines the vector(s) (default -1).
axisb (int) – Axis of b that defines the vector(s) (default -1).
axisc (int) – Axis of the result containing the cross product (default -1).
axis (int) – If given, overrides axisa, axisb, and axisc simultaneously.
- Returns:
The cross product.
- Return type:
Quantity
Examples
>>> import saiunit as u
>>> import jax.numpy as jnp
>>> a = u.Quantity(jnp.array([1.0, 0.0, 0.0]), unit=u.mV)
>>> b = u.Quantity(jnp.array([0.0, 1.0, 0.0]), unit=u.second)
>>> a.cross(b)
Quantity([0. 0. 1.], "mV * s")
- cumprod(*args, **kwds)
Return the cumulative product of elements along a given axis.
Because each position in the result corresponds to a different number of multiplied elements, the unit exponent varies across the output. This is only representable when the quantity is dimensionless.
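The unit bookkeeping behind this restriction can be sketched with plain arrays. The example tracks, for a hypothetical mantissa measured in mV, both the numeric cumulative product and the exponent each output position would need; it does not call saiunit.

```python
import jax.numpy as jnp

# Hypothetical mantissa of a quantity measured in mV.
mantissa = jnp.array([2.0, 3.0, 4.0])

# The numeric part of the cumulative product is well defined ...
values = jnp.cumprod(mantissa)                     # -> [2., 6., 24.]

# ... but position k multiplies k + 1 factors of mV, so its unit is mV ** (k + 1).
exponents = jnp.arange(1, mantissa.shape[0] + 1)   # -> [1, 2, 3]

# A single Quantity carries one unit, so these exponents can only be reconciled
# when the unit is dimensionless (where every power is the same unit).
```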
- cumsum(axis=None, dtype=None, out=None)
Cumulative sum of elements along an axis.
JAX implementation of numpy.cumsum().
- Parameters:
a (Array | ndarray | bool | number | int | float | complex) – N-dimensional array to be accumulated.
axis (int | None) – integer axis along which to accumulate. If None (default), then the array will be flattened and accumulated along the flattened axis.
dtype (str | type[Any] | dtype | SupportsDType | None) – optionally specify the dtype of the output. If not specified, then the output dtype will match the input dtype.
out (None) – unused by JAX
- Return type:
Array
- Returns:
An array containing the accumulated sum along the given axis.
See also
jax.numpy.cumulative_sum(): cumulative sum via the array API standard.
jax.numpy.add.accumulate(): cumulative sum via ufunc methods.
jax.numpy.nancumsum(): cumulative sum ignoring NaN values.
jax.numpy.sum(): sum along axis.
Examples
>>> x = jnp.array([[1, 2, 3],
...                [4, 5, 6]])
>>> jnp.cumsum(x)  # flattened cumulative sum
Array([ 1, 3, 6, 10, 15, 21], dtype=int32)
>>> jnp.cumsum(x, axis=1)  # cumulative sum along axis 1
Array([[ 1, 3, 6],
       [ 4, 9, 15]], dtype=int32)
- diagonal(offset=0, axis1=0, axis2=1)
Return specified diagonals, preserving units.
- Parameters:
- Returns:
The diagonal elements.
- Return type:
Quantity
Examples
>>> import saiunit as u
>>> import jax.numpy as jnp
>>> q = u.Quantity(jnp.array([[1.0, 2.0], [3.0, 4.0]]), unit=u.mV)
>>> q.diagonal()
Quantity([1. 4.], "mV")
- property dim: Dimension
The physical dimension of this quantity (e.g. length, mass, time).
The dimension is independent of scale (metres and kilometres both have the length dimension).
- Returns:
The physical dimension object.
- Return type:
Dimension
Examples
>>> import saiunit as u
>>> q = u.Quantity(5.0, unit=u.metre)
>>> q.dim
m
See also
unit – The full unit (dimension + scale).
- dot(b)
Dot product of two arrays.
The resulting unit is self.unit * b.unit.
- Parameters:
b (Quantity or array_like) – Second operand.
- Returns:
The dot product.
- Return type:
Examples
>>> import saiunit as u
>>> import jax.numpy as jnp
>>> a = u.Quantity(jnp.array([1.0, 2.0, 3.0]), unit=u.mV)
>>> b = u.Quantity(jnp.array([1.0, 1.0, 1.0]), unit=u.mV)
>>> a.dot(b)
Quantity(6., "mV^2")
- property dtype
The data type of the mantissa.
- Returns:
The JAX/NumPy dtype of the underlying array.
- Return type:
Examples
>>> import saiunit as u
>>> import jax.numpy as jnp
>>> q = u.Quantity(jnp.array([1.0, 2.0]), unit=u.mV)
>>> q.dtype
float32
- expand_as(array)
Expand an array to the shape of another array.
- Parameters:
array (Quantity | Array | ndarray | bool | number | int | float | complex) – The array whose shape is used as the target shape.
- Returns:
expanded – A readonly view on the original array with the given shape of array. It is typically not contiguous. Furthermore, more than one element of an expanded array may refer to a single memory location.
- Return type:
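The same operation can be sketched on a plain JAX array. This assumes (the page does not say so explicitly) that expand_as broadcasts the mantissa to array.shape under NumPy broadcasting rules, the way jnp.broadcast_to does; the quantity's unit would simply be carried along.

```python
import jax.numpy as jnp

mantissa = jnp.array([1.0, 2.0, 3.0])   # shape (3,)
target = jnp.zeros((2, 3))              # only its shape matters

# Broadcast along the leading dimension so the result matches target.shape.
expanded = jnp.broadcast_to(mantissa, target.shape)
# expanded -> [[1., 2., 3.],
#              [1., 2., 3.]]
```

Because JAX arrays are immutable, the broadcast result cannot be written to, whatever storage the backend uses for it.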
- expand_dims(axis)
Insert new axes at the given positions.
- Parameters:
axis (int | Sequence[int]) – Position(s) where the new axis (axes) are placed.
- Returns:
The expanded quantity.
- Return type:
Quantity
Examples
>>> import saiunit as u
>>> import jax.numpy as jnp
>>> q = u.Quantity(jnp.array([1.0, 2.0]), unit=u.mV)
>>> q.expand_dims(0).shape
(1, 2)
- factorless()
Return an equivalent quantity whose unit has factor == 1.0.
If the unit already has no extra factor, the original object is returned unchanged.
- Returns:
A quantity with the factor folded into the mantissa.
- Return type:
Quantity
Examples
>>> import saiunit as u
>>> q = u.Quantity(3.0, unit=u.mV)
>>> q.factorless()
Quantity(3., "mV")
- property flat
1-D iterator over the mantissa elements, unit preserved.
- flatten()
Return a 1-D copy of this quantity.
- Returns:
Flattened quantity with the same unit.
- Return type:
Quantity
Examples
>>> import saiunit as u
>>> import jax.numpy as jnp
>>> q = u.Quantity(jnp.array([[1.0, 2.0], [3.0, 4.0]]), unit=u.mV)
>>> q.flatten()
Quantity([1. 2. 3. 4.], "mV")
- has_same_unit(other)
Check whether this quantity shares the same physical dimension as other.
Two quantities that differ only in scale (e.g. mV vs V) are considered to have the same unit dimension.
- Parameters:
- Returns:
True if both have identical physical dimensions.
- Return type:
bool
Examples
>>> import saiunit as u
>>> a = u.Quantity(1.0, unit=u.mV)
>>> b = u.Quantity(2.0, unit=u.volt)
>>> a.has_same_unit(b)
True
>>> c = u.Quantity(1.0, unit=u.second)
>>> a.has_same_unit(c)
False
- in_unit(unit, err_msg=None)
Convert this quantity to a compatible unit.
Behaves identically to to(); kept for API compatibility.
- Parameters:
- Returns:
A new Quantity expressed in unit.
- Return type:
Quantity
- Raises:
UnitMismatchError – If unit has a different dimension.
Examples
>>> import saiunit as u
>>> import jax.numpy as jnp
>>> q = u.Quantity(jnp.array([1.0, 2.0, 3.0]), unit=u.mV)
>>> q.in_unit(u.volt)
Quantity([0.001 0.002 0.003], "V")
- property is_unitless: bool
True if this quantity is dimensionless (has no physical unit).
- Returns:
Whether the quantity is unitless.
- Return type:
bool
Examples
>>> import saiunit as u
>>> u.Quantity(5.0).is_unitless
True
>>> u.Quantity(5.0, unit=u.mV).is_unitless
False
- item(*args)
Extract a single element as a scalar Quantity.
- Parameters:
*args (int) – Index into the flat array.
- Returns:
A 0-D Quantity containing the selected element.
- Return type:
Quantity
Examples
>>> import saiunit as u
>>> import jax.numpy as jnp
>>> q = u.Quantity(jnp.array([10.0, 20.0]), unit=u.mV)
>>> q.item(0)
Quantity(10., "mV")
- property mT: saiunit.Quantity
Matrix transpose of the last two dimensions, preserving units.
The array must be at least 2-D.
- Returns:
The matrix-transposed quantity.
- Return type:
Quantity
Examples
>>> import saiunit as u
>>> import jax.numpy as jnp
>>> q = u.Quantity(jnp.array([[1.0, 2.0], [3.0, 4.0]]), unit=u.mV)
>>> q.mT.shape
(2, 2)
- property magnitude: Array | ndarray | bool | number | int | float | complex
Alias for mantissa.
- Returns:
The raw numerical data of this quantity.
- Return type:
array_like
Examples
>>> import saiunit as u
>>> q = u.Quantity(5.0, unit=u.metre)
>>> q.magnitude
5.0
See also
mantissa – Primary accessor for the numerical data.
- property mantissa: Array | ndarray | bool | number | int | float | complex
The raw numerical data of this quantity (without the unit).
In scientific notation \(x = a \times 10^{b}\), the mantissa is the coefficient \(a\). For a Quantity, it is the underlying JAX/NumPy array (or Python scalar) that stores the numeric value.
- Returns:
The mantissa array or scalar.
- Return type:
array_like
Examples
>>> import saiunit as u
>>> q = u.Quantity(3.0, unit=u.mV)
>>> q.mantissa
3.0
- max(axis=None, out=None, keepdims=False, initial=None, where=None)
Return the maximum of the array elements along a given axis.
JAX implementation of numpy.max().
- Parameters:
a (Array | ndarray | bool | number | int | float | complex) – Input array.
axis (int | Sequence[int] | None) – int or sequence of ints, default=None. Axis along which the maximum is computed. If None, the maximum is computed along all the axes.
keepdims (bool) – bool, default=False. If True, reduced axes are left in the result with size 1.
initial (Array | ndarray | bool | number | int | float | complex | None) – int or array, default=None. Initial value for the maximum.
where (Array | ndarray | bool | number | int | float | complex | None) – int or array of boolean dtype, default=None. The elements to be used in the maximum. Must be broadcast-compatible with the input. initial must be specified when where is used.
out (None) – Unused by JAX.
- Return type:
Array
- Returns:
An array of maximum values along the given axis.
See also
jax.numpy.min(): Compute the minimum of array elements along a given axis.
jax.numpy.sum(): Compute the sum of array elements along a given axis.
jax.numpy.prod(): Compute the product of array elements along a given axis.
Examples
By default, jnp.max computes the maximum of elements along all the axes.
>>> x = jnp.array([[9, 3, 4, 5],
...                [5, 2, 7, 4],
...                [8, 1, 3, 6]])
>>> jnp.max(x)
Array(9, dtype=int32)
If axis=1, the maximum is computed along axis 1.
>>> jnp.max(x, axis=1)
Array([9, 7, 8], dtype=int32)
If keepdims=True, the ndim of the output is the same as that of the input.
>>> jnp.max(x, axis=1, keepdims=True)
Array([[9],
       [7],
       [8]], dtype=int32)
To include only specific elements in computing the maximum, use where. It can either have the same shape as the input
>>> where=jnp.array([[0, 0, 1, 0],
...                  [0, 0, 1, 1],
...                  [1, 1, 1, 0]], dtype=bool)
>>> jnp.max(x, axis=1, keepdims=True, initial=0, where=where)
Array([[4],
       [7],
       [8]], dtype=int32)
or be broadcast-compatible with the input.
>>> where = jnp.array([[False],
...                    [False],
...                    [False]])
>>> jnp.max(x, axis=0, keepdims=True, initial=0, where=where)
Array([[0, 0, 0, 0]], dtype=int32)
- mean(axis=None, dtype=None, out=None, keepdims=False, *, where=None)
Return the mean of array elements along a given axis.
JAX implementation of numpy.mean().
- Parameters:
a (Array | ndarray | bool | number | int | float | complex) – input array.
axis (int | Sequence[int] | None) – optional, int or sequence of ints, default=None. Axis along which the mean is computed. If None, the mean is computed along all the axes.
dtype (str | type[Any] | dtype | SupportsDType | None) – The type of the output array. If None (default), the output dtype will match the input dtype for floating-point inputs, or be set to float32 or float64 for non-floating-point inputs.
keepdims (bool) – bool, default=False. If True, reduced axes are left in the result with size 1.
where (Array | ndarray | bool | number | int | float | complex | None) – optional, boolean array, default=None. The elements to be used in the mean. Must be broadcast-compatible with the input.
out (None) – Unused by JAX.
- Return type:
Array
- Returns:
An array of the mean along the given axis.
Notes
For inputs of type float16 or bfloat16, the reductions will be performed at float32 precision.
See also
jax.numpy.average(): Compute the weighted average of array elements.
jax.numpy.sum(): Compute the sum of array elements.
Examples
By default, the mean is computed along all the axes.
>>> x = jnp.array([[1, 3, 4, 2], ... [5, 2, 6, 3], ... [8, 1, 2, 9]]) >>> jnp.mean(x) Array(3.8333335, dtype=float32)
If axis=1, the mean is computed along axis 1.
>>> x = jnp.array([[1, 3, 4, 2],
...                [5, 2, 6, 3],
...                [8, 1, 2, 9]])
>>> jnp.mean(x, axis=1)
Array([2.5, 4. , 5. ], dtype=float32)
If keepdims=True, the ndim of the output is equal to that of the input.
>>> jnp.mean(x, axis=1, keepdims=True)
Array([[2.5],
       [4. ],
       [5. ]], dtype=float32)
To use only specific elements of x to compute the mean, use where.
>>> where = jnp.array([[1, 0, 1, 0],
...                    [0, 1, 0, 1],
...                    [1, 1, 0, 1]], dtype=bool)
>>> jnp.mean(x, axis=1, keepdims=True, where=where)
Array([[2.5],
       [2.5],
       [6. ]], dtype=float32)
- min(axis=None, out=None, keepdims=False, initial=None, where=None)
Return the minimum of array elements along a given axis.
JAX implementation of numpy.min().
- Parameters:
a (Array | ndarray | bool | number | int | float | complex) – Input array.
axis (int | Sequence[int] | None) – int or sequence of ints, default=None. Axis along which the minimum is computed. If None, the minimum is computed along all the axes.
keepdims (bool) – bool, default=False. If True, reduced axes are left in the result with size 1.
initial (Array | ndarray | bool | number | int | float | complex | None) – int or array, default=None. Initial value for the minimum.
where (Array | ndarray | bool | number | int | float | complex | None) – int or array of boolean dtype, default=None. The elements to be used in the minimum. Must be broadcast-compatible with the input. initial must be specified when where is used.
out (None) – Unused by JAX.
- Return type:
Array
- Returns:
An array of minimum values along the given axis.
See also
jax.numpy.max(): Compute the maximum of array elements along a given axis.
jax.numpy.sum(): Compute the sum of array elements along a given axis.
jax.numpy.prod(): Compute the product of array elements along a given axis.
Examples
By default, the minimum is computed along all the axes.
>>> x = jnp.array([[2, 5, 1, 6],
...                [3, -7, -2, 4],
...                [8, -4, 1, -3]])
>>> jnp.min(x)
Array(-7, dtype=int32)
If axis=1, the minimum is computed along axis 1.
>>> jnp.min(x, axis=1)
Array([ 1, -7, -4], dtype=int32)
If keepdims=True, the ndim of the output is the same as that of the input.
>>> jnp.min(x, axis=1, keepdims=True)
Array([[ 1],
       [-7],
       [-4]], dtype=int32)
To include only specific elements in computing the minimum, use where. where can either have the same shape as the input
>>> where=jnp.array([[1, 0, 1, 0],
...                  [0, 0, 1, 1],
...                  [1, 1, 1, 0]], dtype=bool)
>>> jnp.min(x, axis=1, keepdims=True, initial=0, where=where)
Array([[ 0],
       [-2],
       [-4]], dtype=int32)
or be broadcast-compatible with the input.
>>> where = jnp.array([[False],
...                    [False],
...                    [False]])
>>> jnp.min(x, axis=0, keepdims=True, initial=0, where=where)
Array([[0, 0, 0, 0]], dtype=int32)
- nancumprod(*args, **kwds)
Return the cumulative product of elements along a given axis, treating NaNs as ones.
Because each position in the result corresponds to a different number of multiplied elements, the unit exponent varies across the output. This is only representable when the quantity is dimensionless.
- nanprod(*args, **kwds)
Return the product of array elements over a given axis, treating NaNs as ones.
When reducing along a specific axis, the number of non-NaN elements must be the same for every position in the result so that a single unit exponent can be assigned. If the non-NaN counts differ and the quantity is not dimensionless, a ValueError is raised.
- Returns:
The product (NaNs treated as ones).
- Return type:
- Raises:
ValueError – If the non-NaN counts are not uniform along the reduction axis for a non-dimensionless quantity.
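The uniformity condition above can be checked before reducing. nan_counts_are_uniform is a hypothetical helper written for this illustration, not part of saiunit; it counts the non-NaN entries that each output position would multiply, using jax.numpy only.

```python
import jax.numpy as jnp

def nan_counts_are_uniform(x, axis=None):
    """Hypothetical helper: True when every output element of nanprod(x, axis)
    would multiply the same number of non-NaN factors."""
    counts = jnp.sum(~jnp.isnan(x), axis=axis)       # non-NaN count per output
    return bool(jnp.all(counts == counts.ravel()[0]))

x = jnp.array([[1.0, jnp.nan, 3.0],
               [4.0, 5.0, jnp.nan]])

# Along axis 1 both rows have two non-NaN values: a single exponent (2) exists.
print(nan_counts_are_uniform(x, axis=1))   # True
# Along axis 0 the columns have 2, 1 and 1 non-NaN values: no single exponent.
print(nan_counts_are_uniform(x, axis=0))   # False
```

For a quantity in mV, the first case corresponds to a result in mV ** 2, while the second is the situation in which nanprod raises ValueError.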
- nonzero(*, size=None, fill_value=None)
Return indices of nonzero elements of an array.
JAX implementation of numpy.nonzero().
Because the size of the output of nonzero is data-dependent, the function is not compatible with JIT and other transformations. The JAX version adds the optional size argument, which must be specified statically for jnp.nonzero to be used within JAX’s transformations.
- Parameters:
a (Array | ndarray | bool | number | int | float | complex) – N-dimensional array.
size (int | None) – optional static integer specifying the number of nonzero entries to return. If there are more nonzero elements than the specified size, then indices will be truncated at the end. If there are fewer nonzero elements than the specified size, then indices will be padded with fill_value, which defaults to zero.
fill_value (None | Array | ndarray | bool | number | int | float | complex | tuple[Array | ndarray | bool | number | int | float | complex, ...]) – optional padding value when size is specified. Defaults to 0.
- Return type:
tuple[Array, ...]
- Returns:
Tuple of JAX Arrays of length a.ndim, containing the indices of each nonzero value.
See also
jax.numpy.flatnonzero()
jax.numpy.where()
Examples
One-dimensional array returns a length-1 tuple of indices:
>>> x = jnp.array([0, 5, 0, 6, 0, 7])
>>> jnp.nonzero(x)
(Array([1, 3, 5], dtype=int32),)
Two-dimensional array returns a length-2 tuple of indices:
>>> x = jnp.array([[0, 5, 0],
...                [6, 0, 7]])
>>> jnp.nonzero(x)
(Array([0, 1, 1], dtype=int32), Array([1, 0, 2], dtype=int32))
In either case, the resulting tuple of indices can be used directly to extract the nonzero values:
>>> indices = jnp.nonzero(x)
>>> x[indices]
Array([5, 6, 7], dtype=int32)
The output of nonzero has a dynamic shape, because the number of returned indices depends on the contents of the input array. As such, it is incompatible with JIT and other JAX transformations:
>>> x = jnp.array([0, 5, 0, 6, 0, 7])
>>> jax.jit(jnp.nonzero)(x)
Traceback (most recent call last):
  ...
ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: traced array with shape int32[].
The size argument of jnp.nonzero must be statically specified to use jnp.nonzero within JAX transformations.
This can be addressed by passing a static size parameter to specify the desired output shape:
>>> nonzero_jit = jax.jit(jnp.nonzero, static_argnames='size')
>>> nonzero_jit(x, size=3)
(Array([1, 3, 5], dtype=int32),)
If size does not match the true size, the result will be either truncated or padded:
>>> nonzero_jit(x, size=2)  # size < 3: indices are truncated
(Array([1, 3], dtype=int32),)
>>> nonzero_jit(x, size=5)  # size > 3: indices are padded with zeros.
(Array([1, 3, 5, 0, 0], dtype=int32),)
You can specify a custom fill value for the padding using the fill_value argument:
>>> nonzero_jit(x, size=5, fill_value=len(x))
(Array([1, 3, 5, 6, 6], dtype=int32),)
- outer(b)
Outer product of two 1-D arrays.
The resulting unit is self.unit * b.unit.
Examples
>>> import saiunit as u
>>> import jax.numpy as jnp
>>> a = u.Quantity(jnp.array([1.0, 2.0]), unit=u.mV)
>>> b = u.Quantity(jnp.array([3.0, 4.0]), unit=u.second)
>>> a.outer(b).shape
(2, 2)
- pow(oc)
Raise this quantity to the power oc.
The exponent must be dimensionless. The resulting unit is self.unit ** oc.
- Parameters:
- Returns:
self ** oc.
- Return type:
Quantity
Examples
>>> import saiunit as u
>>> q = u.Quantity(2.0, unit=u.mV)
>>> q.pow(2)
Quantity(4., "mV^2")
- prod(*args, **kwds)
Return the product of array elements over the given axis.
The unit of the result is self.unit ** n, where n is the number of elements multiplied together.
- Returns:
The product.
- Return type:
Examples
>>> import saiunit as u
>>> import jax.numpy as jnp
>>> q = u.Quantity(jnp.array([2.0, 3.0]), unit=u.mV)
>>> q.prod()
Quantity(6., "mV^2")
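The exponent n in the rule above depends on the axis argument. prod_unit_exponent is a hypothetical helper for illustration only, and it assumes (the page does not state this explicitly) that when an axis is given, n is the number of elements along the reduced axis or axes.

```python
import math

def prod_unit_exponent(shape, axis=None):
    """Hypothetical helper: the exponent n in `unit ** n` for prod(axis=axis)
    applied to an array of the given shape."""
    if axis is None:
        return math.prod(shape)                  # every element is multiplied together
    axes = (axis,) if isinstance(axis, int) else tuple(axis)
    return math.prod(shape[a] for a in axes)     # only the reduced axes contribute

print(prod_unit_exponent((2,)))                  # 2 -> mV ** 2 for a length-2 mV array
print(prod_unit_exponent((2, 3), axis=1))        # 3 -> each output multiplies 3 elements
print(prod_unit_exponent((2, 3), axis=(0, 1)))   # 6
```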
- ptp(axis=None, out=None, keepdims=False)
Return the peak-to-peak range along a given axis.
JAX implementation of numpy.ptp().
- Parameters:
a (Array | ndarray | bool | number | int | float | complex) – input array.
axis (int | Sequence[int] | None) – optional, int or sequence of ints, default=None. Axis along which the range is computed. If None, the range is computed on the flattened array.
keepdims (bool) – bool, default=False. If True, reduced axes are left in the result with size 1.
out (None) – Unused by JAX.
- Return type:
Array
- Returns:
An array with the range of elements along specified axis of input.
Examples
By default, jnp.ptp computes the range along all axes.
>>> x = jnp.array([[1, 3, 5, 2],
...                [4, 6, 8, 1],
...                [7, 9, 3, 4]])
>>> jnp.ptp(x)
Array(8, dtype=int32)
If axis=1, computes the range along axis 1.
>>> jnp.ptp(x, axis=1)
Array([4, 7, 6], dtype=int32)
To preserve the dimensions of the input, set keepdims=True.
>>> jnp.ptp(x, axis=1, keepdims=True)
Array([[4],
       [7],
       [6]], dtype=int32)
- put(indices, values)
Replaces specified elements of an array with given values.
- Parameters:
indices (array_like) – Target indices, interpreted as integers.
values (array_like) – Values to place in the array at target indices.
- Return type:
- ravel(order='C', *, out_sharding=None)
Flatten array into a 1-dimensional shape.
JAX implementation of numpy.ravel(), implemented in terms of jax.lax.reshape(). ravel(arr, order=order) is equivalent to reshape(arr, -1, order=order).
- Parameters:
- Return type:
Array
- Returns:
A flattened copy of the input array.
Notes
Unlike numpy.ravel(), jax.numpy.ravel() returns a copy rather than a view of the input array. Under JIT, however, the compiler optimizes away such copies when possible, so in practice this has no performance impact.
See also
jax.Array.ravel(): equivalent functionality via an array method.
jax.numpy.reshape(): general array reshape.
Examples
>>> x = jnp.array([[1, 2, 3],
...                [4, 5, 6]])
By default, ravel uses C-style, row-major order:
>>> jnp.ravel(x)
Array([1, 2, 3, 4, 5, 6], dtype=int32)
Optionally, ravel in Fortran-style, column-major order:
>>> jnp.ravel(x, order='F')
Array([1, 4, 2, 5, 3, 6], dtype=int32)
For convenience, the same functionality is available via the jax.Array.ravel() method:
>>> x.ravel()
Array([1, 2, 3, 4, 5, 6], dtype=int32)
- repeat(repeats, axis=None)
Repeat elements of the array.
- Parameters:
- Returns:
The repeated quantity.
- Return type:
Examples
>>> import saiunit as u
>>> import jax.numpy as jnp
>>> q = u.Quantity(jnp.array([1.0, 2.0]), unit=u.mV)
>>> q.repeat(2)
Quantity([1. 1. 2. 2.], "mV")
- repr_in_unit(precision=None)
Return a human-readable string of this quantity in its current unit.
The format is "<value> <unit>", e.g. "3. mV" or "[1. 2. 3.] mV".
- Parameters:
precision (int|None) – Floating-point print precision, interpreted as in NumPy's print options. When None, the value from numpy.get_printoptions() is used.
- Returns:
The formatted string.
- Return type:
Examples
>>> import saiunit as u
>>> x = u.Quantity(25.0, unit=u.mV)
>>> x.repr_in_unit()
'25. mV'
>>> x.to(u.volt).repr_in_unit(3)
'0.025 V'
- reshape(shape, order='C')
Return a quantity with the same data but a new shape.
- Parameters:
- Returns:
Reshaped quantity.
- Return type:
Examples
>>> import saiunit as u
>>> import jax.numpy as jnp
>>> q = u.Quantity(jnp.array([1.0, 2.0, 3.0]), unit=u.mV)
>>> q.reshape((3, 1)).shape
(3, 1)
- round(decimals=0)
Round the mantissa to the given number of decimals. As in numpy.round(), values exactly halfway between two candidates are rounded to the nearest even value.
- Parameters:
decimals (int) – Number of decimal places (default 0). Negative values round to positions left of the decimal point.
- Returns:
A new quantity with the rounded mantissa.
- Return type:
Examples
>>> import saiunit as u
>>> q = u.Quantity(1.567, unit=u.mV)
>>> q.round(1)
Quantity(1.6, "mV")
- scatter_add(index, value)
Return a copy with value added at index.
- Parameters:
- Returns:
A new quantity with the update applied.
- Return type:
Examples
>>> import saiunit as u
>>> import jax.numpy as jnp
>>> q = u.Quantity(jnp.array([1.0, 2.0, 3.0]), unit=u.mV)
>>> q.scatter_add(0, u.Quantity(10.0, unit=u.mV))
Quantity([11. 2. 3.], "mV")
- scatter_div(index, value)
Return a copy with the element at index divided by value.
value must be dimensionless (a pure scale factor).
- Parameters:
- Returns:
A new quantity with the update applied.
- Return type:
- Raises:
TypeError – If value is not dimensionless.
Examples
>>> import saiunit as u
>>> import jax.numpy as jnp
>>> q = u.Quantity(jnp.array([1.0, 2.0, 3.0]), unit=u.mV)
>>> q.scatter_div(0, u.Quantity(2.0))
Quantity([0.5 2. 3. ], "mV")
- scatter_max(index, value)
Return a copy where the element at index is the maximum of the current value and value.
- Parameters:
- Returns:
A new quantity with the update applied.
- Return type:
Examples
>>> import saiunit as u
>>> import jax.numpy as jnp
>>> q = u.Quantity(jnp.array([1.0, 2.0, 3.0]), unit=u.mV)
>>> q.scatter_max(0, u.Quantity(10.0, unit=u.mV))
Quantity([10. 2. 3.], "mV")
- scatter_min(index, value)
Return a copy where the element at index is the minimum of the current value and value.
- Parameters:
- Returns:
A new quantity with the update applied.
- Return type:
Examples
>>> import saiunit as u
>>> import jax.numpy as jnp
>>> q = u.Quantity(jnp.array([1.0, 2.0, 3.0]), unit=u.mV)
>>> q.scatter_min(0, u.Quantity(0.5, unit=u.mV))
Quantity([0.5 2. 3. ], "mV")
- scatter_mul(index, value)
Return a copy with the element at index multiplied by value.
value must be dimensionless (a pure scale factor).
- Parameters:
- Returns:
A new quantity with the update applied.
- Return type:
- Raises:
TypeError – If value is not dimensionless.
Examples
>>> import saiunit as u
>>> import jax.numpy as jnp
>>> q = u.Quantity(jnp.array([1.0, 2.0, 3.0]), unit=u.mV)
>>> q.scatter_mul(0, u.Quantity(10.0))
Quantity([10. 2. 3.], "mV")
- scatter_sub(index, value)
Return a copy with value subtracted at index.
- Parameters:
- Returns:
A new quantity with the update applied.
- Return type:
Examples
>>> import saiunit as u
>>> import jax.numpy as jnp
>>> q = u.Quantity(jnp.array([1.0, 2.0, 3.0]), unit=u.mV)
>>> q.scatter_sub(0, u.Quantity(1.0, unit=u.mV))
Quantity([0. 2. 3.], "mV")
- searchsorted(v, side='left', sorter=None)
Find indices where elements should be inserted to maintain order.
- Return type:
Array
- property shape: tuple[int, ...]
The shape of the mantissa array.
Examples
>>> import saiunit as u
>>> import jax.numpy as jnp
>>> q = u.Quantity(jnp.array([[1.0, 2.0], [3.0, 4.0]]), unit=u.mV)
>>> q.shape
(2, 2)
- sort(axis=-1, stable=True, order=None)
Sort the array in-place along the given axis.
- Parameters:
- Returns:
self, with the mantissa sorted in-place.
- Return type:
Examples
>>> import saiunit as u
>>> import jax.numpy as jnp
>>> q = u.Quantity(jnp.array([3.0, 1.0, 2.0]), unit=u.mV)
>>> q.sort()
Quantity([1. 2. 3.], "mV")
- split(indices_or_sections, axis=0)
Split the array into multiple sub-arrays.
- Parameters:
- Returns:
Sub-arrays, each carrying the same unit.
- Return type:
Examples
>>> import saiunit as u
>>> import jax.numpy as jnp
>>> q = u.Quantity(jnp.array([1.0, 2.0, 3.0]), unit=u.mV)
>>> parts = q.split(3)
>>> len(parts)
3
- squeeze(axis=None)
Remove length-one axes from the array.
- Parameters:
axis (int or tuple of ints, optional) – Axes to remove. If None, all length-one axes are removed.
- Returns:
The squeezed quantity.
- Return type:
Examples
>>> import saiunit as u
>>> import jax.numpy as jnp
>>> q = u.Quantity(jnp.array([[[1.0]]]), unit=u.mV)
>>> q.squeeze().shape
()
- std(axis=None, dtype=None, out=None, ddof=0, keepdims=False, *, where=None, mean=None, correction=None)
Compute the standard deviation along a given axis.
JAX implementation of numpy.std().
- Parameters:
a (Array|ndarray|bool|number|bool|int|float|complex) – Input array.
axis (int|Sequence[int] |None) – Optional int or sequence of ints, default=None. Axis along which the standard deviation is computed. If None, the standard deviation is computed along all axes.
dtype (str|type[Any] |dtype|SupportsDType|None) – The type of the output array. Default=None.
ddof (int) – Default=0. Degrees of freedom. The divisor in the standard deviation computation is N - ddof, where N is the number of elements along the given axis.
keepdims (bool) – Default=False. If True, reduced axes are left in the result with size 1.
where (Array|ndarray|bool|number|bool|int|float|complex|None) – Optional boolean array, default=None. The elements to be used in the standard deviation. The array must be broadcast-compatible with the input.
mean (Array|ndarray|bool|number|bool|int|float|complex|None) – Optional mean of the input array, computed along the given axis. If provided, it is used to compute the standard deviation instead of computing the mean from the input array. If specified, mean must be broadcast-compatible with the input array. In general, this can be achieved by computing the mean with keepdims=True and an axis matching this function's axis argument.
correction (int|float|None) – Int or float, default=None. Alternative name for ddof. ddof and correction cannot both be provided.
out (None) – Unused by JAX.
- Return type:
Array
- Returns:
An array of the standard deviation along the given axis.
See also
jax.numpy.var(): Compute the variance of array elements over a given axis.
jax.numpy.mean(): Compute the mean of array elements over a given axis.
jax.numpy.nanvar(): Compute the variance along a given axis, ignoring NaN values.
jax.numpy.nanstd(): Compute the standard deviation along a given axis, ignoring NaN values.
Examples
By default, jnp.std computes the standard deviation along all axes.
>>> x = jnp.array([[1, 3, 4, 2],
...                [4, 2, 5, 3],
...                [5, 4, 2, 3]])
>>> with jnp.printoptions(precision=2, suppress=True):
...   jnp.std(x)
Array(1.21, dtype=float32)
If axis=0, jnp.std computes along axis 0.
>>> with jnp.printoptions(precision=2, suppress=True):
...   print(jnp.std(x, axis=0))
[1.7 0.82 1.25 0.47]
To preserve the dimensions of the input, set keepdims=True.
>>> with jnp.printoptions(precision=2, suppress=True):
...   print(jnp.std(x, axis=0, keepdims=True))
[[1.7 0.82 1.25 0.47]]
With ddof=1:
>>> with jnp.printoptions(precision=2, suppress=True):
...   print(jnp.std(x, axis=0, keepdims=True, ddof=1))
[[2.08 1. 1.53 0.58]]
To include only specific elements of the array in the standard deviation, use where.
>>> where = jnp.array([[1, 0, 1, 0],
...                    [0, 1, 0, 1],
...                    [1, 1, 1, 0]], dtype=bool)
>>> jnp.std(x, axis=0, keepdims=True, where=where)
Array([[2., 1., 1., 0.]], dtype=float32)
- property strides
Tuple of byte-steps in each dimension (mirrors numpy.ndarray.strides).
- sum(axis=None, dtype=None, out=None, keepdims=False, initial=None, where=None, promote_integers=True)
Sum of the elements of the array over a given axis.
JAX implementation of numpy.sum().
- Parameters:
a (Array|ndarray|bool|number|bool|int|float|complex) – Input array.
axis (int|Sequence[int] |None) – Int or sequence of ints, default=None. Axis along which the sum is computed. If None, the sum is computed along all axes.
dtype (str|type[Any] |dtype|SupportsDType|None) – The type of the output array. Default=None.
out (None) – Unused by JAX.
keepdims (bool) – Default=False. If True, reduced axes are left in the result with size 1.
initial (Array|ndarray|bool|number|bool|int|float|complex|None) – Int or array, default=None. Initial value for the sum.
where (Array|ndarray|bool|number|bool|int|float|complex|None) – Int or array, default=None. The elements to be used in the sum. The array must be broadcast-compatible with the input.
promote_integers (bool) – Default=True. If True, integer inputs are promoted to the widest available integer dtype, following NumPy's behavior. If False, the result has the same dtype as the input. promote_integers is ignored if dtype is specified.
- Return type:
Array
- Returns:
An array of the sum along the given axis.
See also
jax.numpy.prod(): Compute the product of array elements over a given axis.
jax.numpy.max(): Compute the maximum of array elements over a given axis.
jax.numpy.min(): Compute the minimum of array elements over a given axis.
Examples
By default, the sum is computed along all the axes.
>>> x = jnp.array([[1, 3, 4, 2],
...                [5, 2, 6, 3],
...                [8, 1, 3, 9]])
>>> jnp.sum(x)
Array(47, dtype=int32)
If axis=1, the sum is computed along axis 1.
>>> jnp.sum(x, axis=1)
Array([10, 16, 21], dtype=int32)
If keepdims=True, the ndim of the output equals that of the input.
>>> jnp.sum(x, axis=1, keepdims=True)
Array([[10],
       [16],
       [21]], dtype=int32)
To include only specific elements in the sum, use where.
>>> where = jnp.array([[0, 0, 1, 0],
...                    [0, 0, 1, 1],
...                    [1, 1, 1, 0]], dtype=bool)
>>> jnp.sum(x, axis=1, keepdims=True, where=where)
Array([[ 4],
       [ 9],
       [12]], dtype=int32)
>>> where = jnp.array([[False],
...                    [False],
...                    [False]])
>>> jnp.sum(x, axis=0, keepdims=True, where=where)
Array([[0, 0, 0, 0]], dtype=int32)
- swapaxes(axis1, axis2)
Interchange two axes of the array.
- Parameters:
- Returns:
The quantity with axes swapped.
- Return type:
Examples
>>> import saiunit as u
>>> import jax.numpy as jnp
>>> q = u.Quantity(jnp.array([[1.0, 2.0], [3.0, 4.0]]), unit=u.mV)
>>> q.swapaxes(0, 1).shape
(2, 2)
- take(indices, axis=None, mode=None, unique_indices=False, indices_are_sorted=False, fill_value=None)
Select elements from the array at the given indices.
- Parameters:
indices (array_like) – Indices of the values to extract.
axis (int, optional) – Axis along which to take (default: the flattened array).
mode (str, optional) – Out-of-bounds index handling.
unique_indices (bool, optional) – Hint that indices are unique.
indices_are_sorted (bool, optional) – Hint that indices are sorted.
fill_value (Quantity or scalar, optional) – Value for out-of-bounds positions when mode is 'fill'.
- Returns:
The selected elements.
- Return type:
Examples
>>> import saiunit as u
>>> import jax.numpy as jnp
>>> q = u.Quantity(jnp.array([10.0, 20.0, 30.0]), unit=u.mV)
>>> q.take(jnp.array([0, 2]))
Quantity([10. 30.], "mV")
- tile(reps)
Construct an array by repeating this quantity.
- Parameters:
reps (int or array_like) – Number of repetitions along each axis.
- Returns:
The tiled quantity.
- Return type:
Examples
>>> import saiunit as u
>>> import jax.numpy as jnp
>>> q = u.Quantity(jnp.array([1.0, 2.0]), unit=u.mV)
>>> q.tile(2)
Quantity([1. 2. 1. 2.], "mV")
- to(new_unit)
Convert this quantity to a different (compatible) unit.
The mantissa is rescaled so that the physical value stays the same, and the returned Quantity carries new_unit.
- Parameters:
new_unit (Unit) – Target unit. Must have the same dimension as self.unit.
- Returns:
A new Quantity expressed in new_unit.
- Return type:
- Raises:
UnitMismatchError – If new_unit has a different dimension.
Examples
>>> import saiunit as u
>>> import jax.numpy as jnp
>>> q = u.Quantity(jnp.array([1.0, 2.0, 3.0]), unit=u.mV)
>>> q.to(u.volt)
Quantity([0.001 0.002 0.003], "V")
See also
in_unit
Identical behaviour (to delegates to in_unit).
to_decimal
Convert to a plain number in the target unit.
- to_decimal(unit=Unit('1'))
Return the numerical value expressed in the given unit, without wrapping the result in a Quantity.
This is useful when you need a plain JAX array for downstream computation that does not support units.
- Parameters:
unit (Unit) – The reference unit. Defaults to UNITLESS.
- Returns:
A plain number or JAX array representing the quantity in unit.
- Return type:
Array|ndarray|bool|number|bool|int|float|complex
- Raises:
UnitMismatchError – If unit has a different dimension than self.unit.
Examples
>>> import saiunit as u
>>> import jax.numpy as jnp
>>> q = u.Quantity(jnp.array([1.0, 2.0, 3.0]), unit=u.mV)
>>> q.to_decimal(u.volt)
Array([0.001, 0.002, 0.003], dtype=float32)
See also
to
Convert while keeping the Quantity wrapper.
- tolist()
Convert the array to a (nested) Python list of Quantity scalars.
Each leaf element is a 0-D Quantity with the same unit.
- Returns:
A nested list of scalar Quantity objects, or a single Quantity for 0-D arrays.
- Return type:
Examples
>>> import saiunit as u
>>> import jax.numpy as jnp
>>> q = u.Quantity(jnp.array([1.0, 2.0]), unit=u.mV)
>>> q.tolist()
[Quantity(1., "mV"), Quantity(2., "mV")]
- trace(offset=0, axis1=0, axis2=1)
Sum along diagonals of the array, preserving units.
- Parameters:
- Returns:
The trace value(s).
- Return type:
Examples
>>> import saiunit as u
>>> import jax.numpy as jnp
>>> q = u.Quantity(jnp.eye(3), unit=u.mV)
>>> q.trace()
Quantity(3., "mV")
- transpose(*axes)
Return the array with axes transposed.
For a 2-D array this is the standard matrix transpose.
- Parameters:
*axes (None, tuple of ints, or n ints) – If omitted, axes are reversed. Otherwise specifies the permutation.
- Returns:
Transposed quantity.
- Return type:
Examples
>>> import saiunit as u
>>> import jax.numpy as jnp
>>> q = u.Quantity(jnp.array([[1.0, 2.0], [3.0, 4.0]]), unit=u.mV)
>>> q.transpose().shape
(2, 2)
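- classmethod tree_unflatten(unit, values)
Rebuild a Quantity from its flattened JAX pytree representation (the inverse of flattening).
- Parameters:
unit – The unit, stored as the pytree's auxiliary (static) data.
values – The flattened numerical data (the pytree children).
- Return type:
- Returns:
The Quantity object.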
- property unit: saiunit.Unit
The Unit attached to this quantity.
The unit carries both the physical dimension and the scale factor (e.g. mV has dimension voltage with scale 1e-3).
- Returns:
The unit of this quantity.
- Return type:
Examples
>>> import saiunit as u
>>> q = u.Quantity(5.0, unit=u.mV)
>>> q.unit
mV
- unsqueeze(axis)
Insert a length-one axis (PyTorch-style alias for expand_dims()).
- Parameters:
axis (int) – Position where the new axis is inserted.
- Returns:
The quantity with an extra dimension.
- Return type:
Examples
>>> import saiunit as u
>>> import jax.numpy as jnp
>>> q = u.Quantity(jnp.array([1.0, 2.0]), unit=u.mV)
>>> q.unsqueeze(0).shape
(1, 2)
- update_mantissa(mantissa)
Replace the mantissa in-place, keeping the same unit.
The new mantissa must have the same shape and dtype as the current one.
- Parameters:
mantissa (Any) – The new numerical data. Must not be a Quantity.
- Raises:
ValueError – If mantissa is a Quantity, or if shape/dtype do not match.
- Return type:
Examples
>>> import saiunit as u
>>> import jax.numpy as jnp
>>> q = u.Quantity(jnp.array([1.0, 2.0, 3.0]), unit=u.mV)
>>> q.update_mantissa(jnp.array([4.0, 5.0, 6.0]))
>>> q
Quantity([4. 5. 6.], "mV")
- var(axis=None, dtype=None, out=None, ddof=0, keepdims=False, *, where=None, mean=None, correction=None)
Compute the variance along a given axis.
JAX implementation of numpy.var().
- Parameters:
a (Array|ndarray|bool|number|bool|int|float|complex) – Input array.
axis (int|Sequence[int] |None) – Optional int or sequence of ints, default=None. Axis along which the variance is computed. If None, the variance is computed along all axes.
dtype (str|type[Any] |dtype|SupportsDType|None) – The type of the output array. Default=None.
ddof (int) – Default=0. Degrees of freedom. The divisor in the variance computation is N - ddof, where N is the number of elements along the given axis.
keepdims (bool) – Default=False. If True, reduced axes are left in the result with size 1.
where (Array|ndarray|bool|number|bool|int|float|complex|None) – Optional boolean array, default=None. The elements to be used in the variance. The array must be broadcast-compatible with the input.
mean (Array|ndarray|bool|number|bool|int|float|complex|None) – Optional mean of the input array, computed along the given axis. If provided, it is used to compute the variance instead of computing the mean from the input array. If specified, mean must be broadcast-compatible with the input array. In general, this can be achieved by computing the mean with keepdims=True and an axis matching this function's axis argument.
correction (int|float|None) – Int or float, default=None. Alternative name for ddof. ddof and correction cannot both be provided.
out (None) – Unused by JAX.
- Return type:
Array
- Returns:
An array of the variance along the given axis.
See also
jax.numpy.mean(): Compute the mean of array elements over a given axis.
jax.numpy.std(): Compute the standard deviation of array elements over a given axis.
jax.numpy.nanvar(): Compute the variance along a given axis, ignoring NaN values.
jax.numpy.nanstd(): Compute the standard deviation along a given axis, ignoring NaN values.
Examples
By default, jnp.var computes the variance along all axes.
>>> x = jnp.array([[1, 3, 4, 2],
...                [5, 2, 6, 3],
...                [8, 4, 2, 9]])
>>> with jnp.printoptions(precision=2, suppress=True):
...   jnp.var(x)
Array(5.74, dtype=float32)
If axis=1, the variance is computed along axis 1.
>>> jnp.var(x, axis=1)
Array([1.25 , 2.5 , 8.1875], dtype=float32)
To preserve the dimensions of the input, set keepdims=True.
>>> jnp.var(x, axis=1, keepdims=True)
Array([[1.25 ],
       [2.5 ],
       [8.1875]], dtype=float32)
With ddof=1:
>>> with jnp.printoptions(precision=2, suppress=True):
...   print(jnp.var(x, axis=1, keepdims=True, ddof=1))
[[ 1.67]
 [ 3.33]
 [10.92]]
To include only specific elements of the array in the variance, use where.
>>> where = jnp.array([[1, 0, 1, 0],
...                    [0, 1, 1, 0],
...                    [1, 1, 1, 0]], dtype=bool)
>>> with jnp.printoptions(precision=2, suppress=True):
...   print(jnp.var(x, axis=1, keepdims=True, where=where))
[[2.25]
 [4. ]
 [6.22]]
- view(*args, dtype=None)
New view of array with the same data.
This function is compatible with PyTorch syntax.
Returns a new tensor with the same data as the self tensor but of a different shape.
The returned tensor shares the same data and must have the same number of elements, but may have a different size. For a tensor to be viewed, the new view size must be compatible with its original size and stride, i.e., each new view dimension must either be a subspace of an original dimension, or only span across original dimensions \(d, d+1, \dots, d+k\) that satisfy the following contiguity-like condition that \(\forall i = d, \dots, d+k-1\),
\[\text{stride}[i] = \text{stride}[i+1] \times \text{size}[i+1]\]
Otherwise, it will not be possible to view the self tensor as shape without copying it (e.g., via contiguous()). When it is unclear whether a view() can be performed, it is advisable to use reshape(), which returns a view if the shapes are compatible, and copies (equivalent to calling contiguous()) otherwise.
Example:
>>> import jax.numpy as jnp, saiunit
>>> x = saiunit.Quantity(jnp.ones((4, 4)))
>>> x.shape
(4, 4)
>>> y = x.view(16)
>>> y.shape
(16,)
>>> z = x.view(2, 8)
>>> z.shape
(2, 8)
- view(dtype) → Tensor
Returns a new tensor with the same data as the self tensor but of a different dtype.
If the element size of dtype is different from that of self.dtype, then the size of the last dimension of the output is scaled proportionally. For instance, if the dtype element size is twice that of self.dtype, then each pair of elements in the last dimension of self is combined, and the size of the last dimension of the output is half that of self. If the dtype element size is half that of self.dtype, then each element in the last dimension of self is split in two, and the size of the last dimension of the output is double that of self. For this to be possible, the following conditions must be true:
self.dim() must be greater than 0.
self.stride(-1) must be 1.
Additionally, if the element size of dtype is greater than that of self.dtype, the following conditions must be true as well:
self.size(-1) must be divisible by the ratio between the element sizes of the dtypes.
self.storage_offset() must be divisible by the ratio between the element sizes of the dtypes.
The strides of all dimensions, except the last dimension, must be divisible by the ratio between the element sizes of the dtypes.
If any of the above conditions are not met, an error is thrown.
- Parameters:
dtype (
dtype) – the desired dtype
Example:
>>> import jax.numpy as jnp, saiunit
>>> x = saiunit.Quantity(jnp.ones((4, 4), dtype=jnp.float32))
>>> x.view(jnp.int32).shape
(4, 4)
>>> x.view(jnp.uint8).shape
(4, 16)
- static with_unit(mantissa, unit)
Create a Quantity from a raw value and a unit.
This is a convenience factory that reads more naturally in some contexts than the standard constructor.
- Parameters:
- Returns:
A new Quantity with the given mantissa and unit.
- Return type:
Examples
>>> import saiunit as u
>>> u.Quantity.with_unit(2.0, unit=u.metre)
Quantity(2., "m")