Installation

brainmass is a JAX-native library for differentiable neural-mass modelling. This page covers installing it for CPU, GPU (CUDA 12 / 13) and TPU, the optional plotting extra, and a short check that the install works.

Pick your accelerator

Install the package with the extra that matches your hardware. CPU is the simplest and is all you need to follow the Quickstart.

pip install -U "brainmass[cpu]"

Works on Linux, macOS (including Apple Silicon) and Windows. Recommended for getting started and for the notebooks in these docs. The quotes keep shells such as zsh from interpreting the square brackets.

pip install -U "brainmass[cuda12]"

For NVIDIA GPUs with a CUDA 12.x runtime. The matching jaxlib CUDA wheels are pulled in automatically.

pip install -U "brainmass[cuda13]"

For NVIDIA GPUs with a CUDA 13.x runtime.

pip install -U "brainmass[tpu]"

For Google Cloud TPU VMs.
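The same choice can be made from inside a notebook or script. The sketch below only builds the pip invocation for one accelerator extra; `build_install_command` is a hypothetical helper for illustration, not part of brainmass, and the extra names are the ones listed on this page.

```python
import sys

# Accelerator extras named on this page.
ACCELERATOR_EXTRAS = ("cpu", "cuda12", "cuda13", "tpu")


def build_install_command(accelerator="cpu"):
    """Return the pip invocation for one accelerator extra as an argument list.

    Hypothetical helper for illustration; it is not part of brainmass.
    """
    if accelerator not in ACCELERATOR_EXTRAS:
        raise ValueError(
            f"unknown accelerator {accelerator!r}; expected one of {ACCELERATOR_EXTRAS}"
        )
    # sys.executable targets the interpreter that is currently running,
    # so the package lands in the active environment.
    return [sys.executable, "-m", "pip", "install", "-U", f"brainmass[{accelerator}]"]


print(" ".join(build_install_command("cuda12")))
```

Passing the returned list to `subprocess.run(cmd, check=True)` would perform the install; an argument list avoids shell quoting of the square brackets altogether.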

The [viz] plotting extra

brainmass has no hard dependency on matplotlib: import brainmass works without it. The brainmass.viz helpers (used throughout the tutorials and the gallery) import matplotlib lazily, so install the [viz] extra to enable them:

pip install -U "brainmass[cpu,viz]"     # CPU + plotting helpers

Combine extras freely, e.g. brainmass[cuda12,viz]. If you call a brainmass.viz.* function without matplotlib installed you get a clear error pointing you back to this extra.
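To find out ahead of time whether the plotting helpers will be usable, one option is to test for matplotlib without importing it. This is a minimal sketch using only the standard library; `has_module` is an illustrative name, not a brainmass function.

```python
import importlib.util


def has_module(name):
    """Return True if the top-level module `name` can be found, without importing it."""
    return importlib.util.find_spec(name) is not None


if has_module("matplotlib"):
    print("matplotlib found: the brainmass.viz helpers should be usable")
else:
    print('matplotlib missing: install it with pip install -U "brainmass[cpu,viz]"')
```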

Verify the install

Check the version and run one tiny simulation through the high-level Simulator — if this prints a version and a trajectory shape, you are ready to go.

import brainmass
import brainunit as u

print(f"brainmass version: {brainmass.__version__}")

# One step-model, driven for 5 ms through the Simulator.
node = brainmass.HopfStep(in_size=4, a=0.25)
res = brainmass.Simulator(node, dt=0.1 * u.ms).run(5.0 * u.ms, monitors=["x"])
print("trajectory shape (steps, regions):", res["x"].shape)
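print("install OK")

Example output. The first line is a JAX warning emitted when an NVIDIA GPU is detected but no CUDA-enabled jaxlib is installed:

An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.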
brainmass version: 0.0.6
trajectory shape (steps, regions): (50, 4)
install OK
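The trajectory shape follows from the run length and the step size: 5.0 ms at dt = 0.1 ms gives 50 steps, with one column per region. The arithmetic can be sketched as below, assuming one recorded sample per step (which matches the output above); `expected_shape` is illustrative and not part of brainmass.

```python
def expected_shape(duration_ms, dt_ms, n_regions):
    """Shape (steps, regions) of a run, assuming one sample is stored per step."""
    # round() guards against floating-point error in the division.
    n_steps = round(duration_ms / dt_ms)
    return (n_steps, n_regions)


print(expected_shape(5.0, 0.1, 4))
```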

To confirm which accelerator JAX is using:

import jax

print(jax.devices())  # CpuDevice on CPU; CudaDevice / TpuDevice with an accelerator

Output on a CPU-only install:

[CpuDevice(id=0)]
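With several devices the raw list gets long, so a per-platform count can be easier to read. The sketch below relies on the `platform` attribute of JAX device objects and on `jax.default_backend()`; `summarize_platforms` is a hypothetical helper, and the import is guarded so the snippet also runs where JAX is absent.

```python
from collections import Counter


def summarize_platforms(devices):
    """Count devices per platform string, as reported by each device object."""
    return dict(Counter(d.platform for d in devices))


try:
    import jax
except ImportError:
    jax = None  # JAX is not installed in this environment

if jax is not None:
    print("devices per platform:", summarize_platforms(jax.devices()))
    print("default backend:", jax.default_backend())
```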

The BrainX ecosystem

brainmass is part of the BrainX ecosystem and builds directly on three of its libraries — they are installed automatically as dependencies:

| Package | Role in brainmass |
|---|---|
| brainstate | State management & the transform primitives (jit, for_loop, grad) the Simulator/Fitter are built on |
| brainunit | Unit-aware arrays (u.ms, u.mm, …) so quantities stay physically meaningful |
| braintools | Initializers, optimizers (braintools.optim) and metrics (braintools.metric) surfaced by the Fitter and viz |

You rarely import them all by hand, but knowing the split helps when reading error messages.
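When an error message points at one of these packages, it helps to know which versions are installed. A minimal sketch using the standard library's `importlib.metadata`, assuming each distribution is published under the same name as its import; `installed_versions` is an illustrative helper.

```python
from importlib import metadata

# Assumption: each distribution is published under the same name as its import.
BRAINX_DEPENDENCIES = ("brainstate", "brainunit", "braintools")


def installed_versions(names=BRAINX_DEPENDENCIES):
    """Map each distribution name to its installed version, or None if absent."""
    versions = {}
    for name in names:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


for name, version in installed_versions().items():
    print(f"{name}: {version or 'not installed'}")
```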

Install from source (development)

git clone https://github.com/chaobrain/brainmass.git
cd brainmass
pip install -e ".[cpu,viz,dev,doc]"   # editable install with test + doc tooling

Next steps

  • Quickstart — your first simulation, plot, network and fit in ~10 minutes.

  • Key Concepts — the mental model behind *Step models, Simulator, Network and Fitter.

  • Choose a Model — browse the model catalogue with brainmass.list_models().

See also