brainunit documentation
brainunit provides physical units and a unit-aware mathematical system in JAX for general AI-driven scientific computing.
The core features of brainunit include:
- Integration of over 2,000 commonly used physical units and constants
- Implementation of more than 500 unit-aware mathematical functions
- Deep integration with JAX, with comprehensive support for modern AI framework features, including automatic differentiation (autograd), just-in-time compilation (JIT), vectorization, and parallel computation
- Unit conversion and unit analysis performed at compilation time, resulting in zero runtime overhead
- A strict physical-unit type-checking and dimensional-inference system that detects unit inconsistencies during compilation
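The dimensional inference and checking described in the list can be illustrated with the standard textbook model of a physical dimension: a tuple of integer exponents over the seven SI base dimensions. The sketch below is a hypothetical, pure-Python illustration of that model only; `dim`, `mul_dim`, `div_dim` and `check_add` are names invented for this example, not brainunit's API, and brainunit's internal representation may differ. Because a dimension is metadata that never depends on the numeric values, a check like this can in principle run once, when a function is traced, instead of on every call.

```python
# Hypothetical sketch of dimensional inference (not brainunit's internals):
# a dimension is a tuple of integer exponents over the SI base dimensions.
BASE = ("m", "kg", "s", "A", "K", "mol", "cd")


def dim(**exps):
    """Build an exponent tuple, e.g. dim(m=1, s=-1) for a velocity."""
    return tuple(exps.get(b, 0) for b in BASE)


def mul_dim(a, b):
    # Multiplying two quantities adds their exponents.
    return tuple(x + y for x, y in zip(a, b))


def div_dim(a, b):
    # Dividing two quantities subtracts their exponents.
    return tuple(x - y for x, y in zip(a, b))


def check_add(a, b):
    # Addition and subtraction are only defined for identical dimensions.
    if a != b:
        raise TypeError(f"dimension mismatch: {a} vs {b}")
    return a


LENGTH = dim(m=1)
TIME = dim(s=1)
VELOCITY = div_dim(LENGTH, TIME)           # inferred, never declared

print(VELOCITY)                            # (1, 0, -1, 0, 0, 0, 0)
print(mul_dim(VELOCITY, TIME) == LENGTH)   # True

try:
    check_add(LENGTH, TIME)                # meters + seconds is rejected
except TypeError as err:
    print("rejected:", err)
```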
Compared with existing unit libraries such as Quantities and Pint, brainunit introduces a rigorous physical unit system designed specifically to support AI computations (e.g., automatic differentiation, just-in-time compilation, and parallelization).
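Installation
Install the variant that matches your hardware. The quotes keep shells such as zsh from interpreting the square brackets.
```shell
pip install -U "brainunit[cpu]"      # CPU only
pip install -U "brainunit[cuda12]"   # NVIDIA GPU, CUDA 12
pip install -U "brainunit[cuda13]"   # NVIDIA GPU, CUDA 13
pip install -U "brainunit[tpu]"      # Google TPU
```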
Quick Start
Most users of the brainunit package will work with Quantity: the combination of a value and a unit. The most convenient way to create a Quantity is to multiply or divide a value by one of the built-in units. This works with scalars, sequences, and NumPy or JAX arrays.
```python
>>> import brainunit as u
>>> 61.8 * u.second
61.8 * second
>>> [1., 2., 3.] * u.second
ArrayImpl([1. 2. 3.]) * second
>>> import numpy as np
>>> np.array([1., 2., 3.]) * u.second
ArrayImpl([1., 2., 3.]) * second
>>> import jax.numpy as jnp
>>> jnp.array([1., 2., 3.]) * u.second
ArrayImpl([1., 2., 3.]) * second
```
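Why does `61.8 * u.second` work when a Python float knows nothing about units? Python's reflected-operator protocol: when the left operand's multiplication cannot handle the right operand (a float returns `NotImplemented`, and a list only supports integer repetition), Python calls the right operand's `__rmul__`. The hypothetical `ToyUnit` and `ToyQuantity` classes below are a minimal sketch of that mechanism, not brainunit's classes. brainunit's own Quantity also stores the value as a JAX array, as the `ArrayImpl` in the outputs above suggests, and NumPy arrays on the left need extra care because `ndarray.__mul__` would otherwise try to handle the unit object itself; NumPy's `__array_ufunc__` protocol exists for that case, and this sketch omits it.

```python
# Hypothetical toy classes (not brainunit's API) showing how `value * unit`
# can build a quantity from plain scalars and sequences.


class ToyQuantity:
    def __init__(self, mantissa, unit):
        self.mantissa = mantissa  # the bare number(s)
        self.unit = unit          # the unit attached to them

    def __repr__(self):
        return f"{self.mantissa!r} * {self.unit.name}"


class ToyUnit:
    def __init__(self, name):
        self.name = name

    def __rmul__(self, value):
        # Called for `value * unit` when the left operand's own
        # multiplication does not support ToyUnit (floats, lists, tuples).
        if isinstance(value, (list, tuple)):
            value = [float(v) for v in value]
        return ToyQuantity(value, self)


second = ToyUnit("second")
print(61.8 * second)          # 61.8 * second
print([1., 2., 3.] * second)  # [1.0, 2.0, 3.0] * second
```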
You can get the unit and mantissa of a Quantity from its unit and mantissa attributes:
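```python
>>> q = 61.8 * u.second
>>> q.mantissa  # dtype is float64 only when JAX's 64-bit mode (jax_enable_x64) is on; float32 otherwise
Array(61.8, dtype=float64, weak_type=True)
>>> q.unit
second
```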
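The name mantissa hints at how the two attributes fit together: the mantissa is a number expressed in `q.unit`, so neither attribute alone determines the physical magnitude. The sketch below illustrates this with hypothetical `ToyUnit` and `ToyQuantity` dataclasses, assuming for illustration that every unit carries a power-of-ten scale relative to its SI base unit; it is not brainunit's implementation.

```python
# Hypothetical sketch (not brainunit's API): a quantity is a mantissa plus a
# unit, and the unit carries an assumed power-of-ten scale vs. the SI unit.
import math
from dataclasses import dataclass


@dataclass(frozen=True)
class ToyUnit:
    name: str
    scale: int  # assumed: power of ten relative to the SI base unit


@dataclass(frozen=True)
class ToyQuantity:
    mantissa: float
    unit: ToyUnit

    def si_value(self) -> float:
        # Physical magnitude in SI base units: mantissa * 10**scale.
        return self.mantissa * 10.0 ** self.unit.scale


second = ToyUnit("second", 0)
msecond = ToyUnit("msecond", -3)

a = ToyQuantity(61.8, second)
b = ToyQuantity(61800.0, msecond)

# Different mantissas, same physical duration once the unit is applied.
print(a.mantissa, a.unit.name)                   # 61.8 second
print(b.mantissa, b.unit.name)                   # 61800.0 msecond
print(math.isclose(a.si_value(), b.si_value()))  # True
```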
You can also combine quantities or units:
```python
>>> 15.1 * u.meter / (32.0 * u.second)
0.471875 * meter / second
>>> 3.0 * u.kmeter / (130.51 * u.meter / u.second)  # length / velocity: a time, 3000/130.51 s ≈ 22.987 s
```
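Working both examples by hand shows which dimension and magnitude each result must have: dimensions combine by the same multiplication and division as the numbers, and the kilo prefix contributes a factor of $10^{3}$.

$$
\frac{15.1\,\mathrm{m}}{32.0\,\mathrm{s}} = 0.471875\,\mathrm{m\,s^{-1}},
\qquad
\frac{3.0\,\mathrm{km}}{130.51\,\mathrm{m\,s^{-1}}}
= \frac{3.0\times 10^{3}\,\mathrm{m}}{130.51\,\mathrm{m\,s^{-1}}}
= \frac{3000}{130.51}\,\mathrm{s}
\approx 22.987\,\mathrm{s}
$$

In dimensional terms the second expression is $\mathsf{L} / (\mathsf{L}\,\mathsf{T}^{-1}) = \mathsf{T}$, so its result is a duration, not a velocity.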
To create a dimensionless quantity, call the Quantity constructor directly with a bare value:
```python
>>> q = u.Quantity(61.8)
>>> q.dim
Dimension()
```
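`Dimension()` prints empty because a dimensionless quantity has a zero exponent for every base dimension, which is also what a ratio of two quantities with the same dimension produces. Dimensionless values matter because transcendental functions such as exp are only physically meaningful for dimensionless arguments: the series 1 + x + x²/2 + … adds terms of different dimensions unless x has none. The pure-Python sketch below illustrates that rule with a hypothetical `toy_exp`; it is not brainunit's implementation, though a unit-aware math library can enforce the same constraint.

```python
# Hypothetical sketch (not brainunit's API): dimensions as SI exponent tuples.
import math

BASE = ("m", "kg", "s", "A", "K", "mol", "cd")
DIMENSIONLESS = (0,) * len(BASE)
LENGTH = (1, 0, 0, 0, 0, 0, 0)


def toy_exp(value, dim):
    # exp(x) mixes all powers of x, so x must carry no dimension.
    if dim != DIMENSIONLESS:
        raise TypeError(f"exp() needs a dimensionless argument, got {dim}")
    return math.exp(value)


ratio_dim = tuple(a - b for a, b in zip(LENGTH, LENGTH))  # meter / meter

print(ratio_dim == DIMENSIONLESS)  # True
print(toy_exp(0.0, ratio_dim))     # 1.0
```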