brainunit documentation


brainunit provides physical units and a unit-aware mathematical system in JAX for general AI-driven scientific computing.

The core features of brainunit include:

  • Integration of over 2,000 commonly used physical units and constants

  • Implementation of more than 500 unit-aware mathematical functions

  • Deep integration with JAX, providing comprehensive support for modern AI framework features including automatic differentiation (autograd), just-in-time compilation (JIT), vectorization, and parallel computation

  • Unit conversion and analysis are performed at compilation time, resulting in zero runtime overhead

  • Strict physical unit type checking and dimensional inference system, detecting unit inconsistencies during compilation

Compared with existing unit libraries such as Quantities and Pint, brainunit introduces a rigorous physical unit system designed specifically to support AI computations (e.g., automatic differentiation, just-in-time compilation, and parallelization).


Installation

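Install the variant that matches your hardware. Quote the argument if your shell expands square brackets, as zsh does.

pip install -U brainunit[cpu]     # CPU only
pip install -U brainunit[cuda12]  # NVIDIA GPU with CUDA 12
pip install -U brainunit[cuda13]  # NVIDIA GPU with CUDA 13
pip install -U brainunit[tpu]     # Google TPU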

Quick Start

Most users of the brainunit package will work with Quantity: the combination of a value and a unit. The most convenient way to create a Quantity is to multiply or divide a value by one of the built-in units. This works with scalars, sequences, and NumPy or JAX arrays.

>>> import brainunit as u
>>> 61.8 * u.second
61.8 * second
>>> [1., 2., 3.] * u.second
ArrayImpl([1. 2. 3.]) * second
>>> import numpy as np
>>> np.array([1., 2., 3.]) * u.second
ArrayImpl([1., 2., 3.]) * second
>>> import jax.numpy as jnp
>>> jnp.array([1., 2., 3.]) * u.second
ArrayImpl([1., 2., 3.]) * second

You can read the mantissa and unit of a Quantity from its mantissa and unit attributes:

>>> q = 61.8 * u.second
>>> q.mantissa
Array(61.8, dtype=float64, weak_type=True)
>>> q.unit
second

Quantities and units can also be combined arithmetically; the result carries the derived unit:

>>> 15.1 * u.meter / (32.0 * u.second)
0.471875 * meter / second
>>> 3.0 * u.kmeter / (130.51 * u.meter / u.second)
22.986744 * second

To create a dimensionless quantity, directly use the Quantity constructor:

>>> q = u.Quantity(61.8)
>>> q.dim
Dimension()