Quickstart Guide
brainunit is a unit-aware scientific computing library built on JAX. It carries physical units through arithmetic, linear algebra, FFTs, automatic differentiation, and JIT compilation, and raises an error when an operation combines incompatible dimensions.
This guide covers the essentials in 5 minutes.
Installation
pip install brainunit
import brainunit as u
import jax
import jax.numpy as jnp
Creating Quantities
A Quantity pairs a numeric value (a scalar or an array) with a physical unit. The simplest way to create one is to multiply a value by a unit.
# Scalars
mass = 5.0 * u.kilogram
speed = 10.0 * u.meter / u.second
print('mass:', mass)
print('speed:', speed)
mass: 5. kg
speed: 10. m / s
# Arrays
voltages = jnp.array([1.0, 2.5, 3.7]) * u.mV
print('voltages:', voltages)
print('shape:', voltages.shape, 'dtype:', voltages.dtype)
voltages: [1. 2.5 3.70000005] mV
shape: (3,) dtype: float32
# Direct construction
current = u.Quantity(jnp.array([0.1, 0.2, 0.3]), unit=u.ampere)
print('current:', current)
current: [0.1 0.2 0.30000001] A
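Arithmetic with Units
Units propagate through arithmetic automatically: multiplication and division combine units, while addition and subtraction require matching dimensions and raise an error otherwise.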
# Addition: same dimension required
t1 = 500.0 * u.ms
t2 = 1.5 * u.second
print('t1 + t2:', t1 + t2) # result is expressed in the first operand's unit (ms)
t1 + t2: 2000. ms
# Multiplication: units multiply
F = 10.0 * u.newton
d = 3.0 * u.meter
print('work = F * d:', F * d) # N * m = J
work = F * d: 30. J
# Division: units divide
print('speed = d / t:', (100.0 * u.meter) / (10.0 * u.second)) # m/s
speed = d / t: 10. m / s
# Dimension mismatch raises error
try:
    result = 5.0 * u.meter + 3.0 * u.second
except Exception as e:
    print('Error:', e)
Error: Cannot calculate
5. m + 3. s, because units do not match: m != s
Unit Conversion
Use to_decimal() to extract the numeric value in a target unit, or in_unit() to get a new Quantity in the target unit.
distance = 2.5 * u.kmeter
print('In meters:', distance.to_decimal(u.meter)) # 2500.0
print('In cm:', distance.to_decimal(u.cmeter)) # 250000.0
print('As Quantity:', distance.in_unit(u.meter)) # 2500.0 m
In meters: 2500.0
In cm: 250000.0
As Quantity: 2500. m
Quantity Attributes
q = jnp.array([[1., 2.], [3., 4.]]) * u.volt
print('mantissa:', q.mantissa) # numeric array
print('unit:', q.unit) # the unit
print('dim:', q.dim) # physical dimension
print('shape:', q.shape) # array shape
print('dtype:', q.dtype) # array dtype
mantissa: [[1. 2.]
[3. 4.]]
unit: V
dim: m^2 kg s^-3 A^-1
shape: (2, 2)
dtype: float32
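Unit-Aware Math Functions
brainunit.math (used as u.math) provides 500+ unit-aware counterparts of NumPy functions. Reductions such as sum and mean keep the input unit, while functions such as sqrt transform it.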
data = jnp.array([2., 4., 6., 8., 10.]) * u.newton
print('sum:', u.math.sum(data)) # keeps unit
print('mean:', u.math.mean(data)) # keeps unit
print('sqrt:', u.math.sqrt(4.0 * u.meter2)) # changes unit: m^2 -> m
print('sort:', u.math.sort(jnp.array([3., 1., 2.]) * u.volt))
sum: 30. N
mean: 6. N
sqrt: 2. m
sort: [1. 2. 3.] V
Physical Constants
from brainunit import constants
print('Avogadro number:', constants.avogadro)
print('Boltzmann constant:', constants.boltzmann)
print('Elementary charge:', constants.elementary_charge)
print('Electron mass:', constants.electron_mass)
Avogadro number: 6.0221406e+23 1 / mol
Boltzmann constant: 1.380649e-23 J / K
Elementary charge: 1.6021766e-19 C
Electron mass: 9.109383e-31 kg
JAX Transforms: jit, vmap, grad
Quantities can be passed directly to jax.jit and jax.vmap. For differentiation, use u.autograd.grad, which carries units into the gradient.
# JIT compilation
@jax.jit
def kinetic_energy(m, v):
    return 0.5 * m * v**2
KE = kinetic_energy(2.0 * u.kilogram, 3.0 * u.meter / u.second)
print('KE =', KE) # kg * m^2 / s^2 = J
KE = 9. J
# vmap: vectorize over a batch
velocities = jnp.array([1., 2., 3., 4., 5.]) * u.meter / u.second
energies = jax.vmap(lambda v: kinetic_energy(2.0 * u.kilogram, v))(velocities)
print('Batch KE:', energies)
Batch KE: [ 1. 4. 9. 16. 25.] J
# grad: automatic differentiation with unit tracking
dKE_dv = u.autograd.grad(lambda v: 0.5 * (2.0 * u.kilogram) * v**2)
print('dKE/dv at v=3 m/s:', dKE_dv(3.0 * u.meter / u.second)) # momentum: kg * m/s
dKE/dv at v=3 m/s: 6. kg * m / s
Unit Validation with Decorators
Use the @u.check_units decorator to declare the unit each argument must have. Calling the decorated function with an argument in the wrong unit raises an error.
@u.check_units(v=u.meter / u.second, t=u.second)
def displacement(v, t):
    return v * t
print('displacement:', displacement(10.0 * u.meter / u.second, 5.0 * u.second))
displacement: 50. m
# Wrong units raise an error
try:
    displacement(10.0 * u.kilogram, 5.0 * u.second)
except Exception as e:
    print('Error:', e)
Error: Function 'displacement' expected a array with unit Unit("m / s") for argument 'v' but got '10. kg' (unit is kg).
What’s Next?
Quantity — Creating and manipulating quantities in depth
Standard Units — All available SI and non-SI units
Unit Conversion — Converting between units
NumPy Functions — 500+ unit-aware math functions
Linear Algebra — Unit-aware linalg
Unit Validation — check_dims and check_units decorators