{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "a37a2a13",
   "metadata": {},
   "source": [
    "# IR Optimization and Code Generation"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "076dde24",
   "metadata": {},
   "source": [
    "Every BrainState transformation ultimately lowers your code to a **jaxpr** — JAX's typed\n",
    "intermediate representation. BrainState exposes that IR and a small toolkit for working with it:\n",
    "inspect the computation graph, run optimization passes over it, and regenerate readable Python\n",
    "from it. These tools are useful for understanding what the compiler sees, verifying that state\n",
    "reads and writes are tracked correctly, and squeezing redundant work out of a hot path."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "5c6f7fb0",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-05-30T16:40:55.605808Z",
     "iopub.status.busy": "2026-05-30T16:40:55.605586Z",
     "iopub.status.idle": "2026-05-30T16:40:57.739624Z",
     "shell.execute_reply": "2026-05-30T16:40:57.738690Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'0.4.0'"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import jax.numpy as jnp\n",
    "\n",
    "import brainstate\n",
    "import brainstate.transform as T\n",
    "\n",
    "brainstate.random.seed(0)\n",
    "brainstate.__version__"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "718d2612",
   "metadata": {},
   "source": [
    "## Inspecting the IR with `make_jaxpr`"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5655a807",
   "metadata": {},
   "source": [
    "`T.make_jaxpr(fn)(*args)` traces `fn` with the given arguments and returns its `ClosedJaxpr`\n",
    "together with the tuple of `State`s it touched. Unlike `jax.make_jaxpr`, it is state-aware:\n",
    "state reads appear as extra inputs and state writes as extra outputs of the jaxpr."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "65227882",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-05-30T16:40:57.742064Z",
     "iopub.status.busy": "2026-05-30T16:40:57.741749Z",
     "iopub.status.idle": "2026-05-30T16:40:57.763715Z",
     "shell.execute_reply": "2026-05-30T16:40:57.763095Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "states touched: 0\n",
      "{ \u001b[34;1mlambda \u001b[39;22m; a\u001b[35m:f32[3]\u001b[39m. \u001b[34;1mlet\n",
      "    \u001b[39;22mb\u001b[35m:f32[3]\u001b[39m = integer_pow[y=2] a\n",
      "    c\u001b[35m:f32[]\u001b[39m = reduce_sum[axes=(0,) out_sharding=None] b\n",
      "  \u001b[34;1min \u001b[39;22m(c,) }\n"
     ]
    }
   ],
   "source": [
    "def pure(x):\n",
    "    return jnp.sum(x ** 2)\n",
    "\n",
    "jaxpr, states = T.make_jaxpr(pure)(jnp.array([1.0, 2.0, 3.0]))\n",
    "print('states touched:', len(states))\n",
    "print(jaxpr)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a1220858",
   "metadata": {},
   "source": [
    "With a stateful function the difference is visible: the state value enters as an input and the\n",
    "updated value leaves as an output, making the data flow through `State` explicit."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "108752ab",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-05-30T16:40:57.765745Z",
     "iopub.status.busy": "2026-05-30T16:40:57.765560Z",
     "iopub.status.idle": "2026-05-30T16:40:57.782656Z",
     "shell.execute_reply": "2026-05-30T16:40:57.781672Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "states touched: 1\n",
      "{ \u001b[34;1mlambda \u001b[39;22m; a\u001b[35m:f32[]\u001b[39m b\u001b[35m:f32[]\u001b[39m. \u001b[34;1mlet\n",
      "    \u001b[39;22mc\u001b[35m:f32[]\u001b[39m = add b a\n",
      "    d\u001b[35m:f32[]\u001b[39m = mul c 2.0:f32[]\n",
      "  \u001b[34;1min \u001b[39;22m(d, c) }\n"
     ]
    }
   ],
   "source": [
    "counter = brainstate.State(jnp.array(0.0))\n",
    "\n",
    "def stateful(x):\n",
    "    counter.value = counter.value + x\n",
    "    return counter.value * 2\n",
    "\n",
    "jaxpr, states = T.make_jaxpr(stateful)(jnp.array(1.0))\n",
    "print('states touched:', len(states))\n",
    "print(jaxpr)"
   ]
  },
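  {
   "cell_type": "markdown",
   "id": "3f9c1d27",
   "metadata": {},
   "source": [
    "As an illustration (not code generated by BrainState), here is a hand-written, pure-Python sketch\n",
    "of the function this jaxpr encodes. The name `stateful_as_pure` is hypothetical, and its argument\n",
    "order `(x, counter_value)` simply mirrors the inputs `a` and `b` in the printed jaxpr. Because the\n",
    "state is now an explicit input and output, the caller has to thread it from one call to the next."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8a4e2b90",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hypothetical pure equivalent of `stateful`, mirroring the jaxpr above:\n",
    "#   inputs  (a, b) = (x, current value of counter)\n",
    "#   outputs (d, c) = (return value, new value of counter)\n",
    "def stateful_as_pure(x, counter_value):\n",
    "    new_counter = counter_value + x  # c = add b a\n",
    "    result = new_counter * 2.0       # d = mul c 2.0\n",
    "    return result, new_counter\n",
    "\n",
    "# The caller, not the function, now carries the state between calls.\n",
    "value = 0.0\n",
    "results = []\n",
    "for x in [1.0, 2.0, 3.0]:\n",
    "    out, value = stateful_as_pure(x, value)\n",
    "    results.append(out)\n",
    "\n",
    "print(results, value)  # [2.0, 6.0, 12.0] 6.0"
   ]
  },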
  {
   "cell_type": "markdown",
   "id": "b54fc386",
   "metadata": {},
   "source": [
    "## Optimizing a jaxpr"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fc0e16dd",
   "metadata": {},
   "source": [
    "`optimize_jaxpr` applies classic compiler passes to a jaxpr. The available passes are:\n",
    "\n",
    "| Pass | Effect |\n",
    "| --- | --- |\n",
    "| `dce` | dead-code elimination — drop equations whose outputs are unused |\n",
    "| `cse` | common-subexpression elimination — reuse repeated computations |\n",
    "| `constant_fold` | evaluate operations on known constants ahead of time |\n",
    "| `algebraic_simplification` | apply identities such as `x * 1 -> x` |\n",
    "| `copy_propagation` | remove redundant copies |\n",
    "\n",
    "The function below contains deliberate waste: an unused product and a multiply-by-one."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "ba94b42d",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-05-30T16:40:57.784679Z",
     "iopub.status.busy": "2026-05-30T16:40:57.784457Z",
     "iopub.status.idle": "2026-05-30T16:40:57.804283Z",
     "shell.execute_reply": "2026-05-30T16:40:57.803527Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "equations: 4 -> 2\n",
      "{ \u001b[34;1mlambda \u001b[39;22m; a\u001b[35m:f32[2]\u001b[39m. \u001b[34;1mlet\n",
      "    \u001b[39;22mb\u001b[35m:f32[2]\u001b[39m = add a 1.0:f32[]\n",
      "    c\u001b[35m:f32[2]\u001b[39m = add b b\n",
      "  \u001b[34;1min \u001b[39;22m(c,) }\n"
     ]
    }
   ],
   "source": [
    "def wasteful(x):\n",
    "    a = x + 1.0\n",
    "    unused = x * 999.0      # dead code\n",
    "    scaled = a * 1.0        # algebraic identity\n",
    "    return scaled + a\n",
    "\n",
    "jaxpr, _ = T.make_jaxpr(wasteful)(jnp.array([1.0, 2.0]))\n",
    "before = len(jaxpr.jaxpr.eqns)\n",
    "\n",
    "optimized = T.optimize_jaxpr(jaxpr, optimizations=['dce', 'algebraic_simplification', 'cse'])\n",
    "after = len(optimized.jaxpr.eqns)\n",
    "\n",
    "print(f'equations: {before} -> {after}')\n",
    "print(optimized)"
   ]
  },
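  {
   "cell_type": "markdown",
   "id": "c5b7e3a1",
   "metadata": {},
   "source": [
    "The passes themselves are ordinary graph rewrites. The following toy sketch is not BrainState's\n",
    "implementation: it uses a hypothetical IR in which each equation is a tuple `(out, op, args)`, and\n",
    "applies algebraic simplification followed by dead-code elimination to a hand-written copy of\n",
    "`wasteful`. Simplification redirects uses of `a * 1.0` to `a`, which leaves both multiplications\n",
    "without consumers, so DCE drops them and the count falls from 4 to 2, mirroring the result above."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d8f0a6c2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Toy IR (hypothetical): each equation is (out_var, op, args).\n",
    "# Arguments are variable names (str) or float literals.\n",
    "eqns = [\n",
    "    ('a', 'add', ('x', 1.0)),\n",
    "    ('unused', 'mul', ('x', 999.0)),  # dead code\n",
    "    ('scaled', 'mul', ('a', 1.0)),    # algebraic identity\n",
    "    ('out', 'add', ('scaled', 'a')),\n",
    "]\n",
    "outputs = ['out']\n",
    "\n",
    "def algebraic_simplification(eqns, outputs):\n",
    "    # Rewrite `v * 1.0` to `v`: keep the equation, but redirect later uses\n",
    "    # of its result to `v` so that DCE can remove it afterwards.\n",
    "    alias = {}\n",
    "    new_eqns = []\n",
    "    for out, op, args in eqns:\n",
    "        args = tuple(alias.get(a, a) if isinstance(a, str) else a for a in args)\n",
    "        if op == 'mul' and isinstance(args[0], str) and args[1] == 1.0:\n",
    "            alias[out] = args[0]\n",
    "        new_eqns.append((out, op, args))\n",
    "    return new_eqns, [alias.get(o, o) for o in outputs]\n",
    "\n",
    "def dce(eqns, outputs):\n",
    "    # Walk backwards from the outputs, keeping only equations whose\n",
    "    # results are still needed, and mark their arguments as live.\n",
    "    live = set(outputs)\n",
    "    kept = []\n",
    "    for out, op, args in reversed(eqns):\n",
    "        if out in live:\n",
    "            kept.append((out, op, args))\n",
    "            live.update(a for a in args if isinstance(a, str))\n",
    "    return kept[::-1], outputs\n",
    "\n",
    "simplified, outs = algebraic_simplification(eqns, outputs)\n",
    "optimized_toy, outs = dce(simplified, outs)\n",
    "\n",
    "print(f'equations: {len(eqns)} -> {len(optimized_toy)}')\n",
    "for eqn in optimized_toy:\n",
    "    print(eqn)"
   ]
  },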
  {
   "cell_type": "markdown",
   "id": "5bf0c19b",
   "metadata": {},
   "source": [
    "## Generating Python from a jaxpr"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "19136714",
   "metadata": {},
   "source": [
    "`jaxpr_to_python_code` turns a jaxpr back into readable Python source — a useful way to *see* the\n",
    "effect of an optimization pass. `fn_to_python_code` is the one-step convenience that traces a\n",
    "function and prints its generated code directly."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "965e4526",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-05-30T16:40:57.806333Z",
     "iopub.status.busy": "2026-05-30T16:40:57.806076Z",
     "iopub.status.idle": "2026-05-30T16:40:57.810620Z",
     "shell.execute_reply": "2026-05-30T16:40:57.809813Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "def optimized_fn(a):\n",
      "    b = a + 1.0\n",
      "    c = b + b\n",
      "    return c\n"
     ]
    }
   ],
   "source": [
    "print(T.jaxpr_to_python_code(optimized.jaxpr, fn_name='optimized_fn'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "edec5eb7",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-05-30T16:40:57.812766Z",
     "iopub.status.busy": "2026-05-30T16:40:57.812513Z",
     "iopub.status.idle": "2026-05-30T16:40:57.817261Z",
     "shell.execute_reply": "2026-05-30T16:40:57.816424Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "def pure(a):\n",
      "    b = jax.lax.integer_pow(a, 2)\n",
      "    c = jax.numpy.sum(b, axis=(0,))\n",
      "    return c\n"
     ]
    }
   ],
   "source": [
    "print(T.fn_to_python_code(pure, jnp.array([1.0, 2.0, 3.0])))"
   ]
  },
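  {
   "cell_type": "markdown",
   "id": "e2a9c4f6",
   "metadata": {},
   "source": [
    "Code generation walks the equations in order and emits one assignment per equation. The sketch\n",
    "below shows the idea on a hypothetical toy IR (equations as `(out, op, args)` tuples, with a\n",
    "made-up infix operator table); it is not how `jaxpr_to_python_code` is implemented. The real\n",
    "function works on jaxprs and, as the output above shows, renders primitives such as `integer_pow`\n",
    "as `jax.lax` calls, which this toy does not attempt."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f4b1d8e3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hypothetical operator table for the toy IR: binary ops rendered infix.\n",
    "INFIX = {'add': '+', 'sub': '-', 'mul': '*'}\n",
    "\n",
    "def to_python_code(eqns, inputs, outputs, fn_name='generated_fn'):\n",
    "    # One `def` line, one assignment per equation, then a return statement.\n",
    "    params = ', '.join(inputs)\n",
    "    lines = [f'def {fn_name}({params}):']\n",
    "    for out, op, args in eqns:\n",
    "        # Variable names are emitted as-is; literals via repr (1.0 -> '1.0').\n",
    "        lhs, rhs = (a if isinstance(a, str) else repr(a) for a in args)\n",
    "        lines.append(f'    {out} = {lhs} {INFIX[op]} {rhs}')\n",
    "    rets = ', '.join(outputs)\n",
    "    lines.append(f'    return {rets}')\n",
    "    return '\\n'.join(lines) + '\\n'\n",
    "\n",
    "# The optimized `wasteful` program from earlier, written in the toy IR.\n",
    "eqns = [('b', 'add', ('a', 1.0)), ('c', 'add', ('b', 'b'))]\n",
    "source = to_python_code(eqns, ['a'], ['c'], fn_name='optimized_fn')\n",
    "print(source)\n",
    "\n",
    "# The generated text is ordinary Python, so it can be executed and called.\n",
    "namespace = {}\n",
    "exec(source, namespace)\n",
    "print(namespace['optimized_fn'](2.0))  # 6.0"
   ]
  },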
  {
   "cell_type": "markdown",
   "id": "229fc6aa",
   "metadata": {},
   "source": [
    "## `StatefulFunction`: the engine behind the transforms"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1589d594",
   "metadata": {},
   "source": [
    "`StatefulFunction` is the lower-level wrapper that every state-aware transform builds on. It\n",
    "traces a function once and then answers precise questions about it: which states it *reads*,\n",
    "which it *writes*, and what its jaxpr looks like. Construct it with `ir_optimizations` to apply\n",
    "optimization passes automatically during tracing."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "d1cec563",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-05-30T16:40:57.819327Z",
     "iopub.status.busy": "2026-05-30T16:40:57.819108Z",
     "iopub.status.idle": "2026-05-30T16:40:57.877738Z",
     "shell.execute_reply": "2026-05-30T16:40:57.877067Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "read states : 0\n",
      "write states: 1\n",
      "jaxpr call result: 6.0\n"
     ]
    }
   ],
   "source": [
    "sf = T.StatefulFunction(stateful, ir_optimizations=['dce', 'cse'])\n",
    "sf.make_jaxpr(jnp.array(1.0))\n",
    "\n",
    "print('read states :', len(sf.get_read_states(jnp.array(1.0))))\n",
    "print('write states:', len(sf.get_write_states(jnp.array(1.0))))\n",
    "\n",
    "# Execute through the traced jaxpr, automatically threading state in and out.\n",
    "counter.value = jnp.array(0.0)\n",
    "print('jaxpr call result:', float(sf.jaxpr_call_auto(jnp.array(3.0))))"
   ]
  },
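  {
   "cell_type": "markdown",
   "id": "a7c3e5b9",
   "metadata": {},
   "source": [
    "The read/write split can be reproduced with simple bookkeeping. The sketch below assumes the rule\n",
    "suggested by the output above: a state that is written is reported only among the write states,\n",
    "even if it is also read. The `AccessLog` class is hypothetical and illustrates that rule only,\n",
    "not BrainState's tracing machinery."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b9d5f7c1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hypothetical access log recording which named states a function touches.\n",
    "class AccessLog:\n",
    "    def __init__(self):\n",
    "        self.read = []\n",
    "        self.written = []\n",
    "\n",
    "    def on_read(self, name):\n",
    "        if name not in self.read:\n",
    "            self.read.append(name)\n",
    "\n",
    "    def on_write(self, name):\n",
    "        if name not in self.written:\n",
    "            self.written.append(name)\n",
    "\n",
    "    def classify(self):\n",
    "        # Assumed rule: any state that is written is a write state; only\n",
    "        # states that are read and never written count as read states.\n",
    "        read_only = [n for n in self.read if n not in self.written]\n",
    "        return read_only, list(self.written)\n",
    "\n",
    "# Replay the accesses made by `stateful`: read counter, then write it.\n",
    "log = AccessLog()\n",
    "log.on_read('counter')\n",
    "log.on_write('counter')\n",
    "read_states, write_states = log.classify()\n",
    "print('read states :', len(read_states))   # 0\n",
    "print('write states:', len(write_states))  # 1\n",
    "\n",
    "# A hypothetical function that reads `weight` and updates `counter`.\n",
    "log2 = AccessLog()\n",
    "log2.on_read('weight')\n",
    "log2.on_read('counter')\n",
    "log2.on_write('counter')\n",
    "print(log2.classify())  # (['weight'], ['counter'])"
   ]
  },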
  {
   "cell_type": "markdown",
   "id": "440f66df",
   "metadata": {},
   "source": [
    "## Summary\n",
    "\n",
    "- **`make_jaxpr`** exposes the state-aware IR: reads become inputs, writes become outputs.\n",
    "- **`optimize_jaxpr`** runs `dce`, `cse`, `constant_fold`, `algebraic_simplification`, and\n",
    "  `copy_propagation` passes to shrink a jaxpr.\n",
    "- **`jaxpr_to_python_code`** / **`fn_to_python_code`** regenerate readable Python from the IR.\n",
    "- **`StatefulFunction`** is the underlying primitive: it reports read/write states and can apply\n",
    "  IR optimizations as it traces.\n",
    "\n",
    "### See also\n",
    "\n",
    "- [JIT and compilation](01_jit_and_compilation.ipynb) — how tracing and caching drive compilation.\n",
    "- [Debugging](07_debugging.ipynb) — inspecting values rather than structure."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.13.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
