{
 "cells": [
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "# Parameter Optimization with Nevergrad\n",
    "\n",
    "[nevergrad](https://github.com/facebookresearch/nevergrad) is a Python toolbox for performing gradient-free optimization."
   ],
   "id": "83b3643df6f32468"
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This notebook demonstrates gradient-free parameter optimization of a whole-brain network model using Nevergrad.\n",
    "\n",
    "- Objective: tune the global coupling strength `k` of a Wilson–Cowan network so that the simulated functional connectivity (FC) matches an empirical target FC.\n",
    "- Approach: define the similarity-based loss `1 - corr(FC_target, FC_model)` and minimize it with a Nevergrad optimizer.\n",
    "- Acceleration: JIT-compile the loss and use `vmap` to evaluate several candidate values of `k` in parallel.\n"
   ],
   "id": "fee5974317af71f2"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "import brainstate\n",
    "import braintools\n",
    "import brainunit as u\n",
    "import jax.numpy as jnp\n",
    "import numpy as np\n",
    "\n",
    "import brainmass"
   ],
   "id": "b692973106adbc1a",
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": "brainstate.environ.set(dt=0.1 * u.ms)",
   "id": "72a0ea87f59d5d5c",
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
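    "import os.path\n",
    "import kagglehub\n",
    "\n",
    "# Download (or reuse the cached copy of) the sample HCP dataset.\n",
    "path = kagglehub.dataset_download(\"oujago/hcp-gw-data-samples\")\n",
    "data = braintools.file.msgpack_load(os.path.join(path, \"hcp-data-sample.msgpack\"))\n",
    "\n",
    "# functional_connectivity expects a time-by-node array, so each BOLD scan is transposed;\n",
    "# the per-scan FC matrices are then averaged into a single target.\n",
    "target_fc = [braintools.metric.functional_connectivity(x.T) for x in data['BOLDs']]\n",
    "target_fc = jnp.mean(jnp.asarray(target_fc), axis=0)"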
   ],
   "id": "ff267f6bb5121ba3",
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Data and Target FC\n",
    "\n",
    "We load a sample HCP dataset via `kagglehub`. It provides a structural connectivity matrix (`Cmat`), an inter-regional distance matrix (`Dmat`), and BOLD time series. For each BOLD scan we compute FC, then average the FC matrices across scans.\n",
    "\n",
    "- `Cmat`: connection weights, used for coupling.\n",
    "- `Dmat`: inter-node distances, divided by a signal speed to obtain conduction delays.\n",
    "- `target_fc`: mean empirical FC, used as the optimization target.\n",
    "\n",
    "If downloading the dataset fails, load the same `hcp-data-sample.msgpack` file from a local path, or substitute your own target FC.\n"
   ],
   "id": "1c9f516f0725af27"
  },
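  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a rough illustration of the target computation, the sketch below treats FC as the node-by-node Pearson correlation matrix of each scan and averages it across scans, using plain NumPy. It assumes that `braintools.metric.functional_connectivity` computes Pearson correlations between node time series; `pearson_fc` is a hypothetical helper, and the random `bolds` array is a hypothetical stand-in for `data['BOLDs']`, stored node-by-time as the `x.T` transpose above implies.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def pearson_fc(ts):\n",
    "    # ts: (time, nodes) array -> (nodes, nodes) correlation matrix.\n",
    "    # np.corrcoef treats rows as variables, so transpose to (nodes, time).\n",
    "    return np.corrcoef(ts.T)\n",
    "\n",
    "# Hypothetical stand-in for data['BOLDs']: 3 scans, 80 nodes, 200 time points.\n",
    "rng = np.random.default_rng(0)\n",
    "bolds = rng.standard_normal((3, 80, 200))\n",
    "\n",
    "# Per-scan FC (each scan transposed to time-by-node), then the mean across scans.\n",
    "fcs = np.stack([pearson_fc(x.T) for x in bolds])\n",
    "mean_fc = fcs.mean(axis=0)\n",
    "print(mean_fc.shape)  # (80, 80)\n",
    "```\n",
    "\n",
    "The averaged matrix stays symmetric with ones on the diagonal, like each per-scan FC."
   ],
   "id": "4c7e2a9d1b3f5e60"
  },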
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
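    "class Network(brainstate.nn.Module):\n",
    "    def __init__(self, signal_speed=2., k=1., sigma=0.01):\n",
    "        super().__init__()\n",
    "\n",
    "        # Structural weights without self-connections.\n",
    "        conn_weight = data['Cmat'].copy()\n",
    "        np.fill_diagonal(conn_weight, 0)\n",
    "        # Conduction delays (ms) = distance / signal speed, zero on the diagonal.\n",
    "        delay_time = data['Dmat'].copy() / signal_speed\n",
    "        np.fill_diagonal(delay_time, 0)\n",
    "        # Source-node index j for every (i, j) entry of the delay matrix.\n",
    "        indices_ = np.arange(conn_weight.shape[1])\n",
    "        indices_ = np.tile(np.expand_dims(indices_, axis=0), (conn_weight.shape[0], 1))\n",
    "        n_node = conn_weight.shape[0]  # 80 regions in this dataset\n",
    "\n",
    "        # Wilson-Cowan nodes with independent OU noise on the E and I populations.\n",
    "        self.node = brainmass.WilsonCowanStep(\n",
    "            n_node,\n",
    "            noise_E=brainmass.OUProcess(n_node, sigma=sigma, init=braintools.init.ZeroInit()),\n",
    "            noise_I=brainmass.OUProcess(n_node, sigma=sigma, init=braintools.init.ZeroInit()),\n",
    "        )\n",
    "        # Diffusive coupling between the delayed presynaptic rE and the current local rE.\n",
    "        self.coupling = brainmass.DiffusiveCoupling(\n",
    "            self.node.prefetch_delay('rE', (delay_time * u.ms, indices_), init=braintools.init.Uniform(0, 0.05)),\n",
    "            self.node.prefetch('rE'),\n",
    "            conn_weight,\n",
    "            k=k\n",
    "        )\n",
    "\n",
    "    def update(self):\n",
    "        # Coupling input from the other nodes drives one Wilson-Cowan step.\n",
    "        current = self.coupling()\n",
    "        rE = self.node(current)\n",
    "        return rE\n",
    "\n",
    "    def step_run(self, i):\n",
    "        # Expose the step index i and the time t = i * dt to the model.\n",
    "        with brainstate.environ.context(i=i, t=i * brainstate.environ.get_dt()):\n",
    "            return self.update()"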
   ],
   "id": "731391424938d899",
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
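    "def simulation(k, sigma):\n",
    "    # Build a fresh network for this (k, sigma) and initialize its states.\n",
    "    net = Network(k=k, sigma=sigma)\n",
    "    brainstate.nn.init_all_states(net)\n",
    "    # 6 s of simulated time at dt = 0.1 ms gives 60,000 integration steps.\n",
    "    indices = np.arange(0, 6e3 * u.ms // brainstate.environ.get_dt())\n",
    "    # Run all steps; exes is the time-by-node series of excitatory rates rE.\n",
    "    exes = brainstate.transform.for_loop(net.step_run, indices)\n",
    "    fc = braintools.metric.functional_connectivity(exes)\n",
    "    # Return the similarity (not the loss) between empirical and simulated FC.\n",
    "    return braintools.metric.matrix_correlation(target_fc, fc)"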
   ],
   "id": "d1003d6d5c98bfb8",
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Model and Coupling\n",
    "\n",
    "We simulate 80 Wilson–Cowan nodes with Ornstein–Uhlenbeck (OU) noise on both the excitatory (E) and inhibitory (I) populations. Diffusive coupling acts on the excitatory rate `rE` via `DiffusiveCoupling`:\n",
    "\n",
    "- The global gain `k` scales the structural weights `Cmat`.\n",
    "- Delays are derived from `Dmat / signal_speed` and handled with `prefetch_delay`.\n",
    "- `update` returns the current excitatory activity `rE`; its time series is used to compute FC.\n"
   ],
   "id": "83e643ca4e35513f"
  },
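  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Diffusive coupling is commonly written as `I_i(t) = k * sum_j C[i, j] * (x_j(t - tau_ij) - x_i(t))`, with delays `tau_ij = D[i, j] / v`. The NumPy sketch below evaluates that formula from a stored activity history. It assumes that `DiffusiveCoupling` follows this standard form (check the brainmass documentation for its exact conventions); `diffusive_input` is a hypothetical helper, and the random weights, distances (assumed in mm) and speed (assumed in mm/ms) are illustrative only.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def diffusive_input(history, C, delay_steps, k):\n",
    "    # history: (T, N) past activity, last row = current step.\n",
    "    # C: (N, N) weights with zero diagonal; delay_steps: (N, N) integer delays.\n",
    "    # Returns I_i = k * sum_j C[i, j] * (x_j(t - delay_ij) - x_i(t)).\n",
    "    T, N = history.shape\n",
    "    current = history[-1]                          # x_i(t), shape (N,)\n",
    "    cols = np.arange(N)[None, :]                   # source index j for each (i, j)\n",
    "    delayed = history[T - 1 - delay_steps, cols]   # x_j(t - delay_ij), shape (N, N)\n",
    "    return k * np.sum(C * (delayed - current[:, None]), axis=1)\n",
    "\n",
    "rng = np.random.default_rng(0)\n",
    "N, T, dt = 4, 50, 0.1                              # nodes, stored steps, dt in ms\n",
    "C = rng.uniform(0, 1, (N, N))\n",
    "np.fill_diagonal(C, 0)\n",
    "D = rng.uniform(0, 4, (N, N))                      # hypothetical distances (assumed mm)\n",
    "np.fill_diagonal(D, 0)\n",
    "speed = 2.0                                        # assumed mm/ms, as Network's default\n",
    "delay_steps = np.round(D / speed / dt).astype(int) # delays in units of dt (at most 20)\n",
    "history = rng.standard_normal((T, N))\n",
    "print(diffusive_input(history, C, delay_steps, k=1.0).shape)  # (4,)\n",
    "```\n",
    "\n",
    "Because the input depends on activity differences, identical activity across all nodes produces zero coupling input."
   ],
   "id": "9a2d6f1e8c4b3a57"
  },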
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Simulation and Loss\n",
    "\n",
    "Given `k` (with `sigma` held fixed at 0.05 in `vmap_loss_fn`), we simulate 6 s of activity, compute the model FC from the excitatory time series, and score it against the target:\n",
    "\n",
    "- `functional_connectivity(X)`: FC from a time-by-node array `X`.\n",
    "- `matrix_correlation(A, B)`: Pearson correlation between the vectorized FC matrices.\n",
    "- Loss: `1 - correlation`, so maximizing the match is equivalent to minimizing the loss.\n",
    "\n",
    "`vmap_loss_fn` wraps this evaluation with `vmap` so that a batch of `k` values is simulated in one call, and JIT-compiles it for speed.\n"
   ],
   "id": "337bade4379f8b83"
  },
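  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For intuition, the loss can be written out in NumPy: take the strict upper triangle of both FC matrices, compute the Pearson correlation of the two resulting vectors, and subtract it from 1. Whether `braintools.metric.matrix_correlation` uses the upper triangle or the full matrix is an assumption here, and `fc_loss` is a hypothetical helper.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def fc_loss(fc_target, fc_model):\n",
    "    # FC matrices are symmetric with a unit diagonal, so only the strict upper\n",
    "    # triangle carries independent information.\n",
    "    iu = np.triu_indices_from(fc_target, k=1)\n",
    "    r = np.corrcoef(fc_target[iu], fc_model[iu])[0, 1]\n",
    "    return 1.0 - r  # 0 for a perfect linear match, 2 for perfect anti-correlation\n",
    "\n",
    "rng = np.random.default_rng(0)\n",
    "fc_a = np.corrcoef(rng.standard_normal((10, 500)))  # 10 nodes, 500 time points\n",
    "print(fc_loss(fc_a, fc_a))   # approximately 0.0\n",
    "print(fc_loss(fc_a, -fc_a))  # approximately 2.0\n",
    "```\n",
    "\n",
    "The loss therefore ranges from 0 to 2, and an uninformative model FC lands near 1."
   ],
   "id": "e3b8c1f0a6d24975"
  },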
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
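    "@brainstate.transform.jit\n",
    "def vmap_loss_fn(k):\n",
    "    # k is a batch of candidate coupling values; sigma is fixed at 0.05.\n",
    "    # vmap2 evaluates simulation() for every candidate, giving one correlation each.\n",
    "    return 1 - brainstate.transform.vmap2(lambda x: simulation(x, sigma=0.05))(k)"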
   ],
   "id": "6da60654b81e48cc",
   "outputs": [],
   "execution_count": null
  },
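  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The batching pattern behind `vmap_loss_fn` can be illustrated with plain JAX: `jax.vmap` turns a loss for one scalar `k` into a loss over a vector of candidates, and `jax.jit` compiles the batched function. This sketch replaces the network simulation with a hypothetical toy model (`toy_loss`, a leaky integrator driven by `k` whose steady state is `k`), so it shows only the transformation pattern, not the brainstate `vmap2` API.\n",
    "\n",
    "```python\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "def toy_loss(k):\n",
    "    # Hypothetical stand-in for simulation(): a leaky integrator x <- x + 0.1 * (k - x)\n",
    "    # run for 200 steps, scored by the squared distance of the final state from 1.5.\n",
    "    def step(x, _):\n",
    "        return x + 0.1 * (k - x), None\n",
    "    x_final, _ = jax.lax.scan(step, jnp.zeros_like(k), None, length=200)\n",
    "    return (x_final - 1.5) ** 2\n",
    "\n",
    "# vmap maps toy_loss over a batch of k values; jit compiles the batched function.\n",
    "batched_loss = jax.jit(jax.vmap(toy_loss))\n",
    "\n",
    "ks = jnp.array([0.5, 1.0, 1.5, 2.0])  # e.g. one batch of 4 candidates\n",
    "losses = batched_loss(ks)\n",
    "print(losses.shape)  # (4,)\n",
    "```\n",
    "\n",
    "Later calls with a batch of the same shape reuse the compiled function, which is what makes repeated optimizer iterations cheap."
   ],
   "id": "7f5a0c3e9d1b2864"
  },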
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
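    "# Differential evolution ('DE') over k in [0.5, 3.0]; each iteration draws n_sample=4\n",
    "# candidates, which vmap_loss_fn evaluates as one batch.\n",
    "opt = braintools.optim.NevergradOptimizer(\n",
    "    vmap_loss_fn, method='DE', n_sample=4, bounds={'k': [0.5, 3.0]}\n",
    ")\n",
    "opt.initialize()\n",
    "opt.minimize(n_iter=10)\n"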
   ],
   "id": "b5faea0014ba586c",
   "outputs": [],
   "execution_count": null
  },
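  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`NevergradOptimizer` hides how candidates are generated and scored. Assuming it follows Nevergrad's standard ask/tell pattern (an assumption about braintools internals), a batched search with `n_sample=4` looks roughly like the sketch below, which calls the `nevergrad` package directly. The quadratic `batch_loss` is a hypothetical stand-in for `vmap_loss_fn`, so the sketch runs in seconds.\n",
    "\n",
    "```python\n",
    "import nevergrad as ng\n",
    "import numpy as np\n",
    "\n",
    "def batch_loss(ks):\n",
    "    # Hypothetical stand-in for vmap_loss_fn: one loss per candidate, minimized at k = 1.7.\n",
    "    return (np.asarray(ks) - 1.7) ** 2\n",
    "\n",
    "# Scalar parameter k bounded to [0.5, 3.0], matching the bounds used above.\n",
    "param = ng.p.Scalar(lower=0.5, upper=3.0)\n",
    "param.random_state.seed(0)  # seed the parametrization's RNG for reproducibility\n",
    "optimizer = ng.optimizers.registry['DE'](parametrization=param, budget=200, num_workers=4)\n",
    "\n",
    "for _ in range(50):  # 50 rounds x 4 candidates = 200 evaluations (the budget)\n",
    "    candidates = [optimizer.ask() for _ in range(4)]    # one batch of 4 candidates\n",
    "    losses = batch_loss([c.value for c in candidates])  # evaluate the whole batch at once\n",
    "    for cand, loss in zip(candidates, losses):\n",
    "        optimizer.tell(cand, float(loss))                # report each loss back\n",
    "\n",
    "best_k = optimizer.provide_recommendation().value\n",
    "print(best_k)\n",
    "```\n",
    "\n",
    "Swapping `'DE'` for `'CMA'` or `'PSO'` in the registry lookup changes the search strategy without touching the loop."
   ],
   "id": "b61e4d7a2c9f0835"
  },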
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Practical Tips and Extensions\n",
    "\n",
    "- Increase `n_iter` and try other values of `method` (e.g., CMA, PSO, DE) for a more thorough search.\n",
    "- Optimize several parameters at once: extend `vmap_loss_fn` and `bounds` to include, e.g., `sigma` or `signal_speed`.\n",
    "- Start with shorter simulations for quick iterations, then refine the best candidates with longer runs.\n",
    "- Set random seeds and consider fixing the initial noise states for reproducibility.\n",
    "- Make sure JAX is configured for the intended backend (CPU or GPU) to benefit from `jit` and `vmap`.\n"
   ],
   "id": "2dfdcbe41cf8ad43"
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
