{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "c7f153617cc46383",
   "metadata": {},
   "source": [
    "# JIT Compilation"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d435eaf6d45d207d",
   "metadata": {},
   "source": [
    "`brainstate.transform.jit` extends `jax.jit` with state tracking and extra runtime\n",
    "controls. This guide highlights how BrainState JIT differs from plain JAX JIT,\n",
    "when to prefer each API, and how to decompose modules with\n",
    "`brainstate.graph.treefy_split` or `brainstate.graph.treefy_states`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "2878f39135473a0d",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-11T02:27:09.108494Z",
     "start_time": "2025-10-11T02:27:07.644640Z"
    }
   },
   "outputs": [],
   "source": [
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "import brainstate"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "417243151b19f37d",
   "metadata": {},
   "source": [
    "## Why BrainState JIT?\n",
    "\n",
    "`brainstate.transform.jit` understands `State` objects and automatically wires\n",
    "read/write traces into the compiled function. The returned object is a\n",
    "`JittedFunction` with helper methods such as `compile`, `clear_cache`, and\n",
    "`origin_fun`. Pure functions still work, but stateful modules are first-class\n",
    "citizens."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "559010de46df0d34",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-11T02:27:47.059353Z",
     "start_time": "2025-10-11T02:27:46.897830Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Array([0.00671535, 0.03505242, 0.17300805, 0.69314724, 1.839675  ,\n",
       "       3.368386  , 5.0067153 ], dtype=float32)"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "@brainstate.transform.jit\n",
    "def softplus(x: jax.Array) -> jax.Array:\n",
    "    return jnp.log1p(jnp.exp(-jnp.abs(x))) + jnp.maximum(x, 0)\n",
    "\n",
    "xs = jnp.linspace(-5.0, 5.0, 7)\n",
    "softplus(xs)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ad3abff8d147902b",
   "metadata": {},
   "source": [
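    "Subsequent calls with the same input shapes and dtypes reuse the compiled\n",
    "executable. If you disable JIT globally\n",
    "(`jax.config.update('jax_disable_jit', True)`) or locally with the\n",
    "`jax.disable_jit()` context manager, BrainState falls back to the original\n",
    "Python implementation automatically."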
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "a887dd42c06413f0",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-11T03:08:54.810110Z",
     "start_time": "2025-10-11T03:08:54.556595Z"
    }
   },
   "outputs": [],
   "source": [
    "with jax.disable_jit():\n",
    "    softplus(xs * 2.0)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6ff5a6da8bf6bed3",
   "metadata": {},
   "source": [
    "## Stateful modules with zero boilerplate\n",
    "\n",
    "BrainState keeps modules stateful inside compiled code. Below, a running-mean\n",
    "tracker updates hidden state at each call without any manual intervention."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "744bb1d6353b52d3",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-11T02:27:49.392064Z",
     "start_time": "2025-10-11T02:27:49.279600Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "step 0: mean=1.50\n",
      "step 1: mean=2.00\n",
      "step 2: mean=2.50\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "(30.0, 12)"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "class RunningMean(brainstate.nn.Module):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.sum = brainstate.HiddenState(jnp.array(0.0))\n",
    "        self.count = brainstate.HiddenState(jnp.array(0))\n",
    "\n",
    "    def __call__(self, batch: jax.Array) -> jax.Array:\n",
    "        self.sum.value += jnp.sum(batch)\n",
    "        self.count.value += batch.size\n",
    "        return self.sum.value / self.count.value\n",
    "\n",
    "\n",
    "tracker = RunningMean()\n",
    "\n",
    "@brainstate.transform.jit\n",
    "def update_running_mean(batch: jax.Array) -> jax.Array:\n",
    "    return tracker(batch)\n",
    "\n",
    "for step in range(3):\n",
    "    data = jnp.arange(4.0) + step\n",
    "    print(f'step {step}: mean={float(update_running_mean(data)):.2f}')\n",
    "\n",
    "float(tracker.sum.value), int(tracker.count.value)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "12e778e9200f783f",
   "metadata": {},
   "source": [
    "The hidden states remain in sync because BrainState records which states are\n",
    "read and written while tracing, then writes the updated values back after each\n",
    "call to the compiled executable."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "383c4c2dfb8be480",
   "metadata": {},
   "source": [
    "## Extra controls exposed by `JittedFunction`\n",
    "\n",
    "Unlike bare `jax.jit`, BrainState's wrapper exposes runtime helpers: `compile`\n",
    "precompiles the executable for a set of example arguments, and `clear_cache`\n",
    "drops the cached compilations explicitly."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "b5ae12bd805723d5",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-11T02:27:50.994238Z",
     "start_time": "2025-10-11T02:27:50.926425Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Array([1.3132617, 1.3132617, 1.3132617, 1.3132617], dtype=float32)"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "softplus.compile(jnp.ones((4,)))\n",
    "softplus(jnp.ones((4,)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "62102c248d14d73e",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-11T02:27:51.758659Z",
     "start_time": "2025-10-11T02:27:51.681401Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Array([0.3132617, 0.474077 , 0.6931472, 0.974077 , 1.3132617], dtype=float32)"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "softplus.clear_cache()\n",
    "softplus(jnp.linspace(-1.0, 1.0, 5))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4e2f2fade4a53f2a",
   "metadata": {},
   "source": [
    "## Working directly with `jax.jit`\n",
    "\n",
    "If you prefer raw JAX primitives, you can still make modules jit-friendly by\n",
    "wrapping them in pure, stateless functions. `brainstate.graph.treefy_split`\n",
    "returns a `GraphDef` plus one or more state trees that you must thread manually."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "c6ff773afe7f9061",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-11T02:27:53.100760Z",
     "start_time": "2025-10-11T02:27:53.037279Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "step 0: mean=1.50\n",
      "step 1: mean=2.00\n",
      "step 2: mean=2.50\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "(12, 30.0)"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model = RunningMean()\n",
    "\n",
    "graph_def, hidden_state_tree = brainstate.graph.treefy_split(model, brainstate.HiddenState)\n",
    "\n",
    "\n",
    "def running_mean_stateless(state_tree, batch):\n",
    "    module = brainstate.graph.treefy_merge(graph_def, state_tree)\n",
    "    out = module(batch)\n",
    "    new_state_tree = brainstate.graph.treefy_states(module, brainstate.HiddenState)\n",
    "    return out, new_state_tree\n",
    "\n",
    "\n",
    "jax_jitted = jax.jit(running_mean_stateless)\n",
    "\n",
    "state_tree = hidden_state_tree\n",
    "for step in range(3):\n",
    "    batch = jnp.arange(4.0) + step\n",
    "    mean, state_tree = jax_jitted(state_tree, batch)\n",
    "    print(f'step {step}: mean={float(mean):.2f}')\n",
    "\n",
    "int(state_tree['count'].value), float(state_tree['sum'].value)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5c214903e079bae9",
   "metadata": {},
   "source": [
    "The JAX version works, but you are responsible for threading state containers and\n",
    "reconstructing modules yourself."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "482d276e85d9881d",
   "metadata": {},
   "source": [
    "## `treefy_split` vs `treefy_states`\n",
    "\n",
    "Both helpers live in `brainstate.graph` but serve different purposes:\n",
    "\n",
    "- **`treefy_split`** → returns `(graph_def, state_tree1, state_tree2, ...)`. Use\n",
    "  it when you need to rebuild modules (e.g. JAX interop or serialising complete\n",
    "  graphs).\n",
    "- **`treefy_states`** → returns one or more state trees without the graph\n",
    "  definition. It's the lightweight choice when you only need a PyTree of\n",
    "  parameters for optimisation or checkpointing.\n",
    "\n",
    "\n",
    "See also [BrainState Graph and Node System](../utilities/01_graph_operations.ipynb) for more details on how to use these interfaces."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "c54411d91700a6cd",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-11T02:28:35.756526Z",
     "start_time": "2025-10-11T02:28:35.729321Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "treefy_split param paths: [('bias',), ('weight',)]\n",
      "treefy_states param paths: [('bias',), ('weight',)]\n"
     ]
    }
   ],
   "source": [
    "class TinyLinear(brainstate.nn.Module):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.weight = brainstate.ParamState(jnp.array([[1.0]]))\n",
    "        self.bias = brainstate.ParamState(jnp.array([0.0]))\n",
    "\n",
    "    def __call__(self, x: jax.Array) -> jax.Array:\n",
    "        return x @ self.weight.value + self.bias.value\n",
    "\n",
    "\n",
    "lin = TinyLinear()\n",
    "\n",
    "# Split into graph + states (useful for reconstruction / JAX interop)\n",
    "lin_graph, param_tree, other_states = brainstate.graph.treefy_split(\n",
    "    lin, brainstate.ParamState, ...,\n",
    ")\n",
    "print('treefy_split param paths:', list(param_tree.to_flat().keys()))\n",
    "\n",
    "# Fetch only the parameter tree (perfect for gradient updates)\n",
    "params_only = brainstate.graph.treefy_states(lin, brainstate.ParamState)\n",
    "print('treefy_states param paths:', list(params_only.to_flat().keys()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "ac3c0c4fde23a995",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-11T02:31:05.831079Z",
     "start_time": "2025-10-11T02:31:05.813216Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "loss: 2.5\n",
      "grad bias TreefyState(\n",
      "  type=<class 'brainstate.ParamState'>,\n",
      "  value=Array([-3.], dtype=float32),\n",
      "  tag=None\n",
      ")\n",
      "grad weight TreefyState(\n",
      "  type=<class 'brainstate.ParamState'>,\n",
      "  value=Array([[-2.]], dtype=float32),\n",
      "  tag=None\n",
      ")\n"
     ]
    }
   ],
   "source": [
    "# Example: compute gradients w.r.t. the ParamState tree using jax.value_and_grad\n",
    "def mse_loss(params, others, x):\n",
    "    lin_recovered = brainstate.graph.treefy_merge(lin_graph, params, others)\n",
    "    pred = lin_recovered(x)\n",
    "    target = 2.0 * x + 1.0\n",
    "    return jnp.mean((pred - target) ** 2)\n",
    "\n",
    "loss_grad = jax.value_and_grad(mse_loss)\n",
    "\n",
    "(loss_value, grads) = loss_grad(param_tree, other_states, jnp.array([[0.0], [1.0]]))\n",
    "print('loss:', float(loss_value))\n",
    "for path, g in grads.items():\n",
    "    print('grad', path, g)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fb4e5442cf6571ab",
   "metadata": {},
   "source": [
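    "The example above differentiates the `param_tree` returned by `treefy_split`\n",
    "because `treefy_merge` needs the `GraphDef` to rebuild the module. When you only\n",
    "need the parameters themselves, `treefy_states` yields a PyTree with the same\n",
    "parameter paths, so you can skip the `GraphDef` unless you plan to reconstruct\n",
    "the module elsewhere."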
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3e8df0750d1f98a6",
   "metadata": {},
   "source": [
    "## Static arguments still apply\n",
    "\n",
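    "Static-argument handling mirrors `jax.jit`. The example below specialises the\n",
    "compiled program by polynomial degree: `degree` is marked static, so the two\n",
    "calls with degree 3 share one compiled program, while degree 4 triggers a new\n",
    "trace and compilation."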
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "7ce5c6b92dd21e7e",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-11T02:31:29.171041Z",
     "start_time": "2025-10-11T02:31:29.114248Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[ 6. 34.] [ 6. 34.] [10. 98.]\n"
     ]
    }
   ],
   "source": [
    "@brainstate.transform.jit(static_argnums=1)\n",
    "def polynomial_series(x: jax.Array, degree: int) -> jax.Array:\n",
    "    powers = [x ** (i + 1) for i in range(degree)]\n",
    "    coeffs = jnp.arange(1, degree + 1, dtype=x.dtype)\n",
    "    return jnp.tensordot(coeffs, jnp.stack(powers, axis=0), axes=1)\n",
    "\n",
    "\n",
    "p1 = polynomial_series(jnp.array([1.0, 2.0]), 3)\n",
    "p2 = polynomial_series(jnp.array([1.0, 2.0]), 3)\n",
    "p3 = polynomial_series(jnp.array([1.0, 2.0]), 4)\n",
    "print(p1, p2, p3)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4a9d14c53cdd50e9",
   "metadata": {},
   "source": [
    "## Which API should you choose?\n",
    "\n",
    "| Scenario | `brainstate.transform.jit` | `jax.jit` |\n",
    "| -------- | -------------------------- | --------- |\n",
    "| Stateful BrainState modules | ✅ Zero boilerplate | ⚠️ Requires `treefy_split` and manual state threading |\n",
    "| Pure stateless functions | ✅ Works (with helper methods) | ✅ Often the leanest choice |\n",
    "| Need `compile()` / `clear_cache()` | ✅ Built-in | ❌ Not available |\n",
    "| Custom sharding / device placement | ✅ Same signature as `jax.jit` | ✅ |\n",
    "\n",
    "`treefy_split` is the workhorse when you need a `GraphDef` for reconstruction or\n",
    "JAX interop. `treefy_states` is the light option for extracting parameter\n",
    "PyTrees, for example before calling `brainstate.transform.grad` or saving a\n",
    "checkpoint."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a041357a47f07cd7",
   "metadata": {},
   "source": [
    "## Summary\n",
    "\n",
    "- `brainstate.transform.jit` tracks BrainState `State` objects automatically and\n",
    "  returns a `JittedFunction` with extra controls.\n",
    "- `jax.jit` still works, but you must explicitly split and merge module state.\n",
    "- `graph.treefy_split` produces `(graph_def, state_tree1, state_tree2, …)` for\n",
    "  reconstruction; `graph.treefy_states` returns just the requested state trees.\n",
    "- Choose the interface that matches your workflow: use BrainState JIT for\n",
    "  module-centric code, drop down to JAX primitives when integrating with other\n",
    "  systems."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Ecosystem-py",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
