Unit-aware Computation with CustomArray
Introduction
The CustomArray class in brainunit provides a practical foundation for creating unit-aware computational arrays that maintain dimensional consistency throughout complex calculations. This tutorial shows how to use CustomArray to build array types that automatically handle units, enabling safer and more maintainable scientific computing.
What Is Unit-aware Computation?
Unit-aware computation keeps physical quantities dimensionally correct across operations. Typical rules:
Adding meters to meters results in meters
Multiplying meters by meters results in square meters
Dividing distance by time results in velocity
Invalid operations (e.g., meters + seconds) are detected and raise errors
Why Use CustomArray?
Dimensional safety: dimension mismatches raise errors at runtime, when the offending operation executes
Automatic unit propagation through operations
Works with NumPy and JAX arrays (and supports PyTorch-like methods)
Extensible: create domain-specific array types (physics, neuroscience, etc.)
Low overhead: operations delegate directly to the array stored in .data
# Imports
import brainunit as u
import brainstate
print("brainunit version:", getattr(u, '__version__', 'unknown'))
print('Sample units:', 'm, s, Hz, V, A, kg, N, Pa, J')
brainunit version: 0.2.0
Sample units: m, s, Hz, V, A, kg, N, Pa, J
CustomArray Architecture
CustomArray is a base class. Any class that inherits from it and provides a .data attribute automatically gains rich array behavior and unit-aware math via brainunit.math.
Core requirements:
Inherit from u.CustomArray
Store your underlying data (with units) in self.data
Benefits:
Separation of concerns: you focus on data and state, while CustomArray handles array operations
Unit propagation: math operations keep correct units
Backend flexibility: self.data can be a NumPy array, a JAX array, or another array-like
# A minimal, practical CustomArray
class MyArray(u.CustomArray):
    """Minimal unit-aware array: just store a `.data`."""
    def __init__(self, data):
        self.data = data  # typically a brainunit Quantity or plain array
    def __repr__(self):
        return f'MyArray({self.data})'
# Create an instance with units
length = MyArray([1, 2, 3] * u.meter)
length, length.shape, getattr(length.data, 'unit', 'unitless')
(MyArray([1 2 3] m), (3,), Unit("m"))
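Unit Propagation with Operators
When .data is a Quantity, the standard arithmetic operators propagate units: adding compatible units (such as meters and centimeters) yields a length, multiplication and division combine units, and incompatible additions raise an error.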
# Compatible addition keeps units
length_cm = MyArray([100, 200, 300] * u.cmeter)
total_length = length + length_cm # meters + centimeters -> meters
print('total_length:', total_length)
total_length: [2. 4. 6.] m
# Multiplication changes units (area)
area = length * length # m * m -> m^2
print('area:', area)
area: [1 4 9] m^2
# Division changes units (velocity)
time = MyArray([1, 2, 3] * u.second)
velocity = length / time # m / s
print('velocity:', velocity)
velocity: [1. 1. 1.] m / s
# Incompatible addition raises an error
try:
    bad = length + time
except Exception as e:
    print('Expected error:', e)
Expected error: Cannot calculate
[1 2 3] m + [1 2 3] s, because units do not match: m != s
Using brainunit.math with CustomArray
The brainunit.math module mirrors NumPy/JAX APIs and is unit-aware. All functions accept CustomArray instances: internally, brainunit extracts .data via helper utilities and returns quantities with correct units.
Categories (simplified):
Keep-unit functions (e.g., mean, sum, concatenate, stack) return the same unit
Change-unit functions (e.g., square, sqrt, multiply, divide, var) transform units according to math rules
Some functions require unitless inputs (e.g., round, floor)
# Keep-unit examples
print('mean(length):', u.math.mean(length))
print('sum(length):', u.math.sum(length))
# Change-unit examples
print('square(length):', u.math.square(length)) # m^2
print('sqrt(square(length)):', u.math.sqrt(u.math.square(length))) # back to m
print('var(length):', u.math.var(length)) # m^2
# Broadcasting and stacking
stacked = u.math.stack([length, length_cm])
print('stacked shape:', getattr(stacked, 'shape', None))
# Linear algebra with units
force = MyArray([10, 20, 30] * u.newton)
displacement = MyArray([0.5, 1.0, 1.5] * u.meter)
work = u.math.dot(force, displacement) # N·m -> J (joule)
print('work (dot):', work)
mean(length): 2. m
sum(length): 6 m
square(length): [1 4 9] m^2
sqrt(square(length)): [1. 2. 3.] m
var(length): 0.6666667 m^2
stacked shape: (2, 3)
work (dot): 70. J
Converting Units for Display or Interop
Use Quantity.to_decimal(target_unit) to get plain numeric values expressed in the target unit, for display, logging, or plotting.
# Convert quantity values inside your CustomArray for display
meters = MyArray([1, 2, 3] * u.meter)
print('as meters:', meters.data.to_decimal(u.meter))
print('as centimeters:', meters.data.to_decimal(u.cmeter))
as meters: [1 2 3]
as centimeters: [100. 200. 300.]
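Stateful and Learnable Arrays (BrainState)
For stateful workflows, combine brainstate.State with CustomArray to create learnable, unit-aware parameters. In the class below, data is a property that returns the state's current value, so CustomArray always operates on the latest state.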
class StatefulArray(brainstate.State, u.CustomArray):
    @property
    def data(self):
        return self.value
# Example: a learnable parameter with units
param = StatefulArray(0.1 * u.second)
print('stateful param:', param)
stateful param: 0.1 s
Robust Patterns and Error Handling
Tips:
Document expected units for each array (e.g., meters for length)
Validate inputs when building domain-specific types
Catch and surface unit mismatch errors with clear messages
Prefer brainunit.math over raw NumPy for unit-aware operations
Summary
Inherit from u.CustomArray and set self.data (often a Quantity)
Use operators and brainunit.math to get automatic unit propagation
Convert units with Quantity.to_decimal for display or interop
Combine with BrainState to build stateful, unit-aware components
With these patterns, you can build reliable, unit-safe computational workflows across NumPy and JAX backends.