{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "a760f7ec2a78",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-06-19T07:40:49.147414Z",
     "iopub.status.busy": "2026-06-19T07:40:49.147184Z",
     "iopub.status.idle": "2026-06-19T07:40:54.464833Z",
     "shell.execute_reply": "2026-06-19T07:40:54.463745Z"
    },
    "tags": [
     "remove-cell"
    ]
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.\n"
     ]
    }
   ],
   "source": [
    "%matplotlib inline\n",
    "import warnings\n",
    "import time\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "import brainstate\n",
    "import braintools\n",
    "import brainunit as u\n",
    "import brainmass\n",
    "from brainmass import objectives\n",
    "from brainstate.nn import Param\n",
    "brainstate.environ.set(dt=0.1 * u.ms)\n",
    "brainstate.random.seed(0)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fe6e66cfa34a",
   "metadata": {},
   "source": [
    "# Batch and Accelerate\n",
    "\n",
    "**Goal:** make `brainmass` simulations fast: compile with `jit`, run an ensemble\n",
    "or a parameter grid in one shot with `vmap`, and choose the right loop primitive.\n",
    "\n",
    "`brainmass` is built on JAX, so the same code runs on CPU, GPU, or TPU. The\n",
    "speed-ups come from a small set of `brainstate.transform` primitives: `jit` to\n",
    "compile, a loop primitive (`for_loop` or `scan`) to fuse a rollout, and `vmap` to\n",
    "batch. This recipe shows when to reach for each, with before/after timings.\n",
    "\n",
    ":::{note}\n",
    "**Never drive a long rollout with a bare Python `for`/`while` loop.** An\n",
    "uncompiled Python loop executes op-by-op (per-op dispatch overhead, no fusion);\n",
    "the `brainstate.transform` primitives lower the whole loop into one compiled XLA\n",
    "program, tracing the body only once. This guide uses the raw primitives directly\n",
    "to *teach* them; in normal use, `brainmass.Simulator` wraps `for_loop` for you.\n",
    ":::\n"
   ]
  },
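  {
   "cell_type": "markdown",
   "id": "7c2e41a9d0b3",
   "metadata": {},
   "source": [
    "The contrast above can be sketched in plain JAX, outside `brainmass`. This is a\n",
    "minimal illustration, not how `brainstate` implements its loops: `step` is a\n",
    "hypothetical leaky-decay update standing in for a model, and `jax.lax.fori_loop`\n",
    "inside `jax.jit` plays the role of a fused loop primitive. Both versions compute\n",
    "the same result; only the second runs as one compiled XLA program.\n",
    "\n",
    "```python\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "# Hypothetical toy update standing in for a model step (not brainmass).\n",
    "def step(x, drive):\n",
    "    return x + 0.1 * (-x + drive)\n",
    "\n",
    "x0 = jnp.zeros(64)\n",
    "drives = jnp.ones((300, 64))\n",
    "\n",
    "# Bare Python loop: every operation is dispatched separately, step after step.\n",
    "x_py = x0\n",
    "for t in range(drives.shape[0]):\n",
    "    x_py = step(x_py, drives[t])\n",
    "\n",
    "# Whole loop lowered into one compiled XLA program; the body is traced once.\n",
    "@jax.jit\n",
    "def rollout(x, drives):\n",
    "    def body(t, x):\n",
    "        return step(x, drives[t])\n",
    "    return jax.lax.fori_loop(0, drives.shape[0], body, x)\n",
    "\n",
    "x_jit = rollout(x0, drives)\n",
    "print(jnp.allclose(x_py, x_jit, atol=1e-5))\n",
    "```\n"
   ]
  },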
  {
   "cell_type": "markdown",
   "id": "0de6d87fa5ba",
   "metadata": {},
   "source": [
    "## Which primitive?\n",
    "\n",
    "| Shape of the work | Primitive |\n",
    "| --- | --- |\n",
    "| Single step / one-shot call | `brainstate.transform.jit` |\n",
    "| Many steps, collect outputs | `brainstate.transform.for_loop` |\n",
    "| Many steps with an explicit carry | `brainstate.transform.scan` |\n",
    "| Batch over inputs / parameters | `brainstate.transform.vmap` |\n",
    "| Long rollout under autograd (BPTT) | `checkpointed_for_loop` / `checkpointed_scan` |\n",
    "\n",
    "`brainmass.Simulator` already composes `jit` + `for_loop` internally, so most of\n",
    "the time you just call `Simulator(model, dt).run(...)`. Reach for the raw\n",
    "primitives when you need a custom step or an explicit carry.\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "718b5277bfeb",
   "metadata": {},
   "source": [
    "## `jit`: compile a step once\n",
    "\n",
    "Without compilation, every call runs the model op-by-op, paying Python and\n",
    "dispatch overhead on each operation. Wrapping the step in\n",
    "`brainstate.transform.jit` traces and compiles it on the first call; subsequent\n",
    "calls with the same input shapes reuse the compiled program. Here we time a\n",
    "hand-written 300-step rollout with and without `jit`.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "be7b966963b5",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-06-19T07:40:54.468219Z",
     "iopub.status.busy": "2026-06-19T07:40:54.467585Z",
     "iopub.status.idle": "2026-06-19T07:40:59.057376Z",
     "shell.execute_reply": "2026-06-19T07:40:59.056805Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "uncompiled :  4357.3 ms\n",
      "jit        :    91.4 ms\n",
      "speed-up   :   47.7x\n"
     ]
    }
   ],
   "source": [
    "node = brainmass.HopfStep(in_size=64, a=0.25, w=0.3)\n",
    "node.init_all_states()\n",
    "\n",
    "def step_uncompiled(i):\n",
    "    with brainstate.environ.context(i=i, t=i * 0.1 * u.ms):\n",
    "        node.update()\n",
    "    return node.x.value\n",
    "\n",
    "step_jit = brainstate.transform.jit(step_uncompiled)\n",
    "\n",
    "n = 300\n",
    "# warm up the compiled version (first call traces + compiles)\n",
    "_ = step_jit(0); jax.block_until_ready(node.x.value)\n",
    "\n",
    "# uncompiled: every call dispatches op-by-op\n",
    "node.init_all_states()\n",
    "t0 = time.perf_counter()\n",
    "for i in range(n):\n",
    "    step_uncompiled(i)\n",
    "jax.block_until_ready(node.x.value)\n",
    "t_uncompiled = time.perf_counter() - t0\n",
    "\n",
    "# compiled: each call reuses the compiled program\n",
    "node.init_all_states()\n",
    "t0 = time.perf_counter()\n",
    "for i in range(n):\n",
    "    step_jit(i)\n",
    "jax.block_until_ready(node.x.value)\n",
    "t_jit = time.perf_counter() - t0\n",
    "\n",
    "print(f\"uncompiled : {t_uncompiled * 1e3:7.1f} ms\")\n",
    "print(f\"jit        : {t_jit * 1e3:7.1f} ms\")\n",
    "print(f\"speed-up   : {t_uncompiled / t_jit:6.1f}x\")"
   ]
  },
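  {
   "cell_type": "markdown",
   "id": "4b8f13e6a2c7",
   "metadata": {},
   "source": [
    "To see the compile-once behaviour directly, the plain-JAX sketch below counts how\n",
    "often a jitted function is traced. It relies on a standard JAX property rather\n",
    "than anything `brainmass`-specific: Python side effects (here, incrementing the\n",
    "hypothetical counter `trace_count`) run only while JAX traces the function, not\n",
    "on cached calls.\n",
    "\n",
    "```python\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "trace_count = 0  # hypothetical counter; bumped only when JAX traces `step`\n",
    "\n",
    "@jax.jit\n",
    "def step(x, i):\n",
    "    global trace_count\n",
    "    trace_count += 1  # Python side effect: runs at trace time, not per call\n",
    "    return 0.99 * x + 0.01 * jnp.sin(0.1 * i)\n",
    "\n",
    "x = jnp.zeros(64)\n",
    "for i in range(100):\n",
    "    x = step(x, i)  # same shapes and dtypes every call: cached program reused\n",
    "jax.block_until_ready(x)\n",
    "print('traces:', trace_count)\n",
    "```\n",
    "\n",
    "Calling `step` with an input of a different shape, such as `jnp.zeros(32)`, would\n",
    "trigger a second trace and compilation, which is why the timing cell above warms\n",
    "up `step_jit` before starting the clock.\n"
   ]
  },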
  {
   "cell_type": "markdown",
   "id": "dc07cbf6c765",
   "metadata": {},
   "source": [
    "## `for_loop`: fuse the whole rollout\n",
    "\n",
    "Even a `jit`-ed step is still dispatched once per iteration from Python.\n",
    "`brainstate.transform.for_loop` lowers the **entire** loop into one XLA program —\n",
    "it traces the body once and stacks the per-step outputs for you. `State` (the\n",
    "model's hidden variables) is carried automatically. This is exactly what\n",
    "`Simulator.run` does under the hood.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "de7113cc29c9",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-06-19T07:40:59.060191Z",
     "iopub.status.busy": "2026-06-19T07:40:59.059619Z",
     "iopub.status.idle": "2026-06-19T07:40:59.248003Z",
     "shell.execute_reply": "2026-06-19T07:40:59.247050Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "for_loop output: (300, 64) (time, regions)\n",
      "for_loop   :    79.6 ms  (55x vs uncompiled)\n"
     ]
    }
   ],
   "source": [
    "node = brainmass.HopfStep(in_size=64, a=0.25, w=0.3)\n",
    "node.init_all_states()\n",
    "\n",
    "def step(i):\n",
    "    with brainstate.environ.context(i=i, t=i * 0.1 * u.ms):\n",
    "        node.update()\n",
    "    return node.x.value\n",
    "\n",
    "run = brainstate.transform.jit(lambda: brainstate.transform.for_loop(step, jnp.arange(n)))\n",
    "xs = run(); jax.block_until_ready(xs)  # warm up\n",
    "\n",
    "t0 = time.perf_counter()\n",
    "xs = run(); jax.block_until_ready(xs)\n",
    "t_forloop = time.perf_counter() - t0\n",
    "\n",
    "print(\"for_loop output:\", xs.shape, \"(time, regions)\")\n",
    "print(f\"for_loop   : {t_forloop * 1e3:7.1f} ms  ({t_uncompiled / t_forloop:.0f}x vs uncompiled)\")"
   ]
  },
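  {
   "cell_type": "markdown",
   "id": "e9d2057f1a4c",
   "metadata": {},
   "source": [
    "For intuition about what `for_loop` does with the model's `State`, here is a\n",
    "plain-JAX analogue built on `jax.lax.scan`. It is a sketch under assumptions:\n",
    "`hopf_step` integrates a textbook Hopf normal form with forward Euler and is not\n",
    "taken from `HopfStep`'s source, and the state is threaded explicitly as the scan\n",
    "carry, which is the bookkeeping `brainstate` does for you.\n",
    "\n",
    "```python\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "a, w, dt, n = 0.25, 0.3, 0.1, 300  # assumed toy parameters (dimensionless)\n",
    "\n",
    "def hopf_step(state, _):\n",
    "    # Forward-Euler step of a textbook Hopf normal form (assumption).\n",
    "    x, y = state\n",
    "    r2 = x ** 2 + y ** 2\n",
    "    x_new = x + dt * ((a - r2) * x - w * y)\n",
    "    y_new = y + dt * ((a - r2) * y + w * x)\n",
    "    return (x_new, y_new), x_new  # (carried state, per-step output)\n",
    "\n",
    "@jax.jit\n",
    "def rollout(x0, y0):\n",
    "    # xs=None with an explicit length: the state is threaded through the loop\n",
    "    # and the per-step outputs are stacked along a new leading time axis.\n",
    "    return jax.lax.scan(hopf_step, (x0, y0), None, length=n)\n",
    "\n",
    "(final_x, final_y), xs = rollout(0.1 * jnp.ones(64), jnp.zeros(64))\n",
    "print(xs.shape)  # (time, regions)\n",
    "```\n"
   ]
  },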
  {
   "cell_type": "markdown",
   "id": "7f5b28f9858a",
   "metadata": {},
   "source": [
    "In practice you rarely write that loop yourself: `brainmass.Simulator` wraps the\n",
    "same `jit` + `for_loop` and adds validation plus handling of monitors, the\n",
    "transient, and units. Its timing should be in the same range as the hand-written\n",
    "`for_loop` above.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "700e058b0c8b",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-06-19T07:40:59.250189Z",
     "iopub.status.busy": "2026-06-19T07:40:59.249936Z",
     "iopub.status.idle": "2026-06-19T07:40:59.464943Z",
     "shell.execute_reply": "2026-06-19T07:40:59.464057Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Simulator output: (300, 64)\n",
      "Simulator  :   103.6 ms\n"
     ]
    }
   ],
   "source": [
    "sim = brainmass.Simulator(node, dt=0.1 * u.ms)\n",
    "res = sim.run(n * 0.1 * u.ms, monitors=['x'])  # warm up + compile\n",
    "\n",
    "t0 = time.perf_counter()\n",
    "res = sim.run(n * 0.1 * u.ms, monitors=['x'])\n",
    "jax.block_until_ready(res['x'])\n",
    "print(\"Simulator output:\", res['x'].shape)\n",
    "print(f\"Simulator  : {(time.perf_counter() - t0) * 1e3:7.1f} ms\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9d0069809b17",
   "metadata": {},
   "source": [
    "## `scan`: thread an explicit carry\n",
    "\n",
    "When you need to carry a value *alongside* the model's `State`\n",
    "(`f(carry, x) -> (carry, y)`), use `brainstate.transform.scan`. A typical use is\n",
    "feeding a time-varying external drive into the model and accumulating a running\n",
    "statistic. Here the scanned input supplies a slow sinusoidal drive and the carry\n",
    "is a running sum of the node's `x`.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "b891e66a3c79",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-06-19T07:40:59.467268Z",
     "iopub.status.busy": "2026-06-19T07:40:59.466987Z",
     "iopub.status.idle": "2026-06-19T07:40:59.797588Z",
     "shell.execute_reply": "2026-06-19T07:40:59.796926Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "scan stacked outputs: (300, 8)\n",
      "carried running sum  : (8,)\n"
     ]
    }
   ],
   "source": [
    "node = brainmass.HopfStep(in_size=8, a=0.25, w=0.3)\n",
    "node.init_all_states()\n",
    "\n",
    "drive = 0.05 * jnp.sin(2 * jnp.pi * jnp.arange(n) / n)[:, None]  # (time, 1)\n",
    "\n",
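    "def body(carry, inp):\n",
    "    running_sum = carry\n",
    "    i, drive_i = inp              # step index and this step's drive\n",
    "    with brainstate.environ.context(i=i, t=i * 0.1 * u.ms):\n",
    "        node.update(drive_i)      # external drive into x\n",
    "    x = node.x.value\n",
    "    return running_sum + x, x     # (new carry, per-step output)\n",
    "\n",
    "# scan over (step index, drive) so the time context advances every step\n",
    "total, xs = brainstate.transform.scan(body, jnp.zeros(8), (jnp.arange(n), drive))\n",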
    "print(\"scan stacked outputs:\", xs.shape)\n",
    "print(\"carried running sum  :\", total.shape)"
   ]
  },
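  {
   "cell_type": "markdown",
   "id": "a31c6b8e5f20",
   "metadata": {},
   "source": [
    "Without `brainstate`'s implicit `State`, the model state has to travel in the\n",
    "carry next to the extra statistic. The plain-JAX sketch below shows that layout\n",
    "with `jax.lax.scan`; the leaky integrator and its parameters are hypothetical\n",
    "stand-ins for the Hopf node, and the drive mirrors the cell above.\n",
    "\n",
    "```python\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "n, n_regions, dt = 300, 8, 0.1  # assumed toy sizes and step\n",
    "drive = 0.05 * jnp.sin(2 * jnp.pi * jnp.arange(n) / n)[:, None]  # (time, 1)\n",
    "\n",
    "def body(carry, inp):\n",
    "    x, running_sum = carry          # model state plus the extra statistic\n",
    "    x = x + dt * (-x + inp)         # hypothetical leaky integrator; inp broadcasts\n",
    "    return (x, running_sum + x), x  # (new carry, per-step output)\n",
    "\n",
    "init = (jnp.zeros(n_regions), jnp.zeros(n_regions))\n",
    "(x_final, total), xs = jax.lax.scan(body, init, drive)\n",
    "print(xs.shape, total.shape)\n",
    "```\n"
   ]
  },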
  {
   "cell_type": "markdown",
   "id": "117d76a241df",
   "metadata": {},
   "source": [
    "## `vmap`: batch an ensemble in one run\n",
    "\n",
    "`brainstate.transform.vmap` adds a batch axis to a whole computation, and\n",
    "`Simulator.run` can add one for you. There are two common patterns.\n",
    "\n",
    "**Batched initial conditions:** pass `batch_size=B` to `Simulator.run`. It calls\n",
    "`init_all_states(batch_size=B)`, so each state variable gains a batch dimension\n",
    "and the monitored outputs come back as `(time, batch, ...)`, with the batch axis\n",
    "after the time axis. This runs `B` independent trajectories (e.g. a noise\n",
    "ensemble) in a single compiled program.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "c24114039fbe",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-06-19T07:40:59.799714Z",
     "iopub.status.busy": "2026-06-19T07:40:59.799441Z",
     "iopub.status.idle": "2026-06-19T07:41:00.237686Z",
     "shell.execute_reply": "2026-06-19T07:41:00.236946Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "batched trajectory: (2000, 16, 4) (time, batch, regions)\n"
     ]
    }
   ],
   "source": [
    "node = brainmass.HopfStep(\n",
    "    in_size=4, a=0.1, w=0.3,\n",
    "    noise_x=brainmass.OUProcess(4, sigma=0.1, tau=10 * u.ms),\n",
    ")\n",
    "brainstate.random.seed(0)\n",
    "res = brainmass.Simulator(node, dt=0.1 * u.ms).run(\n",
    "    200 * u.ms, monitors=['x'], batch_size=16,\n",
    ")\n",
    "print(\"batched trajectory:\", res['x'].shape, \"(time, batch, regions)\")"
   ]
  },
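  {
   "cell_type": "markdown",
   "id": "5f7a09c3d1e8",
   "metadata": {},
   "source": [
    "The same kind of ensemble can be written by hand with `jax.vmap`, which makes the\n",
    "batching explicit. This sketch assumes a hypothetical noisy leaky unit integrated\n",
    "with Euler-Maruyama (not `brainmass`'s `OUProcess`), gives each trajectory its\n",
    "own PRNG key, and moves the batch axis to position 1 to match the\n",
    "`(time, batch, regions)` layout printed above.\n",
    "\n",
    "```python\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "n_steps, n_regions, batch, dt, sigma = 2000, 4, 16, 0.1, 0.1  # assumed toy values\n",
    "\n",
    "def trajectory(key):\n",
    "    # One noisy trajectory: Euler-Maruyama on dx = -x dt + sigma dW (toy model).\n",
    "    noise = jax.random.normal(key, (n_steps, n_regions))\n",
    "    def body(x, eps):\n",
    "        x = x - dt * x + sigma * jnp.sqrt(dt) * eps\n",
    "        return x, x\n",
    "    _, xs = jax.lax.scan(body, jnp.zeros(n_regions), noise)\n",
    "    return xs  # (time, regions)\n",
    "\n",
    "keys = jax.random.split(jax.random.PRNGKey(0), batch)  # one key per member\n",
    "ens = jax.vmap(trajectory)(keys)  # (batch, time, regions)\n",
    "ens = jnp.moveaxis(ens, 0, 1)     # (time, batch, regions)\n",
    "print(ens.shape)\n",
    "```\n"
   ]
  },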
  {
   "cell_type": "markdown",
   "id": "95dca7a94939",
   "metadata": {},
   "source": [
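    "**Batched parameters:** write a function that builds and runs the model for one\n",
    "parameter value, then `vmap` it over a parameter array. Because the model is\n",
    "constructed *inside* the mapped function, each parameter value gets its own\n",
    "simulation, and all of them run in one fused program. (See\n",
    "{doc}`/howto/parameter_sweeps` for the full grid-sweep recipe.)\n"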
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "75cfc4efabd4",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-06-19T07:41:00.239484Z",
     "iopub.status.busy": "2026-06-19T07:41:00.239262Z",
     "iopub.status.idle": "2026-06-19T07:41:00.974774Z",
     "shell.execute_reply": "2026-06-19T07:41:00.974043Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "a = 0.10  ->  amplitude 0.329\n",
      "a = 0.30  ->  amplitude 0.561\n",
      "a = 0.50  ->  amplitude 0.721\n",
      "a = 0.70  ->  amplitude 0.849\n",
      "a = 0.90  ->  amplitude 0.958\n",
      "a = 1.10  ->  amplitude 1.055\n",
      "a = 1.30  ->  amplitude 1.142\n",
      "a = 1.50  ->  amplitude 1.223\n"
     ]
    }
   ],
   "source": [
    "a_values = jnp.linspace(0.1, 1.5, 8)\n",
    "\n",
    "def amplitude_for(a):\n",
    "    node = brainmass.HopfStep(in_size=1, a=a, w=0.3,\n",
    "                              init_x=braintools.init.Constant(0.5))\n",
    "    r = brainmass.Simulator(node, dt=0.1 * u.ms).run(\n",
    "        150 * u.ms, monitors=['x'], transient=50 * u.ms)\n",
    "    x = u.get_magnitude(r['x'])[:, 0]\n",
    "    return jnp.sqrt(jnp.mean(x ** 2)) * jnp.sqrt(2.0)  # peak amplitude = sqrt(2) * RMS for a sinusoid\n",
    "\n",
    "amps = brainstate.transform.vmap(amplitude_for)(a_values)\n",
    "for a, amp in zip(a_values, amps):\n",
    "    print(f\"a = {float(a):.2f}  ->  amplitude {float(amp):.3f}\")"
   ]
  },
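  {
   "cell_type": "markdown",
   "id": "c84e2b6190fa",
   "metadata": {},
   "source": [
    "The parameter-batching pattern itself is ordinary `jax.vmap` over a function that\n",
    "takes the parameter as its argument. A plain-JAX sketch, assuming a textbook Hopf\n",
    "normal form as a stand-in for `HopfStep` (not its source) and measuring amplitude\n",
    "as the mean radius `sqrt(x**2 + y**2)` after a transient instead of the RMS of `x`:\n",
    "\n",
    "```python\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "w, dt, n_steps, n_transient = 0.3, 0.1, 1500, 500  # assumed toy settings\n",
    "\n",
    "def amplitude_for(a):\n",
    "    # Forward-Euler Hopf normal form for one value of `a` (stand-in model).\n",
    "    def body(state, _):\n",
    "        x, y = state\n",
    "        r2 = x ** 2 + y ** 2\n",
    "        x_new = x + dt * ((a - r2) * x - w * y)\n",
    "        y_new = y + dt * ((a - r2) * y + w * x)\n",
    "        return (x_new, y_new), jnp.sqrt(x_new ** 2 + y_new ** 2)\n",
    "    init = (jnp.float32(0.5), jnp.float32(0.0))\n",
    "    _, r = jax.lax.scan(body, init, None, length=n_steps)\n",
    "    return jnp.mean(r[n_transient:])  # mean radius after discarding the transient\n",
    "\n",
    "a_values = jnp.linspace(0.1, 1.5, 8)\n",
    "amps = jax.jit(jax.vmap(amplitude_for))(a_values)\n",
    "print(jnp.round(amps, 3))\n",
    "```\n",
    "\n",
    "For this normal form the limit-cycle radius is `sqrt(a)`, so `amps` should track\n",
    "`jnp.sqrt(a_values)` closely; forward Euler with `dt = 0.1` shifts it slightly.\n"
   ]
  },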
  {
   "cell_type": "markdown",
   "id": "12c82771c2a5",
   "metadata": {},
   "source": [
    "## GPU / TPU\n",
    "\n",
    "The code above is device-agnostic — JAX dispatches to whatever backend `jaxlib`\n",
    "was built for. To use an accelerator:\n",
    "\n",
    "- **GPU:** install a CUDA build of `jaxlib` (`pip install brainmass` then\n",
    "  `pip install -U \"jax[cuda12]\"`). No code change.\n",
    "- **TPU:** install `jax[tpu]`. No code change.\n",
    "\n",
    "Check the active backend at runtime:\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "df08c0ed0aa2",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-06-19T07:41:00.976884Z",
     "iopub.status.busy": "2026-06-19T07:41:00.976739Z",
     "iopub.status.idle": "2026-06-19T07:41:00.980598Z",
     "shell.execute_reply": "2026-06-19T07:41:00.979814Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "JAX default backend: cpu\n",
      "devices: [CpuDevice(id=0)]\n"
     ]
    }
   ],
   "source": [
    "print(\"JAX default backend:\", jax.default_backend())\n",
    "print(\"devices:\", jax.devices())"
   ]
  },
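  {
   "cell_type": "markdown",
   "id": "0d6a3f9e7b21",
   "metadata": {},
   "source": [
    "If you want explicit control over placement, for example to pin an array to a\n",
    "particular accelerator, plain JAX exposes it through `jax.device_put`. A minimal\n",
    "sketch; it simply uses the first device of the default backend, so it runs\n",
    "unchanged on CPU, GPU, or TPU:\n",
    "\n",
    "```python\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "dev = jax.devices()[0]                    # first device of the default backend\n",
    "x = jax.device_put(jnp.arange(8.0), dev)  # commit the array to that device\n",
    "y = jax.jit(lambda v: 2.0 * v)(x)         # computation follows its committed input\n",
    "print(dev.platform, y)\n",
    "```\n"
   ]
  },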
  {
   "cell_type": "markdown",
   "id": "c99e889beccf",
   "metadata": {},
   "source": [
    "Two practical tips for accelerators:\n",
    "\n",
    "- **`vmap` is how you fill a GPU.** A single small network leaves most of a GPU\n",
    "  idle; a `vmap`-ed ensemble or parameter grid keeps it busy and amortises the\n",
    "  per-launch overhead across the whole batch.\n",
    "- **`block_until_ready` for honest timing.** JAX dispatches work asynchronously;\n",
    "  without it you time the dispatch, not the compute. This is why every timing\n",
    "  above calls it before reading the clock.\n",
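    "\n",
    "The second tip, together with a warm-up call that keeps compilation out of the\n",
    "measurement, can be packaged as a small helper. This is a plain-JAX sketch; the\n",
    "name `timed` and the matrix-multiply workload are hypothetical.\n",
    "\n",
    "```python\n",
    "import time\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "def timed(fn, *args, repeats=5):\n",
    "    # Warm-up call: triggers tracing and compilation outside the timed region.\n",
    "    jax.block_until_ready(fn(*args))\n",
    "    t0 = time.perf_counter()\n",
    "    for _ in range(repeats):\n",
    "        out = fn(*args)\n",
    "    jax.block_until_ready(out)  # wait for the device to finish, not just the dispatch\n",
    "    return (time.perf_counter() - t0) / repeats  # seconds per call\n",
    "\n",
    "matmul = jax.jit(lambda m: m @ m)\n",
    "m = jnp.ones((512, 512))\n",
    "print(f'{timed(matmul, m) * 1e3:.2f} ms per call')\n",
    "```\n",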
    "\n",
    "## Long rollouts under autograd\n",
    "\n",
    "If you backpropagate through a *long* simulation (backprop-through-time), storing\n",
    "every step's activations can exhaust memory. Swap `for_loop` / `scan` for\n",
    "`brainstate.transform.checkpointed_for_loop` / `checkpointed_scan`: the forward\n",
    "semantics are the same, but intermediate activations are rematerialised on the\n",
    "backward pass (tune `base`), bounding peak memory at the cost of extra\n",
    "recomputation. Use these *only* when a reverse-mode gradient through a long run\n",
    "would otherwise run out of memory; in every other case, prefer plain\n",
    "`for_loop` / `scan`.\n",
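    "\n",
    "The idea can be sketched in plain JAX with `jax.checkpoint` applied to a scan\n",
    "body, which recomputes each step's intermediates on the backward pass instead of\n",
    "storing them. This is only the simplest per-step form of rematerialisation; the\n",
    "`brainstate` checkpointed loops expose the `base` knob and may use a different\n",
    "schedule. The toy step is hypothetical; the check is that both gradients agree.\n",
    "\n",
    "```python\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "def body(x, drive):\n",
    "    x = x + 0.1 * (-x + jnp.tanh(drive * x))  # hypothetical toy step\n",
    "    return x, x\n",
    "\n",
    "def make_loss(step):\n",
    "    def loss(drive):\n",
    "        _, xs = jax.lax.scan(step, jnp.ones(16), jnp.full(1000, drive))\n",
    "        return jnp.mean(xs ** 2)\n",
    "    return loss\n",
    "\n",
    "g_plain = jax.grad(make_loss(body))(2.0)                  # saves per-step residuals\n",
    "g_remat = jax.grad(make_loss(jax.checkpoint(body)))(2.0)  # recomputes on backward\n",
    "print(jnp.allclose(g_plain, g_remat))\n",
    "```\n",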
    "\n",
    "## Next steps\n",
    "\n",
    "- {doc}`/howto/parameter_sweeps` — sweep a parameter grid with `vmap`.\n",
    "- {doc}`/tutorials/06_fitting_with_gradients` — gradients through a `Simulator`.\n",
    "- {doc}`/concepts/index` — how the run loop and transforms fit together.\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.13.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
