{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "f59daf5a",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-06-19T09:16:30.204103Z",
     "iopub.status.busy": "2026-06-19T09:16:30.203864Z",
     "iopub.status.idle": "2026-06-19T09:16:34.873324Z",
     "shell.execute_reply": "2026-06-19T09:16:34.872392Z"
    },
    "tags": [
     "remove-cell"
    ]
   },
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "import numpy as np\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "import matplotlib.pyplot as plt\n",
    "import brainmass\n",
    "import brainstate\n",
    "import braintools\n",
    "import brainunit as u\n",
    "from brainstate.nn import Param\n",
    "brainstate.random.seed(0)\n",
    "brainstate.environ.set(dt=0.1 * u.ms)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8de84371",
   "metadata": {},
   "source": [
    "# Creating an Objective\n",
    "\n",
    "An **objective** scores how well a simulated trajectory matches data. `brainmass`\n",
    "ships a toolkit in `brainmass.objectives` (`timeseries_rmse`, `fc_corr`, `fcd_*`,\n",
    "...). This guide shows how to write your *own* objective so it behaves exactly like\n",
    "the built-ins: composable, unit-aware, and usable with `Fitter` across **all three**\n",
    "optimizer backends (gradient, Nevergrad, SciPy).\n",
    "\n",
    "To merely *combine* the existing objectives, see {doc}`/howto/custom_objective`.\n",
    "This guide spells out the contract for authoring a new one."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6b9dc716",
   "metadata": {},
   "source": [
    "## The contract\n",
    "\n",
    "An objective is a **builder**: a function that takes configuration and returns a\n",
    "small `callable(prediction, target) -> scalar`. The returned callable must be:\n",
    "\n",
    "- **pure and traced-array-friendly** -- built from `jax.numpy` / `brainunit` ops so\n",
    "  it survives `jit`, `grad`, and `vmap` (the three backends each need a different\n",
    "  one of these).\n",
    "- **unit-aware** -- strip units with `brainunit.get_magnitude` where the metric is\n",
    "  scale-invariant (correlations, cosine, a variance ratio); *keep* them on a\n",
    "  difference you want unit-checked (subtracting `mV` from `Hz` should raise).\n",
    "- **a single scalar** -- the optimizers minimise a scalar.\n",
    "\n",
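    "A `prediction` / `target` is a `(time, regions)` trajectory, i.e. one monitored\n",
    "variable from `Simulator.run` (for example `sim.run(...)['x']`). By convention a\n",
    "builder takes `as_loss=`: `as_loss=False` returns a score to *maximise*, and\n",
    "`as_loss=True` returns a loss to *minimise*, typically `1 - score` for a bounded\n",
    "score or the negated score otherwise."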
   ]
  },
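  {
   "cell_type": "markdown",
   "id": "3c1e7a42",
   "metadata": {},
   "source": [
    "You can check the contract directly. The cell below defines a hypothetical builder,\n",
    "`mean_corr` (not a `brainmass` function). It scores the mean per-region Pearson\n",
    "correlation of two plain, unitless `jax` arrays, so it needs no `get_magnitude` call.\n",
    "A correlation is scale-invariant. Because the builder uses only `jax.numpy` ops, it\n",
    "passes through `jit`, `grad`, and `vmap` unchanged:\n",
    "\n",
    "```python\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "\n",
    "def mean_corr(as_loss=True):\n",
    "    # Hypothetical builder: mean per-region Pearson correlation.\n",
    "    def objective(prediction, target):\n",
    "        # prediction, target: unitless (time, regions) arrays\n",
    "        p = prediction - prediction.mean(axis=0)\n",
    "        t = target - target.mean(axis=0)\n",
    "        num = jnp.sum(p * t, axis=0)\n",
    "        den = jnp.sqrt(jnp.sum(p ** 2, axis=0) * jnp.sum(t ** 2, axis=0))\n",
    "        score = jnp.mean(num / (den + 1e-12))\n",
    "        # loss = 1 - score (0 at a perfect match); otherwise a score to maximise\n",
    "        return 1.0 - score if as_loss else score\n",
    "    return objective\n",
    "\n",
    "\n",
    "loss = mean_corr()\n",
    "key_p, key_t = jax.random.split(jax.random.PRNGKey(0))\n",
    "pred = jax.random.normal(key_p, (100, 3))\n",
    "targ = jax.random.normal(key_t, (100, 3))\n",
    "\n",
    "value = jax.jit(loss)(pred, targ)      # jit: compiles to a scalar\n",
    "g = jax.grad(loss)(pred, targ)         # grad w.r.t. the prediction\n",
    "population = jnp.stack([pred, 2.0 * pred])\n",
    "batched = jax.vmap(loss, in_axes=(0, None))(population, targ)  # one loss per candidate\n",
    "```\n",
    "\n",
    "`in_axes=(0, None)` batches over candidate predictions while sharing one target, the\n",
    "kind of batching a population-based search could use. The two batched losses agree\n",
    "because doubling a signal leaves its correlation unchanged."
   ]
  },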
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "21675ea0",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-06-19T09:16:34.875893Z",
     "iopub.status.busy": "2026-06-19T09:16:34.875396Z",
     "iopub.status.idle": "2026-06-19T09:16:35.042925Z",
     "shell.execute_reply": "2026-06-19T09:16:35.042241Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "identity loss (mV) : 0.0\n",
      "identity loss (Hz) : 0.0\n"
     ]
    }
   ],
   "source": [
    "def variance_match(as_loss=True):\n",
    "    \"\"\"Match the overall variance (pooled over time and regions) of two signals.\n",
    "\n",
    "    A scale-sensitive summary: units are stripped, so prediction and target\n",
    "    must share a unit, and the loss tracks amplitude. Like the\n",
    "    brainmass.objectives.* builders, this returns the objective callable.\n",
    "    \"\"\"\n",
    "    def objective(prediction, target):\n",
    "        var_p = jnp.var(u.get_magnitude(prediction))\n",
    "        var_t = jnp.var(u.get_magnitude(target))\n",
    "        d = (var_p - var_t) ** 2\n",
    "        return d if as_loss else -d\n",
    "    return objective\n",
    "\n",
    "\n",
    "# Units are stripped, so mV and Hz inputs both work; identical inputs give 0.\n",
    "loss = variance_match()\n",
    "x_mV = jnp.ones((50, 3)) * u.mV\n",
    "print('identity loss (mV) :', float(loss(x_mV, x_mV)))\n",
    "x_Hz = (jnp.ones((50, 3)) * 2.0) * u.Hz\n",
    "print('identity loss (Hz) :', float(loss(x_Hz, x_Hz)))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fba13a62",
   "metadata": {},
   "source": [
    "## It composes with the built-ins\n",
    "\n",
    "Because a custom objective has the exact `(prediction, target) -> scalar` shape,\n",
    "`brainmass.objectives.combine` mixes it with built-ins as a weighted sum -- a\n",
    "common pattern when fitting to several features at once (e.g. FC correlation *and*\n",
    "an amplitude term)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "ea84305f",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-06-19T09:16:35.045260Z",
     "iopub.status.busy": "2026-06-19T09:16:35.044972Z",
     "iopub.status.idle": "2026-06-19T09:16:36.079896Z",
     "shell.execute_reply": "2026-06-19T09:16:36.078185Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "combined loss, identity : 0.0\n",
      "combined loss, perturbed: 0.782201\n"
     ]
    }
   ],
   "source": [
    "from brainmass import objectives\n",
    "\n",
    "mixed = objectives.combine(\n",
    "    (1.0, objectives.fc_corr(as_loss=True)),   # match functional connectivity\n",
    "    (0.5, variance_match(as_loss=True)),       # ... and overall amplitude\n",
    ")\n",
    "rng = np.random.default_rng(0)\n",
    "a = jnp.asarray(rng.standard_normal((200, 4)))\n",
    "print('combined loss, identity :', round(float(mixed(a, a)), 6))   # both terms 0\n",
    "b = a * 1.5 + 0.2   # affine rescale: FC unchanged, variance x2.25\n",
    "print('combined loss, perturbed:', round(float(mixed(a, b)), 6))"
   ]
  },
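  {
   "cell_type": "markdown",
   "id": "7d5b90e1",
   "metadata": {},
   "source": [
    "The text above describes `combine` as a weighted sum. Read that way, it reduces to a\n",
    "few lines. In the sketch below, `weighted_sum`, `mse`, and `var_gap` are hypothetical\n",
    "stand-ins on unitless arrays, not the `brainmass` implementation:\n",
    "\n",
    "```python\n",
    "import jax.numpy as jnp\n",
    "\n",
    "\n",
    "def weighted_sum(*terms):\n",
    "    # Hypothetical weighted-sum combinator; terms are (weight, objective) pairs.\n",
    "    def objective(prediction, target):\n",
    "        return sum(w * f(prediction, target) for w, f in terms)\n",
    "    return objective\n",
    "\n",
    "\n",
    "def mse(prediction, target):\n",
    "    return jnp.mean((prediction - target) ** 2)\n",
    "\n",
    "\n",
    "def var_gap(prediction, target):\n",
    "    return (jnp.var(prediction) - jnp.var(target)) ** 2\n",
    "\n",
    "\n",
    "mixed_sketch = weighted_sum((1.0, mse), (0.5, var_gap))\n",
    "x = jnp.array([0.0, 1.0, 2.0, 3.0])\n",
    "y = 2.0 * x\n",
    "total = mixed_sketch(x, y)\n",
    "```\n",
    "\n",
    "For these inputs `mse` is 3.5 and the variance gap is (1.25 - 5)^2 = 14.0625, so the\n",
    "total is 3.5 + 0.5 * 14.0625 = 10.53125."
   ]
  },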
  {
   "cell_type": "markdown",
   "id": "9245fd47",
   "metadata": {},
   "source": [
    "## It works across all three `Fitter` backends\n",
    "\n",
    "The payoff of the contract: *write the objective once, swap the backend*. All three\n",
    "runs share the objective callable and the model (with its bounded `a`); only the\n",
    "optimizer argument and the `backend=` string change.\n",
    "\n",
    "- **`grad`** differentiates through the objective -- needs it to be a smooth\n",
    "  `jax` function (ours is).\n",
    "- **`nevergrad`** / **`scipy`** are derivative-free and search a bounded box. They\n",
    "  derive that box from the trainable `Param`'s transform, so the fitted parameter\n",
    "  needs a *finite* transform interval -- `SigmoidT(lower, upper)` gives one.\n",
    "\n",
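    "Why does a sigmoid-type transform give a finite box? Assume `SigmoidT(lower, upper)` maps\n",
    "an unconstrained value `z` through `lower + (upper - lower) * sigmoid(z)`. That is an\n",
    "assumption about its internals, not checked here. Under it, every `z` lands inside\n",
    "`[lower, upper]`, so the two bounds define the search box. The helper names `to_box`\n",
    "and `from_box` below are hypothetical:\n",
    "\n",
    "```python\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "\n",
    "def to_box(z, lower, upper):\n",
    "    # Unconstrained z -> [lower, upper] via a scaled logistic sigmoid\n",
    "    return lower + (upper - lower) * jax.nn.sigmoid(z)\n",
    "\n",
    "\n",
    "def from_box(a, lower, upper):\n",
    "    # Inverse map (a logit): a in (lower, upper) -> unconstrained z\n",
    "    s = (a - lower) / (upper - lower)\n",
    "    return jnp.log(s) - jnp.log1p(-s)\n",
    "\n",
    "\n",
    "lower, upper = 0.1, 2.0\n",
    "a = to_box(jnp.linspace(-20.0, 20.0, 5), lower, upper)   # all inside the box\n",
    "z0 = from_box(0.3, lower, upper)   # unconstrained value for the initial a0 = 0.3\n",
    "a0 = to_box(z0, lower, upper)      # maps back to 0.3\n",
    "```\n",
    "\n",
    "An identity or softplus transform has no finite upper end, so it leaves no box to read off.\n",
    "\n",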
    "We fit the Hopf bifurcation parameter `a` so its settled limit-cycle **variance**\n",
    "matches a target generated at `a* = 1.0`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "a145e59f",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-06-19T09:16:36.082135Z",
     "iopub.status.busy": "2026-06-19T09:16:36.081984Z",
     "iopub.status.idle": "2026-06-19T09:16:36.433431Z",
     "shell.execute_reply": "2026-06-19T09:16:36.432471Z"
    }
   },
   "outputs": [],
   "source": [
    "from brainstate.nn import SigmoidT\n",
    "\n",
    "def make_model(a0=0.3):\n",
    "    # SigmoidT(0.1, 2.0) -> a bounded, trainable a in [0.1, 2.0]. init_x=0.5 kicks\n",
    "    # x off the unstable fixed point at the origin so the limit cycle develops amplitude.\n",
    "    return brainmass.HopfStep(\n",
    "        3, a=Param(a0, t=SigmoidT(0.1, 2.0)), w=0.3,\n",
    "        init_x=braintools.init.Constant(0.5),\n",
    "    )\n",
    "\n",
    "def predict(m):\n",
    "    sim = brainmass.Simulator(m, dt=0.1 * u.ms)\n",
    "    return sim.run(200. * u.ms, monitors=['x'], transient=50 * u.ms)['x']\n",
    "\n",
    "target = predict(make_model(1.0))      # ground truth at a* = 1.0\n",
    "objective = variance_match(as_loss=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "a95b52b1",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-06-19T09:16:36.435408Z",
     "iopub.status.busy": "2026-06-19T09:16:36.435264Z",
     "iopub.status.idle": "2026-06-19T09:16:50.651104Z",
     "shell.execute_reply": "2026-06-19T09:16:50.650340Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "     grad:  a = 0.9495   best_loss = 1.135e-03\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "    scipy:  a = 1.0000   best_loss = 5.288e-11\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "nevergrad:  a = 1.0042   best_loss = 4.530e-06\n",
      "\n",
      "true a* = 1.0; all three recovered it from the SAME objective.\n"
     ]
    }
   ],
   "source": [
    "backends = [\n",
    "    ('grad',      braintools.optim.Adam(lr=0.05), 40),\n",
    "    ('scipy',     {'method': 'Nelder-Mead'},      4),\n",
    "    ('nevergrad', {'method': 'DE', 'n_sample': 6}, 4),\n",
    "]\n",
    "\n",
    "results = {}\n",
    "for backend, opt, n in backends:\n",
    "    m = make_model(0.3)\n",
    "    fitter = brainmass.Fitter(m, opt, objective=objective,\n",
    "                              predict=predict, backend=backend)\n",
    "    res = fitter.fit(target=target, n_steps=n)\n",
    "    a_fit = float(list(res.best_params.values())[0])\n",
    "    results[backend] = a_fit\n",
    "    print(f'{backend:>9s}:  a = {a_fit:.4f}   best_loss = {res.best_loss:.3e}')\n",
    "\n",
    "print('\\ntrue a* = 1.0; all three recovered it from the SAME objective.')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0f688043",
   "metadata": {},
   "source": [
    "All three backends drive `a` from `0.3` to near the true value `1.0` (from `0.95`\n",
    "for `grad` after 40 steps to `1.000` for SciPy) using one objective callable. That\n",
    "is the whole point: the objective is decoupled from how it is optimised."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "32d82910",
   "metadata": {},
   "source": [
    "## Notes for a gradient-friendly objective\n",
    "\n",
    "- **Smoothness matters for `grad`.** A `max`/`argmax` (like a KS statistic) is\n",
    "  non-smooth -- usable for evaluation but a poor *gradient* loss. Prefer a smooth\n",
    "  surrogate (an integral / Wasserstein-style distance) when the objective drives\n",
    "  the gradient backend. The built-in `fcd_ks` vs `fcd_wasserstein` pair is exactly\n",
    "  this trade-off.\n",
    "- **Fit a well-conditioned summary**, not a phase-degenerate raw oscillatory trace\n",
    "  (see {doc}`building_a_data_driven_workflow`).\n",
    "- **Reuse `braintools.metric`** rather than re-implementing metrics; the built-in\n",
    "  objectives are thin wrappers over it (`functional_connectivity`,\n",
    "  `matrix_correlation`, `power_spectral_density`, ...).\n",
    "\n",
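    "Here is a small 1-D illustration of the smoothness trade-off. `ks_stat` and\n",
    "`wasserstein_1d` below are hypothetical helpers, not `brainmass` or `braintools`\n",
    "functions. The KS statistic is a `max` over step-function empirical CDFs, so it is\n",
    "piecewise constant in the sample values and its gradient is zero almost everywhere.\n",
    "The Wasserstein-1 distance between equal-size samples is the mean gap between sorted\n",
    "values, which has a usable gradient:\n",
    "\n",
    "```python\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "\n",
    "def ks_stat(x, y):\n",
    "    # Two-sample KS statistic: max |ECDF_x - ECDF_y| over the pooled points\n",
    "    grid = jnp.sort(jnp.concatenate([x, y]))\n",
    "    cdf_x = jnp.mean(x[None, :] <= grid[:, None], axis=1)\n",
    "    cdf_y = jnp.mean(y[None, :] <= grid[:, None], axis=1)\n",
    "    return jnp.max(jnp.abs(cdf_x - cdf_y))\n",
    "\n",
    "\n",
    "def wasserstein_1d(x, y):\n",
    "    # W1 for equal-size 1-D samples: mean gap between sorted values\n",
    "    return jnp.mean(jnp.abs(jnp.sort(x) - jnp.sort(y)))\n",
    "\n",
    "\n",
    "x = jnp.array([0.0, 1.0, 2.0, 3.0])\n",
    "y = x + 10.0\n",
    "ks = ks_stat(x, y)                        # disjoint samples -> 1.0\n",
    "w1 = wasserstein_1d(x, y)                 # every point shifted by 10 -> 10.0\n",
    "g_w1 = jax.grad(wasserstein_1d)(x, y)     # each entry -0.25: descent moves x toward y\n",
    "```\n",
    "\n",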
    "## See Also\n",
    "\n",
    "- {doc}`/howto/custom_objective` -- combining and applying objectives.\n",
    "- {doc}`building_a_data_driven_workflow` -- the full fitting playbook.\n",
    "- {doc}`/reference/orchestration` -- the `objectives` / `Fitter` API reference."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.13.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
