{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "c7f153617cc46383",
   "metadata": {},
   "source": [
    "# JIT and Compilation"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d435eaf6d45d207d",
   "metadata": {},
   "source": [
    "`brainstate.transform.jit` extends `jax.jit` with automatic state tracking and\n",
    "extra runtime controls. This guide shows how BrainState JIT differs from plain\n",
    "JAX JIT, when to prefer each API, and how to decompose modules with\n",
    "`brainstate.graph.treefy_split` or `brainstate.graph.treefy_states`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "2878f39135473a0d",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-11T02:27:09.108494Z",
     "start_time": "2025-10-11T02:27:07.644640Z"
    },
    "execution": {
     "iopub.execute_input": "2026-05-30T17:17:49.912968Z",
     "iopub.status.busy": "2026-05-30T17:17:49.912670Z",
     "iopub.status.idle": "2026-05-30T17:17:51.996484Z",
     "shell.execute_reply": "2026-05-30T17:17:51.995258Z"
    }
   },
   "outputs": [],
   "source": [
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "import brainstate"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "417243151b19f37d",
   "metadata": {},
   "source": [
    "## Why BrainState JIT?\n",
    "\n",
    "`brainstate.transform.jit` understands `State` objects: it tracks which states a\n",
    "function reads and writes and threads their values through the compiled\n",
    "function for you. The returned object is a `JittedFunction` with helpers such as\n",
    "`compile`, `clear_cache`, and `origin_fun`. Pure functions still work, but\n",
    "stateful modules are first-class citizens."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "559010de46df0d34",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-11T02:27:47.059353Z",
     "start_time": "2025-10-11T02:27:46.897830Z"
    },
    "execution": {
     "iopub.execute_input": "2026-05-30T17:17:51.999283Z",
     "iopub.status.busy": "2026-05-30T17:17:51.998870Z",
     "iopub.status.idle": "2026-05-30T17:17:52.154715Z",
     "shell.execute_reply": "2026-05-30T17:17:52.154049Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Array([0.00671535, 0.03505242, 0.17300805, 0.69314724, 1.839675  ,\n",
       "       3.368386  , 5.0067153 ], dtype=float32)"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "@brainstate.transform.jit\n",
    "def softplus(x: jax.Array) -> jax.Array:\n",
    "    return jnp.log1p(jnp.exp(-jnp.abs(x))) + jnp.maximum(x, 0)\n",
    "\n",
    "xs = jnp.linspace(-5.0, 5.0, 7)\n",
    "softplus(xs)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ad3abff8d147902b",
   "metadata": {},
   "source": [
    "Subsequent calls with the same input shapes and dtypes reuse the compiled\n",
    "executable. If you disable JIT, either globally with\n",
    "`jax.config.update('jax_disable_jit', True)` or locally with the\n",
    "`jax.disable_jit()` context manager used below, BrainState falls back to the\n",
    "original Python implementation automatically."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "a887dd42c06413f0",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-11T03:08:54.810110Z",
     "start_time": "2025-10-11T03:08:54.556595Z"
    },
    "execution": {
     "iopub.execute_input": "2026-05-30T17:17:52.156786Z",
     "iopub.status.busy": "2026-05-30T17:17:52.156550Z",
     "iopub.status.idle": "2026-05-30T17:17:52.394253Z",
     "shell.execute_reply": "2026-05-30T17:17:52.393384Z"
    }
   },
   "outputs": [],
   "source": [
    "with jax.disable_jit():\n",
    "    softplus(xs * 2.0)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6ff5a6da8bf6bed3",
   "metadata": {},
   "source": [
    "## Stateful modules with zero boilerplate\n",
    "\n",
    "BrainState keeps modules stateful inside compiled code. Below, a running-mean\n",
    "tracker updates hidden state at each call without any manual intervention."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "744bb1d6353b52d3",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-11T02:27:49.392064Z",
     "start_time": "2025-10-11T02:27:49.279600Z"
    },
    "execution": {
     "iopub.execute_input": "2026-05-30T17:17:52.396690Z",
     "iopub.status.busy": "2026-05-30T17:17:52.396396Z",
     "iopub.status.idle": "2026-05-30T17:17:52.557603Z",
     "shell.execute_reply": "2026-05-30T17:17:52.556754Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "step 0: mean=1.50\n",
      "step 1: mean=2.00\n",
      "step 2: mean=2.50\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "(30.0, 12)"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "class RunningMean(brainstate.nn.Module):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.sum = brainstate.HiddenState(jnp.array(0.0))\n",
    "        self.count = brainstate.HiddenState(jnp.array(0))\n",
    "\n",
    "    def __call__(self, batch: jax.Array) -> jax.Array:\n",
    "        self.sum.value += jnp.sum(batch)\n",
    "        self.count.value += batch.size\n",
    "        return self.sum.value / self.count.value\n",
    "\n",
    "\n",
    "tracker = RunningMean()\n",
    "\n",
    "@brainstate.transform.jit\n",
    "def update_running_mean(batch: jax.Array) -> jax.Array:\n",
    "    return tracker(batch)\n",
    "\n",
    "for step in range(3):\n",
    "    data = jnp.arange(4.0) + step\n",
    "    print(f'step {step}: mean={float(update_running_mean(data)):.2f}')\n",
    "\n",
    "float(tracker.sum.value), int(tracker.count.value)"
   ]
  },
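  {
   "cell_type": "markdown",
   "id": "12e778e9200f783f",
   "metadata": {},
   "source": [
    "The hidden states remain in sync because BrainState records which states the\n",
    "function reads and writes, passes their current values into the compiled\n",
    "executable, and writes the updated values back after each call. Here the three\n",
    "batches sum to 6 + 10 + 14 = 30 over 12 elements, matching the final state."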
   ]
  },
  {
   "cell_type": "markdown",
   "id": "383c4c2dfb8be480",
   "metadata": {},
   "source": [
    "## Extra controls exposed by `JittedFunction`\n",
    "\n",
    "BrainState's wrapper exposes runtime helpers directly on the returned object.\n",
    "You can precompile an executable for example inputs with `compile`, or drop\n",
    "cached traces explicitly with `clear_cache`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "b5ae12bd805723d5",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-11T02:27:50.994238Z",
     "start_time": "2025-10-11T02:27:50.926425Z"
    },
    "execution": {
     "iopub.execute_input": "2026-05-30T17:17:52.559861Z",
     "iopub.status.busy": "2026-05-30T17:17:52.559643Z",
     "iopub.status.idle": "2026-05-30T17:17:52.645835Z",
     "shell.execute_reply": "2026-05-30T17:17:52.644973Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Array([1.3132617, 1.3132617, 1.3132617, 1.3132617], dtype=float32)"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "softplus.compile(jnp.ones((4,)))\n",
    "softplus(jnp.ones((4,)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "62102c248d14d73e",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-11T02:27:51.758659Z",
     "start_time": "2025-10-11T02:27:51.681401Z"
    },
    "execution": {
     "iopub.execute_input": "2026-05-30T17:17:52.647768Z",
     "iopub.status.busy": "2026-05-30T17:17:52.647508Z",
     "iopub.status.idle": "2026-05-30T17:17:52.732186Z",
     "shell.execute_reply": "2026-05-30T17:17:52.731577Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Array([0.3132617, 0.474077 , 0.6931472, 0.974077 , 1.3132617], dtype=float32)"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "softplus.clear_cache()\n",
    "softplus(jnp.linspace(-1.0, 1.0, 5))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4e2f2fade4a53f2a",
   "metadata": {},
   "source": [
    "## Working directly with `jax.jit`\n",
    "\n",
    "`brainstate.transform.jit` is convenient precisely because it removes a chore. With raw\n",
    "`jax.jit` you must make the computation *pure*: pull every piece of mutable state out of the\n",
    "model, pass it in as an argument, and return its updated value so you can feed it back on the\n",
    "next call. The running mean below is rewritten as a pure function of `(carry, batch)`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "c6ff773afe7f9061",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-11T02:27:53.100760Z",
     "start_time": "2025-10-11T02:27:53.037279Z"
    },
    "execution": {
     "iopub.execute_input": "2026-05-30T17:17:52.734419Z",
     "iopub.status.busy": "2026-05-30T17:17:52.734166Z",
     "iopub.status.idle": "2026-05-30T17:17:52.817854Z",
     "shell.execute_reply": "2026-05-30T17:17:52.816900Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "step 0: mean=1.50\n",
      "step 1: mean=2.00\n",
      "step 2: mean=2.50\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "(30.0, 12)"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def running_mean_pure(carry, batch):\n",
    "    total, count = carry\n",
    "    total = total + jnp.sum(batch)\n",
    "    count = count + batch.size\n",
    "    return (total, count), total / count\n",
    "\n",
    "\n",
    "jax_jitted = jax.jit(running_mean_pure)\n",
    "\n",
    "carry = (jnp.array(0.0), jnp.array(0))   # the state, threaded by hand\n",
    "for step in range(3):\n",
    "    batch = jnp.arange(4.0) + step\n",
    "    carry, mean = jax_jitted(carry, batch)\n",
    "    print(f'step {step}: mean={float(mean):.2f}')\n",
    "\n",
    "total, count = carry\n",
    "float(total), int(count)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5c214903e079bae9",
   "metadata": {},
   "source": [
    "The JAX version works, but every stateful quantity has to be threaded through the call signature\n",
    "by hand, and a real model may have dozens of them. `brainstate.transform.jit` lets you keep the\n",
    "natural stateful module and does this bookkeeping for you. (For explicit graph/state manipulation\n",
    "when you *do* want it, see `brainstate.graph.treefy_split` / `treefy_merge` and the\n",
    "[graph how-to](../../how_to/inspect_and_edit_state_graph.ipynb).)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "482d276e85d9881d",
   "metadata": {},
   "source": [
    "## `treefy_split` vs `treefy_states`\n",
    "\n",
    "Both helpers live in `brainstate.graph` but serve different purposes:\n",
    "\n",
    "- **`treefy_split`** \u2192 returns `(graph_def, state_tree1, state_tree2, ...)`. Use\n",
    "  it when you need to rebuild modules (e.g. JAX interop or serialising complete\n",
    "  graphs).\n",
    "- **`treefy_states`** \u2192 returns one or more state trees without the graph\n",
    "  definition. It's the lightweight choice when you only need a PyTree of\n",
    "  parameters for optimisation or checkpointing.\n",
    "\n",
    "See also [BrainState Graph and Node System](../../how_to/inspect_and_edit_state_graph.ipynb) for more details on how to use these interfaces."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "c54411d91700a6cd",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-11T02:28:35.756526Z",
     "start_time": "2025-10-11T02:28:35.729321Z"
    },
    "execution": {
     "iopub.execute_input": "2026-05-30T17:17:52.819853Z",
     "iopub.status.busy": "2026-05-30T17:17:52.819688Z",
     "iopub.status.idle": "2026-05-30T17:17:52.852430Z",
     "shell.execute_reply": "2026-05-30T17:17:52.851540Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "treefy_split param paths: [('bias',), ('weight',)]\n",
      "treefy_states param paths: [('bias',), ('weight',)]\n"
     ]
    }
   ],
   "source": [
    "class TinyLinear(brainstate.nn.Module):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.weight = brainstate.ParamState(jnp.array([[1.0]]))\n",
    "        self.bias = brainstate.ParamState(jnp.array([0.0]))\n",
    "\n",
    "    def __call__(self, x: jax.Array) -> jax.Array:\n",
    "        return x @ self.weight.value + self.bias.value\n",
    "\n",
    "\n",
    "lin = TinyLinear()\n",
    "\n",
    "# Split into graph + states (useful for reconstruction / JAX interop)\n",
    "lin_graph, param_tree, other_states = brainstate.graph.treefy_split(\n",
    "    lin, brainstate.ParamState, ...,\n",
    ")\n",
    "print('treefy_split param paths:', list(param_tree.to_flat().keys()))\n",
    "\n",
    "# Fetch only the parameter tree (perfect for gradient updates)\n",
    "params_only = brainstate.graph.treefy_states(lin, brainstate.ParamState)\n",
    "print('treefy_states param paths:', list(params_only.to_flat().keys()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "ac3c0c4fde23a995",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-11T02:31:05.831079Z",
     "start_time": "2025-10-11T02:31:05.813216Z"
    },
    "execution": {
     "iopub.execute_input": "2026-05-30T17:17:52.854874Z",
     "iopub.status.busy": "2026-05-30T17:17:52.854675Z",
     "iopub.status.idle": "2026-05-30T17:17:53.286016Z",
     "shell.execute_reply": "2026-05-30T17:17:53.285392Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "loss: 2.5\n",
      "grad bias TreefyState(\n",
      "  type=<class 'brainstate.ParamState'>,\n",
      "  value=Array([-3.], dtype=float32),\n",
      "  _hooks_manager=<brainstate._state_hook_manager.HookManager object at 0x7ebe801356e0>,\n",
      "  tag=None\n",
      ")\n",
      "grad weight TreefyState(\n",
      "  type=<class 'brainstate.ParamState'>,\n",
      "  value=Array([[-2.]], dtype=float32),\n",
      "  _hooks_manager=<brainstate._state_hook_manager.HookManager object at 0x7ebe801350f0>,\n",
      "  tag=None\n",
      ")\n"
     ]
    }
   ],
   "source": [
    "# Example: compute gradients w.r.t. the ParamState tree using jax.value_and_grad\n",
    "def mse_loss(params, others, x):\n",
    "    lin_recovered = brainstate.graph.treefy_merge(lin_graph, params, others)\n",
    "    pred = lin_recovered(x)\n",
    "    target = 2.0 * x + 1.0\n",
    "    return jnp.mean((pred - target) ** 2)\n",
    "\n",
    "loss_grad = jax.value_and_grad(mse_loss)\n",
    "\n",
    "(loss_value, grads) = loss_grad(param_tree, other_states, jnp.array([[0.0], [1.0]]))\n",
    "print('loss:', float(loss_value))\n",
    "for path, g in grads.items():\n",
    "    print('grad', path, g)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fb4e5442cf6571ab",
   "metadata": {},
   "source": [
    "The gradient example above needs the `GraphDef` because it rebuilds the module\n",
    "with `treefy_merge`. When you do not plan to reconstruct the module,\n",
    "`treefy_states` is enough: it returns a PyTree keyed by parameter paths that\n",
    "drops directly into optimisation pipelines."
   ]
  },
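  {
   "cell_type": "markdown",
   "id": "3e8df0750d1f98a6",
   "metadata": {},
   "source": [
    "## Static arguments still apply\n",
    "\n",
    "Static-argument handling mirrors `jax.jit`: an argument marked static is treated\n",
    "as a compile-time constant, so each distinct value gets its own compiled\n",
    "program. The example below specialises the compiled program by polynomial\n",
    "degree. The two calls with `degree=3` can share one program, while `degree=4`\n",
    "requires a new one."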
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "7ce5c6b92dd21e7e",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-11T02:31:29.171041Z",
     "start_time": "2025-10-11T02:31:29.114248Z"
    },
    "execution": {
     "iopub.execute_input": "2026-05-30T17:17:53.287922Z",
     "iopub.status.busy": "2026-05-30T17:17:53.287675Z",
     "iopub.status.idle": "2026-05-30T17:17:53.351022Z",
     "shell.execute_reply": "2026-05-30T17:17:53.350101Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[ 6. 34.] [ 6. 34.] [10. 98.]\n"
     ]
    }
   ],
   "source": [
    "@brainstate.transform.jit(static_argnums=1)\n",
    "def polynomial_series(x: jax.Array, degree: int) -> jax.Array:\n",
    "    powers = [x ** (i + 1) for i in range(degree)]\n",
    "    coeffs = jnp.arange(1, degree + 1, dtype=x.dtype)\n",
    "    return jnp.tensordot(coeffs, jnp.stack(powers, axis=0), axes=1)\n",
    "\n",
    "\n",
    "p1 = polynomial_series(jnp.array([1.0, 2.0]), 3)\n",
    "p2 = polynomial_series(jnp.array([1.0, 2.0]), 3)\n",
    "p3 = polynomial_series(jnp.array([1.0, 2.0]), 4)\n",
    "print(p1, p2, p3)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4a9d14c53cdd50e9",
   "metadata": {},
   "source": [
    "## Which API should you choose?\n",
    "\n",
    "| Scenario | `brainstate.transform.jit` | `jax.jit` |\n",
    "| -------- | -------------------------- | --------- |\n",
    "| Stateful BrainState modules | \u2705 Zero boilerplate | \u26a0\ufe0f Requires `treefy_split` and manual state threading |\n",
    "| Pure stateless functions | \u2705 Works (with helper methods) | \u2705 Often the leanest choice |\n",
    "| Need `compile()` / `clear_cache()` | \u2705 Built into `JittedFunction` | \u26a0\ufe0f Use `.lower(...).compile()` and `jax.clear_caches()` instead |\n",
    "| Custom sharding / device placement | \u2705 Same signature as `jax.jit` | \u2705 |\n",
    "\n",
    "`treefy_split` is the workhorse when you need a `GraphDef` for reconstruction or\n",
    "JAX interop. `treefy_states` is the light option for extracting parameter\n",
    "PyTrees, for example before calling `brainstate.transform.grad` or saving a\n",
    "checkpoint."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a041357a47f07cd7",
   "metadata": {},
   "source": [
    "## Summary\n",
    "\n",
    "- `brainstate.transform.jit` tracks BrainState `State` objects automatically and\n",
    "  returns a `JittedFunction` with extra controls.\n",
    "- `jax.jit` still works, but you must explicitly split and merge module state.\n",
    "- `graph.treefy_split` produces `(graph_def, state_tree1, state_tree2, \u2026)` for\n",
    "  reconstruction; `graph.treefy_states` returns just the requested state trees.\n",
    "- Choose the interface that matches your workflow: use BrainState JIT for\n",
    "  module-centric code, and drop down to JAX primitives when integrating with\n",
    "  other systems."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Ecosystem-py",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.13.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}