NumPy Functions
brainunit.math provides 500+ unit-aware functions compatible with NumPy/JAX.
They are categorized by how they handle units:
- Array Creation: create arrays with or without units
- Functions Accepting Unitless: require dimensionless input (trig, exp, log)
- Functions Changing Unit: the output unit differs from the input unit (multiply, sqrt, dot)
- Functions Keeping Unit: the output unit matches the input unit (sort, sum, reshape)
- Functions Removing Unit: return dimensionless results (comparisons, argmax)
import brainunit as u
import jax.numpy as jnp
1. Array Creation
Functions for creating arrays with units.
Includes: array, asarray, zeros, ones, full, eye, arange, linspace, logspace, meshgrid, and their _like variants.
# zeros and ones with unit
print('zeros:', u.math.zeros(3, unit=u.volt))
print('ones:', u.math.ones((2, 2), unit=u.meter))
zeros: [0. 0. 0.] V
ones: [[1. 1.]
[1. 1.]] m
# arange with Quantity endpoints
times = u.math.arange(0 * u.ms, 10 * u.ms, step=2 * u.ms)
print('arange:', times)
arange:
[0 2 4 6 8] ms
# linspace between Quantity endpoints
voltages = u.math.linspace(0 * u.mV, 100 * u.mV, 5)
print('linspace:', voltages)
linspace: [ 0. 25. 50. 75. 100.] mV
# eye with unit
print('eye:')
print(u.math.eye(3, unit=u.ohm))
eye:
[[1. 0. 0.]
[0. 1. 0.]
[0. 0. 1.]] ohm
# full_like: create array with same shape as a Quantity
template = jnp.array([1., 2., 3.]) * u.newton
print('full_like:', u.math.full_like(template, 99.0 * u.newton))
full_like: [99. 99. 99.] N
2. Functions Accepting Unitless
These functions require dimensionless inputs. Transcendental and trigonometric functions such as exp, log, and sin are only defined for pure numbers, so applying them to a quantity that carries a physical unit has no meaning.
Includes: exp, log, sin, cos, tan, arcsin, arctan2, sinh, cosh, deg2rad, rad2deg, logaddexp, and more.
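One way to see why is through the power series that defines the exponential. If $x$ carried a unit such as metres, the terms would have units $1, \mathrm{m}, \mathrm{m}^2, \dots$, and their sum would be undefined. Dividing by a reference quantity first turns the argument into a pure number, which is exactly what the `to_decimal` conversion in the examples on this page does:

```latex
e^{x} = \sum_{n=0}^{\infty} \frac{x^{n}}{n!} = 1 + x + \frac{x^{2}}{2!} + \cdots,
\qquad
x = \frac{5\,\mathrm{mV}}{1\,\mathrm{V}} = 0.005
\;\Rightarrow\;
e^{0.005} \approx 1.0050125
```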
# Trigonometric functions on dimensionless values
angles = jnp.array([0., jnp.pi/6, jnp.pi/4, jnp.pi/3, jnp.pi/2])
print('sin:', u.math.sin(angles))
print('cos:', u.math.cos(angles))
sin: [0. 0.5 0.70710677 0.86602545 1. ]
cos: [ 1.0000000e+00 8.6602539e-01 7.0710677e-01 4.9999997e-01
-4.3711388e-08]
# Exponential and logarithm
x = jnp.array([0., 1., 2.])
print('exp:', u.math.exp(x))
print('log:', u.math.log(u.math.exp(x))) # round-trip
exp: [1. 2.7182817 7.389056 ]
log: [0. 0.99999994 2. ]
# Passing a Quantity with units raises an error
try:
    u.math.exp(2.0 * u.meter)
except Exception as e:
    print(type(e).__name__, ':', e)
TypeError : exp requires a dimensionless "x" when "unit_to_scale" is not provided. Got Quantity(unit=m, dim=m). Pass "unit_to_scale=<Unit>" to scale before applying exp, or convert explicitly to a dimensionless value first.
# Convert to dimensionless first using to_decimal
ratio = (5.0 * u.mV).to_decimal(u.volt) # 0.005, dimensionless
print('exp(ratio):', u.math.exp(ratio))
exp(ratio): 1.0050125
# Angle conversions
print('deg2rad(180):', u.math.deg2rad(180.0))
print('rad2deg(pi):', u.math.rad2deg(jnp.pi))
deg2rad(180): 3.1415927
rad2deg(pi): 180.0
# arctan2 for 2D angle computation
y = jnp.array([1., 0., -1.])
x = jnp.array([0., 1., 0.])
print('arctan2:', u.math.arctan2(y, x))
arctan2: [ 1.5707964 0. -1.5707964]
3. Functions Changing Unit
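These functions return outputs whose unit differs from the input unit. The output unit is derived from the input units by the same rule as the operation itself: for example, multiply multiplies the units, and sqrt takes the square root of the unit.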
Includes: multiply, divide, power, sqrt, square, reciprocal, prod, dot, matmul, inner, outer, kron, cross, convolve, and more.
# multiply: units multiply
force = jnp.array([10., 20.]) * u.newton
distance = jnp.array([5., 3.]) * u.meter
work = u.math.multiply(force, distance)
print('F * d (work):', work) # N * m = J
F * d (work):
[50. 60.] J
# divide: units divide
print('d / t (speed):', u.math.divide(100.0 * u.meter, 10.0 * u.second))
d / t (speed): 10. m / s
# sqrt and square
area = jnp.array([4., 9., 16.]) * u.meter2
print('sqrt(area):', u.math.sqrt(area)) # m^2 -> m
side = jnp.array([2., 3., 4.]) * u.meter
print('square(side):', u.math.square(side)) # m -> m^2
sqrt(area): [2. 3. 4.] m
square(side): [ 4. 9. 16.] m^2
# reciprocal
resistance = jnp.array([10., 50.]) * u.ohm
print('1/R (conductance):', u.math.reciprocal(resistance)) # 1/ohm = S
1/R (conductance): [0.1 0.02] S
# dot product
a = jnp.array([1., 2., 3.]) * u.meter
b = jnp.array([4., 5., 6.]) * u.meter
print('dot(a, b):', u.math.dot(a, b)) # m^2
dot(a, b): 32. m^2
# outer product
x = jnp.array([1., 2.]) * u.volt
y = jnp.array([3., 4., 5.]) * u.ampere
print('outer(V, A):')
print(u.math.outer(x, y)) # V * A = W
outer(V, A):
[[ 3. 4. 5.]
[ 6. 8. 10.]] W
# prod: unit raised to the power of number of elements
vals = jnp.array([2., 3., 4.]) * u.meter
print('prod:', u.math.prod(vals)) # 24 m^3
prod: 24. m^3
4. Functions Keeping Unit
These functions preserve the unit of the input in the output.
Includes: sum, mean, std, min, max, sort, cumsum, reshape, transpose, concatenate, stack, split, flip, roll, clip, abs, negative, diff, interp, where, unique, and many more.
# Statistical functions
data = jnp.array([1., 3., 5., 7., 9.]) * u.volt
print('sum:', u.math.sum(data))
print('mean:', u.math.mean(data))
print('std:', u.math.std(data))
print('min:', u.math.min(data))
print('max:', u.math.max(data))
print('median:', u.math.median(data))
sum: 25. V
mean: 5. V
std: 2.828427 V
min: 1. V
max: 9. V
median: 5. V
# Shape manipulation
M = jnp.arange(6.).reshape(2, 3) * u.ampere
print('original:', M)
print('reshape (3,2):', u.math.reshape(M, (3, 2)))
print('transpose:', u.math.transpose(M))
print('flatten:', u.math.flatten(M))
original: [[0. 1. 2.]
[3. 4. 5.]] A
reshape (3,2): [[0. 1.]
[2. 3.]
[4. 5.]] A
transpose: [[0. 3.]
[1. 4.]
[2. 5.]] A
flatten: [0. 1. 2. 3. 4. 5.] A
# Concatenation and stacking
a = jnp.array([1., 2.]) * u.meter
b = jnp.array([3., 4.]) * u.meter
print('concatenate:', u.math.concatenate([a, b]))
print('stack:', u.math.stack([a, b]))
concatenate: [1. 2. 3. 4.] m
stack: [[1. 2.]
[3. 4.]] m
# Sorting and ordering
unsorted = jnp.array([3., 1., 4., 1., 5.]) * u.pascal
print('sort:', u.math.sort(unsorted))
print('cumsum:', u.math.cumsum(unsorted))
sort: [1. 1. 3. 4. 5.] Pa
cumsum: [ 3. 4. 8. 9. 14.] Pa
# abs and negative
mixed = jnp.array([-2., 3., -1., 4.]) * u.newton
print('abs:', u.math.abs(mixed))
print('negative:', u.math.negative(mixed))
abs: [2. 3. 1. 4.] N
negative: [ 2. -3. 1. -4.] N
# clip (clamp values to a range)
x = jnp.array([0.5, 1.5, 2.5, 3.5, 4.5]) * u.volt
print('clip [1V, 3V]:', u.math.clip(x, 1.0 * u.volt, 3.0 * u.volt))
clip [1V, 3V]: [1. 1.5 2.5 3. 3. ] V
# diff: first-order discrete difference between adjacent elements
t = jnp.array([0., 1., 4., 9., 16.]) * u.second
print('diff:', u.math.diff(t))
diff: [1. 3. 5. 7.] s
# where: conditional selection
cond = jnp.array([True, False, True, False])
a = jnp.array([1., 2., 3., 4.]) * u.meter
b = jnp.array([10., 20., 30., 40.]) * u.meter
print('where:', u.math.where(cond, a, b))
where: [ 1. 20. 3. 40.] m
5. Functions Removing Unit
These functions return dimensionless results (booleans, indices, etc.).
Includes: equal, greater, less, isclose, allclose, argmax, argmin, argsort, nonzero, sign, all, any, count_nonzero, and more.
# Comparisons return boolean arrays (dimensionless)
a = jnp.array([1., 2., 3.]) * u.volt
b = jnp.array([2., 2., 2.]) * u.volt
print('equal:', u.math.equal(a, b))
print('greater:', u.math.greater(a, b))
print('less:', u.math.less(a, b))
equal: [False True False]
greater: [False False True]
less: [ True False False]
# isclose and allclose for approximate comparison
x = jnp.array([1.0, 2.0, 3.0]) * u.meter
y = jnp.array([1.0, 2.00001, 3.0]) * u.meter
print('isclose:', u.math.isclose(x, y))
print('allclose:', u.math.allclose(x, y))
isclose: [ True True True]
allclose: True
# Index-finding functions
data = jnp.array([10., 5., 30., 15., 25.]) * u.watt
print('argmax:', u.math.argmax(data)) # index of max
print('argmin:', u.math.argmin(data)) # index of min
print('argsort:', u.math.argsort(data)) # indices that would sort
argmax: 2
argmin: 1
argsort: [1 0 3 4 2]
# sign and logical operations
mixed = jnp.array([-3., 0., 5., -1., 2.]) * u.ampere
print('sign:', u.math.sign(mixed))
print('any > 0:', u.math.any(u.math.greater(mixed, 0 * u.ampere)))
print('all > 0:', u.math.all(u.math.greater(mixed, 0 * u.ampere)))
sign: [-1. 0. 1. -1. 1.]
any > 0: True
all > 0: False
# count_nonzero
sparse_data = jnp.array([0., 1., 0., 0., 3., 0., 2.]) * u.volt
print('count_nonzero:', u.math.count_nonzero(sparse_data))
count_nonzero: 3
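Quick Reference: Unit Behavior by Category

| Category | Example Functions | Unit Behavior |
|---|---|---|
| Array Creation | zeros, ones, eye, arange, linspace, full_like | Set via the unit argument or by Quantity arguments |
| Accept Unitless | exp, log, sin, cos, arctan2, deg2rad | Input must be dimensionless |
| Change Unit | multiply, divide, sqrt, reciprocal, dot, outer, prod | Output unit derived from operation |
| Keep Unit | sum, mean, sort, reshape, concatenate, clip, where | Same unit as input |
| Remove Unit | equal, isclose, argmax, argsort, sign, count_nonzero | Dimensionless output |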
For the complete function listing, see the API documentation.