{
 "cells": [
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "# Grid Search using ``vmap``",
   "id": "83b3643df6f32468"
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "This notebook explores how network parameters impact the match between simulated functional connectivity (FC) and an empirical target FC.\n",
    "\n",
    "What you will see:\n",
    "- Load an HCP dataset and compute a target FC.\n",
    "- Define a Wilson–Cowan network with delayed, diffusive coupling and stochastic inputs.\n",
    "- Simulate the network, compute FC from its excitatory rates, and correlate it with the target FC.\n",
    "- Explore `k` and `sigma` with a grid search that is vectorized with `vmap` and compiled with `jit`.\n",
    "\n",
    "Key parameters:\n",
    "- `k` (global coupling gain): scales the strength of inter‑regional coupling.\n",
    "- `sigma` (noise scale): sets the magnitude of the Ornstein–Uhlenbeck (OU) noise driving the E and I populations.\n",
    "- `signal_speed` (m/s, i.e. mm/ms): converts the structural distances in `Dmat` into conduction delays; it stays at its default of 2.0 in this notebook.\n",
    "\n",
    "Performance notes:\n",
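    "- Simulations are vectorized with `brainstate.transform.vmap2` and compiled with `brainstate.transform.jit`, brainstate's state-aware counterparts of `jax.vmap` and `jax.jit`.\n",
    "- Every run keeps its full rate trace in memory, so grid sweeps are memory intensive; reduce the grid size or the simulation length on limited hardware."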
   ],
   "id": "explain-intro-1"
  },
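  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "A quick back-of-the-envelope estimate shows where the memory goes. The sketch below only counts the stacked `rE` traces that each run returns. It assumes float32 values (JAX's default), 80 regions, 6 s of simulated time at `dt = 0.1` ms and the 4 x 4 grid used later. Actual peak usage will be higher because of delay buffers and intermediate arrays.",
   "id": "sketch-memory-md"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "# Rough memory estimate for the stacked rE traces (assumed sizes, not measurements)\n",
    "duration_ms = 6e3       # simulated time per run\n",
    "dt_ms = 0.1             # integration step\n",
    "n_regions = 80          # regions in the HCP sample\n",
    "bytes_per_value = 4     # float32\n",
    "n_grid = 4 * 4          # len(all_ks) * len(all_sigmas)\n",
    "\n",
    "n_steps = round(duration_ms / dt_ms)\n",
    "per_run_mb = n_steps * n_regions * bytes_per_value / 1e6\n",
    "grid_mb = per_run_mb * n_grid\n",
    "print(f'{n_steps} steps, {per_run_mb:.1f} MB per run, {grid_mb:.1f} MB for the whole grid')"
   ],
   "id": "sketch-memory-code",
   "outputs": [],
   "execution_count": null
  },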
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "import brainmass\n",
    "import brainstate\n",
    "import braintools\n",
    "import brainunit as u\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np"
   ],
   "id": "b692973106adbc1a",
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": "brainstate.environ.set(dt=0.1 * u.ms)",
   "id": "72a0ea87f59d5d5c",
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "import os.path\n",
    "import kagglehub\n",
    "\n",
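    "# Download the HCP data sample from Kaggle (kagglehub caches it locally)\n",
    "path = kagglehub.dataset_download(\"oujago/hcp-gw-data-samples\")\n",
    "data = braintools.file.msgpack_load(os.path.join(path, \"hcp-data-sample.msgpack\"))\n",
    "\n",
    "# FC per subject from BOLD transposed to (time, regions), then averaged across subjects\n",
    "target_fc = [braintools.metric.functional_connectivity(x.T) for x in data['BOLDs']]\n",
    "target_fc = jnp.mean(jnp.asarray(target_fc), axis=0)"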
   ],
   "id": "ff267f6bb5121ba3",
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "Environment and data:\n",
    "- `dt` sets the simulation time step (here 0.1 ms).\n",
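    "- We download an HCP data sample (`oujago/hcp-gw-data-samples`) from Kaggle with `kagglehub`. It contains the structural matrices `Cmat` (coupling weights) and `Dmat` (inter‑regional distances, used for delays), plus per-subject BOLD time series in `BOLDs`.\n",
    "- The target FC is the matrix of pairwise correlations between regional BOLD signals, computed for each subject and then averaged across subjects. This is the fitting target."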
   ],
   "id": "explain-data-1"
  },
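  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "To make the target concrete, the sketch below repeats the computation on synthetic data. It assumes that `braintools.metric.functional_connectivity` returns the Pearson correlation matrix between the columns of a `(time, regions)` array, which is why each BOLD array is transposed with `x.T`. Here `np.corrcoef` stands in for that function, and the random toy subjects are placeholders, not HCP data.",
   "id": "sketch-fc-md"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "import numpy as np\n",
    "\n",
    "rng = np.random.default_rng(0)\n",
    "n_subjects, n_toy_regions, n_time = 3, 5, 200\n",
    "\n",
    "# Synthetic stand-in for data['BOLDs']: one (regions, time) array per toy subject\n",
    "bolds = [rng.standard_normal((n_toy_regions, n_time)) for _ in range(n_subjects)]\n",
    "\n",
    "def pearson_fc(ts):\n",
    "    # ts has shape (time, regions); rowvar=False makes np.corrcoef correlate columns\n",
    "    return np.corrcoef(ts, rowvar=False)\n",
    "\n",
    "# FC per toy subject on (time, regions), then the element-wise mean across subjects\n",
    "toy_target_fc = np.mean([pearson_fc(x.T) for x in bolds], axis=0)\n",
    "print(toy_target_fc.shape)  # (5, 5)"
   ],
   "id": "sketch-fc-code",
   "outputs": [],
   "execution_count": null
  },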
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
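    "class Network(brainstate.nn.Module):\n",
    "    def __init__(self, signal_speed=2., k=1., sigma=0.01):\n",
    "        super().__init__()\n",
    "\n",
    "        # Structural weights without self-connections\n",
    "        conn_weight = data['Cmat'].copy()\n",
    "        np.fill_diagonal(conn_weight, 0)\n",
    "        # Conduction delays in ms: distance (mm) / signal_speed (m/s = mm/ms)\n",
    "        delay_time = data['Dmat'].copy() / signal_speed\n",
    "        np.fill_diagonal(delay_time, 0)\n",
    "        # indices_[i, j] = j: target region i reads the delayed rE of source region j\n",
    "        n_regions = conn_weight.shape[0]\n",
    "        indices_ = np.arange(conn_weight.shape[1])\n",
    "        indices_ = np.tile(np.expand_dims(indices_, axis=0), (n_regions, 1))\n",
    "\n",
    "        # One Wilson-Cowan unit per region; E and I populations each receive OU noise\n",
    "        self.node = brainmass.WilsonCowanStep(\n",
    "            n_regions,\n",
    "            noise_E=brainmass.OUProcess(n_regions, sigma=sigma, init=braintools.init.ZeroInit()),\n",
    "            noise_I=brainmass.OUProcess(n_regions, sigma=sigma, init=braintools.init.ZeroInit()),\n",
    "        )\n",
    "        # Diffusive coupling: delayed source rE vs. current target rE, weighted by conn_weight, scaled by k\n",
    "        # (`init` initialises the delayed rE history uniformly in [0, 0.05])\n",
    "        self.coupling = brainmass.DiffusiveCoupling(\n",
    "            self.node.prefetch_delay('rE', (delay_time * u.ms, indices_), init=braintools.init.Uniform(0, 0.05)),\n",
    "            self.node.prefetch('rE'),\n",
    "            conn_weight,\n",
    "            k=k\n",
    "        )\n",
    "\n",
    "    def update(self):\n",
    "        current = self.coupling()  # input current from the other regions\n",
    "        rE = self.node(current)    # advance every node by one dt\n",
    "        return rE\n",
    "\n",
    "    def step_run(self, i):\n",
    "        # Expose the step index and current time to the model via brainstate's environment\n",
    "        with brainstate.environ.context(i=i, t=i * brainstate.environ.get_dt()):\n",
    "            return self.update()"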
   ],
   "id": "731391424938d899",
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "Model overview:\n",
    "- Node dynamics: one Wilson–Cowan unit per region with E/I populations, each driven by OU noise whose scale is set by `sigma`.\n",
    "- Delays: `Dmat / signal_speed`, in ms (assuming distances in mm and speed in m/s); the delayed E rate is read through `prefetch_delay`.\n",
    "- Coupling: diffusive coupling compares each source region's delayed E rate with the target region's current E rate, weights the differences by `Cmat`, scales the sum by `k`, and injects the result as input current.\n",
    "- `step_run(i)`: sets the step index `i` and the time `t = i * dt` in brainstate's environment context, then advances the network by one step."
   ],
   "id": "explain-model-1"
  },
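  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "The delay and coupling terms can be illustrated with a small NumPy sketch. It assumes `Dmat` is in millimetres and `signal_speed` in m/s (= mm/ms), so that `Dmat / signal_speed` is in ms. It also writes diffusive coupling in its usual form, `k * sum_j Cmat[i, j] * (rE_j(t - delay[i, j]) - rE_i(t))`. The 3-region matrices and the rate history are made up, and the code illustrates the idea rather than brainmass's actual implementation. When all regions have identical rates the coupling vanishes, which is what makes it diffusive.",
   "id": "sketch-coupling-md"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "import numpy as np\n",
    "\n",
    "dt = 0.1             # ms\n",
    "signal_speed = 2.0   # m/s, i.e. mm/ms\n",
    "k = 1.0              # global coupling gain\n",
    "\n",
    "# Hypothetical 3-region structural data (weights, and distances in mm)\n",
    "cmat = np.array([[0.0, 0.5, 0.2], [0.5, 0.0, 0.3], [0.2, 0.3, 0.0]])\n",
    "dmat = np.array([[0.0, 10.0, 20.0], [10.0, 0.0, 15.0], [20.0, 15.0, 0.0]])\n",
    "\n",
    "delay_ms = dmat / signal_speed                      # conduction delays in ms\n",
    "delay_steps = np.round(delay_ms / dt).astype(int)   # delays in integration steps\n",
    "src = np.tile(np.arange(3)[None, :], (3, 1))        # src[i, j] = j, like `indices_` in Network\n",
    "\n",
    "# Made-up rE history: row s holds the rates of all regions at step s\n",
    "n_hist = delay_steps.max() + 1\n",
    "history = np.random.default_rng(0).uniform(0, 0.05, size=(n_hist, 3))\n",
    "t = n_hist - 1  # current step is the last row\n",
    "\n",
    "def diffusive_coupling(history, t):\n",
    "    delayed = history[t - delay_steps, src]   # delayed[i, j] = rE_j(t - delay[i, j])\n",
    "    local = history[t][:, None]               # rE_i(t), broadcast over j\n",
    "    return k * np.sum(cmat * (delayed - local), axis=1)\n",
    "\n",
    "print(delay_ms[0, 2], delay_steps[0, 2])  # 10.0 100\n",
    "print(diffusive_coupling(history, t))"
   ],
   "id": "sketch-coupling-code",
   "outputs": [],
   "execution_count": null
  },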
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
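    "def simulation(k, sigma):\n",
    "    # Build and initialise a fresh network for this (k, sigma) pair\n",
    "    net = Network(k=k, sigma=sigma)\n",
    "    brainstate.nn.init_all_states(net)\n",
    "    # 6 s of simulated time at dt = 0.1 ms, i.e. 60,000 integration steps\n",
    "    indices = np.arange(0, 6e3 * u.ms // brainstate.environ.get_dt())\n",
    "    # Run the step loop; outputs are stacked into a (time, regions) array of rE\n",
    "    exes = brainstate.transform.for_loop(net.step_run, indices)\n",
    "    # Simulated FC from the E rates, scored by its correlation with the target FC\n",
    "    fc = braintools.metric.functional_connectivity(exes)\n",
    "    return braintools.metric.matrix_correlation(target_fc, fc)"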
   ],
   "id": "d1003d6d5c98bfb8",
   "outputs": [],
   "execution_count": null
  },
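  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "`simulation` reduces each run to a single number: how well the simulated FC matches the target. Note that the simulated FC is computed directly from the excitatory rates `rE` rather than from a simulated BOLD signal. A common way to score two FC matrices, assumed in the sketch below, is the Pearson correlation of their off-diagonal entries from one triangle. The diagonal is left out because it is always 1 and would inflate the score. The helper `fc_similarity` is hypothetical and may differ in detail from `braintools.metric.matrix_correlation`.",
   "id": "sketch-score-md"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "import numpy as np\n",
    "\n",
    "def fc_similarity(fc_a, fc_b):\n",
    "    # Pearson correlation between the strictly upper-triangular entries of two FC matrices\n",
    "    iu = np.triu_indices_from(fc_a, k=1)\n",
    "    return np.corrcoef(fc_a[iu], fc_b[iu])[0, 1]\n",
    "\n",
    "rng = np.random.default_rng(1)\n",
    "fc_a = np.corrcoef(rng.standard_normal((4, 100)))   # toy 4 x 4 FC matrix\n",
    "noise = 0.05 * rng.standard_normal((4, 4))\n",
    "fc_b = fc_a + (noise + noise.T) / 2                 # symmetric, slightly perturbed copy\n",
    "print(fc_similarity(fc_a, fc_a), fc_similarity(fc_a, fc_b))"
   ],
   "id": "sketch-score-code",
   "outputs": [],
   "execution_count": null
  },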
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "We sweep `k` and `sigma` over user‑defined ranges and evaluate correlation between simulated and target FC.\n",
    "\n",
    "How it works:\n",
    "- Define grids `all_ks` and `all_sigmas`.\n",
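    "- Vectorize `simulation(k, sigma)` across both axes with nested `brainstate.transform.vmap2` calls (inner over `sigma`, outer over `k`), which yields an array of shape `(len(all_ks), len(all_sigmas))`.\n",
    "- Compile the whole sweep once with `brainstate.transform.jit`.\n",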
    "- Use `jax.block_until_ready` to ensure all computations complete before plotting.\n",
    "\n",
    "Practical tips:\n",
    "- Start with a coarse grid to locate promising regions; refine around peaks.\n",
    "- This step is memory hungry. Reduce grid size or simulation length if you see OOM.\n",
    "- `signal_speed` stays at its `Network` default (2.0) here; settle on a plausible value first, then scan `k` and `sigma`."
   ],
   "id": "79a133a7f7a3d53"
  },
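  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "The nesting can be checked on a cheap stand-in objective. This sketch uses plain `jax.vmap` and `jax.jit` instead of the `brainstate` wrappers, and a made-up `toy_score` in place of `simulation`. It confirms that the inner `vmap` runs over `sigma` and the outer one over `k`, so the result has shape `(len(ks), len(sigmas))` and matches a plain double loop. This row/column order is why `k` ends up on the vertical axis of the heatmap.",
   "id": "sketch-vmap-md"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "def toy_score(k, sigma):\n",
    "    # Made-up smooth objective peaking at k = 2.0, sigma = 0.1 (stands in for `simulation`)\n",
    "    return jnp.exp(-(k - 2.0) ** 2 - ((sigma - 0.1) / 0.05) ** 2)\n",
    "\n",
    "@jax.jit\n",
    "def toy_grid(ks, sigmas):\n",
    "    # Outer vmap over k, inner vmap over sigma -> shape (len(ks), len(sigmas))\n",
    "    return jax.vmap(lambda k: jax.vmap(lambda s: toy_score(k, s))(sigmas))(ks)\n",
    "\n",
    "ks = jnp.linspace(0.5, 3.0, 4)\n",
    "sigmas = jnp.linspace(0.01, 0.2, 5)\n",
    "grid = toy_grid(ks, sigmas)\n",
    "loop = jnp.array([[toy_score(k, s) for s in sigmas] for k in ks])  # reference double loop\n",
    "print(grid.shape, bool(jnp.allclose(grid, loop, atol=1e-6)))  # (4, 5) True"
   ],
   "id": "sketch-vmap-code",
   "outputs": [],
   "execution_count": null
  },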
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "all_ks = jnp.linspace(0.5, 3.0, 4)\n",
    "all_sigmas = jnp.linspace(0.01, 0.2, 4)"
   ],
   "id": "43e60ae00e257f75",
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "@brainstate.transform.jit\n",
    "def parameter_exploration(ks, sigmas):\n",
    "    # Inner vmap over sigma, outer vmap over k -> results has shape (len(ks), len(sigmas))\n",
    "    results = brainstate.transform.vmap2(\n",
    "        lambda k: brainstate.transform.vmap2(lambda sigma: simulation(k, sigma))(sigmas)\n",
    "    )(ks)\n",
    "    return results"
   ],
   "id": "abce2a8698873509",
   "outputs": [],
   "execution_count": null
  },
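  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "If the fully vectorized sweep runs out of memory, one alternative (swapped in here, not used elsewhere in this notebook) is to keep `vmap` over `sigma` but loop over `k` with `jax.lax.map`. That way only one row of the grid is in flight at a time. The sketch uses plain JAX and a made-up `toy_score` in place of `simulation`. Whether the `brainstate`-based `simulation` works unchanged under `jax.lax.map` is an assumption that has not been checked here. Recent JAX releases also accept a `batch_size` argument to `jax.lax.map` for chunked evaluation.",
   "id": "sketch-laxmap-md"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "def toy_score(k, sigma):\n",
    "    # Made-up objective standing in for `simulation(k, sigma)`\n",
    "    return jnp.sin(k) * jnp.cos(sigma)\n",
    "\n",
    "@jax.jit\n",
    "def grid_vmap(ks, sigmas):\n",
    "    # All len(ks) * len(sigmas) evaluations are batched together\n",
    "    return jax.vmap(lambda k: jax.vmap(lambda s: toy_score(k, s))(sigmas))(ks)\n",
    "\n",
    "@jax.jit\n",
    "def grid_rowwise(ks, sigmas):\n",
    "    # jax.lax.map loops over k sequentially; each row is still vectorized over sigma\n",
    "    return jax.lax.map(lambda k: jax.vmap(lambda s: toy_score(k, s))(sigmas), ks)\n",
    "\n",
    "ks = jnp.linspace(0.5, 3.0, 4)\n",
    "sigmas = jnp.linspace(0.01, 0.2, 4)\n",
    "print(bool(jnp.allclose(grid_vmap(ks, sigmas), grid_rowwise(ks, sigmas), atol=1e-6)))  # True"
   ],
   "id": "sketch-laxmap-code",
   "outputs": [],
   "execution_count": null
  },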
  {
   "metadata": {},
   "cell_type": "code",
   "source": "correlations = jax.block_until_ready(parameter_exploration(all_ks, all_sigmas))",
   "id": "801c334cbe0a3277",
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
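    "# Pad the extent by half a grid step so each pixel is centred on its (sigma, k) grid value\n",
    "corr = np.asarray(correlations)\n",
    "ks_np, sigmas_np = np.asarray(all_ks), np.asarray(all_sigmas)\n",
    "dk, ds = ks_np[1] - ks_np[0], sigmas_np[1] - sigmas_np[0]\n",
    "extent = (sigmas_np[0] - ds / 2, sigmas_np[-1] + ds / 2, ks_np[0] - dk / 2, ks_np[-1] + dk / 2)\n",
    "plt.imshow(corr, extent=extent, origin='lower', aspect='auto')\n",
    "plt.colorbar(label='Correlation with target FC')\n",
    "plt.xlabel('sigma (noise scale)')\n",
    "plt.ylabel('k (global coupling)')\n",
    "plt.title('Parameter exploration')\n",
    "plt.show()"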
   ],
   "id": "ec2a23f31d8b0bd4",
   "outputs": [],
   "execution_count": null
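  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "The coarse-then-refine tip can be automated. The sketch below defines a hypothetical helper, `refine_around`, that finds the best cell of a `(len(ks), len(sigmas))` score grid and returns a finer grid spanning one coarse step on each side of it. It is demonstrated on a synthetic score surface, not on simulation results. With this notebook's arrays it would be called as `refine_around(correlations, all_ks, all_sigmas)`, and the new ranges could be fed back into `parameter_exploration`.",
   "id": "sketch-refine-md"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "import numpy as np\n",
    "\n",
    "def refine_around(scores, ks, sigmas, n=5):\n",
    "    # Best cell of a (len(ks), len(sigmas)) grid; NaNs (e.g. diverged runs) are ignored\n",
    "    scores, ks, sigmas = np.asarray(scores), np.asarray(ks), np.asarray(sigmas)\n",
    "    i, j = np.unravel_index(np.nanargmax(scores), scores.shape)\n",
    "    dk, ds = ks[1] - ks[0], sigmas[1] - sigmas[0]\n",
    "    # Finer ranges spanning one coarse step on each side of the peak, clipped to the old range\n",
    "    new_ks = np.linspace(max(ks[i] - dk, ks[0]), min(ks[i] + dk, ks[-1]), n)\n",
    "    new_sigmas = np.linspace(max(sigmas[j] - ds, sigmas[0]), min(sigmas[j] + ds, sigmas[-1]), n)\n",
    "    return (ks[i], sigmas[j], scores[i, j]), new_ks, new_sigmas\n",
    "\n",
    "# Synthetic coarse-grid scores for illustration only (not simulation results)\n",
    "coarse_ks = np.linspace(0.5, 3.0, 4)\n",
    "coarse_sigmas = np.linspace(0.01, 0.2, 4)\n",
    "toy_scores = -(coarse_ks[:, None] - 2.2) ** 2 - ((coarse_sigmas[None, :] - 0.08) / 0.1) ** 2\n",
    "best, new_ks, new_sigmas = refine_around(toy_scores, coarse_ks, coarse_sigmas)\n",
    "print('best (k, sigma, score):', best)\n",
    "print('refined k grid:', new_ks)\n",
    "print('refined sigma grid:', new_sigmas)"
   ],
   "id": "sketch-refine-code",
   "outputs": [],
   "execution_count": null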
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
