saiunit.math module

Unit-aware mathematical functions for JAX arrays.

This subpackage provides unit-aware wrappers for NumPy-style functions, organized by how they handle units:

  • Array creation: array, zeros, ones, arange, linspace, etc.

  • Keep unit: functions that preserve the input unit (sum, mean, concatenate, reshape, abs, round, etc.).

  • Change unit: functions whose output unit differs from the input (multiply, divide, square, sqrt, dot, matmul, etc.).

  • Remove unit: functions that return dimensionless results (equal, greater, argmax, argsort, sign, etc.).

  • Accept unitless: functions that require unitless inputs (exp, log, sin, cos, arctan, etc.).

  • Activations: neural-network activation functions (relu, sigmoid, gelu, silu, etc.).

  • Einops: Einstein-notation operations (einsum, einrearrange, etc.).
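The unit-handling categories above can be illustrated with a deliberately tiny, hypothetical `Q` class that pairs a float with a dict of dimension exponents. This is a pure-Python sketch of the idea only. It is not saiunit's `Quantity` class or its actual dispatch logic, and the helper names simply mirror the categories for readability.

```python
import math

# Hypothetical toy quantity: a float plus a dimension dict such as {"m": 1}.
# This is NOT saiunit's Quantity; it only illustrates the unit categories.
class Q:
    def __init__(self, value, dim=None):
        self.value = value
        # Drop zero exponents so that {} means "dimensionless".
        self.dim = {k: v for k, v in (dim or {}).items() if v != 0}

def _same_dim(a, b):
    if a.dim != b.dim:
        raise ValueError(f"dimension mismatch: {a.dim} vs {b.dim}")

# Keep unit: adding requires equal dimensions, and the result keeps them.
def add(a, b):
    _same_dim(a, b)
    return Q(a.value + b.value, a.dim)

# Change unit: multiplying adds the dimension exponents.
def multiply(a, b):
    dim = dict(a.dim)
    for k, v in b.dim.items():
        dim[k] = dim.get(k, 0) + v
    return Q(a.value * b.value, dim)

# Remove unit: a comparison returns a plain bool with no dimension.
def greater(a, b):
    _same_dim(a, b)
    return a.value > b.value

# Accept unitless: exp is only defined for dimensionless input.
def exp(a):
    if a.dim:
        raise TypeError(f"exp requires a dimensionless input, got {a.dim}")
    return Q(math.exp(a.value))

total = add(Q(3.0, {"m": 1}), Q(2.0, {"m": 1}))       # 5.0 with dim {"m": 1}
area = multiply(Q(3.0, {"m": 1}), Q(2.0, {"m": 1}))   # 6.0 with dim {"m": 2}
```

Calling `exp(Q(1.0, {"m": 1}))` raises `TypeError` in this sketch. The same check is what separates the "accept unitless" functions from the others.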

Activation Functions

relu

Rectified linear unit activation function.

relu6

Rectified Linear Unit 6 activation function.

sigmoid

Sigmoid activation function.

softplus

Softplus activation function.

sparse_plus

Sparse plus function.

sparse_sigmoid

Sparse sigmoid activation function.

soft_sign

Soft-sign activation function.

silu

SiLU (aka swish) activation function.

swish

Swish (aka SiLU) activation function.

log_sigmoid

Log-sigmoid activation function.

leaky_relu

Leaky rectified linear unit activation function.

hard_sigmoid

Hard Sigmoid activation function.

hard_silu

Hard SiLU (swish) activation function.

hard_swish

Hard SiLU (swish) activation function.

hard_tanh

Hard tanh activation function.

elu

Exponential linear unit activation function.

celu

Continuously-differentiable exponential linear unit activation.

selu

Scaled exponential linear unit activation.

gelu

Gaussian error linear unit activation function.

glu

Gated linear unit activation function.

squareplus

Squareplus activation function.

mish

Mish activation function.
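For reference, here are the standard scalar formulas behind several of the activations listed above, written in plain Python. This is a sketch of the textbook definitions and not saiunit's JAX-based, array-valued implementation. Note that `jax.nn.gelu` defaults to a tanh approximation, while the version below uses the exact erf-based form.

```python
import math

def relu(x):
    return max(x, 0.0)

def relu6(x):
    return min(max(x, 0.0), 6.0)

def sigmoid(x):
    # Split by sign so that math.exp never overflows.
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)

def softplus(x):
    # log(1 + exp(x)), rewritten as max(x, 0) + log1p(exp(-|x|)) for stability.
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))

def silu(x):
    # SiLU (swish): x * sigmoid(x)
    return x * sigmoid(x)

def gelu_exact(x):
    # Exact GELU: x * Phi(x), where Phi is the standard normal CDF.
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))

def leaky_relu(x, negative_slope=0.01):
    return x if x >= 0 else negative_slope * x
```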

Unit Processing

is_dimensionless

Test if a value is dimensionless or not.

is_unitless

Test if a value is unitless or not.

get_dim

Return the dimensions of any object that has them.

get_unit

Return the unit of any object that has one.

get_mantissa

Return the mantissa of a Quantity or a number.

get_magnitude

Return the mantissa of a Quantity or a number.

display_in_unit

Display a value in a certain unit with a given precision.

maybe_decimal

Convert a quantity to a plain number if it is dimensionless.

check_dims

Decorator to check dimensions of arguments passed to a function.

check_units

Decorator to check units of arguments passed to a function.

fail_for_dimension_mismatch

Compare the dimensions of two objects.

fail_for_unit_mismatch

Compare the units of two objects.

assert_quantity

Assert that a Quantity has a certain mantissa and unit.

get_or_create_dimension

Create a new Dimension object or get a reference to an existing one.
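The idea behind check_dims and fail_for_dimension_mismatch is to compare dimension exponents and raise an error when they disagree. The sketch below shows that idea in pure Python. It represents a quantity as a `(value, dims)` tuple over the base units (m, kg, s). All names in it (`require_same_dims`, `expect_dims`, `ToyDimensionError`) are hypothetical and are not saiunit's API.

```python
import functools
import inspect

# Hypothetical toy representation: (value, dims), where dims is a tuple of
# exponents for the base units (m, kg, s). Not saiunit's Quantity/Dimension.
class ToyDimensionError(ValueError):
    pass

def require_same_dims(a, b, what="operation"):
    """Raise if two (value, dims) pairs have different dimensions."""
    if a[1] != b[1]:
        raise ToyDimensionError(f"{what}: {a[1]} != {b[1]}")

def expect_dims(**expected):
    """Decorator: check the dims of named arguments before calling."""
    def decorator(func):
        sig = inspect.signature(func)

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            bound = sig.bind(*args, **kwargs)
            for name, dims in expected.items():
                actual = bound.arguments[name][1]
                if actual != dims:
                    raise ToyDimensionError(
                        f"argument {name!r} has dims {actual}, expected {dims}")
            return func(*args, **kwargs)
        return wrapper
    return decorator

METRE = (1, 0, 0)
SECOND = (0, 0, 1)

@expect_dims(distance=METRE, duration=SECOND)
def mean_speed(distance, duration):
    # metre / second has exponents (1, 0, -1)
    return (distance[0] / duration[0], (1, 0, -1))
```

Passing `(2.0, METRE)` as `duration` would raise `ToyDimensionError` before the function body runs.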

Einstein Operations

einreduce

Combine reordering and reduction using reader-friendly notation.

einrearrange

Reader-friendly smart element reordering for multidimensional tensors.

einrepeat

Reorder elements and repeat them in arbitrary combinations.

einshape

Parse a tensor shape to a dictionary mapping axis names to their lengths.

einsum

Einstein summation for arrays and quantities.
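A minimal pure-Python sketch of explicit-mode Einstein summation over nested lists shows what a spec such as `"ij,jk->ik"` computes. Letters absent from the output are summed over. Each output element is a sum of products that take one element from each operand, so a unit-aware einsum would naturally give a result whose unit is the product of the operand units. That last point is reasoning about the operation, not a statement about saiunit's implementation, and the function below is a hypothetical stand-in.

```python
import itertools

# Hypothetical reference einsum over nested lists (explicit "->" form only).

def _shape(arr):
    """Shape of a rectangular nested list, e.g. [[1, 2], [3, 4]] -> [2, 2]."""
    shape = []
    while isinstance(arr, list):
        shape.append(len(arr))
        arr = arr[0]
    return shape

def _get(arr, idx):
    """Index a nested list with a sequence of integers."""
    for i in idx:
        arr = arr[i]
    return arr

def einsum(spec, *operands):
    inputs, output = spec.replace(" ", "").split("->")
    in_subs = inputs.split(",")
    sizes = {}
    for subs, op in zip(in_subs, operands):
        for letter, n in zip(subs, _shape(op)):
            sizes[letter] = n
    # Letters that do not appear in the output are summed over.
    summed = [c for c in sizes if c not in output]

    def value_at(out_idx):
        env = dict(zip(output, out_idx))
        total = 0
        for combo in itertools.product(*(range(sizes[c]) for c in summed)):
            env.update(zip(summed, combo))
            term = 1
            for subs, op in zip(in_subs, operands):
                term *= _get(op, [env[c] for c in subs])
            total += term
        return total

    def build(prefix):
        # Recursively build the nested output, one output axis at a time.
        if len(prefix) == len(output):
            return value_at(prefix)
        n = sizes[output[len(prefix)]]
        return [build(prefix + [i]) for i in range(n)]

    return build([])
```

For example, `einsum("ij,jk->ik", a, b)` is a matrix product, `einsum("i,i->", x, y)` is a dot product, and `einsum("ij->ji", a)` is a transpose.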

Functions that Accept Unitless Inputs

exprel

Relative error exponential, (exp(x) - 1)/x.

set_exprel_order

Set the Taylor series order used by exprel() near zero.

exp

Calculate the exponential of all elements in the input.

exp2

Calculate 2**x element-wise.

expm1

Calculate exp(x) - 1 element-wise with improved precision near zero.

log

Natural logarithm, element-wise.

log10

Base-10 logarithm, element-wise.

log1p

Natural logarithm of 1 + x, element-wise.

log2

Base-2 logarithm, element-wise.

arccos

Inverse cosine, element-wise.

arccosh

Inverse hyperbolic cosine, element-wise.

arcsin

Inverse sine, element-wise.

arcsinh

Inverse hyperbolic sine, element-wise.

arctan

Inverse tangent, element-wise.

arctanh

Inverse hyperbolic tangent, element-wise.

cos

Cosine, element-wise.

cosh

Hyperbolic cosine, element-wise.

sin

Sine, element-wise.

sinc

Normalized sinc function, sin(pi*x) / (pi*x), element-wise.

sinh

Hyperbolic sine, element-wise.

tan

Tangent, element-wise.

tanh

Hyperbolic tangent, element-wise.

deg2rad

Convert angles from degrees to radians.

rad2deg

Convert angles from radians to degrees.

degrees

Convert angles from radians to degrees (alias for rad2deg()).

radians

Convert angles from degrees to radians (alias for deg2rad()).

angle

Return the angle of the complex argument, element-wise.

frexp

Decompose elements into mantissa and base-2 exponent.

hypot

Given the legs of a right triangle, return its hypotenuse.

arctan2

Element-wise arc tangent of the first argument divided by the second, choosing the quadrant correctly.

logaddexp

Logarithm of the sum of exponentiations of the inputs.

logaddexp2

Logarithm of the sum of exponentiations of the inputs in base 2.

corrcoef

Return Pearson product-moment correlation coefficients.

correlate

Cross-correlation of two 1-dimensional sequences.

cov

Estimate a covariance matrix, given data and weights.

ldexp

Returns x * 2**y, element-wise.

bitwise_not

Compute bit-wise NOT, element-wise.

invert

Compute bit-wise inversion (NOT), element-wise.

bitwise_and

Compute bit-wise AND of two arrays, element-wise.

bitwise_or

Compute bit-wise OR of two arrays, element-wise.

bitwise_xor

Compute bit-wise XOR of two arrays, element-wise.

left_shift

Shift the bits of an integer to the left, element-wise.

right_shift

Shift the bits of an integer to the right, element-wise.
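The exprel entry computes (exp(x) - 1)/x, and set_exprel_order controls a Taylor series used near zero. Evaluating the quotient naively near x = 0 loses precision, so a series is a natural fallback there. The sketch below shows the technique in plain Python. The `order` and `threshold` parameters are hypothetical stand-ins for whatever set_exprel_order configures in saiunit, and the threshold value is an arbitrary choice.

```python
import math

def exprel(x, order=4, threshold=1e-3):
    """(exp(x) - 1) / x with a truncated Taylor series near zero.

    `order` and `threshold` are illustrative, hypothetical parameters.
    """
    if abs(x) < threshold:
        # (e^x - 1) / x = sum_{k>=0} x^k / (k + 1)!
        total = 0.0
        term = 1.0  # k = 0 term: x^0 / 1!
        for k in range(order + 1):
            total += term
            term *= x / (k + 2)  # x^(k+1) / (k+2)! from x^k / (k+1)!
        return total
    # Away from zero, expm1 keeps the numerator accurate.
    return math.expm1(x) / x
```

At x = 0 the sketch returns exactly 1.0, which is the limit of the expression, and it avoids a division by zero.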

Array Creation Functions

full

Return a new quantity or array of given shape, filled with fill_value.

eye

Return a 2-D identity-like quantity or array with ones on the diagonal.

identity

Return the identity quantity or array.

tri

Return an array with ones at and below the given diagonal and zeros elsewhere.

empty

Return a new quantity or array of given shape and type, without initializing entries.

ones

Return a new quantity or array of given shape and type, filled with ones.

zeros

Return a new quantity or array of given shape and type, filled with zeros.

full_like

Return a new quantity or array with the same shape and type as a given array, filled with fill_value.

diag

Extract a diagonal or construct a diagonal array.

tril

Return the lower triangle of an array.

triu

Return the upper triangle of an array.

empty_like

Return a new uninitialized quantity or array with the same shape and type as a given array.

ones_like

Return a quantity or array of ones with the same shape and type as a given array.

zeros_like

Return a quantity or array of zeros with the same shape and type as a given array.

fill_diagonal

Fill the main diagonal of the given array of any dimensionality.

array

Convert the input to a quantity or array.

asarray

Convert the input to a quantity or array.

arange

Return evenly spaced values within a given interval.

linspace

Return evenly spaced numbers over a specified interval.

logspace

Return numbers spaced evenly on a log scale.

meshgrid

Return coordinate matrices from coordinate vectors.

vander

Generate a Vandermonde matrix.

tril_indices

Return the indices of the lower triangle of an array of size (n, m).

tril_indices_from

Return the indices for the lower-triangle of an (n, m) array.

triu_indices

Return the indices of the upper triangle of an array of size (n, m).

triu_indices_from

Return the indices for the upper-triangle of an (n, m) array.

from_numpy

Convert a NumPy array to a JAX array, optionally attaching a unit.

as_numpy

Convert a JAX array (or Quantity) to a NumPy array.

tree_ones_like

Create a tree with the same structure as the input, but with ones in each leaf.

tree_zeros_like

Create a tree with the same structure as the input, but with zeros in each leaf.
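The spacing rules for linspace and logspace are easy to get wrong at the endpoints. The pure-Python sketch below follows NumPy's documented semantics: `num` points, with `stop` included unless `endpoint=False`, and logspace spacing the exponents linearly. A unit-aware version would presumably also require `start` and `stop` to share a unit and attach it to the result. That is an assumption about saiunit that is not shown here.

```python
def linspace(start, stop, num=50, endpoint=True):
    if num <= 0:
        return []
    if num == 1:
        return [float(start)]
    # With endpoint=True, num points span num - 1 intervals.
    div = (num - 1) if endpoint else num
    step = (stop - start) / div
    values = [start + i * step for i in range(num)]
    if endpoint:
        values[-1] = float(stop)  # avoid round-off in the last point
    return values

def logspace(start, stop, num=50, endpoint=True, base=10.0):
    # Exponents are spaced linearly; the values are base ** exponent.
    return [base ** e for e in linspace(start, stop, num, endpoint)]
```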

Functions that Change Unit

reciprocal

Return the reciprocal of the argument, element-wise.

prod

Return the product of array elements over a given axis.

product

Return the product of array elements over a given axis.

nancumprod

Return the cumulative product of elements along a given axis treating NaNs as one.

nanprod

Return the product of array elements over a given axis treating NaNs as one.

cumprod

Return the cumulative product of elements along a given axis.

cumproduct

Return the cumulative product of elements along a given axis.

var

Compute the variance along the specified axis.

nanvar

Compute the variance along the specified axis, while ignoring NaNs.

cbrt

Compute the cube root of each element.

square

Compute the square of each element.

sqrt

Compute the positive square root of each element.

multiply

Multiply arguments element-wise.

divide

Divide arguments element-wise.

power

First array elements raised to powers from second array, element-wise.

cross

Return the cross product of two (arrays of) vectors.

true_divide

Return a true division of the inputs, element-wise.

floor_divide

Return the largest integer less than or equal to the division of the inputs.

float_power

First array elements raised to powers from second array, element-wise.

divmod

Return element-wise quotient and remainder simultaneously.

convolve

Return the discrete, linear convolution of two one-dimensional sequences.

dot

Compute the dot product of two arrays or quantities.

multi_dot

Efficiently compute matrix products between a sequence of arrays.

vdot

Perform a conjugate multiplication of two 1D vectors.

vecdot

Perform a conjugate multiplication of two batched vectors.

inner

Compute the inner product of two arrays or quantities.

outer

Compute the outer product of two vectors or quantities.

kron

Compute the Kronecker product of two arrays or quantities.

matmul

Compute the matrix product of two arrays or quantities.

tensordot

Compute tensor dot product along specified axes.

matrix_power

Raise a square matrix to the (integer) power n.
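The output unit of the "change unit" functions follows simple exponent arithmetic:

- multiplication adds exponents;
- division subtracts them;
- raising to a power p scales them by p.

It follows that var returns the input unit squared, and prod over n elements returns the unit to the power n. The sketch below represents a unit as a hypothetical `{base_unit: exponent}` dict and uses `Fraction` so that sqrt and cbrt give exact fractional exponents. It illustrates the rules only and is not saiunit's internal representation.

```python
from fractions import Fraction

def _clean(u):
    # Drop zero exponents so {} means dimensionless.
    return {k: v for k, v in u.items() if v != 0}

def mul_unit(a, b):   # multiply, dot, matmul, outer, kron, ...
    out = dict(a)
    for k, v in b.items():
        out[k] = out.get(k, 0) + v
    return _clean(out)

def div_unit(a, b):   # divide, true_divide
    return mul_unit(a, {k: -v for k, v in b.items()})

def pow_unit(a, p):   # power; square (p=2), sqrt (1/2), cbrt (1/3), reciprocal (-1)
    p = Fraction(p)
    return _clean({k: v * p for k, v in a.items()})

METRE = {"m": 1}
SECOND = {"s": 1}
speed = div_unit(METRE, SECOND)        # {"m": 1, "s": -1}
variance_unit = pow_unit(speed, 2)     # var of a speed: {"m": 2, "s": -2}
```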

Functions that Keep Unit

row_stack

Stack quantities or arrays in sequence vertically (row wise).

concatenate

Join a sequence of quantities or arrays along an existing axis.

stack

Join a sequence of quantities or arrays along a new axis.

vstack

Stack quantities or arrays in sequence vertically (row wise).

hstack

Stack quantities or arrays in sequence horizontally (column wise).

dstack

Stack quantities or arrays in sequence depth wise (along third axis).

column_stack

Stack 1-D arrays as columns into a 2-D array.

block

Assemble a quantity or an array from nested lists of blocks.

append

Append values to the end of a quantity or an array.

split

Split quantity or array into a list of multiple sub-arrays.

array_split

Split an array into multiple sub-arrays.

dsplit

Split a quantity or an array into multiple sub-arrays along the 3rd axis (depth).

hsplit

Split a quantity or an array into multiple sub-arrays horizontally (column-wise).

vsplit

Split a quantity or an array into multiple sub-arrays vertically (row-wise).

atleast_1d

View inputs as quantities or arrays with at least one dimension.

atleast_2d

View inputs as quantities or arrays with at least two dimensions.

atleast_3d

View inputs as quantities or arrays with at least three dimensions.

broadcast_arrays

Broadcast any number of arrays against each other.

broadcast_to

Broadcast an array to a new shape.

reshape

Gives a new shape to a quantity or an array without changing its data.

moveaxis

Moves axes of a quantity or an array to new positions.

transpose

Permute the dimensions of a quantity or an array.

swapaxes

Interchange two axes of a quantity or an array.

tile

Construct a quantity or an array by repeating A the number of times given by reps.

repeat

Repeat elements of a quantity or an array.

flip

Reverse the order of elements in a quantity or an array along the given axis.

fliplr

Flip quantity or array in the left/right direction.

flipud

Flip quantity or array in the up/down direction.

roll

Roll quantity or array elements along a given axis.

expand_dims

Expand the shape of a quantity or an array.

squeeze

Remove single-dimensional entries from the shape of a quantity or an array.

sort

Return a sorted copy of a quantity or an array.

max

Return the maximum of a quantity or an array or maximum along an axis.

min

Return the minimum of a quantity or an array or minimum along an axis.

amax

Return the maximum of a quantity or an array or maximum along an axis.

amin

Return the minimum of a quantity or an array or minimum along an axis.

diagflat

Create a two-dimensional quantity or array with the flattened input as its diagonal.

diagonal

Return specified diagonals.

choose

Construct a quantity or an array from an index array and a set of arrays to choose from.

ravel

Return a contiguous flattened quantity or array.

flatten

Flattens input by reshaping it into a one-dimensional tensor.

unflatten

Expands a dimension of the input tensor over multiple dimensions.

remove_diag

Remove the diagonal of the matrix.

astype

Copy of the array, cast to a specified type.

real

Return the real part of the complex argument.

imag

Return the imaginary part of the complex argument.

conj

Return the complex conjugate of the argument.

conjugate

Return the complex conjugate of the argument.

negative

Return the negative of the argument.

positive

Return the positive of the argument.

abs

Return the absolute value of the argument.

sum

Return the sum of the array elements.

nancumsum

Return the cumulative sum of the array elements, ignoring NaNs.

nansum

Return the sum of the array elements, ignoring NaNs.

cumsum

Return the cumulative sum of the array elements.

ediff1d

Return the differences between consecutive elements of the array.

absolute

Return the absolute value of the argument.

fabs

Return the absolute value of the argument.

median

Return the median of the array elements.

nanmin

Return the minimum of the array elements, ignoring NaNs.

nanmax

Return the maximum of the array elements, ignoring NaNs.

ptp

Return the range of the array elements (maximum - minimum).

average

Return the weighted average of the array elements.

mean

Return the mean of the array elements.

std

Return the standard deviation of the array elements.

nanmedian

Return the median of the array elements, ignoring NaNs.

nanmean

Return the mean of the array elements, ignoring NaNs.

nanstd

Return the standard deviation of the array elements, ignoring NaNs.

diff

Return the differences between consecutive elements of the array.

rot90

Rotate an array by 90 degrees in the plane specified by axes.

intersect1d

Find the intersection of two arrays.

nan_to_num

Replace NaN with zero and infinity with large finite numbers (default behaviour) or with the numbers defined by the user using the nan, posinf and/or neginf keywords.

percentile

Compute the q-th percentile of the data along the specified axis.

nanpercentile

Compute the q-th percentile of the data along the specified axis, while ignoring nan values.

quantile

Compute the q-th quantile of the data along the specified axis.

nanquantile

Compute the q-th quantile of the data along the specified axis, while ignoring nan values.

round

Round an array to the nearest integer.

around

Round an array to the nearest integer.

rint

Round an array to the nearest integer.

floor

Return the floor of the argument.

ceil

Return the ceiling of the argument.

trunc

Return the truncated value of the argument.

fix

Return the nearest integer towards zero.

modf

Return the fractional and integer parts of the array elements.

fmod

Return the element-wise remainder of division.

mod

Return the element-wise modulus of division.

copysign

Return a copy of the first array elements with the sign of the second array.

remainder

Returns the element-wise remainder of division.

maximum

Element-wise maximum of array elements.

minimum

Element-wise minimum of array elements.

fmax

Element-wise maximum of array elements ignoring NaNs.

fmin

Element-wise minimum of array elements ignoring NaNs.

lcm

Return the least common multiple of x1 and x2.

gcd

Return the greatest common divisor of x1 and x2.

trace

Return the sum along diagonals of the array.

add

Add arguments element-wise.

subtract

Subtract arguments, element-wise.

nextafter

Return the next floating-point value after x towards y, element-wise.

promote_dtypes

Promote the data types of the inputs to a common type.

interp

One-dimensional linear interpolation.

clip

Clip (limit) the values in an array.

histogram

Compute the histogram of a set of data.

compress

Return selected slices of a quantity or an array along given axis.

extract

Return the elements of an array that satisfy some condition.

take

Take elements from an array along an axis.

select

Return an array drawn from elements in choicelist, depending on conditions.

where

Return elements chosen from x or y depending on condition.

unique

Find the unique elements of a quantity or an array.

gather

Gather values along an axis specified by dim, according to index.
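The quantile and percentile functions listed above differ only in the scale of q: quantile takes q in [0, 1], and percentile takes q in [0, 100]. Both keep the input's unit, because the result is a linear interpolation between two data values. The pure-Python sketch below implements NumPy's default ("linear") method on a flat list. It does not handle axes, NaNs, or the other interpolation methods.

```python
import math

def quantile(data, q):
    if not 0.0 <= q <= 1.0:
        raise ValueError("q must be in [0, 1]")
    xs = sorted(data)
    pos = q * (len(xs) - 1)      # fractional index into the sorted data
    lo = math.floor(pos)
    hi = min(lo + 1, len(xs) - 1)
    frac = pos - lo
    # Linear interpolation between neighbouring order statistics.
    return xs[lo] + (xs[hi] - xs[lo]) * frac

def percentile(data, q):
    return quantile(data, q / 100.0)
```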

Functions that Remove Unit

iscomplexobj

Return True if x is a complex type or an array of complex numbers.

heaviside

Compute the Heaviside step function.

signbit

Return element-wise True where the sign bit is set (less than zero).

sign

Return the sign of each element in the input array.

bincount

Count number of occurrences of each value in array of non-negative ints.

digitize

Return the indices of the bins to which each value in input array belongs.

get_promote_dtypes

Promote the data types of the inputs to a common type.

all

Test whether all array elements along a given axis evaluate to True.

any

Test whether any array element along a given axis evaluates to True.

logical_not

Compute the truth value of NOT x element-wise.

equal

Return (x == y) element-wise.

not_equal

Return (x != y) element-wise.

greater

Return (x > y) element-wise.

greater_equal

Return (x >= y) element-wise.

less

Return (x < y) element-wise.

less_equal

Return (x <= y) element-wise.

array_equal

Return True if two arrays have the same shape and elements.

isclose

Returns a boolean array where two arrays are element-wise equal within a tolerance.

allclose

Returns True if two arrays are element-wise equal within a tolerance.

logical_and

Compute the truth value of x AND y element-wise.

logical_or

Compute the truth value of x OR y element-wise.

logical_xor

Compute the truth value of x XOR y element-wise.

alltrue

Test whether all array elements along a given axis evaluate to True.

sometrue

Test whether any array element along a given axis evaluates to True.

argsort

Return the indices that would sort an array or Quantity.

argmax

Return the index of the maximum value along an axis.

argmin

Return the index of the minimum value along an axis.

nanargmax

Return the index of the maximum value, ignoring NaNs.

nanargmin

Return the index of the minimum value, ignoring NaNs.

argwhere

Find the indices of array elements that are non-zero.

nonzero

Return the indices of non-zero elements.

flatnonzero

Return indices that are non-zero in the flattened input.

searchsorted

Find indices where elements should be inserted to maintain order.

count_nonzero

Count the number of non-zero values in the input.

diag_indices_from

Return indices for accessing the main diagonal of a given array.
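isclose and allclose return unitless booleans, but the tolerance rule behind them is worth spelling out. NumPy documents the rule as |a - b| <= atol + rtol * |b|, with defaults rtol=1e-05 and atol=1e-08. The sketch below applies that rule to plain numbers that are assumed to already be expressed in the same unit. In a unit-aware library, `a`, `b`, and presumably `atol` would need compatible units. The list-based `allclose` is a simplification that does no broadcasting.

```python
def isclose(a, b, rtol=1e-05, atol=1e-08):
    # Asymmetric in a and b: the relative term scales with |b|.
    # NaN compares unequal, because any comparison with NaN is False.
    return abs(a - b) <= atol + rtol * abs(b)

def allclose(xs, ys, rtol=1e-05, atol=1e-08):
    return len(xs) == len(ys) and all(
        isclose(x, y, rtol, atol) for x, y in zip(xs, ys))
```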

Other Functions

finfo

iinfo

is_quantity

Check whether x is a Quantity instance.

issubdtype

Check if a dtype is a sub-dtype of another in the type hierarchy.

result_type

Determine the result dtype from a set of input arrays or dtypes.

ndim

Return the number of dimensions of an array or Quantity.

isreal

Test element-wise whether each element is real (has zero imaginary part).

isscalar

Return True if the input is a scalar (zero-dimensional).

isfinite

Test element-wise for finiteness (not inf and not NaN).

isinf

Test element-wise for positive or negative infinity.

isnan

Test element-wise for NaN.

shape

Return the shape of an array.

size

Return the number of elements along a given axis.

get_dtype

Get the dtype of an array, Quantity, or Python scalar.

is_float

Check if the array has a floating-point dtype.

is_int

Check if the array has an integer dtype.

broadcast_shapes

Broadcast a sequence of array shapes.

gradient

Computes the gradient of a scalar field.

bartlett

Return a Bartlett window of size M.

blackman

Return a Blackman window of size M.

hamming

Return a Hamming window of size M.

hanning

Return a Hanning window of size M.

kaiser

Return a Kaiser window of size M.

bool_

uint2

A JAX scalar constructor of type uint2.

uint4

A JAX scalar constructor of type uint4.

uint8

A JAX scalar constructor of type uint8.

uint16

A JAX scalar constructor of type uint16.

uint32

A JAX scalar constructor of type uint32.

uint64

A JAX scalar constructor of type uint64.

int2

A JAX scalar constructor of type int2.

int4

A JAX scalar constructor of type int4.

int8

A JAX scalar constructor of type int8.

int16

A JAX scalar constructor of type int16.

int32

A JAX scalar constructor of type int32.

int64

A JAX scalar constructor of type int64.

bfloat16

A JAX scalar constructor of type bfloat16.

float16

A JAX scalar constructor of type float16.

float32

A JAX scalar constructor of type float32.

float64

A JAX scalar constructor of type float64.

complex64

A JAX scalar constructor of type complex64.

complex128

A JAX scalar constructor of type complex128.

int_

uint

float_

complex_

single

double

csingle

cdouble

inexact

Abstract base class of all numeric scalar types with a (potentially) inexact representation of the values in its range, such as floating-point numbers.

dtype

e

Euler's number, e ≈ 2.718281828, the base of the natural logarithm (a Python float).

pi

The constant π ≈ 3.141592654 (a Python float).

inf

IEEE 754 floating-point positive infinity.

nan

IEEE 754 floating-point Not a Number (NaN).

euler_gamma

The Euler-Mascheroni constant γ ≈ 0.577215665 (a Python float).

newaxis

Alias for None, used in indexing to insert a new axis of length one.
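The window functions listed in this section (bartlett, blackman, hamming, hanning, kaiser) return unitless weight sequences of length M. Hamming and Hann ("hanning") are both cosine windows of the form w(n) = a - b·cos(2πn / (M - 1)) for n = 0 .. M - 1. NumPy documents the coefficients as a = 0.54, b = 0.46 for Hamming and a = b = 0.5 for Hann. The pure-Python sketch below follows that formula and NumPy's edge cases (an empty result for M < 1, and [1.0] for M = 1). It is not saiunit's implementation.

```python
import math

def _cosine_window(M, a, b):
    """Generalized cosine window w(n) = a - b*cos(2*pi*n/(M-1)), n = 0..M-1."""
    if M < 1:
        return []
    if M == 1:
        return [1.0]
    return [a - b * math.cos(2.0 * math.pi * n / (M - 1)) for n in range(M)]

def hamming(M):
    return _cosine_window(M, 0.54, 0.46)

def hanning(M):
    return _cosine_window(M, 0.5, 0.5)
```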