{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "f0c8baa0",
   "metadata": {},
   "source": [
    "# Advanced Integration\n",
    "\n",
    "This page covers three topics beyond the basic *define-and-step* workflow:\n",
    "\n",
    "1. giving a sub-system its **own solver** with `IndependentIntegration`,\n",
    "2. the stochastic **`diffusion`** slot and where SDE support stands today, and\n",
    "3. **registering your own integrator** so it is available by name."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "d06f2ccd",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-05-25T09:23:30.020359Z",
     "iopub.status.busy": "2026-05-25T09:23:30.020208Z",
     "iopub.status.idle": "2026-05-25T09:23:33.655332Z",
     "shell.execute_reply": "2026-05-25T09:23:33.654001Z"
    }
   },
   "outputs": [],
   "source": [
    "import brainstate\n",
    "import numpy as np\n",
    "import jax.numpy as jnp\n",
    "\n",
    "import braincell\n",
    "from braincell import DiffEqState, DiffEqModule\n",
    "from braincell.quad import get_integrator, register_integrator, get_registry"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "83bfc3e0",
   "metadata": {},
   "source": [
    "## Mixing solvers with `IndependentIntegration`\n",
    "\n",
    "By default, one solver advances *every* `DiffEqState` in a model with a single\n",
    "shared `dt`. Sometimes a sub-system wants something different — fast voltage\n",
    "gating that needs exponential Euler while the rest of the cell runs an explicit\n",
    "scheme, or a calcium pool that prefers backward Euler.\n",
    "\n",
    "`IndependentIntegration` is the mixin for that. States owned by an\n",
    "`IndependentIntegration` sub-module are **filtered out** of the parent's\n",
    "integration loop; the sub-module instead advances its own states by calling\n",
    "`make_integration`, which dispatches to whatever solver it was constructed with:\n",
    "\n",
    "```python\n",
    "from braincell import DiffEqModule, IndependentIntegration\n",
    "\n",
    "class FastGate(IndependentIntegration, DiffEqModule):\n",
    "    def __init__(self):\n",
    "        super().__init__(solver='exp_euler')   # this sub-system's own solver\n",
    "\n",
    "    def compute_derivative(self, *args):\n",
    "        ...   # write derivatives as usual\n",
    "```\n",
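    "\n",
    "To make the filtering rule concrete, here is a plain-Python toy of the bookkeeping.\n",
    "It is **not** braincell code: `Node`, `Independent` and `parent_states` are\n",
    "hypothetical stand-ins, with `Independent` playing the role of the mixin. The\n",
    "parent collects states from its whole module tree but stops at any independent\n",
    "sub-module, whose states are left for that sub-module's own solver.\n",
    "\n",
    "```python\n",
    "# Hypothetical toy of the state-filtering rule (not braincell's implementation).\n",
    "class Node:\n",
    "    def __init__(self, name, states=(), children=()):\n",
    "        self.name = name\n",
    "        self.states = list(states)      # names of this module's own states\n",
    "        self.children = list(children)  # sub-modules\n",
    "\n",
    "\n",
    "class Independent(Node):\n",
    "    pass  # marker class standing in for the IndependentIntegration mixin\n",
    "\n",
    "\n",
    "def parent_states(node, is_root=True):\n",
    "    # An independent sub-module integrates its own states, so a parent skips it.\n",
    "    if isinstance(node, Independent) and not is_root:\n",
    "        return []\n",
    "    out = list(node.states)\n",
    "    for child in node.children:\n",
    "        out += parent_states(child, is_root=False)\n",
    "    return out\n",
    "\n",
    "\n",
    "cell = Node('cell', ['V'], [Independent('fast_gate', ['m', 'h']), Node('pool', ['Ca'])])\n",
    "print(parent_states(cell))              # ['V', 'Ca'] -> advanced by the parent's solver\n",
    "print(parent_states(cell.children[0]))  # ['m', 'h']  -> advanced by the gate's own solver\n",
    "```\n",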
    "\n",
    "Reach for `IndependentIntegration` only when you are building a sub-system that\n",
    "genuinely needs a different solver from its parent.\n",
    "\n",
    "> **Name clash.** Do not confuse `IndependentIntegration` (this mixin) with the\n",
    "> **`ind_exp_euler`** *solver*. The latter is the decoupled sibling of\n",
    "> `exp_euler` — it linearizes each `DiffEqState` independently rather than\n",
    "> building one global Jacobian — and has nothing to do with this mixin. Its\n",
    "> registry entry is shown below."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "979e1257",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-05-25T09:23:33.660713Z",
     "iopub.status.busy": "2026-05-25T09:23:33.659982Z",
     "iopub.status.idle": "2026-05-25T09:23:33.664405Z",
     "shell.execute_reply": "2026-05-25T09:23:33.663724Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "name        : ind_exp_euler\n",
      "category    : exponential\n",
      "description : Independent exponential Euler step (per-state linearization).\n"
     ]
    }
   ],
   "source": [
    "# `ind_exp_euler` — a decoupled (per-state) exponential-Euler solver,\n",
    "# NOT the IndependentIntegration mixin despite the similar name\n",
    "entry = get_registry().entry(\"ind_exp_euler\")\n",
    "print(\"name        :\", entry.name)\n",
    "print(\"category    :\", entry.category)\n",
    "print(\"description :\", entry.description)"
   ]
  },
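  {
   "cell_type": "markdown",
   "id": "7a3e5c91",
   "metadata": {},
   "source": [
    "For intuition about what *per-state linearization* means, here is a minimal NumPy\n",
    "sketch of one exponential-Euler step that uses only the diagonal of the Jacobian,\n",
    "i.e. each state is linearized on its own. It is the generic textbook update\n",
    "$y_{n+1} = y_n + \\Delta t\\,\\varphi_1(a\\,\\Delta t)\\,f(y_n)$ with\n",
    "$\\varphi_1(z) = (e^z - 1)/z$ and $a = \\partial f/\\partial y$; it is assumed for\n",
    "illustration and is not the source of `ind_exp_euler`. The function name\n",
    "`exp_euler_step` is hypothetical.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def exp_euler_step(y, f, jac_diag, dt):\n",
    "    # Exponential Euler with a diagonal (per-state) Jacobian a = df/dy:\n",
    "    #   y <- y + dt * phi1(a * dt) * f(y),   phi1(z) = (exp(z) - 1) / z\n",
    "    z = jac_diag(y) * dt\n",
    "    safe_z = np.where(z == 0.0, 1.0, z)                    # avoid 0/0\n",
    "    phi1 = np.where(z == 0.0, 1.0, np.expm1(z) / safe_z)  # phi1(0) = 1\n",
    "    return y + dt * phi1 * f(y)\n",
    "\n",
    "\n",
    "tau, dt = 10.0, 0.05\n",
    "f = lambda y: -y / tau                        # same decay as the Decay model later on\n",
    "jac = lambda y: np.full_like(y, -1.0 / tau)  # df/dy, one entry per state\n",
    "\n",
    "y = np.ones(1)\n",
    "for _ in range(200):                          # integrate to t = 10\n",
    "    y = exp_euler_step(y, f, jac, dt)\n",
    "print(f'exp-Euler y(10) = {y[0]:.6f}')        # 0.367879, i.e. exp(-1)\n",
    "```\n",
    "\n",
    "For a linear right-hand side the step multiplies $y$ by exactly $e^{a\\,\\Delta t}$,\n",
    "so the result above carries no truncation error; gating equations, which are linear\n",
    "in the gate variable for fixed voltage, are the classic use case."
   ]
  },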
  {
   "cell_type": "markdown",
   "id": "fc038a64",
   "metadata": {},
   "source": [
    "## The `diffusion` slot: stochastic systems\n",
    "\n",
    "Every `DiffEqState` carries two solver-facing slots. So far we have only used\n",
    "`derivative` — the drift term $f(t, y)$ of an ODE. The second slot,\n",
    "`diffusion`, is the noise coefficient $g(t, y)$ of a stochastic differential\n",
    "equation\n",
    "\n",
    "$$dy = f(t, y)\\,dt + g(t, y)\\,dW.$$\n",
    "\n",
    "Here $W$ is a Wiener process (Brownian motion). The slot defaults to `None`,\n",
    "which marks the state as deterministic."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "cb26f38b",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-05-25T09:23:33.666572Z",
     "iopub.status.busy": "2026-05-25T09:23:33.666378Z",
     "iopub.status.idle": "2026-05-25T09:23:33.764381Z",
     "shell.execute_reply": "2026-05-25T09:23:33.763530Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "default diffusion: None\n",
      "drift     : [-0.]\n",
      "diffusion : [0.1]\n"
     ]
    }
   ],
   "source": [
    "s = DiffEqState(jnp.zeros(1))\n",
    "print(\"default diffusion:\", s.diffusion)   # None  ->  treated as an ODE\n",
    "\n",
    "# assigning a coefficient marks the state as stochastic (SDE drift + noise)\n",
    "s.derivative = -s.value          # drift  f(t, y)\n",
    "s.diffusion = jnp.ones(1) * 0.1  # noise  g(t, y)\n",
    "print(\"drift     :\", s.derivative)\n",
    "print(\"diffusion :\", s.diffusion)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c26f40ec",
   "metadata": {},
   "source": [
    "> **Status.** The `diffusion` slot is part of the protocol so that models\n",
    "> can *declare* stochastic dynamics, but the integrators shipped in\n",
    "> `braincell.quad` today read `derivative` and advance the **deterministic**\n",
    "> system — none of them consume `diffusion` yet. Setting it does not, on its\n",
    "> own, produce a stochastic trajectory; SDE stepping is reserved for\n",
    "> future SDE-aware solvers. Treat the slot as forward-looking API surface, not a\n",
    "> working stochastic integrator."
   ]
  },
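  {
   "cell_type": "markdown",
   "id": "c5d2e8b4",
   "metadata": {},
   "source": [
    "For reference, the simplest scheme an SDE-aware solver could apply to the pair\n",
    "(`derivative`, `diffusion`) is Euler-Maruyama,\n",
    "$y_{n+1} = y_n + f(y_n)\\,\\Delta t + g(y_n)\\,\\sqrt{\\Delta t}\\,\\xi_n$ with\n",
    "$\\xi_n \\sim \\mathcal{N}(0, 1)$. The NumPy sketch below is a generic textbook\n",
    "version and is not part of `braincell.quad`; the function name `euler_maruyama`\n",
    "is hypothetical. It reuses the drift $-y$ and the noise coefficient $0.1$ from the\n",
    "example above.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def euler_maruyama(f, g, y0, dt, n_steps, rng):\n",
    "    # y <- y + f(y) * dt + g(y) * sqrt(dt) * xi, with a fresh xi ~ N(0, 1) per step\n",
    "    y = np.asarray(y0, dtype=float)\n",
    "    for _ in range(n_steps):\n",
    "        xi = rng.standard_normal(y.shape)\n",
    "        y = y + f(y) * dt + g(y) * np.sqrt(dt) * xi\n",
    "    return y\n",
    "\n",
    "\n",
    "drift = lambda y: -y                     # f(t, y), as in s.derivative above\n",
    "noise = lambda y: 0.1 * np.ones_like(y)  # g(t, y), as in s.diffusion above\n",
    "rng = np.random.default_rng(0)\n",
    "\n",
    "# one noisy path to t = 1, then the same scheme with g = 0\n",
    "y_sde = euler_maruyama(drift, noise, np.ones(1), 0.01, 100, rng)\n",
    "y_ode = euler_maruyama(drift, lambda y: 0.0 * y, np.ones(1), 0.01, 100, rng)\n",
    "print(y_sde, y_ode)  # with g = 0 the scheme reduces to forward Euler, 0.99**100\n",
    "```"
   ]
  },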
  {
   "cell_type": "markdown",
   "id": "1adf253c",
   "metadata": {},
   "source": [
    "## Registering your own integrator\n",
    "\n",
    "The registry is open: decorate a step function with `@register_integrator` and\n",
    "it becomes available by name everywhere `get_integrator` is used, including the\n",
    "`solver=` argument of a cell.\n",
    "\n",
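    "A step function receives the target `DiffEqModule` (plus any extra positional\n",
    "arguments) and is responsible for the whole step: calling `pre_integral`,\n",
    "filling derivatives with `compute_derivative`, updating every `DiffEqState`, and\n",
    "calling `post_integral`. Here is a from-scratch forward Euler that mirrors how the\n",
    "built-in solvers are structured. Note how it enumerates the module's\n",
    "`DiffEqState`s through the public `brainstate.graph` API."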
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "14fe9aa8",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-05-25T09:23:33.766525Z",
     "iopub.status.busy": "2026-05-25T09:23:33.766284Z",
     "iopub.status.idle": "2026-05-25T09:23:33.771214Z",
     "shell.execute_reply": "2026-05-25T09:23:33.770593Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "registered: True\n",
      "alias resolves: True\n"
     ]
    }
   ],
   "source": [
    "@register_integrator(\n",
    "    \"tutorial_euler\",\n",
    "    aliases=(\"tut_euler\",),\n",
    "    category=\"explicit\",\n",
    "    order=1,\n",
    "    description=\"Forward Euler built in the advanced integration tutorial.\",\n",
    "    override=True,   # keeps this cell safe to re-run in the same kernel\n",
    ")\n",
    "def tutorial_euler(target, *args):\n",
    "    dt = brainstate.environ.get(\"dt\")\n",
    "    target.pre_integral(*args)                                  # 1. before\n",
    "    states = brainstate.graph.states(target).filter(DiffEqState)\n",
    "    target.compute_derivative(*args)                            # 2. fill derivatives\n",
    "    for st in states.values():                                  # 3. y <- y + dt * f\n",
    "        st.value = st.value + dt * st.derivative\n",
    "    target.post_integral(*args)                                 # 4. after\n",
    "\n",
    "print(\"registered:\", \"tutorial_euler\" in get_registry().names())\n",
    "print(\"alias resolves:\", get_integrator(\"tut_euler\") is tutorial_euler)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bc4f5db6",
   "metadata": {},
   "source": [
    "Now drive a model with it, exactly like any built-in solver, and check it\n",
    "against the analytic solution of the decay problem $dy/dt = -y/\\tau$ with\n",
    "$y(0) = 1$, namely $y(t) = e^{-t/\\tau}$."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "22e12bbd",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-05-25T09:23:33.772964Z",
     "iopub.status.busy": "2026-05-25T09:23:33.772738Z",
     "iopub.status.idle": "2026-05-25T09:23:34.040792Z",
     "shell.execute_reply": "2026-05-25T09:23:34.039923Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "analytic        y(10) = 0.367879\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tutorial_euler  y(10) = 0.366958\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "built-in euler  y(10) = 0.366958\n"
     ]
    }
   ],
   "source": [
    "class Decay(brainstate.nn.Dynamics, DiffEqModule):\n",
    "    def __init__(self, tau=10.0):\n",
    "        super().__init__(in_size=1)\n",
    "        self.tau = tau\n",
    "\n",
    "    def init_state(self, *args):\n",
    "        self.y = DiffEqState(jnp.ones(1))\n",
    "\n",
    "    def compute_derivative(self, *args):\n",
    "        self.y.derivative = -self.y.value / self.tau\n",
    "\n",
    "\n",
    "def run(solver_name, dt=0.05, t_end=10.0):\n",
    "    model = Decay(tau=10.0)\n",
    "    brainstate.nn.init_all_states(model)\n",
    "    step = get_integrator(solver_name)\n",
    "    with brainstate.environ.context(dt=dt):\n",
    "        for i in range(int(t_end / dt)):\n",
    "            with brainstate.environ.context(t=i * dt):\n",
    "                step(model)\n",
    "    return float(model.y.value[0])\n",
    "\n",
    "\n",
    "exact = np.exp(-10.0 / 10.0)\n",
    "print(f\"analytic        y(10) = {exact:.6f}\")\n",
    "print(f\"tutorial_euler  y(10) = {run('tutorial_euler'):.6f}\")\n",
    "print(f\"built-in euler  y(10) = {run('euler'):.6f}\")"
   ]
  },
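  {
   "cell_type": "markdown",
   "id": "e4b19f06",
   "metadata": {},
   "source": [
    "Why do both solvers print 0.366958 rather than the analytic 0.367879? For this\n",
    "linear problem every forward-Euler step multiplies $y$ by the same factor\n",
    "$1 - \\Delta t/\\tau$, so after $N = t_{\\text{end}}/\\Delta t$ steps\n",
    "$y_N = (1 - \\Delta t/\\tau)^N$. A quick check in plain Python, using the `dt`,\n",
    "`tau` and `t_end` values from `run()` above:\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "dt, tau, t_end = 0.05, 10.0, 10.0\n",
    "n = round(t_end / dt)             # 200 steps, as in run()\n",
    "y_euler = (1 - dt / tau) ** n     # per-step amplification factor, applied n times\n",
    "y_exact = math.exp(-t_end / tau)\n",
    "print(f'{y_euler:.6f} vs {y_exact:.6f}')  # 0.366958 vs 0.367879\n",
    "```\n",
    "\n",
    "The gap of about $9 \\times 10^{-4}$ is forward Euler's first-order truncation\n",
    "error; halving `dt` should roughly halve it."
   ]
  },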
  {
   "cell_type": "markdown",
   "id": "b88bd893",
   "metadata": {},
   "source": [
    "Our hand-written `tutorial_euler` matches the built-in `euler` in every printed\n",
    "digit, as expected, since both implement the same scheme. From here, the registry's\n",
    "`override=` flag lets you replace an entry, and `unregister` removes one; see\n",
    "the [API reference](../apis/integration) for the full registry surface.\n",
    "\n",
    "## Recap\n",
    "\n",
    "- `IndependentIntegration` lets a sub-system run its own solver (distinct from\n",
    "  the similarly named `ind_exp_euler` solver).\n",
    "- `diffusion` declares SDE noise, but current solvers integrate the\n",
    "  deterministic drift only.\n",
    "- `@register_integrator` adds a named solver that plugs into the same\n",
    "  `get_integrator` / `solver=` machinery as everything else."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.13.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
