{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "51d206e0",
   "metadata": {},
   "source": [
    "# Error Handling and Runtime Checks"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5e1965f5",
   "metadata": {},
   "source": [
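    "Inside a compiled (`jit`) function, ordinary Python `assert`s and `if` statements that depend on\n",
    "array *values* do not work: at trace time the values are abstract tracers, and converting one to\n",
    "a Python `bool` raises an error. The compiled code itself never raises on bad values either. A\n",
    "floating-point division by zero yields `inf`, an out-of-bounds index is silently clamped, and a\n",
    "NaN propagates through every later operation, surfacing much later as a meaningless result.\n",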
    "\n",
    "`brainstate.transform` provides JIT-compatible runtime checks built on JAX's `checkify`\n",
    "machinery, extended to understand `State`. This tutorial covers three tools:\n",
    "\n",
    "- **`checkify`** — turn value errors into an explicit error object you can inspect.\n",
    "- **`jit_error_if`** — raise on a bad condition from inside a compiled step.\n",
    "- **`debug_nan`** — pinpoint where a NaN or Inf first appears."
   ]
  },
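  {
   "cell_type": "markdown",
   "id": "b3d7a101",
   "metadata": {},
   "source": [
    "A minimal plain-JAX sketch of both problems. It uses only `jax.jit` and\n",
    "`jax.errors.ConcretizationTypeError`, no `brainstate`: numerical errors come back as\n",
    "`inf`/`nan` values, while a Python `if` on a traced value fails while tracing.\n",
    "\n",
    "```python\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "@jax.jit\n",
    "def silent(a, b):\n",
    "    # XLA does not raise on numerical errors; they become inf / nan values.\n",
    "    return a / b, jnp.log(b - 1.0)\n",
    "\n",
    "ratio, logv = silent(jnp.array(1.0), jnp.array(0.0))\n",
    "print('1/0 ->', ratio, '| log(-1) ->', logv)\n",
    "\n",
    "@jax.jit\n",
    "def branchy(x):\n",
    "    if x > 0:  # calls bool() on a tracer, which fails at trace time\n",
    "        return x\n",
    "    return -x\n",
    "\n",
    "caught = None\n",
    "try:\n",
    "    branchy(jnp.array(1.0))\n",
    "except jax.errors.ConcretizationTypeError as e:\n",
    "    caught = type(e).__name__  # a ConcretizationTypeError subclass in recent JAX\n",
    "print('trace-time error:', caught)\n",
    "```"
   ]
  },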
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "8cf5cc89",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-05-30T16:40:42.367860Z",
     "iopub.status.busy": "2026-05-30T16:40:42.367703Z",
     "iopub.status.idle": "2026-05-30T16:40:44.480812Z",
     "shell.execute_reply": "2026-05-30T16:40:44.480050Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'0.4.0'"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "import brainstate\n",
    "import brainstate.transform as T\n",
    "\n",
    "brainstate.random.seed(0)\n",
    "brainstate.__version__"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "73c30eda",
   "metadata": {},
   "source": [
    "## `checkify`: functionalize runtime checks"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "63cbf2e3",
   "metadata": {},
   "source": [
    "`checkify` transforms a function so that, instead of raising, it *returns* an `(error, result)`\n",
    "pair. The error object is inert until you ask about it: `error.get()` returns `None` when all\n",
    "checks passed, or the failure message otherwise. Inside the function you assert conditions with\n",
    "`check(pred, msg, *fmt_args)`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "e3c7760a",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-05-30T16:40:44.482946Z",
     "iopub.status.busy": "2026-05-30T16:40:44.482613Z",
     "iopub.status.idle": "2026-05-30T16:40:44.703803Z",
     "shell.execute_reply": "2026-05-30T16:40:44.702885Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "valid input -> error: None\n",
      "valid input -> result: [0.        0.6931472]\n",
      "bad input   -> error: x must be positive, got [-1.  2.] (`check` failed)\n"
     ]
    }
   ],
   "source": [
    "def safe_log(x):\n",
    "    T.check(jnp.all(x > 0), 'x must be positive, got {}', x)\n",
    "    return jnp.log(x)\n",
    "\n",
    "checked = T.checkify(safe_log, errors=T.user_checks)\n",
    "\n",
    "err, out = checked(jnp.array([1.0, 2.0]))\n",
    "print('valid input -> error:', err.get())\n",
    "print('valid input -> result:', out)\n",
    "\n",
    "err, out = checked(jnp.array([-1.0, 2.0]))\n",
    "print('bad input   -> error:', err.get())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c650359d",
   "metadata": {},
   "source": [
    "To turn a captured error back into a real exception at an outer boundary, call\n",
    "`check_error(error)` — it raises if the error is set, and is a no-op otherwise."
   ]
  },
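  {
   "cell_type": "markdown",
   "id": "b3d7a102",
   "metadata": {},
   "source": [
    "A sketch of the same boundary pattern in plain `jax.experimental.checkify`, the JAX API this\n",
    "module builds on. It assumes `T.checkify` and `T.check_error` behave like their JAX namesakes,\n",
    "and it wraps the checkified function in `jax.jit` so the error leaves the compiled code as an\n",
    "ordinary return value and is raised only on the host.\n",
    "\n",
    "```python\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "from jax.experimental import checkify\n",
    "\n",
    "def safe_log(x):\n",
    "    checkify.check(jnp.all(x > 0), 'x must be positive, got {}', x)\n",
    "    return jnp.log(x)\n",
    "\n",
    "# checkify first, then jit: the error is returned alongside the result.\n",
    "checked = jax.jit(checkify.checkify(safe_log))\n",
    "\n",
    "err_ok, out = checked(jnp.array([1.0, 2.0]))\n",
    "checkify.check_error(err_ok)  # no-op: every check passed\n",
    "\n",
    "err_bad, _ = checked(jnp.array([-1.0, 2.0]))\n",
    "message = None\n",
    "try:\n",
    "    checkify.check_error(err_bad)  # raises because the check failed\n",
    "except Exception as e:  # broad catch for the sketch\n",
    "    message = str(e)\n",
    "print('raised at the boundary:', message)\n",
    "```"
   ]
  },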
  {
   "cell_type": "markdown",
   "id": "20827bd3",
   "metadata": {},
   "source": [
    "## Built-in check categories"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f3044981",
   "metadata": {},
   "source": [
    "Beyond your own `check`s, `checkify` can automatically detect whole classes of failures. Select\n",
    "them by passing a predefined set as `errors`:\n",
    "\n",
    "| Set | Detects |\n",
    "| --- | --- |\n",
    "| `user_checks` | your explicit `check(...)` assertions |\n",
    "| `nan_checks` | NaN values produced by any primitive |\n",
    "| `float_checks` | NaN **and** division by zero (`nan_checks` plus `div_checks`); overflow to Inf is not flagged |\n",
    "| `div_checks` | division by zero |\n",
    "| `index_checks` | out-of-bounds array indexing |\n",
    "| `all_checks` | every category above |\n",
    "\n",
    "For these automatic categories no explicit `check` call is needed: the failure is caught at\n",
    "whichever primitive produces it."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "da80d8fa",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-05-30T16:40:44.706022Z",
     "iopub.status.busy": "2026-05-30T16:40:44.705801Z",
     "iopub.status.idle": "2026-05-30T16:40:45.480924Z",
     "shell.execute_reply": "2026-05-30T16:40:45.480230Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "nan_checks  : nan generated by primitive: log.\n",
      "div_checks  : division by zero\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "index_checks: out-of-bounds indexing for array of shape (3,): index 10 is out of bounds for axis 0 with size 3. \n"
     ]
    }
   ],
   "source": [
    "# NaN detection\n",
    "nan_checked = T.checkify(lambda x: jnp.log(x), errors=T.nan_checks)\n",
    "err, _ = nan_checked(jnp.array([-1.0]))\n",
    "print('nan_checks  :', err.get())\n",
    "\n",
    "# Division by zero\n",
    "div_checked = T.checkify(lambda a, b: a / b, errors=T.div_checks)\n",
    "err, _ = div_checked(jnp.array(1.0), jnp.array(0.0))\n",
    "print('div_checks  :', err.get())\n",
    "\n",
    "# Out-of-bounds indexing\n",
    "idx_checked = T.checkify(lambda arr, i: arr[i], errors=T.index_checks)\n",
    "err, _ = idx_checked(jnp.arange(3), 10)\n",
    "print('index_checks:', err.get())"
   ]
  },
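  {
   "cell_type": "markdown",
   "id": "b3d7a103",
   "metadata": {},
   "source": [
    "In JAX's `checkify` the predefined sets are Python sets of error types, so they can be combined\n",
    "with `|` to enable several categories at once. A plain-JAX sketch, assuming the `T.*_checks`\n",
    "sets combine the same way; `lookup_log` is a hypothetical helper with one explicit check and two\n",
    "automatic failure sites.\n",
    "\n",
    "```python\n",
    "import jax.numpy as jnp\n",
    "from jax.experimental import checkify\n",
    "\n",
    "def lookup_log(table, i):\n",
    "    # Hypothetical helper: an explicit check plus two automatic failure sites.\n",
    "    checkify.check(i >= 0, 'index must be non-negative, got {}', i)\n",
    "    return jnp.log(table[i])  # may index out of bounds or take log(<0)\n",
    "\n",
    "errors = checkify.user_checks | checkify.index_checks | checkify.nan_checks\n",
    "checked = checkify.checkify(lookup_log, errors=errors)\n",
    "\n",
    "table = jnp.array([1.0, -1.0, 3.0])\n",
    "results = {}\n",
    "for i in (0, 1, 5):\n",
    "    err, _ = checked(table, i)\n",
    "    results[i] = err.get()\n",
    "    print(i, '->', results[i])\n",
    "```"
   ]
  },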
  {
   "cell_type": "markdown",
   "id": "f4c5d82d",
   "metadata": {},
   "source": [
    "## `jit_error_if`: raise from inside a compiled step"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f4731b77",
   "metadata": {},
   "source": [
    "When you would rather fail loudly than thread an error object around, `jit_error_if(pred, msg)`\n",
    "raises a runtime error if `pred` is true. It works inside `brainstate.transform.jit` and is the\n",
    "right tool for guarding preconditions in a training or simulation step."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "755a0dda",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-05-30T16:40:45.482963Z",
     "iopub.status.busy": "2026-05-30T16:40:45.482764Z",
     "iopub.status.idle": "2026-05-30T16:40:45.524909Z",
     "shell.execute_reply": "2026-05-30T16:40:45.524013Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "valid: [0.5  0.25]\n"
     ]
    }
   ],
   "source": [
    "@brainstate.transform.jit\n",
    "def reciprocal(x):\n",
    "    T.jit_error_if(jnp.any(x == 0.0), 'reciprocal received a zero entry')\n",
    "    return 1.0 / x\n",
    "\n",
    "print('valid:', reciprocal(jnp.array([2.0, 4.0])))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "21a7b7f3",
   "metadata": {},
   "source": [
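    "If the predicate is ever true at runtime, the call raises with your message instead of returning\n",
    "`inf`. The check is not free: the predicate (here a reduction over `x`) is evaluated on every\n",
    "call, although the raise itself only runs when it is true. The overhead is usually small next to\n",
    "the step it guards, but measure it before leaving the check in a hot production loop."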
   ]
  },
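  {
   "cell_type": "markdown",
   "id": "b3d7a104",
   "metadata": {},
   "source": [
    "When only plain JAX is at hand, a similar fail-loudly guard can be sketched by checkifying a\n",
    "jitted step and calling `Error.throw()` at the call boundary. This is an alternative built from\n",
    "`jax.experimental.checkify`, not how `brainstate` implements `jit_error_if`;\n",
    "`reciprocal_or_raise` is a hypothetical wrapper name.\n",
    "\n",
    "```python\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "from jax.experimental import checkify\n",
    "\n",
    "def reciprocal(x):\n",
    "    checkify.check(jnp.all(x != 0.0), 'reciprocal received a zero entry')\n",
    "    return 1.0 / x\n",
    "\n",
    "# Compile once; the error comes out of the jitted call as a value.\n",
    "_checked = jax.jit(checkify.checkify(reciprocal))\n",
    "\n",
    "def reciprocal_or_raise(x):\n",
    "    # Hypothetical wrapper: run the compiled step, then raise on the host.\n",
    "    err, out = _checked(x)\n",
    "    err.throw()  # no-op when every check passed\n",
    "    return out\n",
    "\n",
    "valid = reciprocal_or_raise(jnp.array([2.0, 4.0]))\n",
    "print('valid:', valid)\n",
    "\n",
    "raised = None\n",
    "try:\n",
    "    reciprocal_or_raise(jnp.array([2.0, 0.0]))\n",
    "except Exception as e:  # broad catch for the sketch\n",
    "    raised = str(e)\n",
    "print('raised:', raised)\n",
    "```\n",
    "\n",
    "Here the raise happens after the compiled call returns, so every caller has to go through the\n",
    "wrapper for the check to take effect."
   ]
  },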
  {
   "cell_type": "markdown",
   "id": "aea99456",
   "metadata": {},
   "source": [
    "## `debug_nan`: locate the first NaN or Inf"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "dc18d9ce",
   "metadata": {},
   "source": [
    "When a model diverges, the hard part is finding *where* the NaN was born. `debug_nan(fn, *args)`\n",
    "runs `fn` with NaN/Inf detection enabled and raises — naming the offending primitive — the moment\n",
    "a non-finite value appears."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "cbb8106f",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-05-30T16:40:45.526701Z",
     "iopub.status.busy": "2026-05-30T16:40:45.526562Z",
     "iopub.status.idle": "2026-05-30T16:40:45.986227Z",
     "shell.execute_reply": "2026-05-30T16:40:45.985548Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "finite computation: OK\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "debug_nan caught: RuntimeError\n"
     ]
    }
   ],
   "source": [
    "def unstable(x):\n",
    "    y = x * 1e20\n",
    "    return jnp.exp(y)   # overflows to inf\n",
    "\n",
    "# A finite computation passes through untouched.\n",
    "T.debug_nan(lambda x: x * 2.0, jnp.array([1.0, 2.0]))\n",
    "print('finite computation: OK')\n",
    "\n",
    "try:\n",
    "    T.debug_nan(unstable, jnp.array([10.0]))\n",
    "except Exception as e:\n",
    "    print('debug_nan caught:', type(e).__name__)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2236eba9",
   "metadata": {},
   "source": [
    "Use `debug_nan_if(has_nan, fn, *args)` to enable the (somewhat costly) detection only when an\n",
    "upstream flag already suspects trouble, keeping the fast path fast."
   ]
  },
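  {
   "cell_type": "markdown",
   "id": "b3d7a105",
   "metadata": {},
   "source": [
    "The keep-the-fast-path-fast idea can also be sketched in plain JAX with the global\n",
    "`jax_debug_nans` flag rather than `debug_nan_if`: run the compiled function normally and, only\n",
    "if its output is non-finite, re-run it with the flag on so that JAX raises `FloatingPointError`,\n",
    "re-running the computation op by op to locate the offending primitive.\n",
    "`run_with_nan_fallback` is a hypothetical helper, and `jax_debug_nans` reacts to NaN only (JAX\n",
    "has a separate `jax_debug_infs` flag for Inf).\n",
    "\n",
    "```python\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "def unstable(x):\n",
    "    return jnp.log(x)  # log of a negative entry is nan\n",
    "\n",
    "def run_with_nan_fallback(fn, *args):\n",
    "    # Hypothetical helper: fast path first, instrumented re-run only on trouble.\n",
    "    compiled = jax.jit(fn)\n",
    "    out = compiled(*args)\n",
    "    if bool(jnp.all(jnp.isfinite(out))):\n",
    "        return out\n",
    "    jax.config.update('jax_debug_nans', True)  # global flag, reset to its default below\n",
    "    try:\n",
    "        return compiled(*args)  # raises FloatingPointError on a NaN\n",
    "    finally:\n",
    "        jax.config.update('jax_debug_nans', False)\n",
    "\n",
    "ok = run_with_nan_fallback(unstable, jnp.array([1.0, 2.0]))\n",
    "print('finite path:', ok)\n",
    "\n",
    "caught = None\n",
    "try:\n",
    "    run_with_nan_fallback(unstable, jnp.array([1.0, -1.0]))\n",
    "except FloatingPointError as e:\n",
    "    caught = type(e).__name__\n",
    "print('caught:', caught)\n",
    "```"
   ]
  },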
  {
   "cell_type": "markdown",
   "id": "636c2f2c",
   "metadata": {},
   "source": [
    "## Summary\n",
    "\n",
    "- Value-dependent `assert`/`if` do not work inside `jit`; use the runtime-check tools instead.\n",
    "- **`checkify`** returns `(error, result)`; inspect with `error.get()`, assert with `check(...)`,\n",
    "  re-raise with `check_error(...)`.\n",
    "- Predefined sets — `user_checks`, `nan_checks`, `float_checks`, `div_checks`, `index_checks`,\n",
    "  `all_checks` — catch whole categories of failures automatically.\n",
    "- **`jit_error_if`** raises on a bad condition from inside a compiled step.\n",
    "- **`debug_nan`** / **`debug_nan_if`** pinpoint the primitive that first produced a NaN or Inf.\n",
    "\n",
    "### See also\n",
    "\n",
    "- [Debugging](07_debugging.ipynb) — printing and inspecting values inside transformed code.\n",
    "- [JIT and compilation](01_jit_and_compilation.ipynb) — why value-dependent control flow is restricted."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.13.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
