{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "4b1e0683",
   "metadata": {},
   "source": [
    "# Observe and Intercept State Access with Hooks"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "eea09994",
   "metadata": {},
   "source": [
    "State hooks let you run a callback whenever a `State` is read, written, restored, or created —\n",
    "*without* editing the model that owns the state. They are the mechanism for cross-cutting\n",
    "concerns: logging value changes, validating writes, enforcing invariants, or tracing access\n",
    "patterns while debugging.\n",
    "\n",
    "A hook is registered against one of five operations:\n",
    "\n",
    "| Operation | Fires when | Can modify / cancel? |\n",
    "|---|---|---|\n",
    "| `read` | `state.value` is read | no (inspect only) |\n",
    "| `write_before` | just before `state.value = ...` | **yes** — transform or cancel the write |\n",
    "| `write_after` | just after a write completes | no (inspect only) |\n",
    "| `restore` | `state.restore_value(...)` is called | no |\n",
    "| `init` | a `State` is constructed | no |\n",
    "\n",
    "Hooks come in two scopes. A **global** hook (`brainstate.register_state_hook`) fires for every\n",
    "state in the program; a **per-state** hook (`state.register_hook`) fires only for that one\n",
    "instance."
   ]
  },
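  {
   "cell_type": "markdown",
   "id": "9c2d41a7",
   "metadata": {},
   "source": [
    "To make the two scopes concrete, the following toy model shows one way a state could dispatch an\n",
    "operation to both registries. It is plain Python, not brainstate's implementation. `ToyState` and\n",
    "its attributes are hypothetical, and only `read` and `write_after` are modelled. Running global\n",
    "hooks before per-state ones is an assumption of the sketch.\n",
    "\n",
    "```python\n",
    "from collections import defaultdict\n",
    "\n",
    "\n",
    "class ToyState:\n",
    "    # Hypothetical stand-in for brainstate.State; models hook dispatch only.\n",
    "    global_hooks = defaultdict(list)  # operation -> callbacks for every state\n",
    "\n",
    "    def __init__(self, value):\n",
    "        self.local_hooks = defaultdict(list)  # operation -> callbacks for this instance\n",
    "        self._value = value\n",
    "\n",
    "    def _fire(self, operation, **info):\n",
    "        # Assumption: global hooks run first, then this instance's own hooks.\n",
    "        for hook in ToyState.global_hooks[operation] + self.local_hooks[operation]:\n",
    "            hook(operation, info)\n",
    "\n",
    "    @property\n",
    "    def value(self):\n",
    "        self._fire('read', value=self._value)\n",
    "        return self._value\n",
    "\n",
    "    @value.setter\n",
    "    def value(self, new):\n",
    "        old = self._value\n",
    "        self._value = new\n",
    "        self._fire('write_after', old_value=old, value=new)\n",
    "```\n",
    "\n",
    "A callback added to `ToyState.global_hooks['write_after']` then sees writes to every `ToyState`.\n",
    "A callback appended to `a.local_hooks['write_after']` sees only writes to `a`."
   ]
  },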
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "c3074b0e",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-05-30T17:10:44.355940Z",
     "iopub.status.busy": "2026-05-30T17:10:44.355726Z",
     "iopub.status.idle": "2026-05-30T17:10:46.491710Z",
     "shell.execute_reply": "2026-05-30T17:10:46.491098Z"
    }
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "'0.4.0'"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import jax.numpy as jnp\n",
    "\n",
    "import brainstate\n",
    "\n",
    "brainstate.random.seed(0)\n",
    "brainstate.__version__"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7315e20f",
   "metadata": {},
   "source": [
    "All examples below clear the global registry with `brainstate.clear_state_hooks()` first, so each\n",
    "cell runs independently. In real code you rarely need to clear it: register hooks once at\n",
    "start-up and remove their handles when you are done."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "44038b12",
   "metadata": {},
   "source": [
    "## Logging writes with a `write_after` hook"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "55c95dde",
   "metadata": {},
   "source": [
    "The most common use is read-only observation. A `write_after` hook receives a context carrying\n",
    "the state (`ctx.state`) and its name (`ctx.state_name`), the `old_value` it held, and the new\n",
    "`value` just written. Here we record every change to the state named `weight`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "eddf23d1",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-05-30T17:10:46.494140Z",
     "iopub.status.busy": "2026-05-30T17:10:46.493781Z",
     "iopub.status.idle": "2026-05-30T17:10:46.515583Z",
     "shell.execute_reply": "2026-05-30T17:10:46.514727Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[('weight', 1.0, 2.0), ('weight', 2.0, 2.5)]"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "brainstate.clear_state_hooks()\n",
    "\n",
    "history = []\n",
    "\n",
    "def record(ctx):\n",
    "    history.append((ctx.state_name, float(ctx.old_value), float(ctx.value)))\n",
    "\n",
    "handle = brainstate.register_state_hook('write_after', record)\n",
    "\n",
    "weight = brainstate.State(jnp.array(1.0), name='weight')\n",
    "weight.value = jnp.array(2.0)\n",
    "weight.value = jnp.array(2.5)\n",
    "\n",
    "history"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6aa3af26",
   "metadata": {},
   "source": [
    "The context object also exposes `operation`, `timestamp`, and a `metadata` dict you can use to\n",
    "pass information between hooks. `ctx.state` is resolved through a *weak* reference, so it returns\n",
    "`None` once the state has been garbage-collected. Guard against that in long-lived hooks."
   ]
  },
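  {
   "cell_type": "markdown",
   "id": "3e8b7f10",
   "metadata": {},
   "source": [
    "The weak-reference behaviour can be reproduced in plain Python. The sketch below assumes the\n",
    "context holds a `weakref.ref` to the state. `ToyContext`, `Box`, and `guarded_hook` are\n",
    "hypothetical names used only for illustration. The `None` check at the top of `guarded_hook` is\n",
    "the pattern to reuse in a long-lived hook.\n",
    "\n",
    "```python\n",
    "import weakref\n",
    "\n",
    "\n",
    "class ToyContext:\n",
    "    # Hypothetical stand-in: keeps only a weak reference to the state.\n",
    "    def __init__(self, state):\n",
    "        self._state_ref = weakref.ref(state)\n",
    "\n",
    "    @property\n",
    "    def state(self):\n",
    "        return self._state_ref()  # None once the state has been collected\n",
    "\n",
    "\n",
    "class Box:\n",
    "    pass\n",
    "\n",
    "\n",
    "def guarded_hook(ctx, log):\n",
    "    state = ctx.state\n",
    "    if state is None:  # the state is gone; there is nothing left to inspect\n",
    "        log.append('skipped: state collected')\n",
    "        return\n",
    "    log.append('saw ' + type(state).__name__)\n",
    "```"
   ]
  },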
  {
   "cell_type": "markdown",
   "id": "b4b9eb10",
   "metadata": {},
   "source": [
    "## Enforcing a constraint with `write_before`"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d09354d0",
   "metadata": {},
   "source": [
    "A `write_before` hook runs *before* the value is stored and may rewrite it. Set\n",
    "`ctx.transformed_value` to substitute a new value. If several `write_before` hooks are\n",
    "registered, they chain in priority order and each sees the previous hook's output. A hook\n",
    "should therefore read `ctx.transformed_value` when it is set and fall back to `ctx.value`\n",
    "otherwise. The hook below uses this to keep a parameter inside [-1, 1] no matter who writes to it."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "72895ce9",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-05-30T17:10:46.517607Z",
     "iopub.status.busy": "2026-05-30T17:10:46.517425Z",
     "iopub.status.idle": "2026-05-30T17:10:46.549015Z",
     "shell.execute_reply": "2026-05-30T17:10:46.548255Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "after writing 5.0: 1.0\n",
      "after writing -3.0: -1.0\n"
     ]
    }
   ],
   "source": [
    "brainstate.clear_state_hooks()\n",
    "\n",
    "def clip_to_unit(ctx):\n",
    "    current = ctx.transformed_value if ctx.transformed_value is not None else ctx.value\n",
    "    ctx.transformed_value = jnp.clip(current, -1.0, 1.0)\n",
    "\n",
    "brainstate.register_state_hook('write_before', clip_to_unit)\n",
    "\n",
    "gate = brainstate.State(jnp.array(0.0))\n",
    "gate.value = jnp.array(5.0)    # clipped to 1.0\n",
    "print('after writing 5.0:', float(gate.value))\n",
    "gate.value = jnp.array(-3.0)   # clipped to -1.0\n",
    "print('after writing -3.0:', float(gate.value))"
   ]
  },
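  {
   "cell_type": "markdown",
   "id": "b71c5a2e",
   "metadata": {},
   "source": [
    "Chaining is why `clip_to_unit` reads `ctx.transformed_value` before `ctx.value`. The pure-Python\n",
    "sketch below models the chain, with a `SimpleNamespace` standing in for the context.\n",
    "`run_write_before`, `double`, and `clip` are hypothetical. Running higher priorities first is an\n",
    "assumption of the sketch, not documented brainstate behaviour.\n",
    "\n",
    "```python\n",
    "from types import SimpleNamespace\n",
    "\n",
    "\n",
    "def current(ctx):\n",
    "    # Read the previous hook's output if there is one, else the raw incoming value.\n",
    "    return ctx.value if ctx.transformed_value is None else ctx.transformed_value\n",
    "\n",
    "\n",
    "def double(ctx):\n",
    "    ctx.transformed_value = current(ctx) * 2\n",
    "\n",
    "\n",
    "def clip(ctx):\n",
    "    ctx.transformed_value = max(-1.0, min(1.0, current(ctx)))\n",
    "\n",
    "\n",
    "def run_write_before(hooks, value):\n",
    "    # hooks: list of (priority, callback); assumption: higher priority runs first.\n",
    "    ctx = SimpleNamespace(value=value, transformed_value=None)\n",
    "    for _, hook in sorted(hooks, key=lambda item: item[0], reverse=True):\n",
    "        hook(ctx)\n",
    "    return current(ctx)  # the value that would actually be stored\n",
    "```\n",
    "\n",
    "Suppose `double` has priority 10 and `clip` has priority 0. Writing 0.75 then stores 1.0. If you\n",
    "swap the priorities, the same write stores 1.5, because `clip` now sees the raw 0.75 first."
   ]
  },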
  {
   "cell_type": "markdown",
   "id": "04799e62",
   "metadata": {},
   "source": [
    "## Rejecting an invalid write"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ff080abb",
   "metadata": {},
   "source": [
    "A `write_before` hook can also *cancel* a write by setting `ctx.cancel = True` and, optionally,\n",
    "a `ctx.cancel_reason`. The assignment then raises `HookCancellationError`, whose message includes\n",
    "that reason, and the state keeps its previous value. Use this to guard an invariant that should\n",
    "never be silently violated."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "65d36017",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-05-30T17:10:46.550951Z",
     "iopub.status.busy": "2026-05-30T17:10:46.550770Z",
     "iopub.status.idle": "2026-05-30T17:10:46.595063Z",
     "shell.execute_reply": "2026-05-30T17:10:46.594189Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "write rejected: hook_2: value must be non-negative\n",
      "value unchanged: 0.5\n"
     ]
    }
   ],
   "source": [
    "brainstate.clear_state_hooks()\n",
    "\n",
    "def reject_negative(ctx):\n",
    "    value = ctx.transformed_value if ctx.transformed_value is not None else ctx.value\n",
    "    if jnp.any(value < 0):\n",
    "        ctx.cancel = True\n",
    "        ctx.cancel_reason = 'value must be non-negative'\n",
    "\n",
    "brainstate.register_state_hook('write_before', reject_negative)\n",
    "\n",
    "rate = brainstate.State(jnp.array(0.5))\n",
    "try:\n",
    "    rate.value = jnp.array(-0.1)\n",
    "except brainstate.HookCancellationError as err:\n",
    "    print('write rejected:', err)\n",
    "print('value unchanged:', float(rate.value))"
   ]
  },
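  {
   "cell_type": "markdown",
   "id": "e4a09d63",
   "metadata": {},
   "source": [
    "The old value survives because of an ordering rule: run the `write_before` hooks, then store the\n",
    "value only if none of them cancelled. That rule can be sketched in plain Python. `GuardedBox` and\n",
    "`WriteCancelled` are hypothetical stand-ins for `State` and `HookCancellationError`. Stopping at\n",
    "the first cancelling hook is an assumption of the sketch.\n",
    "\n",
    "```python\n",
    "from types import SimpleNamespace\n",
    "\n",
    "\n",
    "class WriteCancelled(Exception):\n",
    "    pass\n",
    "\n",
    "\n",
    "class GuardedBox:\n",
    "    # Hypothetical container that runs write_before hooks before storing.\n",
    "    def __init__(self, value, hooks=()):\n",
    "        self.value = value\n",
    "        self.hooks = list(hooks)\n",
    "\n",
    "    def write(self, new):\n",
    "        ctx = SimpleNamespace(value=new, transformed_value=None,\n",
    "                              cancel=False, cancel_reason=None)\n",
    "        for hook in self.hooks:\n",
    "            hook(ctx)\n",
    "            if ctx.cancel:\n",
    "                # Raise before touching self.value, so the old value survives.\n",
    "                raise WriteCancelled(ctx.cancel_reason)\n",
    "        self.value = new if ctx.transformed_value is None else ctx.transformed_value\n",
    "\n",
    "\n",
    "def reject_negative(ctx):\n",
    "    value = ctx.value if ctx.transformed_value is None else ctx.transformed_value\n",
    "    if value < 0:\n",
    "        ctx.cancel = True\n",
    "        ctx.cancel_reason = 'value must be non-negative'\n",
    "```"
   ]
  },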
  {
   "cell_type": "markdown",
   "id": "7f2380c5",
   "metadata": {},
   "source": [
    "## Scoping a hook to one state"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8efdfd59",
   "metadata": {},
   "source": [
    "`state.register_hook` attaches the callback to a single instance. Other states are unaffected,\n",
    "which is the right tool when only one buffer needs special treatment."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "ed1089ff",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-05-30T17:10:46.597194Z",
     "iopub.status.busy": "2026-05-30T17:10:46.596965Z",
     "iopub.status.idle": "2026-05-30T17:10:46.602643Z",
     "shell.execute_reply": "2026-05-30T17:10:46.601662Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "writes seen by the per-state hook: [1.0, 2.0]\n"
     ]
    }
   ],
   "source": [
    "brainstate.clear_state_hooks()\n",
    "\n",
    "watched = brainstate.State(jnp.array(0.0), name='watched')\n",
    "other = brainstate.State(jnp.array(0.0), name='other')\n",
    "\n",
    "seen = []\n",
    "watched.register_hook('write_after', lambda ctx: seen.append(float(ctx.value)))\n",
    "\n",
    "watched.value = jnp.array(1.0)\n",
    "other.value = jnp.array(99.0)   # not observed\n",
    "watched.value = jnp.array(2.0)\n",
    "\n",
    "print('writes seen by the per-state hook:', seen)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f180ae7e",
   "metadata": {},
   "source": [
    "## Managing hook handles"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8ac9b7c4",
   "metadata": {},
   "source": [
    "Every registration returns a `HookHandle`. Use it to temporarily silence a hook, re-enable it,\n",
    "or remove it permanently. This is how you bound the lifetime of a debugging hook to a single\n",
    "section of code."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "6ea231f1",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-05-30T17:10:46.604871Z",
     "iopub.status.busy": "2026-05-30T17:10:46.604611Z",
     "iopub.status.idle": "2026-05-30T17:10:46.610027Z",
     "shell.execute_reply": "2026-05-30T17:10:46.609336Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "recorded writes: [1.0, 3.0]\n",
      "handle removed? True\n"
     ]
    }
   ],
   "source": [
    "brainstate.clear_state_hooks()\n",
    "\n",
    "calls = []\n",
    "handle = brainstate.register_state_hook('write_after', lambda ctx: calls.append(float(ctx.value)))\n",
    "s = brainstate.State(jnp.array(0.0))\n",
    "\n",
    "s.value = jnp.array(1.0)   # recorded\n",
    "handle.disable()\n",
    "s.value = jnp.array(2.0)   # skipped while disabled\n",
    "handle.enable()\n",
    "s.value = jnp.array(3.0)   # recorded\n",
    "handle.remove()\n",
    "s.value = jnp.array(4.0)   # hook gone\n",
    "\n",
    "print('recorded writes:', calls)\n",
    "print('handle removed?', handle.is_removed())"
   ]
  },
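  {
   "cell_type": "markdown",
   "id": "5f2e8c94",
   "metadata": {},
   "source": [
    "You may want a hook's lifetime tied to one block of code, with removal guaranteed even if that\n",
    "block raises. One way is to wrap the handle in a context manager. `temporary_hook` below is a\n",
    "hypothetical helper, not part of brainstate. It assumes only that the registration function\n",
    "returns an object with a `remove()` method, as `HookHandle` does.\n",
    "\n",
    "```python\n",
    "from contextlib import contextmanager\n",
    "\n",
    "\n",
    "@contextmanager\n",
    "def temporary_hook(register, *args, **kwargs):\n",
    "    # register: e.g. brainstate.register_state_hook or some_state.register_hook.\n",
    "    handle = register(*args, **kwargs)\n",
    "    try:\n",
    "        yield handle\n",
    "    finally:\n",
    "        # Runs on normal exit and on exceptions, so the hook never outlives the block.\n",
    "        handle.remove()\n",
    "```\n",
    "\n",
    "For example, `with temporary_hook(brainstate.register_state_hook, 'write_after', record): ...`\n",
    "would log writes only while the `with` block is running."
   ]
  },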
  {
   "cell_type": "markdown",
   "id": "2de28a2d",
   "metadata": {},
   "source": [
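    "For introspection, `brainstate.list_state_hooks()` returns the registered hooks (optionally\n",
    "filtered by type), `has_state_hooks()` reports whether any are active, and `clear_state_hooks()`\n",
    "removes them all. In the cell below the registry is already empty, because the previous cell\n",
    "removed its only hook. Both checks therefore print `False`."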
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "d6d151a9",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-05-30T17:10:46.612320Z",
     "iopub.status.busy": "2026-05-30T17:10:46.612063Z",
     "iopub.status.idle": "2026-05-30T17:10:46.616660Z",
     "shell.execute_reply": "2026-05-30T17:10:46.615657Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "hooks registered: False\n",
      "after clear: False\n"
     ]
    }
   ],
   "source": [
    "print('hooks registered:', brainstate.has_state_hooks())\n",
    "brainstate.clear_state_hooks()\n",
    "print('after clear:', brainstate.has_state_hooks())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b3af6125",
   "metadata": {},
   "source": [
    "## Hooks and compiled code"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "57581e0d",
   "metadata": {},
   "source": [
    "Hooks are ordinary Python callbacks, so they fire on every concrete `.value` access. Inside a\n",
    "`brainstate.transform.jit` step they still fire once per call at run time. They also fire once\n",
    "during the initial trace, where `ctx.value` is an abstract tracer rather than a concrete array.\n",
    "Keep hook bodies free of Python branching on a value's *contents* (e.g. `if float(...)`),\n",
    "because converting a tracer to a Python scalar or bool raises an error during tracing. The\n",
    "`clip_to_unit` hook above is trace-safe. `reject_negative` branches on `jnp.any(...)`, so it\n",
    "suits eager code only. Some checks must live *inside* compiled code, such as NaN guards and\n",
    "bounds assertions. For those, prefer the dedicated error-handling tools\n",
    "(`brainstate.transform.checkify`, `check`, and `debug_nan`), which are designed to run under\n",
    "transformation."
   ]
  },
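  {
   "cell_type": "markdown",
   "id": "c08d6b35",
   "metadata": {},
   "source": [
    "You can check the difference between a branching hook body and a trace-safe one with plain\n",
    "`jax.jit`, without brainstate. A `SimpleNamespace` stands in for the hook context. `run_hook`,\n",
    "`clamp_branching`, and `clamp_traceable` are hypothetical names. The point is only that\n",
    "`float(...)` fails on a tracer while `jnp.maximum` does not.\n",
    "\n",
    "```python\n",
    "from types import SimpleNamespace\n",
    "\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "\n",
    "def current(ctx):\n",
    "    return ctx.value if ctx.transformed_value is None else ctx.transformed_value\n",
    "\n",
    "\n",
    "def clamp_branching(ctx):\n",
    "    # Python branch on the contents: needs a concrete number.\n",
    "    if float(current(ctx)) < 0:\n",
    "        ctx.transformed_value = jnp.zeros_like(current(ctx))\n",
    "\n",
    "\n",
    "def clamp_traceable(ctx):\n",
    "    # Same intent as an array operation: works on tracers and concrete arrays alike.\n",
    "    ctx.transformed_value = jnp.maximum(current(ctx), 0.0)\n",
    "\n",
    "\n",
    "def run_hook(hook, value):\n",
    "    # Minimal stand-in for a write_before dispatch: build a context, apply the hook.\n",
    "    ctx = SimpleNamespace(value=value, transformed_value=None)\n",
    "    hook(ctx)\n",
    "    return current(ctx)\n",
    "\n",
    "\n",
    "traceable = jax.jit(lambda v: run_hook(clamp_traceable, v))\n",
    "branching = jax.jit(lambda v: run_hook(clamp_branching, v))\n",
    "```\n",
    "\n",
    "`traceable(jnp.array(-2.0))` returns 0.0. Calling `branching` raises\n",
    "`jax.errors.ConcretizationTypeError` while tracing. Called eagerly,\n",
    "`run_hook(clamp_branching, jnp.array(-2.0))` still works."
   ]
  },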
  {
   "cell_type": "markdown",
   "id": "f5ba24af",
   "metadata": {},
   "source": [
    "## Summary\n",
    "\n",
    "- Hooks observe or intercept `State` operations — `read`, `write_before`, `write_after`,\n",
    "  `restore`, `init` — without modifying model code.\n",
    "- `register_state_hook` registers globally; `state.register_hook` scopes to one instance.\n",
    "- A `write_before` hook can **transform** a value via `ctx.transformed_value` or **cancel** the\n",
    "  write via `ctx.cancel`, which raises `HookCancellationError`.\n",
    "- Registration returns a `HookHandle` with `disable` / `enable` / `remove`; inspect the registry\n",
    "  with `list_state_hooks` / `has_state_hooks` and reset it with `clear_state_hooks`.\n",
    "- Hooks are eager Python callbacks; for checks inside compiled code use\n",
    "  `brainstate.transform.checkify` and friends.\n",
    "\n",
    "### See also\n",
    "\n",
    "- [Constrain and regularize parameters](constrain_and_regularize_parameters.ipynb) — a declarative alternative to write hooks for keeping parameters in range.\n",
    "- [Error handling and validation](../tutorials/transformations/06_error_handling_and_checks.ipynb) — checks that run inside `jit`."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.13.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
