Backends overview

brainunit pairs a physical Unit with an array mantissa. The mantissa can live on any one of several array backends, and every unit-aware operation (brainunit.math, brainunit.linalg, brainunit.fft, plain arithmetic) dispatches to the matching backend’s array library. You can stay inside one backend end-to-end, mix them, or switch by calling a single conversion method.

This page describes the architecture, the supported backends, how selection works, what each backend can and cannot do, and how to install optional backend dependencies. Per-backend notebooks follow it: see JAX, NumPy, CuPy, PyTorch, Dask, and ndonnx.

Supported backends

| Backend | Mantissa type | Installed via | Typical use case |
|---|---|---|---|
| jax | jax.Array | core dependency | autograd, JIT, vmap, accelerators (default) |
| numpy | numpy.ndarray | core dependency | scipy / pandas / sklearn interop, CPU |
| cupy | cupy.ndarray | brainunit[cupy] | NVIDIA GPU arrays, drop-in NumPy replacement |
| torch | torch.Tensor | brainunit[torch] | PyTorch models, CUDA/MPS tensors |
| dask | dask.array.Array | brainunit[dask] | out-of-core / parallel arrays, lazy compute |
| ndonnx | ndonnx.Array | brainunit[ndonnx] | symbolic graph building, ONNX export |

jax and numpy are always available because both are required core dependencies. The other four are opt-in. Without the corresponding extra, brainunit still works, but it will not dispatch onto that backend, and explicitly asking for that backend raises brainunit.BackendError with the matching pip install hint.

Internally, the numpy, cupy, torch, and dask namespaces are sourced from array_api_compat so they all expose the same array-API-standard surface. jax.numpy (JAX ≥ 0.9) and ndonnx are array-API compatible on their own and are used unwrapped.

How backend selection works

For every operation, brainunit asks: “which array library should compute the result?” The rule:

  1. Inspect the input mantissas. If exactly one backend kind is present, use it.

  2. If inputs mix backends, or there are no array inputs, consult the thread-local default backend set by set_default_backend(...) / using_backend(...).

  3. If no default is set, fall back to jax (the historical default).

```python
import numpy as np
import jax.numpy as jnp
import brainunit as u

q_np  = u.Quantity(np.array([1.0]), unit=u.meter)
q_jax = u.Quantity(jnp.array([2.0]), unit=u.meter)

print('q_np.backend        =', q_np.backend)
print('q_jax.backend       =', q_jax.backend)
print('(q_np + q_np).bk    =', (q_np + q_np).backend)   # one backend kind -> it wins
print('(q_np + q_jax).bk   =', (q_np + q_jax).backend)  # mixed kinds -> default
```

```text
q_np.backend        = numpy
q_jax.backend       = jax
(q_np + q_np).bk    = numpy
(q_np + q_jax).bk   = jax
```

Override the tiebreaker with the context manager:

```python
with u.using_backend('numpy'):
    print('inside using_backend:', (q_np + q_jax).backend)  # 'numpy'

print('outside:', (q_np + q_jax).backend)                    # back to 'jax'
```

```text
inside using_backend: jax
outside: jax
```

Or set it for the rest of the program:

```python
u.set_default_backend('numpy')
print(u.get_default_backend())
print((q_np + q_jax).backend)

u.set_default_backend(None)   # restore default
print(u.get_default_backend())
```

```text
numpy
jax
None
```

The default is a ContextVar, so it isolates per-thread and per-task; nested using_backend(...) blocks restore the prior value on exit.

Choosing a backend

There is no universally best backend; each one trades capabilities against ecosystem fit.

  • jax — pick this when you need automatic differentiation, JIT, vmap, or accelerator support out of the box. This is the default and the most fully integrated backend; everything in brainunit.autograd, brainunit.lax, and brainunit.sparse requires it.

  • numpy — pick this for interop with the broader scientific Python stack (scipy, pandas, sklearn, matplotlib) where you want eager results with no JAX tracing. Works on CPU only.

  • cupy — pick this when you want a near-drop-in NumPy replacement running on an NVIDIA GPU and you don’t need autodiff. Requires a CUDA toolkit.

  • torch — pick this to embed unit-aware computations inside an existing PyTorch model. PyTorch’s own autograd is preserved through brainunit ops, so loss.backward() works on a quantity-derived loss. brainunit.autograd itself is JAX-only — call torch.autograd.grad on the mantissa.

  • dask — pick this for arrays that don’t fit in memory, or for embarrassingly parallel array work on a cluster. Operations stay lazy until you call .compute().

  • ndonnx — pick this when you want to build an ONNX graph symbolically. Operations build the graph rather than executing eagerly. Still maturing: not every brainunit operation has an ndonnx implementation.

Backend capabilities and limitations

Dimensional analysis works on every backend — brainunit tracks units on the Python Quantity object, independent of the mantissa library. The limitations below describe what each array backend can and cannot do, not the unit system.

jax (default)

Full feature set. The only backend that supports:

  • brainunit.lax.*: wrappers over jax.lax primitives.

  • brainunit.autograd.*: grad, jacobian, hessian.

  • brainunit.sparse.*: CSR, CSC, COO sparse matrices.

  • jax.jit, jax.vmap, jax.pmap over quantities.

numpy

Eager CPU computation. brainunit.math, brainunit.linalg, and brainunit.fft all work. JAX-specific subpackages raise BackendError.

cupy

NVIDIA GPU arrays via CUDA. Same general capability as numpy for brainunit.math / brainunit.linalg / brainunit.fft, but executed on the GPU. No autograd, no JIT, no brainunit.lax.

torch

PyTorch tensors. brainunit.math / brainunit.linalg / brainunit.fft route through array_api_compat.torch. Use torch.autograd.grad on the mantissa when you need backward passes — brainunit.autograd is JAX-only.

dask

Lazy arrays. Building a quantity, inspecting .shape / .ndim / .dtype, arithmetic, and most array-API operations stay lazy. Operations that need a concrete Python value — float(q), int(q), q.tolist(), np.asarray(q), hash(q), operator.index(q) — raise BackendError; call q.mantissa.compute() first.

ndonnx

Symbolic / ONNX graph building. Routing is correct for the array-API operations that ndonnx implements. Operations ndonnx hasn’t implemented yet surface their own errors unwrapped (brainunit does not catch them). Unit information lives on the Quantity and is not encoded in the ONNX graph.

Example of a JAX-only operation refusing a NumPy mantissa:

```python
from brainunit import BackendError

q_np = u.Quantity(np.array([1.0, 2.0, 3.0]), unit=u.meter)
try:
    u.lax.slice(q_np, (0,), (1,))
except BackendError as exc:
    print('expected:', exc)

# convert and retry
print(u.lax.slice(q_np.to_jax(), (0,), (1,)))
```

```text
expected: brainunit.lax.slice requires the jax backend; got numpy-backed Quantity. Call .to_jax() on the input first.
[1.] m
```

Optional dependencies and graceful failure

Optional backends are detected lazily. The is_*_array helpers remember a failed import (the ImportError is cached for the lifetime of the process) and never raise; for a library that isn't installed they simply return False:

```python
print('is_jax_array(jnp.zeros(1))      =', u.is_jax_array(jnp.zeros(1)))
print('is_numpy_array(np.zeros(1))     =', u.is_numpy_array(np.zeros(1)))
print('is_cupy_array on a non-cupy obj =', u.is_cupy_array([1, 2, 3]))
print('is_torch_array on a non-torch   =', u.is_torch_array([1, 2, 3]))
print('is_dask_array on a non-dask     =', u.is_dask_array([1, 2, 3]))
print('is_ndonnx_array on non-ndonnx   =', u.is_ndonnx_array([1, 2, 3]))
```

```text
is_jax_array(jnp.zeros(1))      = True
is_numpy_array(np.zeros(1))     = True
is_cupy_array on a non-cupy obj = False
is_torch_array on a non-torch   = False
is_dask_array on a non-dask     = False
is_ndonnx_array on non-ndonnx   = False
```

Asking for a backend that isn't installed raises brainunit.BackendError, not a bare ImportError, and the exception message includes the exact install command. For a graceful fallback, probe the imports yourself before selecting a backend:

```python
def pick_backend():
    for name, module in [('torch', 'torch'), ('cupy', 'cupy'),
                         ('jax', 'jax'),    ('numpy', 'numpy')]:
        try:
            __import__(module)
            return name
        except ImportError:
            continue
    raise RuntimeError('no array backend available')

print('preferred backend:', pick_backend())
```

```text
preferred backend: torch
```
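The install hint in that error follows a common pattern: catch the ImportError, re-raise a domain-specific exception that names the fix, and chain the original with raise ... from. The sketch below shows the pattern only; BackendUnavailable, require_backend, and INSTALL_HINTS are hypothetical names, and brainunit's actual message wording is not reproduced.

```python
import importlib

# Install commands taken from the installation table; the mapping itself is illustrative.
INSTALL_HINTS = {
    "cupy": "pip install brainunit[cupy]",
    "torch": "pip install brainunit[torch]",
    "dask": "pip install brainunit[dask]",
    "ndonnx": "pip install brainunit[ndonnx]",
}


class BackendUnavailable(RuntimeError):
    """Hypothetical stand-in for brainunit.BackendError."""


def require_backend(name, module_name=None):
    """Import the module behind a backend, or raise an error carrying the install command."""
    try:
        return importlib.import_module(module_name or name)
    except ImportError as exc:
        hint = INSTALL_HINTS.get(name, f"pip install {name}")
        # 'from exc' keeps the original ImportError attached as __cause__.
        raise BackendUnavailable(
            f"{name!r} backend is not available; install it with: {hint}"
        ) from exc


try:
    # Force a failure by pointing at a module name that does not exist.
    require_backend("ndonnx", module_name="no_such_module_for_demo")
except BackendUnavailable as err:
    print(err)
```

Keeping the chained ImportError is useful when a library is installed but still fails to import, since the original traceback survives underneath the friendlier message.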

Conversion between backends

Every Quantity has one conversion method per backend. Each returns a new Quantity and leaves the original untouched, except when the mantissa is already on the target backend: then the call is a no-op that returns self.

| Method | Notes |
|---|---|
| q.to_jax() | Wraps the mantissa with jnp.asarray. |
| q.to_numpy() | Materializes ndonnx mantissas via unwrap_numpy. |
| q.to_cupy(device=None) | device is a CUDA device index. |
| q.to_torch(device=None, dtype=None) | dtype accepts numpy or torch dtypes. |
| q.to_dask(chunks='auto') | Wraps with dask.array.from_array. |
| q.to_ndonnx() | Calls ndonnx.asarray on the mantissa. |

```python
q_np  = u.Quantity(np.array([1.0, 2.0]), unit=u.meter)
q_jax = q_np.to_jax()         # NumPy -> JAX
q_back = q_jax.to_numpy()     # JAX  -> NumPy

print(q_np.backend, '->', q_jax.backend, '->', q_back.backend)
```

```text
numpy -> jax -> numpy
```
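When the target backend comes from a string, for example a config file or a command-line flag, the to_* methods can be selected by name. ensure_backend below is a hypothetical helper, not part of brainunit; it assumes only that the backend names match the to_&lt;name&gt; method suffixes listed in the table above, and it forwards keyword arguments such as device or chunks unchanged.

```python
BACKENDS = ("jax", "numpy", "cupy", "torch", "dask", "ndonnx")


def ensure_backend(q, name, **kwargs):
    """Move a Quantity to the backend named by a string.

    Dispatches to q.to_<name>(**kwargs), which per the documentation returns
    q itself when the mantissa is already on that backend.
    """
    if name not in BACKENDS:
        raise ValueError(f"unknown backend {name!r}; expected one of {BACKENDS}")
    return getattr(q, f"to_{name}")(**kwargs)
```

With the quantities above, ensure_backend(q_np, 'jax') is equivalent to q_np.to_jax(), and ensure_backend(q_np, 'dask', chunks=1) forwards chunks to to_dask.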

Installation

| Command | Provides |
|---|---|
| pip install brainunit | core Quantity, jax + numpy backends |
| pip install brainunit[cpu] | core + jax[cpu] (pinned CPU wheels) |
| pip install brainunit[cuda12] | core + jax[cuda12] |
| pip install brainunit[cuda13] | core + jax[cuda13] |
| pip install brainunit[tpu] | core + jax[tpu] |
| pip install brainunit[cupy] | adds cupy-cuda12x for the CuPy backend |
| pip install brainunit[torch] | adds torch>=2.0 for the PyTorch backend |
| pip install brainunit[dask] | adds dask[array] for the Dask backend |
| pip install brainunit[ndonnx] | adds ndonnx for the symbolic backend |
| pip install brainunit[all] | shorthand for [cupy,torch,dask,ndonnx] |

JAX is a required dependency — every install includes the JAX backend. The [cpu] / [cuda12] / [cuda13] / [tpu] extras pin the JAX accelerator build; pick at most one. The [cupy] / [torch] / [dask] / [ndonnx] extras are independent and can be combined freely.
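Because the accelerator extras are mutually exclusive while the backend extras combine freely, a requirement string can be assembled mechanically. pip_requirement is a hypothetical helper, not part of brainunit, that encodes those two rules using the extra names from the table above.

```python
ACCELERATOR_EXTRAS = ("cpu", "cuda12", "cuda13", "tpu")  # pick at most one
BACKEND_EXTRAS = ("cupy", "torch", "dask", "ndonnx")     # combine freely


def pip_requirement(accelerator=None, backends=()):
    """Build a 'brainunit[...]' requirement string from the documented extras."""
    extras = []
    # A single accelerator argument makes "at most one" structural rather than checked.
    if accelerator is not None:
        if accelerator not in ACCELERATOR_EXTRAS:
            raise ValueError(f"unknown accelerator extra {accelerator!r}")
        extras.append(accelerator)
    for name in backends:
        if name not in BACKEND_EXTRAS:
            raise ValueError(f"unknown backend extra {name!r}")
        if name not in extras:  # drop duplicates, keep first-seen order
            extras.append(name)
    return f"brainunit[{','.join(extras)}]" if extras else "brainunit"


print(pip_requirement())                             # brainunit
print(pip_requirement("cuda12", ["torch", "dask"]))  # brainunit[cuda12,torch,dask]
```

Quote the result on the command line, for example pip install "brainunit[cuda12,torch,dask]", since shells such as zsh treat square brackets as glob characters.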

See also

  • Per-backend notebooks: JAX, NumPy, CuPy, PyTorch, Dask, ndonnx.

  • API reference: set_default_backend, using_backend, get_default_backend, is_*_array, BackendError, Quantity.backend, Quantity.to_*.