{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "003e540f",
   "metadata": {},
   "source": [
    "# Debugging Transformed Code"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "60148182",
   "metadata": {},
   "source": [
    "A plain `print(x)` inside a `jit`-compiled function runs only while the function is being\n",
    "*traced* (typically once per new input shape or dtype), and it shows a tracer rather than a\n",
    "value. To observe actual values, and to set breakpoints, you need tools that execute at\n",
    "*runtime*, after compilation. This tutorial covers the practical debugging workflow for\n",
    "BrainState code:\n",
    "\n",
    "- `jax.debug.print` for runtime printing of concrete values, including inside `grad` and `vmap`;\n",
    "- `jax.debug.callback` for richer inspection (shapes, statistics);\n",
    "- `brainstate.transform.breakpoint_if` for conditional breakpoints."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "6caa164c",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-05-30T16:40:47.560140Z",
     "iopub.status.busy": "2026-05-30T16:40:47.559973Z",
     "iopub.status.idle": "2026-05-30T16:40:49.709385Z",
     "shell.execute_reply": "2026-05-30T16:40:49.708351Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'0.4.0'"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "import brainstate\n",
    "import brainstate.transform as T\n",
    "\n",
    "brainstate.random.seed(0)\n",
    "brainstate.__version__"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e1be6ab7",
   "metadata": {},
   "source": [
    "## `jax.debug.print`: values at runtime"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "139acd67",
   "metadata": {},
   "source": [
    "`jax.debug.print` works under compilation: the call is staged into the compiled program and\n",
    "prints the concrete value each time the function *runs*, not when it is traced. The format\n",
    "string uses `str.format`-style placeholders; here the `{name}` fields are filled by keyword\n",
    "arguments."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "3c477249",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-05-30T16:40:49.711967Z",
     "iopub.status.busy": "2026-05-30T16:40:49.711615Z",
     "iopub.status.idle": "2026-05-30T16:40:49.772001Z",
     "shell.execute_reply": "2026-05-30T16:40:49.771262Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "input      = [1. 2. 3.]\n",
      "after square = [1. 4. 9.]\n",
      "returned: 14.0\n"
     ]
    }
   ],
   "source": [
    "@brainstate.transform.jit\n",
    "def compute(x):\n",
    "    jax.debug.print('input      = {x}', x=x)\n",
    "    y = x ** 2\n",
    "    jax.debug.print('after square = {y}', y=y)\n",
    "    return jnp.sum(y)\n",
    "\n",
    "total = compute(jnp.array([1.0, 2.0, 3.0]))\n",
    "print('returned:', float(total))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f2b887eb",
   "metadata": {},
   "source": [
    "## Inspecting state updates"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6c1e7288",
   "metadata": {},
   "source": [
    "Because the prints execute at runtime, they see the real `State` values before and after a\n",
    "mutation — invaluable when a buffer drifts or fails to update."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "4260c7d3",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-05-30T16:40:49.773702Z",
     "iopub.status.busy": "2026-05-30T16:40:49.773553Z",
     "iopub.status.idle": "2026-05-30T16:40:49.845734Z",
     "shell.execute_reply": "2026-05-30T16:40:49.844763Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "before: [0. 0. 0.]\n",
      "after : [1. 2. 3.]\n"
     ]
    }
   ],
   "source": [
    "class Accumulator(brainstate.nn.Module):\n",
    "    def __init__(self, size):\n",
    "        super().__init__()\n",
    "        self.total = brainstate.ShortTermState(jnp.zeros(size))\n",
    "\n",
    "    def __call__(self, x):\n",
    "        jax.debug.print('before: {s}', s=self.total.value)\n",
    "        self.total.value = self.total.value + x\n",
    "        jax.debug.print('after : {s}', s=self.total.value)\n",
    "        return self.total.value\n",
    "\n",
    "acc = Accumulator(3)\n",
    "step = brainstate.transform.jit(acc)\n",
    "_ = step(jnp.array([1.0, 2.0, 3.0]))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "30c17fbf",
   "metadata": {},
   "source": [
    "## Debugging inside `grad`"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "331eb3b5",
   "metadata": {},
   "source": [
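    "Prints placed in a loss function fire during the forward pass of differentiation, letting you\n",
    "watch the quantities that feed the gradient. For the loss `sum((w * x) ** 2)` below, the gradient\n",
    "with respect to `w` is `2 * w * x ** 2`, which gives `[1, 6]` for these inputs."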
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "6fc36ca1",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-05-30T16:40:49.847876Z",
     "iopub.status.busy": "2026-05-30T16:40:49.847700Z",
     "iopub.status.idle": "2026-05-30T16:40:50.002322Z",
     "shell.execute_reply": "2026-05-30T16:40:50.001408Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "prediction = [1. 3.]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "grad: [1. 6.]\n"
     ]
    }
   ],
   "source": [
    "weight = brainstate.ParamState(jnp.array([2.0, 3.0]))\n",
    "\n",
    "def loss_fn(x):\n",
    "    pred = weight.value * x\n",
    "    jax.debug.print('prediction = {p}', p=pred)\n",
    "    return jnp.sum(pred ** 2)\n",
    "\n",
    "grads = brainstate.transform.grad(loss_fn, {'w': weight})(jnp.array([0.5, 1.0]))\n",
    "print('grad:', grads['w'])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ca45afe0",
   "metadata": {},
   "source": [
    "## Debugging inside `vmap`"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "42fe531a",
   "metadata": {},
   "source": [
    "Under `vmap` the print runs once per batch element, so you can confirm exactly what each lane\n",
    "receives."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "8a6b104a",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-05-30T16:40:50.004010Z",
     "iopub.status.busy": "2026-05-30T16:40:50.003869Z",
     "iopub.status.idle": "2026-05-30T16:40:50.285195Z",
     "shell.execute_reply": "2026-05-30T16:40:50.284422Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "lane 0: x=1.0\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "lane 1: x=2.0\n",
      "lane 2: x=3.0\n",
      "outputs: [1. 4. 9.]\n"
     ]
    }
   ],
   "source": [
    "def process(x, index):\n",
    "    jax.debug.print('lane {i}: x={x}', i=index, x=x)\n",
    "    return x ** 2\n",
    "\n",
    "batched = brainstate.transform.vmap(process, in_axes=(0, 0))\n",
    "out = batched(jnp.array([1.0, 2.0, 3.0]), jnp.arange(3))\n",
    "print('outputs:', out)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c873a8e0",
   "metadata": {},
   "source": [
    "## Richer inspection with `jax.debug.callback`"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1f7e6a1f",
   "metadata": {},
   "source": [
    "When a one-line print is not enough, `jax.debug.callback` hands the runtime values to an\n",
    "arbitrary Python function, which is ideal for logging summary statistics without leaving the\n",
    "compiled region. The callback is for side effects only: whatever it returns is discarded, so it\n",
    "cannot feed a value back into the computation (`jax.pure_callback` exists for that case)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "0d03a37d",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-05-30T16:40:50.286906Z",
     "iopub.status.busy": "2026-05-30T16:40:50.286620Z",
     "iopub.status.idle": "2026-05-30T16:40:50.711375Z",
     "shell.execute_reply": "2026-05-30T16:40:50.710646Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[activations] shape=(100,) min=-2.442 max=2.130 mean=-0.210\n"
     ]
    }
   ],
   "source": [
    "def summarize(name, value):\n",
    "    print(f'[{name}] shape={value.shape} '\n",
    "          f'min={float(jnp.min(value)):.3f} '\n",
    "          f'max={float(jnp.max(value)):.3f} '\n",
    "          f'mean={float(jnp.mean(value)):.3f}')\n",
    "\n",
    "@brainstate.transform.jit\n",
    "def forward(x):\n",
    "    jax.debug.callback(summarize, 'activations', x)\n",
    "    return jnp.tanh(x)\n",
    "\n",
    "_ = forward(brainstate.random.randn(100))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6ebc63b6",
   "metadata": {},
   "source": [
    "## Conditional breakpoints with `breakpoint_if`"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0bfef537",
   "metadata": {},
   "source": [
    "`breakpoint_if(pred)` drops into JAX's interactive debugger, but only when `pred` is true at\n",
    "runtime. This lets you halt on a rare bad condition (a NaN, a negative value) without stopping on\n",
    "every iteration. In the example below the input is finite, so the predicate is false and\n",
    "execution proceeds normally; in a real session you would set it to your suspected failure\n",
    "condition and inspect the live values when it triggers."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "4acf8c08",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-05-30T16:40:50.713194Z",
     "iopub.status.busy": "2026-05-30T16:40:50.713024Z",
     "iopub.status.idle": "2026-05-30T16:40:51.564506Z",
     "shell.execute_reply": "2026-05-30T16:40:51.563910Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "clean input proceeds: [2. 4. 6.]\n"
     ]
    }
   ],
   "source": [
    "@brainstate.transform.jit\n",
    "def guarded(x):\n",
    "    # Pause for inspection only if a non-finite value appears.\n",
    "    T.breakpoint_if(jnp.any(~jnp.isfinite(x)))\n",
    "    return x * 2.0\n",
    "\n",
    "print('clean input proceeds:', guarded(jnp.array([1.0, 2.0, 3.0])))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ae08562b",
   "metadata": {},
   "source": [
    "## Summary\n",
    "\n",
    "- Ordinary `print` runs at trace time and shows tracers; use runtime-aware tools instead.\n",
    "- **`jax.debug.print`** prints concrete values during execution — inside `jit`, `grad`, and\n",
    "  `vmap` alike, including before/after `State` mutations.\n",
    "- **`jax.debug.callback`** sends values to any Python function for richer inspection.\n",
    "- **`breakpoint_if(pred)`** opens an interactive debugger only when a condition is met.\n",
    "\n",
    "### See also\n",
    "\n",
    "- [Error handling and runtime checks](06_error_handling_and_checks.ipynb) — catching NaNs and bad inputs.\n",
    "- [IR optimization and code generation](08_ir_optimization_and_codegen.ipynb) — inspecting the compiled jaxpr."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.13.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
