Key Concepts

The Quickstart showed what brainmass does. This page builds the mental model so you know why each piece exists and how they fit together. Five ideas carry almost everything:

   units (brainunit)        <- every quantity carries a physical unit
        |
        v
   *Step model   --drive-->  Simulator   -->  trajectories (dict of arrays)
        |                        ^
     noise (optional)            |
        |                     duration -> steps, monitors, transient
        v
   Network  (connectome -> coupling + delays, wraps a *Step)
        |
        v
   Fitter  (+ objectives)  -->  best parameters (gradient / gradient-free)

Read top to bottom: a model describes one region’s dynamics, a Simulator runs it, a Network couples many regions, and a Fitter tunes parameters to data. Units thread through all of them. Each section below is a few runnable lines.

import brainmass
import braintools
import brainstate
import brainunit as u
import jax.numpy as jnp
import numpy as np
from brainstate.nn import Param

# `dt` is global state. Setting it once lets the Network below size its delay
# buffers at construction time; the Simulator is still given an explicit dt= too.
brainstate.environ.set(dt=0.1 * u.ms)

1. Units everywhere (brainunit)

brainmass quantities carry physical units via brainunit (imported as u). A duration is 200 * u.ms, a distance is 30 * u.mm, a conduction speed is 10 * u.mm / u.ms. Units are checked whenever quantities are combined, so a dimensionally wrong expression (adding a time to a length, say) fails loudly instead of silently producing nonsense. The integration step dt is itself a unit-aware quantity.

duration = 200.0 * u.ms
dt = 0.1 * u.ms

# duration / dt is a dimensionless number of steps -- this is exactly how the
# Simulator turns a run length into an integer step count.
n_steps = duration / dt
print("duration / dt =", n_steps, "(dimensionless)")

# A speed has length/time units; distance / speed is therefore a time (a delay).
delay = (30.0 * u.mm) / (10.0 * u.mm / u.ms)
print("delay = distance / speed =", delay)
duration / dt = 2000.0 (dimensionless)
delay = distance / speed = 3. ms
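To see why a mismatched expression fails, here is a toy model of the bookkeeping a unit library does: each quantity carries exponents for its base dimensions, multiplication and division add or subtract those exponents, and addition requires them to match. The class Q and its (length, time) exponent tuple are hypothetical names for illustration only; brainunit's real implementation covers all SI dimensions and far more operations, and is not reproduced here.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Q:
    """Toy quantity (hypothetical, not brainunit): a value plus dimension exponents."""
    value: float
    dims: tuple  # exponents of (length, time): mm -> (1, 0), ms -> (0, 1)

    def __add__(self, other):
        # Addition only makes sense between quantities of the same dimension.
        if self.dims != other.dims:
            raise TypeError(f"cannot add dims {self.dims} and {other.dims}")
        return Q(self.value + other.value, self.dims)

    def __rmul__(self, scalar):
        # float * Q scales the value and keeps the dimension.
        return Q(scalar * self.value, self.dims)

    def __truediv__(self, other):
        # Division subtracts dimension exponents.
        return Q(self.value / other.value,
                 tuple(a - b for a, b in zip(self.dims, other.dims)))


mm = Q(1.0, (1, 0))
ms = Q(1.0, (0, 1))

n_steps = (200.0 * ms) / (0.1 * ms)           # time / time -> dims (0, 0)
delay = (30.0 * mm) / (10.0 * mm / ms)        # length / (length/time) -> time
print(n_steps, delay)
```

Evaluating (30.0 * mm) + (1.0 * ms) raises TypeError, which is the toy analogue of the loud failure described above.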

2. The *Step model contract

Every neural-mass model is a *Step class implementing one update step of its differential equations. The contract is small and uniform across all 20+ models:

| Piece | What it is |
|---|---|
| Model(in_size, **params) | construct, sized for in_size regions; parameters broadcast to that shape |
| init_all_states() | allocate / reset the hidden states (call before stepping) |
| update(*inputs) | advance the state by one dt, applying external inputs |
| model.<var>.value | read a state variable (e.g. model.x.value, model.rE.value) |

You rarely call init_all_states/update by hand — the Simulator does — but seeing them once demystifies what the orchestration layer drives. brainmass.list_models() enumerates every model with its category and number of state variables.

node = brainmass.HopfStep(in_size=3, a=0.25, w=0.3)   # 3 regions

with brainstate.environ.context(dt=0.1 * u.ms):
    node.init_all_states()          # allocate hidden states x, y
    node.update()                   # one step
    print("state x after one step:", node.x.value)     # shape (3,)

# Discover models programmatically.
models = brainmass.list_models()
print("number of models:", len(models))
print("Hopf record:", next(m for m in models if m.name == "HopfStep"))
state x after one step: [0. 0. 0.]
number of models: 20
Hopf record: ModelInfo(name='HopfStep', category='phenomenological', n_state_vars=2, use_case='Oscillation onset, rhythm generation')
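A minimal sketch of the same contract in plain NumPy may make it concrete. ToyHopfStep is a hypothetical class, not brainmass code: it assumes the standard Hopf normal form, a forward-Euler step, dt as a plain float in ms, and stores states as bare arrays (so you read node.x rather than node.x.value). Because the all-zero state is a fixed point of these equations, an undriven step leaves x at zero, consistent with the output above.

```python
import numpy as np


class ToyHopfStep:
    """Hypothetical stand-in for the *Step contract (not brainmass code).

    Assumed Hopf normal form for each region:
        dx/dt = (a - x^2 - y^2) * x - w * y + inp
        dy/dt = (a - x^2 - y^2) * y + w * x
    advanced by one forward-Euler step of size dt per update() call.
    """

    def __init__(self, in_size, a=0.25, w=0.3, dt=0.1):
        self.in_size = in_size
        # Parameters broadcast to one value per region.
        self.a = np.broadcast_to(np.asarray(a, dtype=float), (in_size,))
        self.w = np.broadcast_to(np.asarray(w, dtype=float), (in_size,))
        self.dt = dt              # plain float, assumed to be in ms
        self.x = self.y = None    # states do not exist until init_all_states()

    def init_all_states(self):
        # Allocate / reset the hidden states to zero.
        self.x = np.zeros(self.in_size)
        self.y = np.zeros(self.in_size)

    def update(self, inp=0.0):
        # One Euler step; inp is an external drive added to dx/dt.
        r2 = self.x ** 2 + self.y ** 2
        dx = (self.a - r2) * self.x - self.w * self.y + inp
        dy = (self.a - r2) * self.y + self.w * self.x
        self.x = self.x + self.dt * dx
        self.y = self.y + self.dt * dy
        return self.x


node = ToyHopfStep(in_size=3)
node.init_all_states()
node.update()              # zero is a fixed point, so x stays at zero
node.update(inp=1.0)       # a constant drive pushes x off zero
print(node.x)
```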

3. Simulator: duration, monitors, transient

The Simulator folds the whole sequence (set dt → init states → step loop → collect records) into a single run() call. Three knobs cover most needs:

  • duration: a unit-aware time; the number of steps is duration / dt.

  • monitors: what to record each step. Pass a list of state names (['x']), a callable lambda m: ... for a derived observable (returned under the key 'output'), or a dict.

  • transient: a leading warm-up window (a duration or a step count) to discard, so you keep only the settled dynamics.

It returns a plain dict mapping each monitor name to its stacked trajectory, plus a 'ts' time axis — a valid JAX pytree, safe to return through jit/grad/vmap.

sim = brainmass.Simulator(node, dt=0.1 * u.ms)
res = sim.run(
    100.0 * u.ms,            # -> 1000 steps at dt = 0.1 ms
    monitors=["x", "y"],     # record two state variables
    transient=20.0 * u.ms,   # drop the first 200 steps
)
print("kept steps:", res["x"].shape[0], "(1000 - 200 transient)")
print("keys:", list(res))
kept steps: 800 (1000 - 200 transient)
keys: ['x', 'y', 'ts']
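The loop a Simulator runs can be sketched in a few lines. The function run below is hypothetical and unit-free (durations are floats in ms); it assumes a step function that returns a dict of state arrays, treats an int transient as a step count and a float as a duration, and models the dict form of monitors as {name: fn(states)}, which is an assumption about its shape rather than brainmass's documented format.

```python
import numpy as np


def run(step_fn, duration, dt, monitors, transient=0):
    """Hypothetical sketch of a Simulator-style run loop (not brainmass code)."""
    # Normalise the three monitor forms into {name: fn(states) -> array}.
    if callable(monitors):
        monitors = {"output": monitors}                         # derived observable
    elif isinstance(monitors, (list, tuple)):
        monitors = {n: (lambda s, n=n: s[n]) for n in monitors}  # state names
    n_steps = int(round(duration / dt))          # duration / dt -> step count
    if isinstance(transient, float):             # a duration -> step count
        transient = int(round(transient / dt))
    records = {name: [] for name in monitors}
    ts = []
    for i in range(n_steps):
        states = step_fn()                       # advance the model one dt
        if i < transient:
            continue                             # discard the warm-up window
        for name, fn in monitors.items():
            records[name].append(np.array(fn(states)))   # copy each record
        ts.append(i * dt)                        # assumed time-axis convention
    out = {name: np.stack(vals) for name, vals in records.items()}
    out["ts"] = np.asarray(ts)
    return out


state = {"x": np.zeros(3)}


def step():
    state["x"] = state["x"] + 1.0                # toy dynamics: count steps
    return state


res = run(step, duration=100.0, dt=0.1, monitors=["x"], transient=20.0)
print(res["x"].shape, list(res))
```

Returning a plain dict of stacked arrays keeps the result easy to inspect and, in a JAX setting, a valid pytree.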

4. Where noise fits

Noise is a property of the model, not the simulator. You attach a noise process (e.g. OUProcess) to a state component at construction; it is sized like the model and is sampled and added inside update() automatically. The Simulator call does not change — a deterministic run and a stochastic run differ only in how the model was built.

stochastic = brainmass.HopfStep(
    in_size=3, a=0.25, w=0.3,
    noise_x=brainmass.OUProcess(in_size=3, sigma=0.1, tau=20.0 * u.ms),
)
res_n = brainmass.Simulator(stochastic, dt=0.1 * u.ms).run(50.0 * u.ms, monitors=["x"])
print("stochastic run shape:", res_n["x"].shape)
stochastic run shape: (500, 3)
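An Ornstein-Uhlenbeck process is low-pass filtered white noise with correlation time tau. The sketch below steps one with Euler-Maruyama. ToyOU is a hypothetical class, and its noise scaling (chosen so the stationary standard deviation is about sigma) is an assumption that may not match OUProcess; how a model folds each sample into its own update is model-specific and not shown.

```python
import numpy as np


class ToyOU:
    """Hypothetical Ornstein-Uhlenbeck noise source (not brainmass's OUProcess).

    Assumed Euler-Maruyama update, scaled so the stationary std is ~sigma:
        xi <- xi - (xi / tau) * dt + sigma * sqrt(2 * dt / tau) * N(0, 1)
    """

    def __init__(self, in_size, sigma, tau, dt, seed=0):
        self.sigma, self.tau, self.dt = sigma, tau, dt
        self.rng = np.random.default_rng(seed)
        self.xi = np.zeros(in_size)    # one noise value per region

    def sample(self):
        eta = self.rng.standard_normal(self.xi.shape)
        self.xi = (self.xi - (self.xi / self.tau) * self.dt
                   + self.sigma * np.sqrt(2.0 * self.dt / self.tau) * eta)
        return self.xi


ou = ToyOU(in_size=3, sigma=0.1, tau=20.0, dt=0.1)
samples = np.stack([ou.sample() for _ in range(100_000)])   # (100000, 3)
# Skip the first 10 correlation times (2000 steps), then check the spread.
print("empirical std:", samples[2_000:].std())   # should be close to sigma
```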

5. Network: connectome → coupling + delays

A Network turns a single *Step node (sized for N regions) into a coupled whole-brain model. You give it a structural connectivity matrix and, optionally, a distance matrix plus a conduction speed:

  • the connectivity diagonal is zeroed (no self-coupling),

  • distance / speed becomes per-edge conduction delays,

  • each step it computes a coupling current (diffusive / additive / nonlinear) and feeds it back into the node as its first input.

Crucially, a Network is itself a brainstate module with the same init/update contract, so the same Simulator drives it. The bundled example_connectome gives you a ready-made weights + distances pair.

conn = brainmass.datasets.load_dataset("example_connectome")
N = conn.weights.shape[0]

net = brainmass.Network(
    brainmass.HopfStep(in_size=N, a=0.2, w=0.3),
    conn=conn.weights,
    distance=conn.distances,
    speed=10.0 * u.mm / u.ms,
    coupling="diffusive",
    coupled_var="x",
    k=0.5,
)
res_net = brainmass.Simulator(net, dt=0.1 * u.ms).run(
    50.0 * u.ms, monitors=lambda m: m.node.x.value
)
print(f"{N}-region network output shape:", res_net["output"].shape)
8-region network output shape: (500, 8)
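The connectome transformations listed in this section, plus one coupling evaluation, fit in a short NumPy sketch. Both functions are hypothetical. The diffusive rule I_i = k Σ_j w_ij (x_j(t − d_ij) − x_i(t)) is a common textbook form; brainmass's exact normalisation and delay rounding may differ. Delays are read from a history buffer whose last row is the current step.

```python
import numpy as np


def prepare_connectome(conn, distance, speed, dt):
    """Hypothetical sketch: drop self-coupling and convert distances to delays.

    conn, distance : (N, N) arrays, distance in mm
    speed          : conduction speed in mm/ms
    dt             : integration step in ms
    Returns (weights with a zero diagonal, integer delays in steps).
    """
    w = np.array(conn, dtype=float)             # copy so the input is untouched
    np.fill_diagonal(w, 0.0)                    # no self-coupling
    delay_ms = np.asarray(distance, dtype=float) / speed    # distance / speed
    delay_steps = np.rint(delay_ms / dt).astype(int)        # nearest whole step
    return w, delay_steps


def diffusive_coupling(history, w, delay_steps, k):
    """I_i = k * sum_j w_ij * (x_j(t - d_ij) - x_i(t)).

    history : (T, N) past values of the coupled variable, history[-1] = now;
              T must exceed the largest delay.
    """
    n = w.shape[0]
    x_now = history[-1]
    src = np.arange(n)[None, :]                      # column index j
    x_delayed = history[-1 - delay_steps, src]      # [i, j] -> x_j(t - d_ij)
    return k * np.sum(w * (x_delayed - x_now[:, None]), axis=1)


rng = np.random.default_rng(0)
N = 4
conn = rng.random((N, N))
dist = rng.uniform(5.0, 40.0, size=(N, N))           # mm
w, d = prepare_connectome(conn, dist, speed=10.0, dt=0.1)
history = np.ones((d.max() + 1, N))    # every region has sat at the same value
print(diffusive_coupling(history, w, d, k=0.5))   # equal states -> zero coupling
```

Diffusive coupling depends only on differences between regions, which is why a network sitting at a common value receives no coupling current at all.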

6. Fitter + objectives

The Fitter tunes a model’s trainable parameters to data behind one .fit() call. Two pieces define the problem:

  • trainable parameters — wrap a value in Param(value, fit=True). Only fit=True parameters are optimised; everything else is held fixed.

  • an objective — a callable scoring a prediction against a target. brainmass.objectives provides composable ones (timeseries_rmse, fc_corr, fcd_ks, …); you can also pass a single loss_fn(model) -> (loss, aux).
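Conceptually, the fit flag acts as a filter applied before each optimiser update. ToyParam and sgd_step are hypothetical names in a sketch of that idea; they are not the Param API.

```python
from dataclasses import dataclass


@dataclass
class ToyParam:
    """Hypothetical stand-in for Param(value, fit=...): a value plus a flag."""
    value: float
    fit: bool = False


def sgd_step(params, grads, lr):
    """Apply one gradient step, but only to parameters flagged fit=True."""
    for name, p in params.items():
        if p.fit:
            p.value -= lr * grads[name]
        # fit=False parameters are skipped, i.e. held fixed


params = {"gain": ToyParam(0.0, fit=True), "tau": ToyParam(10.0)}
sgd_step(params, grads={"gain": -1.0, "tau": 5.0}, lr=0.1)
print(params["gain"].value, params["tau"].value)   # gain moved, tau did not
```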

One backend= switch chooses the optimiser:

| backend | how it searches | when to use it |
|---|---|---|
| 'grad' (default) | backprop through the ODE solve | the headline path: fast, scales to many parameters |
| 'nevergrad' | evolutionary, gradient-free | a few scalar parameters, non-differentiable objectives |
| 'scipy' | SciPy optimisers | classic local / derivative-free methods |
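For contrast with backprop, a gradient-free search needs only loss evaluations. The sketch below is a (1+1) evolution strategy with the classic one-fifth success rule for step-size adaptation. It is a generic illustration of the 'nevergrad' row, not one of nevergrad's optimisers, and it minimises an absolute-error loss that has no gradient at the optimum.

```python
import numpy as np


def one_plus_one_es(loss, x0, sigma=0.5, n_iter=200, seed=0):
    """(1+1)-ES: mutate the current point, keep the child if it is no worse."""
    rng = np.random.default_rng(seed)
    x = np.asarray(x0, dtype=float)
    fx = loss(x)
    for _ in range(n_iter):
        cand = x + sigma * rng.standard_normal(x.shape)   # Gaussian mutation
        fc = loss(cand)
        if fc <= fx:                 # success: accept and widen the search
            x, fx = cand, fc
            sigma *= 1.5
        else:                        # failure: narrow the search
            sigma *= 1.5 ** -0.25    # balances at a 1/5 success rate
    return x, fx


best, best_loss = one_plus_one_es(lambda g: float(np.abs(g - 2.0).sum()), np.array([0.0]))
print(best, best_loss)
```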

Below: one trainable parameter gain, fit with the gradient backend so a scalar output matches a target — the smallest possible end-to-end fit.

class Gain(brainstate.nn.Module):
    """A toy 'model' whose single trainable parameter is its output."""
    def __init__(self):
        super().__init__()
        self.gain = Param(0.0, fit=True)   # the one knob to fit

    def update(self):
        return self.gain.value()


def predict(m):
    out = brainmass.Simulator(m, dt=0.1 * u.ms).run(1.0 * u.ms, monitors=None)["output"]
    return jnp.mean(out)


fitter = brainmass.Fitter(
    Gain(),
    braintools.optim.Adam(lr=0.2),
    predict=predict,
    objective=brainmass.objectives.timeseries_rmse(),  # from brainmass.objectives
)
result = fitter.fit(target=jnp.asarray(2.0), n_steps=60)
print(result)
print(f"gain:  0.0  ->  {float(result.best_params['gain']):.3f}  (target 2.0)")
FitResult(backend='grad', best_loss=1.50383e-05, n_steps=60, params=[gain])
gain:  0.0  ->  2.200  (target 2.0)
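Stripped of the simulation, the Gain fit reduces to a one-dimensional descent. The sketch below swaps in plain gradient descent on the squared error with a hand-derived gradient (no autodiff, no Adam, no Simulator), so the arithmetic is visible: each step multiplies the error gain − target by 1 − 2·lr = 0.6, so after 60 steps from 0.0 it is about 2 · 0.6^60 ≈ 1e-13. The function fit_gain is a hypothetical name.

```python
def fit_gain(target, gain=0.0, lr=0.2, n_steps=60):
    """Plain gradient descent on loss = (gain - target) ** 2.

    A stand-in for what the 'grad' backend automates: here the gradient is
    derived by hand, there is no simulation in the loop, and no Adam.
    """
    losses = []
    for _ in range(n_steps):
        losses.append((gain - target) ** 2)
        grad = 2.0 * (gain - target)      # d loss / d gain
        gain = gain - lr * grad           # error shrinks by (1 - 2 * lr) = 0.6
    return gain, losses


best_gain, losses = fit_gain(target=2.0)
print(f"gain: 0.0 -> {best_gain:.3f} (target 2.0)")
```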

Putting it together

  • A *Step model is one region’s dynamics: init_all_states → update → read .value states. Noise attaches to the model.

  • The Simulator drives any model (single node or Network) for a duration, recording monitors after an optional transient.

  • A Network wraps a node with a connectome, deriving coupling and delays, and is driven by the same Simulator.

  • The Fitter optimises Param(fit=True) knobs against an objective, defaulting to gradient descent through the differentiable solve.

  • Units (brainunit) keep every quantity physical, from dt to delays.

Where to go next

See also