{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "kc-title",
   "metadata": {},
   "source": [
    "# Key Concepts\n",
    "\n",
    "The {doc}`/getting_started/quickstart` showed *what* brainmass does. This page builds the\n",
    "**mental model** so you know *why* each piece exists and how they fit together. Five ideas\n",
    "carry almost everything:\n",
    "\n",
    "```text\n",
    "   units (brainunit)        <- every quantity carries a physical unit\n",
    "        |\n",
    "        v\n",
    "   *Step model  <--drives--  Simulator   -->  trajectories (dict of arrays)\n",
    "        |                        ^\n",
    "     noise (optional)            |\n",
    "        |                     duration -> steps, monitors, transient\n",
    "        v\n",
    "   Network  (connectome -> coupling + delays, wraps a *Step)\n",
    "        |\n",
    "        v\n",
    "   Fitter  (+ objectives)  -->  best parameters (gradient / gradient-free)\n",
    "```\n",
    "\n",
    "Read top to bottom: a **model** describes one region's dynamics, a **Simulator** runs it,\n",
    "a **Network** couples many regions, and a **Fitter** tunes parameters to data. Units thread\n",
    "through all of them. Each section below is a few runnable lines."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "kc-imports",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-06-19T06:27:17.672597Z",
     "iopub.status.busy": "2026-06-19T06:27:17.672434Z",
     "iopub.status.idle": "2026-06-19T06:27:22.420322Z",
     "shell.execute_reply": "2026-06-19T06:27:22.419114Z"
    }
   },
   "outputs": [],
   "source": [
    "import brainmass\n",
    "import braintools\n",
    "import brainstate\n",
    "import brainunit as u\n",
    "import jax.numpy as jnp\n",
    "import numpy as np\n",
    "from brainstate.nn import Param\n",
    "\n",
    "# `dt` is global state. Setting it once lets the Network below size its delay\n",
    "# buffers at construction time; the Simulator is still given an explicit dt= too.\n",
    "brainstate.environ.set(dt=0.1 * u.ms)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "kc-units-md",
   "metadata": {},
   "source": [
    "## 1. Units everywhere (`brainunit`)\n",
    "\n",
    "brainmass quantities carry **physical units** via\n",
    "[`brainunit`](https://github.com/chaobrain/brainunit) (imported as `u`). A duration is\n",
    "`200 * u.ms`, a distance is `30 * u.mm`, a conduction speed is `10 * u.mm / u.ms`. Units\n",
    "are checked whenever quantities are combined, so a dimensionally wrong expression (adding a time\n",
    "to a length, say) raises an error instead of silently producing nonsense. The integration step\n",
    "`dt` is itself a unit-aware quantity."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "kc-units-code",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-06-19T06:27:22.422998Z",
     "iopub.status.busy": "2026-06-19T06:27:22.422509Z",
     "iopub.status.idle": "2026-06-19T06:27:22.439557Z",
     "shell.execute_reply": "2026-06-19T06:27:22.438949Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "duration / dt = 2000.0 (dimensionless)\n",
      "delay = distance / speed = 3. ms\n"
     ]
    }
   ],
   "source": [
    "duration = 200.0 * u.ms\n",
    "dt = 0.1 * u.ms\n",
    "\n",
    "# duration / dt is a dimensionless number of steps -- this is exactly how the\n",
    "# Simulator turns a run length into an integer step count.\n",
    "n_steps = duration / dt\n",
    "print(\"duration / dt =\", n_steps, \"(dimensionless)\")\n",
    "\n",
    "# A speed has length/time units; distance / speed is therefore a time (a delay).\n",
    "delay = (30.0 * u.mm) / (10.0 * u.mm / u.ms)\n",
    "print(\"delay = distance / speed =\", delay)"
   ]
  },
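  {
   "cell_type": "markdown",
   "id": "kc-units-sketch",
   "metadata": {},
   "source": [
    "You can check the claim that mismatched units fail loudly for yourself. The sketch below uses only\n",
    "`brainunit` arithmetic. This page does not name the exception type, so the sketch catches the error\n",
    "broadly and prints whichever class is raised.\n",
    "\n",
    "```python\n",
    "import brainunit as u\n",
    "\n",
    "t = 200.0 * u.ms\n",
    "d = 30.0 * u.mm\n",
    "\n",
    "# Same dimension: allowed, and the result is still a time.\n",
    "total = t + 50.0 * u.ms\n",
    "print(\"time + time =\", total)\n",
    "\n",
    "# Different dimensions: adding a length to a time is rejected.\n",
    "raised = False\n",
    "try:\n",
    "    t + d\n",
    "except Exception as err:  # broad on purpose; the exact error class is not documented here\n",
    "    raised = True\n",
    "    print(\"time + length ->\", type(err).__name__)\n",
    "```"
   ]
  },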
  {
   "cell_type": "markdown",
   "id": "kc-step-md",
   "metadata": {},
   "source": [
    "## 2. The `*Step` model contract\n",
    "\n",
    "Every neural-mass model is a `*Step` class implementing **one update step** of its\n",
    "differential equations. The contract is small and uniform across all 20+ models:\n",
    "\n",
    "| Piece | What it is |\n",
    "|---|---|\n",
    "| `Model(in_size, **params)` | construct, sized for `in_size` regions; parameters broadcast to that shape |\n",
    "| `init_all_states()` | allocate / reset the hidden states (call before stepping) |\n",
    "| `update(*inputs)` | advance the state by one `dt`, applying external inputs |\n",
    "| `model.<var>.value` | read a state variable (e.g. `model.x.value`, `model.rE.value`) |\n",
    "\n",
    "You rarely call `init_all_states`/`update` by hand — the `Simulator` does — but seeing them\n",
    "once demystifies what the orchestration layer drives. `brainmass.list_models()` enumerates\n",
    "every model with its category and number of state variables."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "kc-step-code",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-06-19T06:27:22.441740Z",
     "iopub.status.busy": "2026-06-19T06:27:22.441494Z",
     "iopub.status.idle": "2026-06-19T06:27:23.113489Z",
     "shell.execute_reply": "2026-06-19T06:27:23.112892Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "state x after one step: [0. 0. 0.]\n",
      "number of models: 20\n",
      "Hopf record: ModelInfo(name='HopfStep', category='phenomenological', n_state_vars=2, use_case='Oscillation onset, rhythm generation')\n"
     ]
    }
   ],
   "source": [
    "node = brainmass.HopfStep(in_size=3, a=0.25, w=0.3)   # 3 regions\n",
    "\n",
    "with brainstate.environ.context(dt=0.1 * u.ms):\n",
    "    node.init_all_states()          # allocate hidden states x, y\n",
    "    node.update()                   # one step\n",
    "    print(\"state x after one step:\", node.x.value)     # shape (3,); zeros, since x = y = 0 is a fixed point and no input is applied\n",
    "\n",
    "# Discover models programmatically.\n",
    "models = brainmass.list_models()\n",
    "print(\"number of models:\", len(models))\n",
    "print(\"Hopf record:\", next(m for m in models if m.name == \"HopfStep\"))"
   ]
  },
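  {
   "cell_type": "markdown",
   "id": "kc-step-sketch",
   "metadata": {},
   "source": [
    "The contract can be made concrete without relying on library internals. Below is a hypothetical\n",
    "stand-in written in NumPy. `ToyHopfStep` is invented for illustration. It assumes the standard Hopf\n",
    "normal form, a forward-Euler update and plain-float time in ms, and none of these need match how\n",
    "`brainmass.HopfStep` is implemented. The stand-in also shows why the printout above is all zeros: the\n",
    "states start at the origin, there is no input, and so nothing moves.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "class ToyHopfStep:\n",
    "    \"\"\"Hypothetical stand-in for a *Step model; not brainmass.HopfStep.\"\"\"\n",
    "\n",
    "    def __init__(self, in_size, a=0.25, w=0.3, dt=0.1):\n",
    "        self.in_size = in_size\n",
    "        # Parameters broadcast to one value per region.\n",
    "        self.a = np.broadcast_to(np.asarray(a, dtype=float), (in_size,))\n",
    "        self.w = np.broadcast_to(np.asarray(w, dtype=float), (in_size,))\n",
    "        self.dt = dt  # plain float in ms; brainmass uses a unit-aware dt instead\n",
    "        self.x = None\n",
    "        self.y = None\n",
    "\n",
    "    def init_all_states(self):\n",
    "        # Allocate (or reset) the hidden states, one entry per region.\n",
    "        self.x = np.zeros(self.in_size)\n",
    "        self.y = np.zeros(self.in_size)\n",
    "\n",
    "    def update(self, inp_x=0.0):\n",
    "        # One forward-Euler step of the Hopf normal form (assumed equations):\n",
    "        #   dx/dt = (a - x^2 - y^2) x - w y + inp_x\n",
    "        #   dy/dt = (a - x^2 - y^2) y + w x\n",
    "        r2 = self.x ** 2 + self.y ** 2\n",
    "        dx = (self.a - r2) * self.x - self.w * self.y + inp_x\n",
    "        dy = (self.a - r2) * self.y + self.w * self.x\n",
    "        self.x = self.x + self.dt * dx\n",
    "        self.y = self.y + self.dt * dy\n",
    "        return self.x\n",
    "\n",
    "\n",
    "toy = ToyHopfStep(in_size=3)\n",
    "toy.init_all_states()\n",
    "toy.update()\n",
    "print(toy.x)  # all zeros: x = y = 0 is a fixed point and there is no input\n",
    "\n",
    "toy.update(inp_x=np.array([1.0, 0.0, 0.0]))\n",
    "print(toy.x)  # only region 0 has moved off the origin\n",
    "```"
   ]
  },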
  {
   "cell_type": "markdown",
   "id": "kc-sim-md",
   "metadata": {},
   "source": [
    "## 3. `Simulator` — duration, monitors, transient\n",
    "\n",
    "The {class}`~brainmass.Simulator` collapses *set `dt` → init states → loop → collect* into\n",
    "one `run` call. Three knobs cover most needs:\n",
    "\n",
    "- **`duration`** — a unit-aware time; the number of steps is `duration / dt`.\n",
    "- **`monitors`** — *what* to record each step: a list of state names (`['x']`), a callable\n",
    "  `lambda m: ...` for a derived observable (returned under `'output'`), or a dict.\n",
    "- **`transient`** — a leading warm-up window (a duration or a step count) to discard, so you\n",
    "  keep only the settled dynamics.\n",
    "\n",
    "It returns a plain dict mapping each monitor name to its trajectory, stacked along a leading time\n",
    "axis, plus the time points under `'ts'`. The dict is a valid JAX pytree, safe to return through\n",
    "`jit`/`grad`/`vmap`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "kc-sim-code",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-06-19T06:27:23.115663Z",
     "iopub.status.busy": "2026-06-19T06:27:23.115462Z",
     "iopub.status.idle": "2026-06-19T06:27:23.255756Z",
     "shell.execute_reply": "2026-06-19T06:27:23.254764Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "kept steps: 800 (1000 - 200 transient)\n",
      "keys: ['x', 'y', 'ts']\n"
     ]
    }
   ],
   "source": [
    "sim = brainmass.Simulator(node, dt=0.1 * u.ms)\n",
    "res = sim.run(\n",
    "    100.0 * u.ms,            # -> 1000 steps at dt = 0.1 ms\n",
    "    monitors=[\"x\", \"y\"],     # record two state variables\n",
    "    transient=20.0 * u.ms,   # drop the first 200 steps\n",
    ")\n",
    "print(\"kept steps:\", res[\"x\"].shape[0], \"(1000 - 200 transient)\")\n",
    "print(\"keys:\", list(res))"
   ]
  },
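  {
   "cell_type": "markdown",
   "id": "kc-sim-sketch",
   "metadata": {},
   "source": [
    "The bookkeeping behind `duration`, `transient` and the returned arrays is integer arithmetic on the\n",
    "time axis. The plain-NumPy sketch below rests on two assumptions. First, step counts are rounded\n",
    "rather than truncated, because float division does not always land on an integer (in Python,\n",
    "`0.3 / 0.1` evaluates to `2.9999999999999996`). Second, the transient is sliced off the front of\n",
    "each trajectory. The random trajectory is a placeholder, not Simulator output.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "dt_ms = 0.1\n",
    "duration_ms = 100.0\n",
    "transient_ms = 20.0\n",
    "\n",
    "# Round rather than truncate, since float division may land just below an integer.\n",
    "n_steps = int(round(duration_ms / dt_ms))          # 1000\n",
    "n_transient = int(round(transient_ms / dt_ms))     # 200\n",
    "\n",
    "# Placeholder trajectory: one row per step, one column per region.\n",
    "trajectory = np.random.default_rng(0).normal(size=(n_steps, 3))\n",
    "ts = np.arange(n_steps) * dt_ms   # assumed: time of each recorded step, starting at 0\n",
    "\n",
    "# Drop the warm-up window from the front of both arrays.\n",
    "kept = trajectory[n_transient:]\n",
    "kept_ts = ts[n_transient:]\n",
    "print(n_steps, n_transient, kept.shape)\n",
    "```"
   ]
  },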
  {
   "cell_type": "markdown",
   "id": "kc-noise-md",
   "metadata": {},
   "source": [
    "## 4. Where noise fits\n",
    "\n",
    "Noise is a property of the **model**, not the simulator. You attach a noise process (e.g.\n",
    "{class}`~brainmass.OUProcess`) to a state component at construction, through a keyword such as\n",
    "`noise_x` (used below for state `x`). The process is sized like the model (same `in_size`) and is\n",
    "sampled and added inside `update()` automatically. The `Simulator` call does\n",
    "not change — a deterministic run and a stochastic run differ only in how the model was\n",
    "built."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "kc-noise-code",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-06-19T06:27:23.258415Z",
     "iopub.status.busy": "2026-06-19T06:27:23.258164Z",
     "iopub.status.idle": "2026-06-19T06:27:23.519797Z",
     "shell.execute_reply": "2026-06-19T06:27:23.518864Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "stochastic run shape: (500, 3)\n"
     ]
    }
   ],
   "source": [
    "stochastic = brainmass.HopfStep(\n",
    "    in_size=3, a=0.25, w=0.3,\n",
    "    noise_x=brainmass.OUProcess(in_size=3, sigma=0.1, tau=20.0 * u.ms),\n",
    ")\n",
    "res_n = brainmass.Simulator(stochastic, dt=0.1 * u.ms).run(50.0 * u.ms, monitors=[\"x\"])\n",
    "print(\"stochastic run shape:\", res_n[\"x\"].shape)"
   ]
  },
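  {
   "cell_type": "markdown",
   "id": "kc-noise-sketch",
   "metadata": {},
   "source": [
    "To see what sampling and adding noise inside `update()` involves, consider a textbook\n",
    "Ornstein-Uhlenbeck process stepped by Euler-Maruyama. This is an illustration under stated\n",
    "assumptions, not the `OUProcess` source. Three choices are made here: the noise convention\n",
    "`d(xi) = -(xi / tau) dt + sigma dW`, entry of the sample into the state equation as an additive\n",
    "drive, and a simple leak as stand-in dynamics.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def ou_step(xi, tau, sigma, dt, rng):\n",
    "    # One Euler-Maruyama step of d(xi) = -(xi / tau) dt + sigma dW (assumed convention).\n",
    "    noise = rng.standard_normal(xi.shape)\n",
    "    return xi - (xi / tau) * dt + sigma * np.sqrt(dt) * noise\n",
    "\n",
    "\n",
    "rng = np.random.default_rng(0)\n",
    "n, dt, tau, sigma = 3, 0.1, 20.0, 0.1   # 3 regions; dt and tau in ms\n",
    "\n",
    "xi = np.zeros(n)   # noise state, sized like the model\n",
    "x = np.zeros(n)    # the model state the noise is attached to\n",
    "\n",
    "trace = []\n",
    "for _ in range(500):              # 50 ms at dt = 0.1 ms\n",
    "    xi = ou_step(xi, tau, sigma, dt, rng)\n",
    "    drift = -x / 10.0             # stand-in deterministic dynamics (a simple leak)\n",
    "    x = x + dt * (drift + xi)     # noise enters as an additive drive inside the update\n",
    "    trace.append(x.copy())\n",
    "\n",
    "trace = np.stack(trace)\n",
    "print(trace.shape)                # (500, 3): one row per step\n",
    "```\n",
    "\n",
    "The loop body is the only place noise appears. Removing the `ou_step` line gives the\n",
    "deterministic run, and the loop itself is unchanged, which mirrors how the `Simulator` call stays\n",
    "the same for both kinds of model."
   ]
  },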
  {
   "cell_type": "markdown",
   "id": "kc-net-md",
   "metadata": {},
   "source": [
    "## 5. `Network` — connectome → coupling + delays\n",
    "\n",
    "A {class}`~brainmass.Network` turns a single `*Step` node (sized for *N* regions) into a\n",
    "coupled whole-brain model. You give it a **structural connectivity** matrix and, optionally,\n",
    "a **distance** matrix plus a conduction **speed**:\n",
    "\n",
    "- the connectivity diagonal is zeroed (no self-coupling),\n",
    "- `distance / speed` becomes per-edge conduction **delays**,\n",
    "- each step it computes a coupling current (diffusive / additive / nonlinear) and feeds it\n",
    "  back into the node as its first input.\n",
    "\n",
    "Crucially, a `Network` *is itself* a `brainstate` module with the same `init_all_states`/`update`\n",
    "contract, so the **same `Simulator`** drives it. The bundled `example_connectome` gives you\n",
    "a ready-made `weights` + `distances` pair."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "kc-net-code",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-06-19T06:27:23.522344Z",
     "iopub.status.busy": "2026-06-19T06:27:23.522107Z",
     "iopub.status.idle": "2026-06-19T06:27:24.240669Z",
     "shell.execute_reply": "2026-06-19T06:27:24.239877Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "8-region network output shape: (500, 8)\n"
     ]
    }
   ],
   "source": [
    "conn = brainmass.datasets.load_dataset(\"example_connectome\")\n",
    "N = conn.weights.shape[0]\n",
    "\n",
    "net = brainmass.Network(\n",
    "    brainmass.HopfStep(in_size=N, a=0.2, w=0.3),\n",
    "    conn=conn.weights,\n",
    "    distance=conn.distances,\n",
    "    speed=10.0 * u.mm / u.ms,\n",
    "    coupling=\"diffusive\",\n",
    "    coupled_var=\"x\",\n",
    "    k=0.5,\n",
    ")\n",
    "res_net = brainmass.Simulator(net, dt=0.1 * u.ms).run(\n",
    "    50.0 * u.ms, monitors=lambda m: m.node.x.value\n",
    ")\n",
    "print(f\"{N}-region network output shape:\", res_net[\"output\"].shape)"
   ]
  },
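  {
   "cell_type": "markdown",
   "id": "kc-net-sketch",
   "metadata": {},
   "source": [
    "The three bullets above can be sketched in NumPy on a tiny random connectome. The sketch zeroes the\n",
    "diagonal, turns `distance / speed` into integer delays (in steps of `dt`), and computes a delayed\n",
    "diffusive coupling term. It uses the standard diffusive form\n",
    "`k * sum_j w_ij * (x_j(t - d_ij) - x_i(t))` and assumes a particular history-array layout.\n",
    "`brainmass.Network` may normalise, round or store delays differently.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "rng = np.random.default_rng(0)\n",
    "n_regions, dt, speed, k = 4, 0.1, 10.0, 0.5   # dt in ms, speed in mm/ms\n",
    "\n",
    "# Hypothetical connectome: random weights and distances (mm).\n",
    "weights = rng.uniform(size=(n_regions, n_regions))\n",
    "distances = rng.uniform(10.0, 50.0, size=(n_regions, n_regions))\n",
    "\n",
    "# 1. No self-coupling.\n",
    "np.fill_diagonal(weights, 0.0)\n",
    "\n",
    "# 2. distance / speed is a time (ms); dividing by dt and rounding gives delays in steps.\n",
    "delay_steps = np.rint(distances / speed / dt).astype(int)\n",
    "\n",
    "# History of past x values: row h holds x from h steps ago (assumed layout).\n",
    "history = rng.normal(size=(delay_steps.max() + 1, n_regions))\n",
    "x_now = history[0]\n",
    "\n",
    "# delayed[i, j] = x_j(t - d_ij): fancy indexing picks row d_ij, column j for every edge.\n",
    "delayed = history[delay_steps, np.arange(n_regions)[None, :]]\n",
    "\n",
    "# 3. Diffusive coupling: I_i = k * sum_j w_ij * (x_j(t - d_ij) - x_i(t)).\n",
    "coupling = k * np.sum(weights * (delayed - x_now[:, None]), axis=1)\n",
    "\n",
    "print(\"delays (steps):\")\n",
    "print(delay_steps)\n",
    "print(\"coupling into each region:\", coupling)\n",
    "```\n",
    "\n",
    "Under the same assumptions, replacing the last formula with `k * np.sum(weights * delayed, axis=1)`\n",
    "gives an additive rather than diffusive variant."
   ]
  },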
  {
   "cell_type": "markdown",
   "id": "kc-fit-md",
   "metadata": {},
   "source": [
    "## 6. `Fitter` + `objectives`\n",
    "\n",
    "The {class}`~brainmass.Fitter` tunes a model's **trainable parameters** to data behind one\n",
    "`.fit()` call. Two pieces define the problem:\n",
    "\n",
    "- **trainable parameters** — wrap a value in `Param(value, fit=True)`. Only `fit=True`\n",
    "  parameters are optimised; everything else is held fixed.\n",
    "- **an objective** — a callable scoring a prediction against a target.\n",
    "  {mod}`brainmass.objectives` provides composable ones (`timeseries_rmse`, `fc_corr`,\n",
    "  `fcd_ks`, …); you can also pass a single `loss_fn(model) -> (loss, aux)`.\n",
    "\n",
    "One `backend=` switch chooses the optimiser:\n",
    "\n",
    "| backend | how it searches | when |\n",
    "|---|---|---|\n",
    "| `'grad'` (default) | **backprop through the ODE solve** | the headline path — fast, scales to many parameters |\n",
    "| `'nevergrad'` | evolutionary, gradient-free | a few scalar params, non-differentiable objectives |\n",
    "| `'scipy'` | SciPy optimisers | classic local/derivative-free methods |\n",
    "\n",
    "Below: one trainable parameter `gain`, fit with the default `'grad'` backend so a scalar output\n",
    "matches a target — the smallest possible end-to-end fit."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "kc-fit-code",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-06-19T06:27:24.242922Z",
     "iopub.status.busy": "2026-06-19T06:27:24.242674Z",
     "iopub.status.idle": "2026-06-19T06:27:24.364385Z",
     "shell.execute_reply": "2026-06-19T06:27:24.363511Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "FitResult(backend='grad', best_loss=1.50383e-05, n_steps=60, params=[gain])\n",
      "gain:  0.0  ->  2.200  (target 2.0)\n"
     ]
    }
   ],
   "source": [
    "class Gain(brainstate.nn.Module):\n",
    "    \"\"\"A toy 'model' whose single trainable parameter is its output.\"\"\"\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.gain = Param(0.0, fit=True)   # the one knob to fit\n",
    "\n",
    "    def update(self):\n",
    "        return self.gain.value()\n",
    "\n",
    "\n",
    "def predict(m):\n",
    "    # With monitors=None, the value returned by update() is recorded under \"output\".\n",
    "    out = brainmass.Simulator(m, dt=0.1 * u.ms).run(1.0 * u.ms, monitors=None)[\"output\"]\n",
    "    return jnp.mean(out)\n",
    "\n",
    "\n",
    "fitter = brainmass.Fitter(\n",
    "    Gain(),\n",
    "    braintools.optim.Adam(lr=0.2),\n",
    "    predict=predict,\n",
    "    objective=brainmass.objectives.timeseries_rmse(),  # root-mean-square error vs. the target\n",
    ")\n",
    "result = fitter.fit(target=jnp.asarray(2.0), n_steps=60)\n",
    "print(result)\n",
    "print(f\"gain:  0.0  ->  {float(result.best_params['gain']):.3f}  (target 2.0)\")"
   ]
  },
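  {
   "cell_type": "markdown",
   "id": "kc-fit-sketch",
   "metadata": {},
   "source": [
    "Backprop through the ODE solve is the core idea of the `'grad'` backend. The stand-alone JAX sketch\n",
    "below demonstrates that idea without brainmass. It integrates `dx/dt = -rate * x` with forward\n",
    "Euler, scores the final value against a target, and lets `jax.grad` differentiate through every\n",
    "step to update `rate`. This is not how `Fitter` works internally (it takes a `braintools`\n",
    "optimiser, as above). Plain gradient descent keeps the sketch short.\n",
    "\n",
    "```python\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "dt, n_steps = 0.1, 100      # integrate to t = 10\n",
    "x0, target = 1.0, 0.2       # find the rate that makes x(10) equal 0.2\n",
    "\n",
    "\n",
    "def simulate(rate):\n",
    "    # Forward-Euler solve of dx/dt = -rate * x. Under jit the Python loop is\n",
    "    # unrolled, and jax.grad differentiates through every step.\n",
    "    x = x0\n",
    "    for _ in range(n_steps):\n",
    "        x = x + dt * (-rate * x)\n",
    "    return x\n",
    "\n",
    "\n",
    "def loss(rate):\n",
    "    return (simulate(rate) - target) ** 2\n",
    "\n",
    "\n",
    "grad_fn = jax.jit(jax.grad(loss))   # d loss / d rate, through the whole solve\n",
    "\n",
    "rate = jnp.asarray(0.05)\n",
    "for _ in range(200):\n",
    "    rate = rate - 0.02 * grad_fn(rate)   # plain gradient descent\n",
    "\n",
    "print(\"fitted rate:\", float(rate))\n",
    "print(\"x(10) at fitted rate:\", float(simulate(rate)))\n",
    "```\n",
    "\n",
    "For this toy problem the Euler solution has the closed form `(1 - dt * rate) ** n_steps`. That gives\n",
    "a check on the fit: the optimum is `(1 - target ** (1 / n_steps)) / dt`, roughly 0.16."
   ]
  },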
  {
   "cell_type": "markdown",
   "id": "kc-recap",
   "metadata": {},
   "source": [
    "## Putting it together\n",
    "\n",
    "- A **`*Step` model** is one region's dynamics: `init_all_states` → `update` → read\n",
    "  `.value` states. Noise attaches to the model.\n",
    "- The **`Simulator`** drives any model (single node *or* `Network`) for a `duration`,\n",
    "  recording `monitors` after an optional `transient`.\n",
    "- A **`Network`** wraps a node with a connectome, deriving coupling and delays, and is\n",
    "  driven by the same `Simulator`.\n",
    "- The **`Fitter`** optimises `Param(fit=True)` knobs against an **objective**, defaulting to\n",
    "  gradient descent through the differentiable solve.\n",
    "- **Units** (`brainunit`) keep every quantity physical, from `dt` to delays.\n",
    "\n",
    "## Where to go next\n",
    "\n",
    "- {doc}`/concepts/what_is_a_neural_mass_model` — the theory behind the `*Step` models.\n",
    "- {doc}`/concepts/why_differentiable` — why gradient-based fitting changes the game.\n",
    "- {doc}`/concepts/coupling_and_delays` — the structural-connectivity and delay maths.\n",
    "- {doc}`/tutorials/01_first_simulation` — start the learning-oriented tutorial track.\n",
    "- {doc}`/howto/choose_a_model` — pick a model with `brainmass.list_models()`.\n",
    "\n",
    "## See also\n",
    "\n",
    "- {doc}`/reference/orchestration` — `Simulator` / `Network` / `Fitter`.\n",
    "- {doc}`/reference/models`, {doc}`/reference/noise`, {doc}`/reference/coupling`."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.13.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
