{
  "cells": [
    {
      "metadata": {},
      "cell_type": "markdown",
      "source": [
        "# Parameter Optimization with ``SciPy``\n",
        "\n",
        "[SciPy](https://docs.scipy.org/doc/scipy/tutorial/optimize.html) provides a mature set of optimization algorithms. Here we show how to use them to fit a whole-brain network model to empirical functional connectivity (FC)."
      ],
      "id": "83b3643df6f32468"
    },
    {
      "metadata": {},
      "cell_type": "markdown",
      "source": [
        "\n",
        "This notebook demonstrates parameter optimization with SciPy, using either gradient-based or gradient-free methods, to fit a whole-brain network model to empirical FC.\n",
        "\n",
        "- Goal: tune the global coupling k and the noise strength sigma of a Wilson–Cowan network so that the simulated FC matches a target FC.\n",
        "- Loss: 1 - corr(FC_target, FC_model).\n",
        "- We JIT-compile the loss for speed (optional) and pass it to a SciPy optimizer.\n"
      ],
      "id": "f9ec099c37b37957"
    },
    {
      "metadata": {},
      "cell_type": "code",
      "outputs": [],
      "execution_count": null,
      "source": [
        "import brainstate\n",
        "import brainunit as u\n",
        "import jax.numpy as jnp\n",
        "import numpy as np\n",
        "\n",
        "import brainmass\n",
        "import braintools\n"
      ],
      "id": "b91b9a5eb677583f"
    },
    {
      "metadata": {
        "ExecuteTime": {
          "end_time": "2025-09-22T13:23:06.088194Z",
          "start_time": "2025-09-22T13:23:06.081274Z"
        }
      },
      "cell_type": "code",
      "source": "brainstate.environ.set(dt=0.1 * u.ms)",
      "id": "72a0ea87f59d5d5c",
      "outputs": [],
      "execution_count": 2
    },
    {
      "metadata": {
        "ExecuteTime": {
          "end_time": "2025-09-22T13:23:08.480864Z",
          "start_time": "2025-09-22T13:23:06.093021Z"
        }
      },
      "cell_type": "code",
      "source": [
        "import os.path\n",
        "import kagglehub\n",
        "\n",
        "path = kagglehub.dataset_download(\"oujago/hcp-gw-data-samples\")\n",
        "data = braintools.file.msgpack_load(os.path.join(path, \"hcp-data-sample.msgpack\"))\n",
        "\n",
        "target_fc = [braintools.metric.functional_connectivity(x.T) for x in data['BOLDs']]\n",
        "target_fc = jnp.mean(jnp.asarray(target_fc), axis=0)\n",
        "\n"
      ],
      "id": "ff267f6bb5121ba3",
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Loading checkpoint from D:\\Data\\kagglehub\\datasets\\oujago\\hcp-gw-data-samples\\versions\\1\\hcp-data-sample.msgpack\n"
          ]
        }
      ],
      "execution_count": 3
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Data and Target FC\n",
        "\n",
        "We download a small HCP sample via kagglehub. It provides structural connectivity (Cmat), inter-regional distances (Dmat), and BOLD time series (BOLDs). For each BOLD time series, we compute FC, then average across scans to obtain target_fc.\n",
        "\n",
        "- Cmat: weights used for coupling.\n",
        "- Dmat: distances, converted to delays by dividing by a signal speed.\n",
        "- target_fc: average empirical FC used by the loss.\n",
        "\n",
        "If the download fails, point to a local msgpack file or substitute your own target FC.\n"
      ],
      "id": "abd94cb2be86421a"
    },
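    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a rough illustration of the target construction described above, the sketch below computes one FC matrix per scan as the Pearson correlation between regional time series and then averages the matrices. It uses plain NumPy on synthetic data. The helper names `fc_from_bold` and `average_fc` are hypothetical, and `braintools.metric.functional_connectivity` may differ in details such as input orientation or NaN handling.\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "\n",
        "def fc_from_bold(bold):\n",
        "    # bold has shape (n_regions, n_timepoints); np.corrcoef treats rows as variables.\n",
        "    return np.corrcoef(bold)\n",
        "\n",
        "def average_fc(bolds):\n",
        "    # One FC matrix per scan, averaged element-wise across scans.\n",
        "    return np.mean([fc_from_bold(b) for b in bolds], axis=0)\n",
        "\n",
        "rng = np.random.default_rng(0)\n",
        "bolds = [rng.standard_normal((4, 200)) for _ in range(3)]  # 3 synthetic scans\n",
        "target = average_fc(bolds)\n",
        "print(target.shape)\n",
        "```\n",
        "\n",
        "The averaged matrix stays symmetric with a unit diagonal, which is why only the off-diagonal entries carry information for the fit."
      ],
      "id": "sketch-target-fc"
    },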
    {
      "metadata": {
        "ExecuteTime": {
          "end_time": "2025-09-22T13:23:10.099696Z",
          "start_time": "2025-09-22T13:23:10.094479Z"
        }
      },
      "cell_type": "code",
      "source": [
        "class Network(brainstate.nn.Module):\n",
        "    def __init__(self, signal_speed=2., k=1., sigma=0.01):\n",
        "        super().__init__()\n",
        "\n",
        "        # Structural weights; zero the diagonal to remove self-connections.\n",
        "        conn_weight = data['Cmat'].copy()\n",
        "        np.fill_diagonal(conn_weight, 0)\n",
        "        # Conduction delays: distance divided by signal speed.\n",
        "        delay_time = data['Dmat'].copy() / signal_speed\n",
        "        np.fill_diagonal(delay_time, 0)\n",
        "        # Source index for every (target, source) pair, shape (n_node, n_node).\n",
        "        n_node = conn_weight.shape[0]\n",
        "        indices_ = np.tile(np.arange(n_node), (n_node, 1))\n",
        "\n",
        "        self.node = brainmass.WilsonCowanStep(\n",
        "            n_node,\n",
        "            noise_E=brainmass.OUProcess(n_node, sigma=sigma, init=braintools.init.ZeroInit()),\n",
        "            noise_I=brainmass.OUProcess(n_node, sigma=sigma, init=braintools.init.ZeroInit()),\n",
        "        )\n",
        "        self.coupling = brainmass.DiffusiveCoupling(\n",
        "            self.node.prefetch_delay('rE', (delay_time * u.ms, indices_), init=braintools.init.Uniform(0, 0.05)),\n",
        "            self.node.prefetch('rE'),\n",
        "            conn_weight,\n",
        "            k=k\n",
        "        )\n",
        "\n",
        "    def update(self):\n",
        "        current = self.coupling()\n",
        "        rE = self.node(current)\n",
        "        return rE\n",
        "\n",
        "    def step_run(self, i):\n",
        "        with brainstate.environ.context(i=i, t=i * brainstate.environ.get_dt()):\n",
        "            return self.update()\n",
        "\n"
      ],
      "id": "731391424938d899",
      "outputs": [],
      "execution_count": 4
    },
    {
      "metadata": {
        "ExecuteTime": {
          "end_time": "2025-09-22T13:23:10.113498Z",
          "start_time": "2025-09-22T13:23:10.106732Z"
        }
      },
      "cell_type": "code",
      "source": [
        "def simulation(k, sigma):\n",
        "    net = Network(k=k, sigma=sigma)\n",
        "    brainstate.nn.init_all_states(net)\n",
        "    indices = np.arange(0, 1e3 * u.ms // brainstate.environ.get_dt())\n",
        "    exes = brainstate.transform.for_loop(net.step_run, indices)\n",
        "    fc = braintools.metric.functional_connectivity(exes)\n",
        "    return braintools.metric.matrix_correlation(target_fc, fc)\n",
        "\n"
      ],
      "id": "d1003d6d5c98bfb8",
      "outputs": [],
      "execution_count": 5
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Model and Coupling\n",
        "\n",
        "We simulate 80 Wilson–Cowan nodes with OU noise on both the E and I populations. Diffusive coupling acts on rE via DiffusiveCoupling:\n",
        "\n",
        "- The global gain k scales Cmat.\n",
        "- Delays are Dmat / signal_speed and are handled with prefetch_delay.\n",
        "- The module's update returns rE, which is used for the FC computation.\n"
      ],
      "id": "e2b7d931ab70facc"
    },
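    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To make the delayed coupling above concrete, here is a minimal NumPy sketch. It assumes the common diffusive form input_i = k * sum_j C_ij * (x_j(t - d_ij) - x_i(t)), with distances in mm and speed in mm/ms so that delays come out in ms. The functions `delay_steps` and `diffusive_input` are hypothetical helpers, not part of brainmass, and the library's actual implementation may differ.\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "\n",
        "def delay_steps(dmat, signal_speed, dt):\n",
        "    # Delay = distance / speed, rounded to a whole number of integration steps.\n",
        "    return np.rint(dmat / signal_speed / dt).astype(int)\n",
        "\n",
        "def diffusive_input(history, weights, steps, k):\n",
        "    # history: (n_steps, n_nodes) buffer whose last row is the current state.\n",
        "    n = weights.shape[0]\n",
        "    current = history[-1]\n",
        "    # delayed[i, j] is the activity of source j as seen by target i.\n",
        "    delayed = history[-1 - steps, np.arange(n)[None, :]]\n",
        "    return k * np.sum(weights * (delayed - current[:, None]), axis=1)\n",
        "\n",
        "dmat = np.array([[0.0, 4.0], [4.0, 0.0]])     # distances (assumed mm)\n",
        "weights = np.array([[0.0, 1.0], [1.0, 0.0]])  # coupling weights, zero diagonal\n",
        "steps = delay_steps(dmat, signal_speed=2.0, dt=0.1)  # 2 ms delay = 20 steps\n",
        "history = np.zeros((21, 2))\n",
        "history[0] = [1.0, 3.0]  # state 20 steps in the past\n",
        "inp = diffusive_input(history, weights, steps, k=0.5)\n",
        "print(inp)\n",
        "```\n",
        "\n",
        "Each node receives the weighted difference between the delayed activity of its neighbours and its own current activity, so the input vanishes when all nodes share the same constant state."
      ],
      "id": "sketch-diffusive-coupling"
    },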
    {
      "metadata": {
        "ExecuteTime": {
          "end_time": "2025-09-22T13:23:10.121011Z",
          "start_time": "2025-09-22T13:23:10.117504Z"
        }
      },
      "cell_type": "code",
      "source": [
        "@brainstate.transform.jit\n",
        "def loss_fn(arr):\n",
        "    k, sigma = arr\n",
        "    return 1 - simulation(k, sigma)\n",
        "\n"
      ],
      "id": "6da60654b81e48cc",
      "outputs": [],
      "execution_count": 6
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Simulation and Loss\n",
        "\n",
        "The simulation runs the network for a short window (1 s of model time), computes FC from the excitatory activity rE, and returns its correlation with target_fc. The loss is 1 - correlation, so lower is better. We pack the two parameters (k, sigma) into a single array to match SciPy's API, and optionally JIT-compile the loss for speed.\n"
      ],
      "id": "abad6a2f11e50f70"
    },
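    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The loss described above can be sketched in plain NumPy as follows. This version correlates only the strictly upper-triangular entries of the two FC matrices; whether `braintools.metric.matrix_correlation` does exactly the same is an assumption, and `fc_loss` is a hypothetical name used for illustration.\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "\n",
        "def fc_loss(fc_target, fc_model):\n",
        "    # FC is symmetric with a unit diagonal, so compare the upper triangle only.\n",
        "    iu = np.triu_indices_from(fc_target, k=1)\n",
        "    r = np.corrcoef(fc_target[iu], fc_model[iu])[0, 1]\n",
        "    return 1.0 - r\n",
        "\n",
        "rng = np.random.default_rng(1)\n",
        "fc_a = np.corrcoef(rng.standard_normal((5, 300)))\n",
        "fc_b = np.corrcoef(rng.standard_normal((5, 300)))\n",
        "loss_same = fc_loss(fc_a, fc_a)  # identical matrices give a loss of about 0\n",
        "loss_diff = fc_loss(fc_a, fc_b)  # unrelated matrices give a loss between 0 and 2\n",
        "print(loss_same, loss_diff)\n",
        "```\n",
        "\n",
        "A loss above 1 means the simulated and target FC are negatively correlated."
      ],
      "id": "sketch-fc-loss"
    },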
    {
      "metadata": {
        "ExecuteTime": {
          "end_time": "2025-09-22T13:23:34.133259Z",
          "start_time": "2025-09-22T13:23:10.127018Z"
        }
      },
      "cell_type": "code",
      "source": [
        "opt = braintools.optim.ScipyOptimizer(\n",
        "    loss_fn, bounds=[(0.5, 3.0), (0.0, 1.)], method='L-BFGS-B'\n",
        ")\n",
        "best_r = opt.minimize(n_iter=1)\n",
        "print(best_r)\n"
      ],
      "id": "b5faea0014ba586c",
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "  message: CONVERGENCE: REL_REDUCTION_OF_F_<=_FACTR*EPSMCH\n",
            "  success: True\n",
            "   status: 0\n",
            "      fun: 1.0078195333480835\n",
            "        x: [ 5.000e-01  1.000e+00]\n",
            "      nit: 1\n",
            "      jac: [-9.956e-02  1.753e-02]\n",
            "     nfev: 6\n",
            "     njev: 6\n",
            " hess_inv: <2x2 LbfgsInvHessProduct with dtype=float64>\n"
          ]
        }
      ],
      "execution_count": 7
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## SciPy Optimizer Setup\n",
        "\n",
        "We use braintools.optim.ScipyOptimizer as a thin wrapper around scipy.optimize.minimize. The bounds and the method (e.g., L-BFGS-B, Nelder-Mead) can be customized. In the run above, a single iteration stops on the bounds (k = 0.5, sigma = 1.0) with a loss of about 1.008, so the fit has not yet produced a positive FC correlation. Increase n_iter or switch methods for a more robust fit.\n",
        "\n",
        "Tips:\n",
        "- Start with short simulations for fast iterations, then refine.\n",
        "- Consider multiple random restarts.\n",
        "- If gradients are unreliable (for example because of stochastic noise), prefer derivative-free methods (Nelder-Mead, Powell).\n"
      ],
      "id": "63a7aa45fd613f08"
    },
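    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "For readers who want to see the underlying SciPy call, the sketch below runs scipy.optimize.minimize directly with bounds, two methods, and a few random restarts. It uses a cheap quadratic `toy_loss` (a hypothetical stand-in with a known minimum) instead of the network simulation, and it does not show how ScipyOptimizer is implemented internally.\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "from scipy.optimize import minimize\n",
        "\n",
        "def toy_loss(x):\n",
        "    # Smooth stand-in for the FC loss, with its minimum at k=1.5, sigma=0.2.\n",
        "    k, sigma = x\n",
        "    return (k - 1.5) ** 2 + (sigma - 0.2) ** 2\n",
        "\n",
        "bounds = [(0.5, 3.0), (0.0, 1.0)]\n",
        "rng = np.random.default_rng(0)\n",
        "\n",
        "best = None\n",
        "for method in ('L-BFGS-B', 'Powell'):  # gradient-based and derivative-free\n",
        "    for _ in range(3):  # random restarts drawn inside the bounds\n",
        "        x0 = np.array([rng.uniform(lo, hi) for lo, hi in bounds])\n",
        "        res = minimize(toy_loss, x0, method=method, bounds=bounds)\n",
        "        if best is None or res.fun < best.fun:\n",
        "            best = res\n",
        "\n",
        "print(best.x, best.fun)\n",
        "```\n",
        "\n",
        "Without an explicit jac argument, L-BFGS-B falls back to finite-difference gradients, which are unreliable for a noisy loss; Powell avoids gradients entirely."
      ],
      "id": "sketch-scipy-restarts"
    },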
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Notes\n",
        "\n",
        "- JIT compilation pays off only after the first call, which includes compile time. Make sure JAX is running on the backend (CPU or GPU) you intend to use.\n",
        "- Set seeds or fix the initial noise states for reproducibility when comparing runs.\n",
        "- You can extend the loss to multiple objectives (e.g., also matching power spectra).\n"
      ],
      "id": "89c22fd49cb858ef"
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 2
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "2.7.6"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 5
}