Unit-aware Computation with CustomArray

Introduction

The CustomArray class in brainunit is a base class for building array types that carry physical units and keep them dimensionally consistent through calculations. This tutorial shows how to subclass CustomArray so that arithmetic operators and brainunit.math functions handle units automatically, which makes scientific code safer and easier to maintain.

What Is Unit-aware Computation?

Unit-aware computation keeps physical quantities dimensionally correct across operations. Typical rules:

  • Adding meters to meters results in meters

  • Multiplying meters by meters results in square meters

  • Dividing distance by time results in velocity

  • Invalid operations (e.g., meters + seconds) are detected and raise errors
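The rules above come down to bookkeeping on dimension exponents. Below is a minimal, stdlib-only sketch of that bookkeeping. It assumes a hypothetical `Qty` class that stores a magnitude together with a dict of base-unit exponents. It only illustrates the idea and is not how brainunit implements it.

```python
class Qty:
    """Hypothetical toy quantity: a magnitude plus base-unit exponents."""

    def __init__(self, value, dims):
        self.value = value
        # Drop zero exponents so that, e.g., m / m becomes dimensionless ({}).
        self.dims = {k: v for k, v in dims.items() if v != 0}

    def __add__(self, other):
        # Addition is only defined when the dimensions are identical.
        if self.dims != other.dims:
            raise ValueError(f"dimension mismatch: {self.dims} vs {other.dims}")
        return Qty(self.value + other.value, self.dims)

    def _combine(self, other, sign):
        # Add (sign=+1) or subtract (sign=-1) the other operand's exponents.
        dims = dict(self.dims)
        for k, v in other.dims.items():
            dims[k] = dims.get(k, 0) + sign * v
        return dims

    def __mul__(self, other):
        # m * m -> m^2
        return Qty(self.value * other.value, self._combine(other, +1))

    def __truediv__(self, other):
        # m / s -> m s^-1
        return Qty(self.value / other.value, self._combine(other, -1))


meter = Qty(1, {"m": 1})
second = Qty(1, {"s": 1})

two_m = meter + meter    # value 2, dims {'m': 1}
area = meter * meter     # dims {'m': 2}
speed = meter / second   # dims {'m': 1, 's': -1}
try:
    meter + second       # meters + seconds is rejected
except ValueError as err:
    print("rejected:", err)
```

A real library also tracks scale (meter vs. centimeter) and named units such as newtons. The sketch keeps only the exponent arithmetic that decides which operations are allowed.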

Why Use CustomArray?

  • Dimensional safety: unit mismatches raise errors at runtime instead of silently producing wrong numbers

  • Automatic unit propagation through operators and brainunit.math functions

  • Works with NumPy and JAX arrays (and supports PyTorch-like methods)

  • Extensible: create domain-specific array types (physics, neuroscience, etc.)

  • Low overhead: the wrapper delegates the actual computation to the underlying array

# Imports
import brainunit as u
import brainstate

print("brainunit version:", getattr(u, '__version__', 'unknown'))
print('Sample units:', 'm, s, Hz, V, A, kg, N, Pa, J')
brainunit version: 0.2.0
Sample units: m, s, Hz, V, A, kg, N, Pa, J

CustomArray Architecture

CustomArray is a base class. Any class that inherits from it and provides a .data attribute gains array behavior (arithmetic operators, attributes such as .shape, and array methods) and can be passed directly to the unit-aware functions in brainunit.math.

Core requirements:

  1. Inherit from u.CustomArray

  2. Store the underlying data in self.data: a brainunit Quantity when it carries units, or a plain array otherwise

Benefits:

  • Separation of concerns: you focus on data/state, CustomArray handles array ops

  • Unit propagation: math operations keep correct units

  • Backend flexibility: self.data can be NumPy, JAX, or other array-likes
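To see why setting .data is enough, here is a stdlib-only sketch of the delegation pattern. `DelegatingBase` and `Scalar` are hypothetical names that stand in for u.CustomArray and a subclass of it. brainunit's real base class implements far more operators and methods, and its internals may differ.

```python
class DelegatingBase:
    """Hypothetical stand-in for a CustomArray-like base class.

    Subclasses only provide `.data`; each operator unwraps both operands
    to their `.data` and applies the operator to those values.
    """

    @staticmethod
    def _unwrap(x):
        # Accept either another DelegatingBase instance or a plain value.
        return x.data if isinstance(x, DelegatingBase) else x

    def __add__(self, other):
        return self.data + self._unwrap(other)

    def __radd__(self, other):
        return self._unwrap(other) + self.data

    def __mul__(self, other):
        return self.data * self._unwrap(other)

    def __rmul__(self, other):
        return self._unwrap(other) * self.data

    def __truediv__(self, other):
        return self.data / self._unwrap(other)

    def __neg__(self):
        return -self.data


class Scalar(DelegatingBase):
    """Hypothetical subclass: like MyArray, it only stores `.data`."""

    def __init__(self, data):
        self.data = data


a = Scalar(3)
b = Scalar(4)
print(a + b)  # 7
print(2 * a)  # 6 (int.__mul__ defers to Scalar.__rmul__)
print(a / b)  # 0.75
```

In this toy, results come back as whatever type the `.data` operation produces (an int or float here), because the operators do not re-wrap them.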

# A minimal, practical CustomArray
class MyArray(u.CustomArray):
    """Minimal unit-aware array: just store a `.data`."""
    def __init__(self, data):
        self.data = data  # typically a brainunit Quantity or plain array
    def __repr__(self):
        return f'MyArray({self.data})'

# Create an instance with units
length = MyArray([1, 2, 3] * u.meter)
length, length.shape, getattr(length.data, 'unit', 'unitless')
(MyArray([1 2 3] m), (3,), Unit("m"))

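Unit Propagation with Operators

When .data is a Quantity, the standard arithmetic operators propagate units. Values in compatible units are converted before they are added, products and quotients combine units, and incompatible units raise an error.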

# Compatible addition keeps units
length_cm = MyArray([100, 200, 300] * u.cmeter)
total_length = length + length_cm  # meters + centimeters -> meters
print('total_length:', total_length)
total_length: [2. 4. 6.] m
# Multiplication changes units (area)
area = length * length  # m * m -> m^2
print('area:', area)
area: [1 4 9] m^2
# Division changes units (velocity)
time = MyArray([1, 2, 3] * u.second)
velocity = length / time  # m / s
print('velocity:', velocity)
velocity: [1. 1. 1.] m / s
# Incompatible addition raises an error
try:
    bad = length + time
except Exception as e:
    print('Expected error:', e)
Expected error: Cannot calculate 
[1 2 3] m + [1 2 3] s, because units do not match: m != s
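In the compatible-addition example, `length + length_cm` came back in meters. Conceptually, the centimeter values are rescaled into the meter scale before the addition. The hypothetical helper `add_scaled` below sketches that arithmetic with stdlib `Fraction`s, writing each scale as a power of ten (meter = 0, centimeter = -2). How brainunit picks the result unit internally may differ.

```python
from fractions import Fraction


def add_scaled(a_values, a_exp10, b_values, b_exp10):
    """Hypothetical helper: add two same-dimension arrays whose units have
    scales 10**a_exp10 and 10**b_exp10, returning values in the first
    operand's scale."""
    # Exact rescaling factor from b's scale to a's scale (cm -> m: 1/100).
    factor = Fraction(10) ** (b_exp10 - a_exp10)
    return [float(a + b * factor) for a, b in zip(a_values, b_values)]


# [1, 2, 3] m + [100, 200, 300] cm, expressed in meters
print(add_scaled([1, 2, 3], 0, [100, 200, 300], -2))  # [2.0, 4.0, 6.0]
```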

Using brainunit.math with CustomArray

The brainunit.math module mirrors the NumPy/JAX APIs and is unit-aware. Its functions accept CustomArray instances directly: brainunit extracts the underlying .data and returns results with the correct units.

Categories (simplified):

  • Keep-unit functions (e.g., mean, sum, concatenate, stack) return the same unit

  • Change-unit functions (e.g., square, sqrt, multiply, divide, var) transform units according to math rules

  • Some functions are only defined for dimensionless inputs (e.g., exp, log, sin)
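The three categories correspond to simple rules on a unit's exponents. The sketch below is stdlib-only and uses hypothetical helper names, with a unit represented as a dict from base unit to `Fraction` exponent. It illustrates the rules and is not brainunit's API.

```python
from fractions import Fraction


def power_dims(dims, p):
    """Hypothetical: raise a unit (base -> exponent dict) to the power p."""
    p = Fraction(p)
    return {k: v * p for k, v in dims.items() if v * p != 0}


def keep_unit(dims):
    # mean, sum, concatenate, stack: the result unit equals the input unit.
    return dict(dims)


def square_unit(dims):
    # square and var: the unit is squared.
    return power_dims(dims, 2)


def sqrt_unit(dims):
    # sqrt: every exponent is halved.
    return power_dims(dims, Fraction(1, 2))


def require_unitless(dims, name):
    # e.g. exp or log of a quantity with units is undefined.
    if any(v != 0 for v in dims.values()):
        raise ValueError(f"{name} requires a dimensionless input, got {dims}")
    return {}


meter = {"m": Fraction(1)}
print(square_unit(meter))             # exponent 2, i.e. m^2
print(sqrt_unit(square_unit(meter)))  # exponent 1, i.e. back to m
```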

# Keep-unit examples
print('mean(length):', u.math.mean(length))
print('sum(length):', u.math.sum(length))

# Change-unit examples
print('square(length):', u.math.square(length))  # m^2
print('sqrt(square(length)):', u.math.sqrt(u.math.square(length)))  # back to m
print('var(length):', u.math.var(length))  # m^2

# Stacking arrays with compatible units (m and cm)
stacked = u.math.stack([length, length_cm])
print('stacked shape:', getattr(stacked, 'shape', None))

# Linear algebra with units
force = MyArray([10, 20, 30] * u.newton)
displacement = MyArray([0.5, 1.0, 1.5] * u.meter)
work = u.math.dot(force, displacement)  # N·m -> J (joule)
print('work (dot):', work)
mean(length): 2. m
sum(length): 6 m
square(length): [1 4 9] m^2
sqrt(square(length)): [1. 2. 3.] m
var(length): 0.6666667 m^2
stacked shape: (2, 3)
work (dot): 70. J
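The printed values can be checked by hand. For length = [1, 2, 3] m the mean is 2 m. The variance below uses the 1/N normalization (NumPy's default), which matches the 0.6666667 m^2 output. The dot product multiplies newtons by meters component-wise, and 1 N·m = 1 J:

```latex
\begin{aligned}
\operatorname{var}(x) &= \frac{1}{3}\sum_{i=1}^{3}(x_i-\bar{x})^2
  = \frac{(1-2)^2 + (2-2)^2 + (3-2)^2}{3}\,\mathrm{m}^2
  = \tfrac{2}{3}\,\mathrm{m}^2 \approx 0.667\,\mathrm{m}^2 \\
W &= \sum_{i=1}^{3} F_i\,d_i
  = \bigl[(10)(0.5) + (20)(1.0) + (30)(1.5)\bigr]\,\mathrm{N\,m}
  = 70\,\mathrm{N\,m} = 70\,\mathrm{J}
\end{aligned}
```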

Converting Units for Display or Interop

Use Quantity.to_decimal(target_unit) to get plain numeric values (with no unit attached) expressed in a target unit, for example for display, logging, or plotting.

# Convert quantity values inside your CustomArray for display
meters = MyArray([1, 2, 3] * u.meter)
print('as meters:', meters.data.to_decimal(u.meter))
print('as centimeters:', meters.data.to_decimal(u.cmeter))
as meters: [1 2 3]
as centimeters: [100. 200. 300.]
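Conceptually, converting for display multiplies each stored magnitude by the ratio of the source unit's scale to the target unit's scale, and returns bare numbers. The hypothetical `to_decimal_sketch` below shows that arithmetic with scales written as powers of ten (meter = 0, centimeter = -2). It is not brainunit's implementation.

```python
from fractions import Fraction


def to_decimal_sketch(values, from_exp10, to_exp10):
    """Hypothetical: re-express magnitudes stored in a unit of scale
    10**from_exp10 in a unit of scale 10**to_exp10. Returns plain floats."""
    ratio = Fraction(10) ** (from_exp10 - to_exp10)
    return [float(v * ratio) for v in values]


print(to_decimal_sketch([1, 2, 3], 0, -2))  # meters -> centimeters: [100.0, 200.0, 300.0]
print(to_decimal_sketch([150], -2, 0))      # centimeters -> meters: [1.5]
```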

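Stateful and Learnable Arrays (BrainState)

For stateful workflows, combine brainstate.State with CustomArray to create unit-aware arrays whose value can be updated in place. For trainable parameters, a parameter state type such as brainstate.ParamState can serve as the base class in the same way.

class StatefulArray(brainstate.State, u.CustomArray):
    # CustomArray reads .data; here .data is a view of the state's current value.
    @property
    def data(self):
        return self.value

    # Optional setter: assigning to .data updates the underlying state.
    @data.setter
    def data(self, value):
        self.value = value

# Example: a stateful parameter with units (0.1 seconds)
param = StatefulArray(0.1 * u.second)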
print('stateful param:', param)
stateful param: 0.1 s
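Because `data` is a property that reads `self.value`, arithmetic always sees the current state, even after the state has been updated. The stdlib sketch below mimics that wiring. `ToyState` and `ToyDelegating` are hypothetical stand-ins for brainstate.State and u.CustomArray, and the sketch only illustrates why a property-based `data` works.

```python
class ToyDelegating:
    """Hypothetical stand-in for u.CustomArray: operators use self.data."""

    def __mul__(self, other):
        return self.data * other

    def __add__(self, other):
        return self.data + other


class ToyState:
    """Hypothetical stand-in for brainstate.State: holds a mutable .value."""

    def __init__(self, value):
        self.value = value


class ToyStatefulArray(ToyState, ToyDelegating):
    @property
    def data(self):
        # Read through to the state's current value on every access.
        return self.value

    @data.setter
    def data(self, new_value):
        # Writing .data updates the underlying state.
        self.value = new_value


p = ToyStatefulArray(0.1)
print(p * 2)    # 0.2
p.value = 0.5   # e.g. an update step changing the state
print(p * 2)    # 1.0
p.data = 3
print(p.value)  # 3
```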

Robust Patterns and Error Handling

Tips:

  • Document expected units for each array (e.g., meters for length)

  • Validate inputs when building domain-specific types

  • Catch and surface unit mismatch errors with clear messages

  • Prefer brainunit.math over raw NumPy for unit-aware operations
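One way to follow the "validate inputs" and "clear messages" tips is to check dimensions once, at construction time, and raise an error that names the offending argument. The stdlib sketch below uses hypothetical names (`require_dims`, `Length`) and represents units as base-exponent dicts. In real code, the check would use whatever unit information the wrapped quantity carries.

```python
def require_dims(name, dims, expected):
    """Hypothetical validator: raise a descriptive error on a unit mismatch."""
    if dims != expected:
        raise ValueError(f"{name!r} must have dimensions {expected}, got {dims}")


class Length:
    """Hypothetical domain type that only accepts lengths."""

    METER = {"m": 1}

    def __init__(self, value, dims):
        require_dims("value", dims, self.METER)
        self.value = value
        self.dims = dims


ok = Length(3.0, {"m": 1})
try:
    Length(3.0, {"s": 1})
except ValueError as err:
    print("rejected:", err)  # rejected: 'value' must have dimensions {'m': 1}, got {'s': 1}
```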

Summary

  • Inherit from u.CustomArray and set self.data (often a Quantity)

  • Use operators and brainunit.math to get automatic unit propagation

  • Convert units with Quantity.to_decimal for display or interop

  • Combine with BrainState to build stateful, unit-aware components

With these patterns, you can build reliable, unit-safe computational workflows across NumPy and JAX backends.