{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "cf701c7b",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-06-19T09:26:22.877649Z",
     "iopub.status.busy": "2026-06-19T09:26:22.877504Z",
     "iopub.status.idle": "2026-06-19T09:26:27.367934Z",
     "shell.execute_reply": "2026-06-19T09:26:27.366849Z"
    },
    "tags": [
     "remove-cell"
    ]
   },
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "import warnings\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import brainstate\n",
    "import braintools\n",
    "import brainunit as u\n",
    "import brainmass\n",
    "from brainstate.nn import Param\n",
    "\n",
    "brainstate.environ.set(dt=0.1 * u.ms)\n",
    "brainstate.random.seed(0)\n",
    "warnings.filterwarnings(\"ignore\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "629ec39f",
   "metadata": {},
   "source": [
    "# Data-Driven Modeling\n",
    "\n",
    "**This is the headline of brainmass.** Most established whole-brain toolkits simulate\n",
    "forward and fit parameters by black-box search. brainmass is built on JAX, so the\n",
    "*entire* pipeline (parameters → ODE solve → coupling → forward model → signal) is\n",
    "**differentiable**. You can backpropagate through the simulation and obtain the exact\n",
    "gradient of a data-fit loss (exact for the discretised solve, up to floating-point\n",
    "error) with respect to every parameter in a single backward pass.\n",
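    "\n",
    "To make *backpropagating through the simulation* concrete, here is a minimal NumPy sketch\n",
    "of what JAX's reverse-mode autodiff derives automatically. The model is a hypothetical toy\n",
    "(a leaky integrator `dx/dt = (-x + I_EXT) / tau` stepped with explicit Euler, not a\n",
    "brainmass class); the reverse pass carries the adjoint `dL/dx` backwards through every\n",
    "Euler step and is checked against a central finite difference. Agreement to within\n",
    "finite-difference error is what *exact gradient of the discretised solve* means here.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "# Hypothetical toy model (not a brainmass API): dx/dt = (-x + I_EXT) / tau, x(0) = 0,\n",
    "# integrated with explicit Euler.\n",
    "DT, N_STEPS, I_EXT = 0.1, 500, 1.0\n",
    "\n",
    "def simulate(tau):\n",
    "    # Forward pass: keep the whole trajectory x[0..N_STEPS] for the reverse pass.\n",
    "    x = np.zeros(N_STEPS + 1)\n",
    "    for n in range(N_STEPS):\n",
    "        x[n + 1] = x[n] + DT * (-x[n] + I_EXT) / tau\n",
    "    return x\n",
    "\n",
    "def loss(tau, target):\n",
    "    return np.mean((simulate(tau)[1:] - target) ** 2)\n",
    "\n",
    "def grad_by_backprop(tau, target):\n",
    "    # Reverse pass: walk the Euler steps backwards, accumulating dL/dtau.\n",
    "    x = simulate(tau)\n",
    "    direct = 2.0 * (x[1:] - target) / N_STEPS  # dL/dx[n+1] from the loss term itself\n",
    "    adj, dL_dtau = 0.0, 0.0\n",
    "    for n in reversed(range(N_STEPS)):\n",
    "        adj += direct[n]                                 # total dL/dx[n+1]\n",
    "        dL_dtau += adj * DT * (x[n] - I_EXT) / tau ** 2  # times d x[n+1] / d tau\n",
    "        adj *= 1.0 - DT / tau                            # chain rule: d x[n+1] / d x[n]\n",
    "    return dL_dtau\n",
    "\n",
    "target = simulate(10.0)[1:]  # synthetic data generated with tau = 10\n",
    "tau, h = 5.0, 1e-5\n",
    "g_bp = grad_by_backprop(tau, target)\n",
    "g_fd = (loss(tau + h, target) - loss(tau - h, target)) / (2 * h)\n",
    "print(g_bp, g_fd)\n",
    "```\n",
    "\n",
    "In brainmass you do not write `grad_by_backprop` yourself; JAX derives the equivalent\n",
    "reverse pass through the whole simulation pipeline.\n",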
    "\n",
    "That single property reshapes how you do science with neural-mass models. It turns\n",
    "parameter *fitting* into gradient descent instead of grid or evolutionary search; it\n",
    "makes high-dimensional parameter **fields** (a value per region) tractable where grid\n",
    "search is hopeless; and it lets you *train* neural-mass networks on cognitive tasks the\n",
    "way you train any other differentiable model. **Data-driven modeling** — constructing,\n",
    "fitting, and training models against measured data — is the center of gravity of the\n",
    "whole library, and this page is the curated path through it.\n",
    "\n",
    "The three verbs of the data-driven workflow:\n",
    "\n",
    "| Verb | What it means | In-package home |\n",
    "| --- | --- | --- |\n",
    "| **Construct** | Build a whole-brain model from a connectome (structure, delays, coupling). | {class}`brainmass.Network` |\n",
    "| **Fit** | Recover parameters from data by gradient or gradient-free search. | {class}`brainmass.Fitter` |\n",
    "| **Train** | Drive a network's parameters from a task / dataset of `(inputs, targets)`. | a hand-written loop for now; a dedicated `Trainer` is deferred (see the {doc}`roadmap`) |\n",
    "\n",
    "Throughout, models map activity to observable signals (BOLD / EEG / MEG) through\n",
    "in-package forward models, so the loss is computed against the *same modality* as the\n",
    "data you are fitting.\n"
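    ,"\n"
    ,"As one illustration of scoring a fit in the data's own modality: for BOLD, a common\n"
    ,"whole-brain objective compares functional connectivity (FC), the matrix of Pearson\n"
    ,"correlations between regional time series. The plain-NumPy sketch below is not the\n"
    ,"`brainmass.objectives` API; it assumes simulated and empirical signals are arrays of\n"
    ,"shape `(time, regions)` and scores the fit as `1 - r` between the upper triangles of\n"
    ,"the two FC matrices.\n"
    ,"\n"
    ,"```python\n"
    ,"import numpy as np\n"
    ,"\n"
    ,"def fc_matrix(ts):\n"
    ,"    # Pearson correlation between every pair of regions; ts has shape (time, regions).\n"
    ,"    return np.corrcoef(ts, rowvar=False)\n"
    ,"\n"
    ,"def fc_loss(sim_ts, emp_ts):\n"
    ,"    # 1 - Pearson r between the off-diagonal upper triangles of the two FC matrices.\n"
    ,"    iu = np.triu_indices(sim_ts.shape[1], k=1)\n"
    ,"    r = np.corrcoef(fc_matrix(sim_ts)[iu], fc_matrix(emp_ts)[iu])[0, 1]\n"
    ,"    return 1.0 - r\n"
    ,"\n"
    ,"rng = np.random.default_rng(0)\n"
    ,"emp = rng.standard_normal((300, 5))\n"
    ,"emp[:, 1] += emp[:, 0]          # give the toy data some correlation structure\n"
    ,"emp[:, 3] += 0.5 * emp[:, 2]\n"
    ,"sim = rng.standard_normal((300, 5))\n"
    ,"print(fc_loss(emp, emp), fc_loss(sim, emp))\n"
    ,"```\n"
    ,"\n"
    ,"A loss of 0 means the two FC patterns match up to an affine rescaling; it approaches 2\n"
    ,"for anti-correlated patterns.\n"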
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9ed869b3",
   "metadata": {},
   "source": [
    "## Why brainmass for data-driven modeling\n",
    "\n",
    "brainmass shares the neural-mass / whole-brain space with\n",
    "[The Virtual Brain](https://www.thevirtualbrain.org/) (TVB) and\n",
    "[neurolib](https://github.com/neurolib-dev/neurolib). Both are mature, widely used, and\n",
    "excellent, but both are built on **NumPy / Numba**: they run forward simulations and fit\n",
    "with grid or evolutionary search. Neither has an autodiff / JAX core, so neither can\n",
    "backpropagate through the solve or run natively on GPU/TPU. The table below mirrors the\n",
    "landing page and is deliberately **conservative** (\"Partial\" = exists in some narrower\n",
    "form); consult each project for its current state.\n",
    "\n",
    "| Capability | brainmass | The Virtual Brain | neurolib |\n",
    "| --- | --- | --- | --- |\n",
    "| Differentiable / gradient-based fitting (backprop through the solve) | Yes | No | No |\n",
    "| JAX backend with GPU / TPU acceleration | Yes | No | No |\n",
    "| In-package orchestration & fitting (`Simulator` / `Network` / `Fitter`) | Yes | Partial | Partial |\n",
    "| Unit-safe quantities (dimensional analysis) | Yes | No | No |\n",
    "| Next-generation / exact mean-field models (e.g. Montbrió-Pazó-Roxin, Coombes-Byrne) | Yes | Partial | Partial |\n",
    "| In-package BOLD + EEG/MEG forward models | Yes | Yes | Partial |\n",
    "\n",
    "The differentiable row is the one that defines the pillar. The deeper rationale — the\n",
    "maths of backprop-through-the-solve, when *not* to use gradients, and the application\n",
    "scenarios it unlocks — lives in {doc}`/concepts/why_differentiable`.\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "afcb8fae",
   "metadata": {},
   "source": [
    "## A 30-second teaser\n",
    "\n",
    "Here is the whole idea in one fit. We take a {class}`brainmass.HopfStep` node whose\n",
    "settled limit-cycle amplitude depends on its bifurcation parameter `a`, mark `a`\n",
    "trainable with `Param(..., fit=True)`, and let {class}`brainmass.Fitter` recover the\n",
    "`a` that reproduces a target amplitude — by **gradient descent through the simulation**.\n",
    "No finite differences, no grid: the optimiser is handed the exact gradient.\n",
    "\n",
    "(For oscillatory models, fit a **scalar summary** like amplitude, FC, or a spectral\n",
    "peak — never the raw phase-sensitive time series; see\n",
    "{doc}`/tutorials/06_fitting_with_gradients`.)\n"
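    ,"\n"
    ,"Why does the amplitude pin down `a`? Assuming `HopfStep` follows the standard\n"
    ,"Stuart-Landau (supercritical Hopf normal form) equations\n"
    ,"`dx/dt = (a - x^2 - y^2) x - w y` and `dy/dt = (a - x^2 - y^2) y + w x` (an assumption;\n"
    ,"check the class documentation for its exact form and units), any `a > 0` gives a limit\n"
    ,"cycle of radius `sqrt(a)`, so the settled `x(t)` is a sinusoid of amplitude `sqrt(a)`.\n"
    ,"A sinusoid's RMS is its amplitude divided by `sqrt(2)`, which is why `settled_amplitude`\n"
    ,"multiplies the RMS by `sqrt(2)`. The NumPy sketch below (dimensionless time, explicit\n"
    ,"Euler) checks that this estimate tracks `sqrt(a)`.\n"
    ,"\n"
    ,"```python\n"
    ,"import numpy as np\n"
    ,"\n"
    ,"def hopf_amplitude(a, w=0.3, x0=0.5, dt=0.01, t_total=400.0, t_transient=200.0):\n"
    ,"    # Euler-integrate the Stuart-Landau equations (assumed, not read from HopfStep).\n"
    ,"    n_steps = int(round(t_total / dt))\n"
    ,"    n_skip = int(round(t_transient / dt))\n"
    ,"    x, y = x0, 0.0\n"
    ,"    xs = np.empty(n_steps)\n"
    ,"    for n in range(n_steps):\n"
    ,"        r2 = x * x + y * y\n"
    ,"        # simultaneous update: both right-hand sides use the old (x, y)\n"
    ,"        x, y = x + dt * ((a - r2) * x - w * y), y + dt * ((a - r2) * y + w * x)\n"
    ,"        xs[n] = x\n"
    ,"    # sqrt(2) * RMS of the settled sinusoid recovers its amplitude (the cycle radius).\n"
    ,"    return np.sqrt(2.0) * np.sqrt(np.mean(xs[n_skip:] ** 2))\n"
    ,"\n"
    ,"for a in (0.1, 0.5, 1.5):\n"
    ,"    print(a, hopf_amplitude(a), np.sqrt(a))\n"
    ,"```\n"
    ,"\n"
    ,"Under this assumption the map `a -> sqrt(a)` is smooth and monotonic for `a > 0`, so\n"
    ,"the amplitude-matching loss has a single minimum at the true `a` on that range.\n"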
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "04c16dc9",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-06-19T09:26:27.370816Z",
     "iopub.status.busy": "2026-06-19T09:26:27.370340Z",
     "iopub.status.idle": "2026-06-19T09:26:28.046579Z",
     "shell.execute_reply": "2026-06-19T09:26:28.045728Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "true a       = 1.5\n",
      "recovered a  = 1.500\n",
      "final loss   = 2.18e-11\n"
     ]
    }
   ],
   "source": [
    "A_TRUE = 1.5  # the (unknown to the fitter) parameter we want to recover\n",
    "\n",
    "def settled_amplitude(model):\n",
    "    \"\"\"Limit-cycle amplitude, estimated as sqrt(2) * RMS of x after the transient (a smooth scalar summary).\"\"\"\n",
    "    sim = brainmass.Simulator(model, dt=0.1 * u.ms)\n",
    "    res = sim.run(400 * u.ms, monitors=['x'], transient=200 * u.ms)\n",
    "    x = u.get_magnitude(res['x'])\n",
    "    return (2.0 ** 0.5) * (x ** 2).mean() ** 0.5\n",
    "\n",
    "def make_node(a):\n",
    "    # kick off the unstable fixed point so a > 0 actually oscillates\n",
    "    return brainmass.HopfStep(in_size=1, a=a, w=0.3,\n",
    "                              init_x=braintools.init.Constant(0.5))\n",
    "\n",
    "# The \"data\": the target amplitude produced by the true model.\n",
    "target_amp = float(settled_amplitude(make_node(A_TRUE)))\n",
    "\n",
    "# Fit a fresh node, initialised far from the truth (a = 0.1), back to that target by gradient descent.\n",
    "node = make_node(Param(0.1, fit=True))\n",
    "\n",
    "def loss_fn(model):\n",
    "    amp = settled_amplitude(model)\n",
    "    return (amp - target_amp) ** 2, amp\n",
    "\n",
    "fitter = brainmass.Fitter(node, braintools.optim.Adam(lr=0.1), loss_fn=loss_fn)\n",
    "result = fitter.fit(n_steps=120)\n",
    "\n",
    "print(f\"true a       = {A_TRUE}\")\n",
    "print(f\"recovered a  = {result.best_params['a']:.3f}\")\n",
    "print(f\"final loss   = {result.best_loss:.2e}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c54b89be",
   "metadata": {},
   "source": [
    "The optimiser walks `a` straight to the value that matches the target amplitude in a\n",
    "few dozen steps. Swap the Adam optimiser for `backend='nevergrad'` or `backend='scipy'`\n",
    "and the *same* objective is minimised by gradient-free search; the guided path below\n",
    "walks through all three approaches.\n"
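    ,"\n"
    ,"To see *same objective, different optimiser* in miniature, the pure-Python sketch below\n"
    ,"swaps the simulation for the surrogate amplitude `sqrt(a)` (valid only under the\n"
    ,"Stuart-Landau assumption about `HopfStep`) and minimises one loss two ways: gradient\n"
    ,"descent with the analytic derivative, and golden-section search, a simple derivative-free\n"
    ,"method used here as a stand-in for the Nevergrad and SciPy backends (it is not what those\n"
    ,"backends run).\n"
    ,"\n"
    ,"```python\n"
    ,"import math\n"
    ,"\n"
    ,"A_TRUE = 1.5\n"
    ,"TARGET = math.sqrt(A_TRUE)  # surrogate data: amplitude sqrt(a) at the true a\n"
    ,"\n"
    ,"def loss(a):\n"
    ,"    return (math.sqrt(a) - TARGET) ** 2\n"
    ,"\n"
    ,"def dloss(a):\n"
    ,"    # d/da (sqrt(a) - T)^2 = (sqrt(a) - T) / sqrt(a)\n"
    ,"    return (math.sqrt(a) - TARGET) / math.sqrt(a)\n"
    ,"\n"
    ,"def gradient_descent(a, lr=0.5, n_steps=200):\n"
    ,"    # Needs the gradient: each step moves downhill in proportion to it.\n"
    ,"    for _ in range(n_steps):\n"
    ,"        a -= lr * dloss(a)\n"
    ,"    return a\n"
    ,"\n"
    ,"def golden_section(lo, hi, n_iter=80):\n"
    ,"    # Needs only loss values: shrinks a bracket around the minimum each iteration.\n"
    ,"    inv_phi = (math.sqrt(5.0) - 1.0) / 2.0\n"
    ,"    for _ in range(n_iter):\n"
    ,"        c = hi - inv_phi * (hi - lo)\n"
    ,"        d = lo + inv_phi * (hi - lo)\n"
    ,"        if loss(c) < loss(d):\n"
    ,"            hi = d  # minimum lies in [lo, d]\n"
    ,"        else:\n"
    ,"            lo = c  # minimum lies in [c, hi]\n"
    ,"    return 0.5 * (lo + hi)\n"
    ,"\n"
    ,"print(gradient_descent(0.1), golden_section(0.01, 4.0))\n"
    ,"```\n"
    ,"\n"
    ,"Both should land on `a = 1.5`. Golden-section search only works in one dimension, and\n"
    ,"derivative-free methods generally need more loss evaluations as the parameter count\n"
    ,"grows, while one reverse pass yields the full gradient.\n"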
   ]
  },
  {
   "cell_type": "markdown",
   "id": "076b326a",
   "metadata": {},
   "source": [
    "## A guided path\n",
    "\n",
    "The data-driven workflow is spread across the Diátaxis quadrants — this hub curates a\n",
    "reading order through them rather than duplicating their content. Pick the entry point\n",
    "that matches where you are.\n",
    "\n",
    "### 1 · Understand the idea\n",
    "\n",
    "- {doc}`/concepts/why_differentiable` — the conceptual core: backprop-through-the-solve\n",
    "  maths, the scaling argument vs grid / evolutionary search, when *not* to reach for\n",
    "  gradients, and the full brainmass-vs-TVB-vs-neurolib comparison.\n",
    "- {doc}`/concepts/what_is_a_neural_mass_model` — what these models are, if you are new\n",
    "  to the mean-field idea.\n",
    "\n",
    "### 2 · Learn it hands-on (tutorials)\n",
    "\n",
    "Work these in order — each is a runnable notebook that builds on the last:\n",
    "\n",
    "- {doc}`/tutorials/06_fitting_with_gradients` — **the centerpiece.** Fit a parameter by\n",
    "  backprop through the `Simulator`; see autodiff match a finite-difference check, then\n",
    "  fit a network coupling against functional connectivity.\n",
    "- {doc}`/tutorials/07_gradient_free_fitting` — the *same* objective with Nevergrad and\n",
    "  SciPy backends, head-to-head against gradients. When the loss is non-differentiable or\n",
    "  noisy, this is your tool.\n",
    "- {doc}`/tutorials/08_training_on_tasks` — **training**, not fitting: drive a\n",
    "  `HORNSeqNetwork`'s parameters from a `(inputs, targets)` task with a minibatched,\n",
    "  epoched loop. This is the loop the deferred `Trainer` will wrap.\n",
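    "\n",
    "The *minibatched, epoched* shape of that training loop can be sketched independently of\n",
    "any neural-mass model. The NumPy example below is hypothetical: a noisy linear task stands\n",
    "in for the `HORNSeqNetwork` data and a hand-derived gradient stands in for autodiff; only\n",
    "the loop structure (reshuffle each epoch, one update per minibatch) is the point.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "# Hypothetical task of (inputs, targets) pairs: a noisy linear map.\n",
    "rng = np.random.default_rng(0)\n",
    "X = rng.standard_normal((512, 3))\n",
    "W_TRUE = np.array([[1.0], [-2.0], [0.5]])\n",
    "Y = X @ W_TRUE + 0.01 * rng.standard_normal((512, 1))\n",
    "\n",
    "def train(X, Y, n_epochs=20, batch_size=32, lr=0.1, seed=1):\n",
    "    shuffle_rng = np.random.default_rng(seed)\n",
    "    W = np.zeros((X.shape[1], Y.shape[1]))\n",
    "    for _ in range(n_epochs):                        # one epoch = one pass over the data\n",
    "        order = shuffle_rng.permutation(len(X))      # reshuffle every epoch\n",
    "        for start in range(0, len(X), batch_size):\n",
    "            idx = order[start:start + batch_size]    # one minibatch of (inputs, targets)\n",
    "            err = X[idx] @ W - Y[idx]\n",
    "            grad = 2.0 * X[idx].T @ err / len(idx)   # gradient of the minibatch MSE\n",
    "            W -= lr * grad\n",
    "    return W\n",
    "\n",
    "W_fit = train(X, Y)\n",
    "print(np.round(W_fit.ravel(), 3))\n",
    "```\n",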
    "\n",
    "### 3 · Apply it to a recipe (how-to)\n",
    "\n",
    "Task-focused recipes for when you already know the shape of your problem:\n",
    "\n",
    "- {doc}`/howto/custom_objective` — compose `brainmass.objectives` or write your own\n",
    "  `(prediction, target) -> scalar` loss and plug it into the `Fitter`.\n",
    "- {doc}`/howto/parameter_sweeps` — `vmap` a model over a grid of parameter values to map\n",
    "  a loss landscape before (or instead of) fitting.\n",
    "- {doc}`/howto/analyze_results` — turn a fitted simulation into FC / FCD / spectra to\n",
    "  validate it against data.\n",
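    "\n",
    "A sweep evaluates one loss over a whole grid of parameter values in a single batched\n",
    "call. In brainmass that batching is done with `vmap`; the sketch below imitates it with\n",
    "NumPy broadcasting on a hypothetical leaky-integrator toy model (not a brainmass class),\n",
    "so every grid point advances one Euler step per loop iteration.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "# Hypothetical toy model: dx/dt = (-x + I_EXT) / tau, explicit Euler, x(0) = 0.\n",
    "DT, N_STEPS, I_EXT = 0.1, 500, 1.0\n",
    "\n",
    "def simulate_batch(taus):\n",
    "    # Integrate every tau in the 1-D array taus at once; returns shape (len(taus), N_STEPS).\n",
    "    x = np.zeros(taus.shape)\n",
    "    out = np.empty((taus.size, N_STEPS))\n",
    "    for n in range(N_STEPS):\n",
    "        x = x + DT * (-x + I_EXT) / taus  # one vectorised Euler step for the whole grid\n",
    "        out[:, n] = x\n",
    "    return out\n",
    "\n",
    "tau_grid = np.linspace(2.0, 20.0, 37)         # spacing 0.5, includes tau = 10\n",
    "target = simulate_batch(np.array([10.0]))[0]  # synthetic data generated with tau = 10\n",
    "landscape = np.mean((simulate_batch(tau_grid) - target) ** 2, axis=1)\n",
    "best_tau = tau_grid[np.argmin(landscape)]\n",
    "print(best_tau)\n",
    "```\n",
    "\n",
    "Plotting `landscape` against `tau_grid` shows whether the loss has a single basin before\n",
    "you commit to a gradient-based fit.\n",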
    "\n",
    "### 4 · See it end-to-end (case studies)\n",
    "\n",
    "Complete application stories, from connectome to fitted result:\n",
    "\n",
    "- {doc}`/gallery/case_studies/eeg_fitting` — fit a Jansen-Rit EEG column to a target\n",
    "  oscillation, gradient *and* gradient-free, side by side.\n",
    "- {doc}`/gallery/case_studies/horn_cognitive_task` — train a neural-mass network from\n",
    "  chance to 100% accuracy on a cognitive task.\n",
    "\n",
    "### 5 · Extend it (developer)\n",
    "\n",
    "- {doc}`/developer/building_a_data_driven_workflow` — **the extension contract.** Expose\n",
    "  trainable parameters on a custom model, write a composable objective, fit it both ways\n",
    "  (high-level `Fitter` and a hand-written `grad` loop), and batch the search. This is the\n",
    "  stable contract the deferred `Trainer` and model-discovery tooling build *against* —\n",
    "  read it before you write new data-driven machinery.\n"
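    ,"\n"
    ,"A *composable objective* can be as simple as a combinator that turns several\n"
    ,"`(prediction, target) -> scalar` losses into one weighted sum with the same signature.\n"
    ,"The pure-Python sketch below is hypothetical: `weighted_sum`, `mse`, and `mae` are\n"
    ,"illustrative names, not the `brainmass.objectives` API.\n"
    ,"\n"
    ,"```python\n"
    ,"def mse(prediction, target):\n"
    ,"    return sum((p - t) ** 2 for p, t in zip(prediction, target)) / len(target)\n"
    ,"\n"
    ,"def mae(prediction, target):\n"
    ,"    return sum(abs(p - t) for p, t in zip(prediction, target)) / len(target)\n"
    ,"\n"
    ,"def weighted_sum(*terms):\n"
    ,"    # terms are (weight, loss_fn) pairs; the result keeps the (prediction, target) -> scalar shape.\n"
    ,"    def objective(prediction, target):\n"
    ,"        return sum(w * fn(prediction, target) for w, fn in terms)\n"
    ,"    return objective\n"
    ,"\n"
    ,"combined = weighted_sum((1.0, mse), (0.5, mae))\n"
    ,"print(combined([1.0, 2.0, 3.0], [1.0, 1.0, 1.0]))\n"
    ,"```\n"
    ,"\n"
    ,"Because the combined objective has the same signature as its parts, it can be nested or\n"
    ,"passed anywhere a single loss is expected.\n"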
   ]
  },
  {
   "cell_type": "markdown",
   "id": "523f9c09",
   "metadata": {},
   "source": [
    "## Where this pillar grows next\n",
    "\n",
    "The data-driven pillar has reserved homes for three growth areas: model discovery /\n",
    "system identification, a task-shaped **`Trainer`** (distinct from the target-fitting\n",
    "`Fitter`), and simulation-based / amortized inference. They are **named and given a home\n",
    "now but not yet built** (they are out of scope for the current development goal,\n",
    "goal-13), and each has a documented contract pointing back to\n",
    "{doc}`/developer/building_a_data_driven_workflow`. See the roadmap:\n",
    "\n",
    "```{toctree}\n",
    ":maxdepth: 1\n",
    "\n",
    "roadmap\n",
    "```\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.13.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
