NumPy Functions


brainunit.math provides more than 500 unit-aware functions that are compatible with NumPy/JAX. They fall into five categories according to how they handle units:

  1. Array Creation — create arrays with or without units

  2. Functions Accepting Unitless — require dimensionless input (trig, exp, log)

  3. Functions Changing Unit — output unit differs from input (multiply, sqrt, dot)

  4. Functions Keeping Unit — output unit matches input (sort, sum, reshape)

  5. Functions Removing Unit — return dimensionless results (comparisons, argmax)

import brainunit as u
import jax.numpy as jnp

1. Array Creation

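These functions create arrays and can attach a unit to them, either through the unit= parameter or by taking the unit from Quantity arguments.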

Includes: array, asarray, zeros, ones, full, eye, arange, linspace, logspace, meshgrid, and their _like variants.

# zeros and ones with unit
print('zeros:', u.math.zeros(3, unit=u.volt))
print('ones:', u.math.ones((2, 2), unit=u.meter))
zeros: [0. 0. 0.] V
ones: [[1. 1.]
 [1. 1.]] m
# arange with Quantity endpoints
times = u.math.arange(0 * u.ms, 10 * u.ms, step=2 * u.ms)
print('arange:', times)
arange: 
[0 2 4 6 8] ms
# linspace between Quantity endpoints
voltages = u.math.linspace(0 * u.mV, 100 * u.mV, 5)
print('linspace:', voltages)
linspace: [  0.  25.  50.  75. 100.] mV
# eye with unit
print('eye:')
print(u.math.eye(3, unit=u.ohm))
eye:
[[1. 0. 0.]
 [0. 1. 0.]
 [0. 0. 1.]] ohm
# full_like: create array with same shape as a Quantity
template = jnp.array([1., 2., 3.]) * u.newton
print('full_like:', u.math.full_like(template, 99.0 * u.newton))
full_like: [99. 99. 99.] N

2. Functions Accepting Unitless

These functions require dimensionless inputs, because transcendental and trigonometric operations are not physically meaningful when applied to a value that carries a unit.

Includes: exp, log, sin, cos, tan, arcsin, arctan2, sinh, cosh, deg2rad, rad2deg, logaddexp, and more.

# Trigonometric functions on dimensionless values
angles = jnp.array([0., jnp.pi/6, jnp.pi/4, jnp.pi/3, jnp.pi/2])
print('sin:', u.math.sin(angles))
print('cos:', u.math.cos(angles))
sin: [0.         0.5        0.70710677 0.86602545 1.        ]
cos: [ 1.0000000e+00  8.6602539e-01  7.0710677e-01  4.9999997e-01
 -4.3711388e-08]
# Exponential and logarithm
x = jnp.array([0., 1., 2.])
print('exp:', u.math.exp(x))
print('log:', u.math.log(u.math.exp(x)))  # round-trip
exp: [1.        2.7182817 7.389056 ]
log: [0.         0.99999994 2.        ]
# Passing a Quantity with units raises an error
try:
    u.math.exp(2.0 * u.meter)
except Exception as e:
    print(type(e).__name__, ':', e)
TypeError : exp requires a dimensionless "x" when "unit_to_scale" is not provided. Got Quantity(unit=m, dim=m). Pass "unit_to_scale=<Unit>" to scale before applying exp, or convert explicitly to a dimensionless value first.
# Convert to a plain number first using to_decimal
ratio = (5.0 * u.mV).to_decimal(u.volt)  # 0.005: the value expressed in volts, as a plain number
print('exp(ratio):', u.math.exp(ratio))
exp(ratio): 1.0050125
# Angle conversions
print('deg2rad(180):', u.math.deg2rad(180.0))
print('rad2deg(pi):', u.math.rad2deg(jnp.pi))
deg2rad(180): 3.1415927
rad2deg(pi): 180.0
# arctan2 for 2D angle computation
y = jnp.array([1., 0., -1.])
x = jnp.array([0., 1., 0.])
print('arctan2:', u.math.arctan2(y, x))
arctan2: [ 1.5707964  0.        -1.5707964]

3. Functions Changing Unit

These functions produce outputs with different units than their inputs, following the mathematical rules of the operation.

Includes: multiply, divide, power, sqrt, square, reciprocal, prod, dot, matmul, inner, outer, kron, cross, convolve, and more.

# multiply: units multiply
force = jnp.array([10., 20.]) * u.newton
distance = jnp.array([5., 3.]) * u.meter
work = u.math.multiply(force, distance)
print('F * d (work):', work)  # N * m = J
F * d (work): 
[50. 60.] J
# divide: units divide
print('d / t (speed):', u.math.divide(100.0 * u.meter, 10.0 * u.second))
d / t (speed): 10. m / s
# sqrt and square
area = jnp.array([4., 9., 16.]) * u.meter2
print('sqrt(area):', u.math.sqrt(area))  # m^2 -> m

side = jnp.array([2., 3., 4.]) * u.meter
print('square(side):', u.math.square(side))  # m -> m^2
sqrt(area): [2. 3. 4.] m
square(side): [ 4.  9. 16.] m^2
# reciprocal
resistance = jnp.array([10., 50.]) * u.ohm
print('1/R (conductance):', u.math.reciprocal(resistance))  # 1/ohm = S
1/R (conductance): [0.1  0.02] S
# dot product
a = jnp.array([1., 2., 3.]) * u.meter
b = jnp.array([4., 5., 6.]) * u.meter
print('dot(a, b):', u.math.dot(a, b))  # m^2
dot(a, b): 32. m^2
# outer product
x = jnp.array([1., 2.]) * u.volt
y = jnp.array([3., 4., 5.]) * u.ampere
print('outer(V, A):')
print(u.math.outer(x, y))  # V * A = W
outer(V, A):
[[ 3.  4.  5.]
 [ 6.  8. 10.]] W
# prod: the unit is raised to the power of the number of elements
vals = jnp.array([2., 3., 4.]) * u.meter
print('prod:', u.math.prod(vals))  # 24 m^3
prod: 24. m^3

4. Functions Keeping Unit

These functions preserve the unit of the input in the output.

Includes: sum, mean, std, min, max, sort, cumsum, reshape, transpose, concatenate, stack, split, flip, roll, clip, abs, negative, diff, interp, where, unique, and many more.

# Statistical functions
data = jnp.array([1., 3., 5., 7., 9.]) * u.volt
print('sum:', u.math.sum(data))
print('mean:', u.math.mean(data))
print('std:', u.math.std(data))
print('min:', u.math.min(data))
print('max:', u.math.max(data))
print('median:', u.math.median(data))
sum: 25. V
mean: 5. V
std: 2.828427 V
min: 1. V
max: 9. V
median: 5. V
# Shape manipulation
M = jnp.arange(6.).reshape(2, 3) * u.ampere
print('original:', M)
print('reshape (3,2):', u.math.reshape(M, (3, 2)))
print('transpose:', u.math.transpose(M))
print('flatten:', u.math.flatten(M))
original: [[0. 1. 2.]
 [3. 4. 5.]] A
reshape (3,2): [[0. 1.]
 [2. 3.]
 [4. 5.]] A
transpose: [[0. 3.]
 [1. 4.]
 [2. 5.]] A
flatten: [0. 1. 2. 3. 4. 5.] A
# Concatenation and stacking
a = jnp.array([1., 2.]) * u.meter
b = jnp.array([3., 4.]) * u.meter
print('concatenate:', u.math.concatenate([a, b]))
print('stack:', u.math.stack([a, b]))
concatenate: [1. 2. 3. 4.] m
stack: [[1. 2.]
 [3. 4.]] m
# Sorting and ordering
unsorted = jnp.array([3., 1., 4., 1., 5.]) * u.pascal
print('sort:', u.math.sort(unsorted))
print('cumsum:', u.math.cumsum(unsorted))
sort: [1. 1. 3. 4. 5.] Pa
cumsum: [ 3.  4.  8.  9. 14.] Pa
# abs and negative
mixed = jnp.array([-2., 3., -1., 4.]) * u.newton
print('abs:', u.math.abs(mixed))
print('negative:', u.math.negative(mixed))
abs: [2. 3. 1. 4.] N
negative: [ 2. -3.  1. -4.] N
# clip (clamp values to a range)
x = jnp.array([0.5, 1.5, 2.5, 3.5, 4.5]) * u.volt
print('clip [1V, 3V]:', u.math.clip(x, 1.0 * u.volt, 3.0 * u.volt))
clip [1V, 3V]: [1.  1.5 2.5 3.  3. ] V
# diff: differences between consecutive elements
t = jnp.array([0., 1., 4., 9., 16.]) * u.second
print('diff:', u.math.diff(t))
diff: [1. 3. 5. 7.] s
# where: conditional selection
cond = jnp.array([True, False, True, False])
a = jnp.array([1., 2., 3., 4.]) * u.meter
b = jnp.array([10., 20., 30., 40.]) * u.meter
print('where:', u.math.where(cond, a, b))
where: [ 1. 20.  3. 40.] m

5. Functions Removing Unit

These functions return dimensionless results (booleans, indices, etc.).

Includes: equal, greater, less, isclose, allclose, argmax, argmin, argsort, nonzero, sign, all, any, count_nonzero, and more.

# Comparisons return boolean arrays (dimensionless)
a = jnp.array([1., 2., 3.]) * u.volt
b = jnp.array([2., 2., 2.]) * u.volt
print('equal:', u.math.equal(a, b))
print('greater:', u.math.greater(a, b))
print('less:', u.math.less(a, b))
equal: [False  True False]
greater: [False False  True]
less: [ True False False]
# isclose and allclose for approximate comparison
x = jnp.array([1.0, 2.0, 3.0]) * u.meter
y = jnp.array([1.0, 2.00001, 3.0]) * u.meter
print('isclose:', u.math.isclose(x, y))
print('allclose:', u.math.allclose(x, y))
isclose: [ True  True  True]
allclose: True
# Index-finding functions
data = jnp.array([10., 5., 30., 15., 25.]) * u.watt
print('argmax:', u.math.argmax(data))  # index of max
print('argmin:', u.math.argmin(data))  # index of min
print('argsort:', u.math.argsort(data))  # indices that would sort
argmax: 2
argmin: 1
argsort: [1 0 3 4 2]
# sign and logical operations
mixed = jnp.array([-3., 0., 5., -1., 2.]) * u.ampere
print('sign:', u.math.sign(mixed))
print('any > 0:', u.math.any(u.math.greater(mixed, 0 * u.ampere)))
print('all > 0:', u.math.all(u.math.greater(mixed, 0 * u.ampere)))
sign: [-1.  0.  1. -1.  1.]
any > 0: True
all > 0: False
# count_nonzero
sparse_data = jnp.array([0., 1., 0., 0., 3., 0., 2.]) * u.volt
print('count_nonzero:', u.math.count_nonzero(sparse_data))
count_nonzero: 3

Quick Reference: Unit Behavior by Category

| Category | Example Functions | Unit Behavior |
| --- | --- | --- |
| Array Creation | zeros, ones, arange, linspace | Set via unit= parameter |
| Accept Unitless | sin, cos, exp, log | Input must be dimensionless |
| Change Unit | multiply, sqrt, dot, outer | Output unit derived from operation |
| Keep Unit | sum, mean, sort, reshape, clip | Same unit as input |
| Remove Unit | equal, argmax, sign, allclose | Dimensionless output |

For the complete function listing, see the API documentation.