{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Other Transforms\n",
    "\n",
    "This tutorial covers three utilities for memory optimization, introspection, and debugging. The first two live in `brainstate.transform`; the third is part of JAX itself:\n",
    "\n",
    "1. **`checkpoint`**: Memory-efficient gradient computation through rematerialization\n",
    "2. **`make_jaxpr` and `StatefulFunction`**: Inspect traced computation graphs (Jaxprs) and how states enter them\n",
    "3. **`jax.debug.print`**: Runtime debugging in JIT-compiled code\n",
    "\n",
    "Most examples highlight the state-aware behavior that distinguishes BrainState from plain JAX."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Imports and Setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-11T07:36:50.200650Z",
     "start_time": "2025-10-11T07:36:48.818176Z"
    }
   },
   "outputs": [],
   "source": [
    "import jax\n",
    "import jax.numpy as jnp\n",
    "import brainstate\n",
    "from brainstate.transform import checkpoint, make_jaxpr, StatefulFunction"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1. `checkpoint`: Memory-Efficient Gradient Computation\n",
    "\n",
    "`checkpoint` (also known as rematerialization or gradient checkpointing) is a standard technique for training deep neural networks and processing long sequences when activation memory is the bottleneck. It trades extra computation for lower memory use during backpropagation.\n",
    "\n",
    "### How Gradient Computation Works\n",
    "\n",
    "**Without checkpointing:**\n",
    "- Forward pass: Computes outputs and stores **all intermediate activations**\n",
    "- Backward pass: Uses stored activations to compute gradients\n",
    "- Memory usage: O(n) where n is the number of layers/steps\n",
    "\n",
    "**With checkpointing:**\n",
    "- Forward pass: Computes outputs but stores **only the inputs** of each checkpointed segment\n",
    "- Backward pass: **Recomputes** each segment's intermediate activations from its stored input as needed\n",
    "- Memory usage: O(√n) when the n layers are split into roughly √n equally sized segments\n",
    "- Computation: roughly one extra forward pass (each segment is recomputed once during backward)\n",
    "\n",
    "**Key principle: Trade extra computation for reduced memory**"
   ]
  },
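  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A rough way to see where the $O(\\sqrt{n})$ figure comes from, under a simplified memory model that counts one stored activation per layer boundary and ignores parameters and temporary buffers: split the $n$ layers into $k$ equal segments and keep only each segment's input. The backward pass recomputes one segment at a time, so the peak number of stored activations is roughly\n",
    "\n",
    "$$M(k) \\approx k + \\frac{n}{k}, \\qquad \\frac{dM}{dk} = 1 - \\frac{n}{k^2} = 0 \\;\\Rightarrow\\; k^\\ast = \\sqrt{n}, \\quad M(k^\\ast) \\approx 2\\sqrt{n}.$$\n",
    "\n",
    "For example, with $n = 100$ layers and $k = 10$ segments, about $10 + 10 = 20$ activations are held instead of $100$, at the price of roughly one extra forward pass."
   ]
  },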
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 1.1 Basic Usage with Gradient Computation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-11T07:36:50.815357Z",
     "start_time": "2025-10-11T07:36:50.213823Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "=== Example 1: Basic checkpoint usage ===\n",
      "Values match: True\n",
      "Gradients match: True\n",
      "\n",
      "Memory: checkpoint keeps only the input x, not the intermediates y, z, w\n",
      "Cost: the sin/exp/tanh chain is recomputed once during the backward pass\n"
     ]
    }
   ],
   "source": [
    "# Example 1: Memory-efficient gradient computation\n",
    "print(\"=== Example 1: Basic checkpoint usage ===\")\n",
    "\n",
    "# Without checkpoint: intermediates are saved for the backward pass\n",
    "def expensive_forward(x):\n",
    "    \"\"\"Chain of elementwise operations.\"\"\"\n",
    "    y = jnp.sin(x)\n",
    "    z = jnp.exp(y)\n",
    "    w = jnp.tanh(z)\n",
    "    return jnp.sum(w ** 2)\n",
    "\n",
    "# With checkpoint: only the input is saved; y, z, w are recomputed during backward\n",
    "@checkpoint\n",
    "def checkpointed_forward(x):\n",
    "    \"\"\"Same computation, but memory-efficient.\"\"\"\n",
    "    y = jnp.sin(x)\n",
    "    z = jnp.exp(y)\n",
    "    w = jnp.tanh(z)\n",
    "    return jnp.sum(w ** 2)\n",
    "\n",
    "x = jnp.linspace(0, 10, 1000)\n",
    "\n",
    "# Both produce the same value and gradient\n",
    "value1, grad1 = jax.value_and_grad(expensive_forward)(x)\n",
    "value2, grad2 = jax.value_and_grad(checkpointed_forward)(x)\n",
    "\n",
    "print(f\"Values match: {jnp.allclose(value1, value2)}\")\n",
    "print(f\"Gradients match: {jnp.allclose(grad1, grad2)}\")\n",
    "print(\"\\nMemory: checkpoint keeps only the input x, not the intermediates y, z, w\")\n",
    "print(\"Cost: the sin/exp/tanh chain is recomputed once during the backward pass\")"
   ]
  },
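  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The memory claim can be checked rather than taken on faith. JAX's `jax.ad_checkpoint.print_saved_residuals` lists the values (residuals) that a function keeps from its forward pass for use in the backward pass. The sketch below applies it to the Example 1 chain using plain `jax.checkpoint`; it assumes BrainState's `checkpoint` behaves like `jax.checkpoint` for a function that touches no states, and the exact residual list may differ between JAX versions."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sketch (plain JAX): count the residuals saved for the backward pass\n",
    "import io\n",
    "import contextlib\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "from jax.ad_checkpoint import print_saved_residuals\n",
    "\n",
    "def chain(x):\n",
    "    # Same sin -> exp -> tanh -> square -> sum chain as Example 1\n",
    "    return jnp.sum(jnp.tanh(jnp.exp(jnp.sin(x))) ** 2)\n",
    "\n",
    "def saved_residuals_of(f, x):\n",
    "    # print_saved_residuals prints one line per saved value; capture and split them\n",
    "    buf = io.StringIO()\n",
    "    with contextlib.redirect_stdout(buf):\n",
    "        print_saved_residuals(f, x)\n",
    "    return [line for line in buf.getvalue().splitlines() if line.strip()]\n",
    "\n",
    "x_demo = jnp.linspace(0.0, 10.0, 1000)\n",
    "plain_res = saved_residuals_of(chain, x_demo)\n",
    "ckpt_res = saved_residuals_of(jax.checkpoint(chain), x_demo)\n",
    "\n",
    "print(f'Without checkpoint: {len(plain_res)} residual(s)')\n",
    "for line in plain_res:\n",
    "    print('   ', line)\n",
    "print(f'With jax.checkpoint: {len(ckpt_res)} residual(s)')\n",
    "for line in ckpt_res:\n",
    "    print('   ', line)"
   ]
  },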
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 1.2 Checkpointing Stateful Computations\n",
    "\n",
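    "BrainState's `checkpoint` also works on functions that read `State` objects: here the weights are `ParamState`s accessed inside the checkpointed function, and `brainstate.transform.grad` still returns gradients for them."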
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-11T07:36:52.293641Z",
     "start_time": "2025-10-11T07:36:50.823949Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "=== Example 2: Checkpointed neural network ===\n",
      "Number of layers: 10\n",
      "Gradient shapes match: {('layers', 0): True, ('layers', 1): True, ('layers', 2): True, ('layers', 3): True, ('layers', 4): True, ('layers', 5): True, ('layers', 6): True, ('layers', 7): True, ('layers', 8): True, ('layers', 9): True}\n",
      "\n",
      "Without checkpoint: stores all ~10 layer activations for the backward pass\n",
      "With checkpoint: recomputes the activations during the backward pass\n",
      "Caveat: one checkpoint around the whole network rebuilds every activation at once, so peak memory drops little; see Section 1.3\n"
     ]
    }
   ],
   "source": [
    "# Example 2: Checkpoint with stateful neural network\n",
    "print(\"\\n=== Example 2: Checkpointed neural network ===\")\n",
    "\n",
    "class DeepNetwork(brainstate.nn.Module):\n",
    "    \"\"\"Deep network with many layers.\"\"\"\n",
    "    def __init__(self, layer_sizes):\n",
    "        super().__init__()\n",
    "        self.layers = []\n",
    "        for i in range(len(layer_sizes) - 1):\n",
    "            self.layers.append(\n",
    "                brainstate.ParamState(jax.random.normal(\n",
    "                    jax.random.PRNGKey(i), \n",
    "                    (layer_sizes[i], layer_sizes[i+1])\n",
    "                ))\n",
    "            )\n",
    "    \n",
    "    def forward(self, x, use_checkpoint=False):\n",
    "        \"\"\"Forward pass through all layers.\"\"\"\n",
    "        def layer_fn(x):\n",
    "            h = x\n",
    "            for W in self.layers[:-1]:\n",
    "                h = jnp.tanh(h @ W.value)\n",
    "            # Output layer (no activation)\n",
    "            return h @ self.layers[-1].value\n",
    "        \n",
    "        if use_checkpoint:\n",
    "            return checkpoint(layer_fn)(x)\n",
    "        else:\n",
    "            return layer_fn(x)\n",
    "\n",
    "# Create a deep network: 10 layers\n",
    "net = DeepNetwork([128, 256, 256, 256, 256, 256, 256, 256, 256, 128, 10])\n",
    "x_batch = jax.random.normal(jax.random.PRNGKey(42), (32, 128))\n",
    "\n",
    "# Define loss function\n",
    "def loss_fn(use_checkpoint):\n",
    "    y_pred = net.forward(x_batch, use_checkpoint=use_checkpoint)\n",
    "    return jnp.mean(y_pred ** 2)\n",
    "\n",
    "# Get parameters\n",
    "params = net.states(brainstate.ParamState)\n",
    "\n",
    "# Compute gradients with and without checkpoint\n",
    "grads_normal = brainstate.transform.grad(lambda: loss_fn(False), params)()\n",
    "grads_checkpointed = brainstate.transform.grad(lambda: loss_fn(True), params)()\n",
    "\n",
    "# Compare\n",
    "print(f\"Number of layers: {len(net.layers)}\")\n",
    "print(f\"Gradient shapes match: {jax.tree.map(lambda a, b: a.shape == b.shape, grads_normal, grads_checkpointed)}\")\n",
    "print(\"\\nWithout checkpoint: stores all ~10 layer activations for the backward pass\")\n",
    "print(\"With checkpoint: recomputes the activations during the backward pass\")\n",
    "print(\"Caveat: one checkpoint around the whole network rebuilds every activation at once, so peak memory drops little; see Section 1.3\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 1.3 Sequential Layer Checkpointing\n",
    "\n",
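    "For very deep networks, apply `checkpoint` to individual layers or groups of layers rather than to the whole network at once. The example below wraps every second layer in its own `checkpoint`."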
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-11T07:37:17.789118Z",
     "start_time": "2025-10-11T07:37:17.181142Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "=== Example 3: Granular checkpointing ===\n",
      "Network depth: 6 layers\n",
      "Checkpoint frequency: every 2 layers\n",
      "Checkpoints created: 3\n",
      "Loss: 85.1143\n",
      "\n",
      "Checkpointed layers save only their inputs (h, W) and recompute tanh(h @ W) during backward\n"
     ]
    }
   ],
   "source": [
    "# Example 3: Per-layer checkpointing\n",
    "print(\"\\n=== Example 3: Granular checkpointing ===\")\n",
    "\n",
    "class CheckpointedDeepNetwork(brainstate.nn.Module):\n",
    "    \"\"\"Network with per-layer checkpointing.\"\"\"\n",
    "    def __init__(self, layer_sizes, checkpoint_every=2):\n",
    "        super().__init__()\n",
    "        self.checkpoint_every = checkpoint_every\n",
    "        self.weights = []\n",
    "        for i in range(len(layer_sizes) - 1):\n",
    "            self.weights.append(\n",
    "                brainstate.ParamState(jax.random.normal(\n",
    "                    jax.random.PRNGKey(i), \n",
    "                    (layer_sizes[i], layer_sizes[i+1])\n",
    "                ) * 0.1)\n",
    "            )\n",
    "    \n",
    "    def __call__(self, x):\n",
    "        h = x\n",
    "        for i, W in enumerate(self.weights):\n",
    "            # Define layer computation\n",
    "            def layer_forward(h):\n",
    "                return jnp.tanh(h @ W.value)\n",
    "            \n",
    "            # Wrap every N-th layer in its own checkpoint\n",
    "            if (i + 1) % self.checkpoint_every == 0:\n",
    "                h = checkpoint(layer_forward)(h)\n",
    "            else:\n",
    "                h = layer_forward(h)\n",
    "        return h\n",
    "\n",
    "# Create network: checkpoint every 2nd layer\n",
    "ckpt_net = CheckpointedDeepNetwork(\n",
    "    [64, 128, 128, 128, 128, 128, 32],  # 6 layers\n",
    "    checkpoint_every=2\n",
    ")\n",
    "\n",
    "x_in = jax.random.normal(jax.random.PRNGKey(123), (16, 64))\n",
    "\n",
    "# Forward and backward\n",
    "def forward_loss():\n",
    "    return jnp.sum(ckpt_net(x_in) ** 2)\n",
    "\n",
    "grads, value = brainstate.transform.grad(\n",
    "    forward_loss, \n",
    "    ckpt_net.states(brainstate.ParamState),\n",
    "    return_value=True\n",
    ")()\n",
    "\n",
    "print(f\"Network depth: {len(ckpt_net.weights)} layers\")\n",
    "print(f\"Checkpoint frequency: every {ckpt_net.checkpoint_every} layers\")\n",
    "print(f\"Checkpoints created: {len(ckpt_net.weights) // ckpt_net.checkpoint_every}\")\n",
    "print(f\"Loss: {value:.4f}\")\n",
    "print(\"\\nCheckpointed layers save only their inputs (h, W) and recompute tanh(h @ W) during backward\")"
   ]
  },
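  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Example 3 wraps single layers. The pattern behind the $O(\\sqrt{n})$ estimate instead checkpoints *segments* of consecutive layers, so that only each segment's input is kept. Below is a minimal plain-JAX sketch of that idea, assuming the weights live in a Python list; `make_weights`, `run_layers`, and `segmented_forward` are illustrative helpers, not BrainState APIs, and a segment size of $\\sqrt{n}$ is one common choice rather than a requirement."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sketch (plain JAX): checkpoint segments of consecutive layers\n",
    "import math\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "def make_weights(n_layers, width, seed=0):\n",
    "    # One (width, width) matrix per layer; small scale keeps tanh out of saturation\n",
    "    keys = jax.random.split(jax.random.PRNGKey(seed), n_layers)\n",
    "    return [0.1 * jax.random.normal(k, (width, width)) for k in keys]\n",
    "\n",
    "def run_layers(ws, h):\n",
    "    for W in ws:\n",
    "        h = jnp.tanh(h @ W)\n",
    "    return h\n",
    "\n",
    "def segmented_forward(ws, x, seg_size):\n",
    "    # Each block of seg_size layers is checkpointed: only the block input is saved,\n",
    "    # and the activations inside the block are recomputed during the backward pass\n",
    "    h = x\n",
    "    for start in range(0, len(ws), seg_size):\n",
    "        h = jax.checkpoint(run_layers)(ws[start:start + seg_size], h)\n",
    "    return jnp.sum(h ** 2)\n",
    "\n",
    "def plain_forward(ws, x):\n",
    "    return jnp.sum(run_layers(ws, x) ** 2)\n",
    "\n",
    "n_layers = 16\n",
    "seg = math.isqrt(n_layers)  # sqrt(n) layers per segment, giving 4 segments\n",
    "ws = make_weights(n_layers, width=32)\n",
    "x_seg = jax.random.normal(jax.random.PRNGKey(1), (8, 32))\n",
    "\n",
    "g_plain = jax.grad(plain_forward)(ws, x_seg)\n",
    "g_seg = jax.grad(segmented_forward)(ws, x_seg, seg)\n",
    "same = all(bool(jnp.allclose(a, b, atol=1e-6)) for a, b in zip(g_plain, g_seg))\n",
    "print(f'{n_layers} layers in segments of {seg}; gradients match: {same}')"
   ]
  },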
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 1.4 Memory-Computation Tradeoff\n",
    "\n",
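    "Checkpointing costs extra computation, so it pays off only when activation memory is the bottleneck. The cell below contrasts a small and a large network qualitatively; it does not measure time or memory."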
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-11T07:37:41.567790Z",
     "start_time": "2025-10-11T07:37:41.081735Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "=== Example 4: When to use checkpoint ===\n",
      "Small network (3 layers, 64 hidden):\n",
      "  → Normal gradient: Fast, low memory\n",
      "  → Checkpoint: Overhead not justified\n",
      "\n",
      "Large network (20 layers, 512 hidden):\n",
      "  → Normal gradient: Stores ~20 activations (high memory)\n",
      "  → Checkpoint: Recomputes activations (saves memory)\n",
      "  → Recommended: Use checkpoint for deep/wide networks\n",
      "\n",
      "Rough heuristic (profile on your own hardware):\n",
      "  Consider checkpoint when activations dominate memory (deep/wide nets, long sequences)\n",
      "  Skip checkpoint for small networks, where recomputation costs more than it saves\n"
     ]
    }
   ],
   "source": [
    "# Example 4: A small vs. a large network (qualitative comparison)\n",
    "print(\"\\n=== Example 4: When to use checkpoint ===\")\n",
    "\n",
    "class BenchmarkNet(brainstate.nn.Module):\n",
    "    def __init__(self, n_layers, hidden_size):\n",
    "        super().__init__()\n",
    "        self.layers = []\n",
    "        for i in range(n_layers):\n",
    "            self.layers.append(\n",
    "                brainstate.ParamState(jax.random.normal(\n",
    "                    jax.random.PRNGKey(i), \n",
    "                    (hidden_size, hidden_size)\n",
    "                ) * 0.1)\n",
    "            )\n",
    "    \n",
    "    def forward_normal(self, x):\n",
    "        h = x\n",
    "        for W in self.layers:\n",
    "            h = jnp.tanh(h @ W.value)\n",
    "        return jnp.sum(h)\n",
    "    \n",
    "    def forward_checkpointed(self, x):\n",
    "        def layer_block(h):\n",
    "            for W in self.layers:\n",
    "                h = jnp.tanh(h @ W.value)\n",
    "            return jnp.sum(h)\n",
    "        return checkpoint(layer_block)(x)\n",
    "\n",
    "# Small network: checkpoint overhead not worth it\n",
    "small_net = BenchmarkNet(n_layers=3, hidden_size=64)\n",
    "x_small = jax.random.normal(jax.random.PRNGKey(0), (64,))\n",
    "\n",
    "# Large network: checkpoint saves significant memory\n",
    "large_net = BenchmarkNet(n_layers=20, hidden_size=512)\n",
    "x_large = jax.random.normal(jax.random.PRNGKey(0), (512,))\n",
    "\n",
    "print(\"Small network (3 layers, 64 hidden):\")\n",
    "print(\"  → Normal gradient: Fast, low memory\")\n",
    "print(\"  → Checkpoint: Overhead not justified\\n\")\n",
    "\n",
    "print(\"Large network (20 layers, 512 hidden):\")\n",
    "print(\"  → Normal gradient: Stores ~20 activations (high memory)\")\n",
    "print(\"  → Checkpoint: Recomputes activations (saves memory)\")\n",
    "print(\"  → Recommended: Use checkpoint for deep/wide networks\\n\")\n",
    "\n",
    "print(\"Rough heuristic (profile on your own hardware):\")\n",
    "print(\"  Consider checkpoint when activations dominate memory (deep/wide nets, long sequences)\")\n",
    "print(\"  Skip checkpoint for small networks, where recomputation costs more than it saves\")"
   ]
  },
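  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Between saving every intermediate and saving only the inputs, JAX offers named policies that let a checkpointed function keep selected values. The sketch below uses plain `jax.checkpoint` with `jax.checkpoint_policies.nothing_saveable` and `jax.checkpoint_policies.dots_saveable` (keep matrix-multiply outputs, recompute the cheap `tanh`). It stays at the JAX level because this tutorial does not show whether BrainState's `checkpoint` accepts a `policy` argument. The gradient check only confirms that a policy changes what is stored, not the result; the residual listing shows what the `dots_saveable` variant keeps."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sketch (plain JAX): checkpoint policies change what is saved, not the result\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "from jax.ad_checkpoint import print_saved_residuals\n",
    "\n",
    "def mlp(ws, x):\n",
    "    h = x\n",
    "    for W in ws:\n",
    "        h = jnp.tanh(h @ W)\n",
    "    return jnp.sum(h)\n",
    "\n",
    "ws_pol = [0.1 * jax.random.normal(jax.random.PRNGKey(i), (64, 64)) for i in range(4)]\n",
    "x_pol = jax.random.normal(jax.random.PRNGKey(10), (16, 64))\n",
    "\n",
    "variants = {\n",
    "    'no checkpoint': mlp,\n",
    "    'save nothing': jax.checkpoint(mlp, policy=jax.checkpoint_policies.nothing_saveable),\n",
    "    'save matmuls': jax.checkpoint(mlp, policy=jax.checkpoint_policies.dots_saveable),\n",
    "}\n",
    "grads_by_variant = {name: jax.grad(f)(ws_pol, x_pol) for name, f in variants.items()}\n",
    "\n",
    "ref = grads_by_variant['no checkpoint']\n",
    "for name, g in grads_by_variant.items():\n",
    "    ok = all(bool(jnp.allclose(a, b, atol=1e-6)) for a, b in zip(ref, g))\n",
    "    print(f'{name:>14}: gradients match the reference = {ok}')\n",
    "\n",
    "print('Residuals kept by the dots_saveable variant:')\n",
    "print_saved_residuals(variants['save matmuls'], ws_pol, x_pol)"
   ]
  },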
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. `make_jaxpr` and `StatefulFunction`: Inspecting Compiled Code\n",
    "\n",
    "`make_jaxpr` traces a function into JAX's intermediate representation (Jaxpr), showing the sequence of primitive operations that JAX then lowers and compiles with XLA. `StatefulFunction` is the underlying mechanism that enables state-aware transformations.\n",
    "\n",
    "### What is Jaxpr?\n",
    "\n",
    "Jaxpr is JAX's intermediate representation based on a simply-typed first-order lambda calculus with let-bindings. It shows:\n",
    "- Primitive operations (add, mul, sin, etc.)\n",
    "- Data dependencies\n",
    "- How state reads and writes appear (as extra inputs and outputs)\n",
    "- Shapes and dtypes of every intermediate value"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 2.1 Basic Jaxpr Inspection"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-11T07:37:44.965751Z",
     "start_time": "2025-10-11T07:37:44.959594Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "=== Example 1: Basic jaxpr ===\n",
      "Function: z = cos(sin(x)) * 2\n",
      "\n",
      "Jaxpr representation:\n",
      "{ \u001b[34;1mlambda \u001b[39;22m; a\u001b[35m:f32[]\u001b[39m. \u001b[34;1mlet\n",
      "    \u001b[39;22mb\u001b[35m:f32[]\u001b[39m = sin a\n",
      "    c\u001b[35m:f32[]\u001b[39m = cos b\n",
      "    d\u001b[35m:f32[]\u001b[39m = mul c 2.0:f32[]\n",
      "  \u001b[34;1min \u001b[39;22m(d,) }\n",
      "\n",
      "States used: 0 (none for this simple function)\n"
     ]
    }
   ],
   "source": [
    "# Example 1: Simple function jaxpr\n",
    "print(\"=== Example 1: Basic jaxpr ===\")\n",
    "\n",
    "def simple_fn(x):\n",
    "    y = jnp.sin(x)\n",
    "    z = jnp.cos(y)\n",
    "    return z * 2\n",
    "\n",
    "# Create jaxpr\n",
    "jaxpr_fn = make_jaxpr(simple_fn)\n",
    "jaxpr, states = jaxpr_fn(3.0)\n",
    "\n",
    "print(\"Function: z = cos(sin(x)) * 2\")\n",
    "print(\"\\nJaxpr representation:\")\n",
    "print(jaxpr)\n",
    "print(f\"\\nStates used: {len(states)} (none for this simple function)\")"
   ]
  },
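  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For a function that touches no states, the program printed above should match what JAX's own `jax.make_jaxpr` traces. A jaxpr can also be inspected programmatically: the returned `ClosedJaxpr` exposes `.jaxpr.eqns`, and each equation names its primitive. A short plain-JAX sketch (the function is renamed so it does not overwrite `simple_fn`):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sketch (plain JAX): trace a state-free function and list its primitives\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "def sin_cos_fn(x):\n",
    "    return jnp.cos(jnp.sin(x)) * 2\n",
    "\n",
    "closed_sc = jax.make_jaxpr(sin_cos_fn)(jnp.float32(3.0))\n",
    "primitive_names = [eqn.primitive.name for eqn in closed_sc.jaxpr.eqns]\n",
    "print(closed_sc)\n",
    "print('Primitives in order:', primitive_names)"
   ]
  },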
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 2.2 Stateful Jaxpr: Tracking State Reads and Writes\n",
    "\n",
    "BrainState's `make_jaxpr` reveals how states are accessed."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-11T07:37:46.243143Z",
     "start_time": "2025-10-11T07:37:46.214013Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "=== Example 2: Stateful jaxpr ===\n",
      "Function: running average tracker\n",
      "\n",
      "States accessed: 2\n",
      "  [0] ShortTermState: 0\n",
      "  [1] ShortTermState: 0.0\n",
      "\n",
      "Jaxpr (state operations visible):\n",
      "{ \u001b[34;1mlambda \u001b[39;22m; a\u001b[35m:f32[]\u001b[39m b\u001b[35m:i32[]\u001b[39m c\u001b[35m:f32[]\u001b[39m. \u001b[34;1mlet\n",
      "    \u001b[39;22md\u001b[35m:i32[]\u001b[39m = add b 1:i32[]\n",
      "    e\u001b[35m:f32[]\u001b[39m = add c a\n",
      "    f\u001b[35m:f32[]\u001b[39m = convert_element_type[new_dtype=float32 weak_type=True] d\n",
      "    g\u001b[35m:f32[]\u001b[39m = div e f\n",
      "  \u001b[34;1min \u001b[39;22m(g, d, e) }\n",
      "\n",
      "Note: Jaxpr shows state reads as inputs, writes as outputs\n"
     ]
    }
   ],
   "source": [
    "# Example 2: Jaxpr with states\n",
    "print(\"\\n=== Example 2: Stateful jaxpr ===\")\n",
    "\n",
    "# Create states\n",
    "counter = brainstate.ShortTermState(jnp.array(0))\n",
    "accumulator = brainstate.ShortTermState(jnp.array(0.0))\n",
    "\n",
    "def stateful_fn(x):\n",
    "    # Read states\n",
    "    count = counter.value\n",
    "    accum = accumulator.value\n",
    "    \n",
    "    # Update states\n",
    "    counter.value = count + 1\n",
    "    accumulator.value = accum + x\n",
    "    \n",
    "    return accumulator.value / counter.value\n",
    "\n",
    "# Inspect jaxpr\n",
    "jaxpr_fn = make_jaxpr(stateful_fn)\n",
    "jaxpr, states = jaxpr_fn(5.0)\n",
    "\n",
    "print(\"Function: running average tracker\")\n",
    "print(f\"\\nStates accessed: {len(states)}\")\n",
    "for i, state in enumerate(states):\n",
    "    print(f\"  [{i}] {type(state).__name__}: {state.value}\")\n",
    "\n",
    "print(\"\\nJaxpr (state operations visible):\")\n",
    "print(jaxpr)\n",
    "print(\"\\nNote: Jaxpr shows state reads as inputs, writes as outputs\")"
   ]
  },
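  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The jaxpr above is the program of a pure function that receives the old state values as extra arguments and returns the new ones: inputs `(x, count, accum)` and outputs `(average, new_count, new_accum)`. Writing that function by hand in plain JAX makes the correspondence explicit. `running_average_pure` is an illustrative name, and the loop shows the state threading that BrainState otherwise performs for you."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sketch (plain JAX): the running average as a pure function with explicit state\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "def running_average_pure(x, count, accum):\n",
    "    # State reads become extra arguments ...\n",
    "    new_count = count + 1\n",
    "    new_accum = accum + x\n",
    "    # ... and state writes become extra return values\n",
    "    return new_accum / new_count, new_count, new_accum\n",
    "\n",
    "closed_avg = jax.make_jaxpr(running_average_pure)(5.0, jnp.array(0), jnp.array(0.0))\n",
    "print(closed_avg)\n",
    "print('inputs:', len(closed_avg.jaxpr.invars), '| outputs:', len(closed_avg.jaxpr.outvars))\n",
    "\n",
    "# Threading the state by hand through three updates\n",
    "count, accum = jnp.array(0), jnp.array(0.0)\n",
    "for value in [5.0, 7.0, 9.0]:\n",
    "    avg, count, accum = running_average_pure(value, count, accum)\n",
    "print('average after three updates:', avg)  # (5 + 7 + 9) / 3 = 7.0"
   ]
  },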
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 2.3 Understanding `StatefulFunction`\n",
    "\n",
    "`StatefulFunction` is the core abstraction that enables BrainState's transformations. It:\n",
    "1. **Identifies states** accessed during function execution\n",
    "2. **Traces to a Jaxpr** with explicit state inputs/outputs\n",
    "3. **Manages state values** before and after execution\n",
    "4. **Caches traced jaxprs** for efficient repeated calls"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-11T07:37:47.583409Z",
     "start_time": "2025-10-11T07:37:47.291665Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "=== Example 3: StatefulFunction mechanics ===\n",
      "Step 1: Compilation\n",
      "  Compiled for input shape: (10,)\n",
      "\n",
      "Step 2: State identification\n",
      "  Total states: 2\n",
      "  Read states: 1\n",
      "    - ParamState: shape (10, 20)\n",
      "  Write states: 1\n",
      "    - ShortTermState: shape (20,)\n",
      "\n",
      "Step 3: Jaxpr compilation\n",
      "  Jaxpr variables: 3 inputs, 2 outputs\n",
      "  Jaxpr equations: 3 operations\n",
      "\n",
      "Step 4: Execution\n",
      "  Output shape: (20,)\n",
      "  Hidden state updated: (20,)\n"
     ]
    }
   ],
   "source": [
    "# Example 3: Using StatefulFunction directly\n",
    "print(\"\\n=== Example 3: StatefulFunction mechanics ===\")\n",
    "\n",
    "# Create a module with state\n",
    "class NeuralCell(brainstate.nn.Module):\n",
    "    def __init__(self, input_size, hidden_size):\n",
    "        super().__init__()\n",
    "        self.W = brainstate.ParamState(jax.random.normal(\n",
    "            jax.random.PRNGKey(0), (input_size, hidden_size)\n",
    "        ))\n",
    "        self.h = brainstate.ShortTermState(jnp.zeros(hidden_size))\n",
    "    \n",
    "    def __call__(self, x):\n",
    "        # Update hidden state\n",
    "        self.h.value = jnp.tanh(x @ self.W.value + self.h.value)\n",
    "        return self.h.value\n",
    "\n",
    "cell = NeuralCell(input_size=10, hidden_size=20)\n",
    "\n",
    "# Wrap in StatefulFunction\n",
    "sf = StatefulFunction(cell)\n",
    "\n",
    "# Example input\n",
    "x = jax.random.normal(jax.random.PRNGKey(1), (10,))\n",
    "\n",
    "# Step 1: Compile and inspect\n",
    "sf.make_jaxpr(x)\n",
    "print(\"Step 1: Compilation\")\n",
    "print(f\"  Compiled for input shape: {x.shape}\")\n",
    "\n",
    "# Step 2: Get tracked states\n",
    "states = sf.get_states(x)\n",
    "read_states = sf.get_read_states(x)\n",
    "write_states = sf.get_write_states(x)\n",
    "\n",
    "print(f\"\\nStep 2: State identification\")\n",
    "print(f\"  Total states: {len(states)}\")\n",
    "print(f\"  Read states: {len(read_states)}\")\n",
    "for s in read_states:\n",
    "    print(f\"    - {type(s).__name__}: shape {s.value.shape}\")\n",
    "print(f\"  Write states: {len(write_states)}\")\n",
    "for s in write_states:\n",
    "    print(f\"    - {type(s).__name__}: shape {s.value.shape}\")\n",
    "\n",
    "# Step 3: Get jaxpr\n",
    "jaxpr = sf.get_jaxpr(x)\n",
    "print(f\"\\nStep 3: Jaxpr compilation\")\n",
    "print(f\"  Jaxpr variables: {len(jaxpr.jaxpr.invars)} inputs, {len(jaxpr.jaxpr.outvars)} outputs\")\n",
    "print(f\"  Jaxpr equations: {len(jaxpr.jaxpr.eqns)} operations\")\n",
    "\n",
    "# Step 4: Execute\n",
    "output = sf(x)\n",
    "print(f\"\\nStep 4: Execution\")\n",
    "print(f\"  Output shape: {output.shape}\")\n",
    "print(f\"  Hidden state updated: {cell.h.value.shape}\")"
   ]
  },
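  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Steps 2 to 4 amount to a read, compute, write-back cycle around a pure function. The sketch below imitates that cycle in plain JAX using a hypothetical `Box` container and a hypothetical `run_with_states` helper. It is not how `StatefulFunction` is implemented (BrainState discovers states automatically while tracing, whereas here they are listed by hand); it only illustrates why the compiled function itself can stay pure."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sketch (plain JAX): a hand-written read / compute / write-back cycle\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "class Box:\n",
    "    # Hypothetical mutable container standing in for a BrainState State\n",
    "    def __init__(self, value):\n",
    "        self.value = value\n",
    "\n",
    "def run_with_states(pure_fn, boxes, *args):\n",
    "    # Hypothetical helper: read values, call the pure function, write results back\n",
    "    old_values = [b.value for b in boxes]\n",
    "    out, new_values = pure_fn(old_values, *args)\n",
    "    for b, v in zip(boxes, new_values):\n",
    "        b.value = v\n",
    "    return out\n",
    "\n",
    "@jax.jit\n",
    "def cell_pure(values, x):\n",
    "    W, h = values\n",
    "    h_new = jnp.tanh(x @ W + h)\n",
    "    return h_new, [W, h_new]  # W is only read; h is written\n",
    "\n",
    "W_box = Box(jax.random.normal(jax.random.PRNGKey(0), (10, 20)))\n",
    "h_box = Box(jnp.zeros(20))\n",
    "x_cell = jax.random.normal(jax.random.PRNGKey(1), (10,))\n",
    "\n",
    "out = run_with_states(cell_pure, [W_box, h_box], x_cell)\n",
    "print('output shape:', out.shape)\n",
    "print('hidden state now equals the output:', bool(jnp.allclose(h_box.value, out)))"
   ]
  },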
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 2.4 Jaxpr for Gradient Computation\n",
    "\n",
    "Inspect how autodiff transforms your code."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-11T07:37:48.668686Z",
     "start_time": "2025-10-11T07:37:48.644312Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "=== Example 4: Gradient computation jaxpr ===\n",
      "Original function jaxpr:\n",
      "{ \u001b[34;1mlambda \u001b[39;22m; a\u001b[35m:f32[3]\u001b[39m b\u001b[35m:f32[3]\u001b[39m. \u001b[34;1mlet\n",
      "    \u001b[39;22mc\u001b[35m:f32[3]\u001b[39m = sub b a\n",
      "    d\u001b[35m:f32[3]\u001b[39m = integer_pow[y=2] c\n",
      "    e\u001b[35m:f32[]\u001b[39m = reduce_sum[axes=(0,) out_sharding=None] d\n",
      "  \u001b[34;1min \u001b[39;22m(e, b) }\n",
      "\n",
      "Gradient function jaxpr:\n",
      "{ \u001b[34;1mlambda \u001b[39;22m; a\u001b[35m:f32[3]\u001b[39m b\u001b[35m:f32[3]\u001b[39m. \u001b[34;1mlet\n",
      "    \u001b[39;22mc\u001b[35m:f32[3]\u001b[39m = sub b a\n",
      "    d\u001b[35m:f32[3]\u001b[39m = integer_pow[y=2] c\n",
      "    e\u001b[35m:f32[3]\u001b[39m = integer_pow[y=1] c\n",
      "    f\u001b[35m:f32[3]\u001b[39m = mul 2.0:f32[] e\n",
      "    _\u001b[35m:f32[]\u001b[39m = reduce_sum[axes=(0,) out_sharding=None] d\n",
      "    g\u001b[35m:f32[3]\u001b[39m = broadcast_in_dim[\n",
      "      broadcast_dimensions=()\n",
      "      shape=(3,)\n",
      "      sharding=None\n",
      "    ] 1.0:f32[]\n",
      "    h\u001b[35m:f32[3]\u001b[39m = mul g f\n",
      "  \u001b[34;1min \u001b[39;22m(h, b) }\n",
      "\n",
      "Note: Gradient jaxpr includes:\n",
      "  - Forward pass operations\n",
      "  - Backward pass (VJP) operations\n",
      "  - For this loss, only a few extra equations that compute 2 * (p - x)\n"
     ]
    }
   ],
   "source": [
    "# Example 4: Gradient jaxpr\n",
    "print(\"\\n=== Example 4: Gradient computation jaxpr ===\")\n",
    "\n",
    "# Simple loss function\n",
    "params = brainstate.ParamState(jnp.array([1.0, 2.0, 3.0]))\n",
    "\n",
    "def loss_fn(x):\n",
    "    return jnp.sum((params.value - x) ** 2)\n",
    "\n",
    "# Original function jaxpr\n",
    "print(\"Original function jaxpr:\")\n",
    "jaxpr_orig, _ = make_jaxpr(loss_fn)(jnp.array([0.5, 1.0, 1.5]))\n",
    "print(jaxpr_orig)\n",
    "\n",
    "# Gradient function jaxpr\n",
    "print(\"\\nGradient function jaxpr:\")\n",
    "grad_fn = brainstate.transform.grad(loss_fn, params)\n",
    "jaxpr_grad, _ = make_jaxpr(grad_fn)(jnp.array([0.5, 1.0, 1.5]))\n",
    "print(jaxpr_grad)\n",
    "\n",
    "print(\"\\nNote: Gradient jaxpr includes:\")\n",
    "print(\"  - Forward pass operations\")\n",
    "print(\"  - Backward pass (VJP) operations\")\n",
    "print(\"  - For this loss, only a few extra equations that compute 2 * (p - x)\")"
   ]
  },
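  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The gradient jaxpr can be checked against the closed form. For $L(p) = \\sum_i (p_i - x_i)^2$ the gradient is $\\partial L / \\partial p_i = 2(p_i - x_i)$, which is what the `mul 2.0` equations above compute. A plain-JAX sketch that passes the parameters as an ordinary argument instead of a `ParamState`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sketch (plain JAX): gradient jaxpr and a closed-form check\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "def loss_pure(p, x):\n",
    "    return jnp.sum((p - x) ** 2)\n",
    "\n",
    "p0 = jnp.array([1.0, 2.0, 3.0])\n",
    "x0 = jnp.array([0.5, 1.0, 1.5])\n",
    "\n",
    "print(jax.make_jaxpr(jax.grad(loss_pure))(p0, x0))\n",
    "g = jax.grad(loss_pure)(p0, x0)\n",
    "print('autodiff gradient:', g)  # 2 * (p0 - x0) = [1. 2. 3.]\n",
    "print('closed form      :', 2 * (p0 - x0))"
   ]
  },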
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 2.5 Jaxpr for Transformed Functions\n",
    "\n",
    "See how a transformation such as `vmap2` changes the traced program."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-11T07:37:50.111196Z",
     "start_time": "2025-10-11T07:37:50.090918Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "=== Example 5: Transformed function jaxpr ===\n",
      "Original function:\n",
      "{ \u001b[34;1mlambda \u001b[39;22m; a\u001b[35m:f32[3]\u001b[39m. \u001b[34;1mlet\u001b[39;22m b\u001b[35m:f32[3]\u001b[39m = integer_pow[y=2] a \u001b[34;1min \u001b[39;22m(b,) }\n",
      "\n",
      "Vmapped function:\n",
      "{ \u001b[34;1mlambda \u001b[39;22m; a\u001b[35m:f32[2,2]\u001b[39m. \u001b[34;1mlet\n",
      "    \u001b[39;22mb\u001b[35m:key<fry>[]\u001b[39m = random_seed[impl=fry] 0:i32[]\n",
      "    c\u001b[35m:u32[2]\u001b[39m = random_unwrap b\n",
      "    d\u001b[35m:key<fry>[]\u001b[39m = random_wrap[impl=fry] c\n",
      "    e\u001b[35m:key<fry>[2]\u001b[39m = random_split[shape=(2,)] d\n",
      "    _\u001b[35m:u32[2,2]\u001b[39m = random_unwrap e\n",
      "    _\u001b[35m:f32[2,2]\u001b[39m = integer_pow[y=2] a\n",
      "    f\u001b[35m:f32[2,2]\u001b[39m = integer_pow[y=2] a\n",
      "  \u001b[34;1min \u001b[39;22m(f,) }\n",
      "\n",
      "Note: integer_pow now runs on the batched f32[2,2] input; the extra random_* equations are added by vmap2 itself\n"
     ]
    }
   ],
   "source": [
    "# Example 5: Transformation jaxpr\n",
    "print(\"\\n=== Example 5: Transformed function jaxpr ===\")\n",
    "\n",
    "def simple_fn(x):\n",
    "    return x ** 2\n",
    "\n",
    "# Original\n",
    "print(\"Original function:\")\n",
    "jaxpr1, _ = make_jaxpr(simple_fn)(jnp.array([1.0, 2.0, 3.0]))\n",
    "print(jaxpr1)\n",
    "\n",
    "# Vmapped version\n",
    "print(\"\\nVmapped function:\")\n",
    "vmapped_fn = brainstate.transform.vmap2(simple_fn)\n",
    "jaxpr2, _ = make_jaxpr(vmapped_fn)(jnp.array([[1.0, 2.0], [3.0, 4.0]]))\n",
    "print(jaxpr2)\n",
    "\n",
    "print(\"\\nNote: integer_pow now runs on the batched f32[2,2] input; the extra random_* equations are added by vmap2 itself\")"
   ]
  },
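  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For comparison, plain `jax.vmap` on the same elementwise function yields a single batched equation, so the `random_*` equations in the `vmap2` jaxpr above are specific to BrainState's transform. A minimal sketch:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sketch (plain JAX): jaxpr of jax.vmap on an elementwise function\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "def square(x):\n",
    "    return x ** 2\n",
    "\n",
    "batch = jnp.array([[1.0, 2.0], [3.0, 4.0]])\n",
    "closed_v = jax.make_jaxpr(jax.vmap(square))(batch)\n",
    "print(closed_v)\n",
    "print('Primitives:', [eqn.primitive.name for eqn in closed_v.jaxpr.eqns])\n",
    "print('Result:', jax.vmap(square)(batch))"
   ]
  },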
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 2.6 StatefulFunction Caching\n",
    "\n",
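    "`StatefulFunction` caches traced jaxprs, keyed by the input signature (for example, the argument shapes), so a repeated call with the same signature reuses the cached program. In the run below the `hits`/`misses` counters stay at 0, so the `size` field is the informative one: it stays at 1 after a same-shape call and grows to 2 after a new shape."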
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-11T07:37:51.246750Z",
     "start_time": "2025-10-11T07:37:51.228535Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "=== Example 6: Compilation caching ===\n",
      "After first compilation:\n",
      "  Jaxpr cache: {'size': 1, 'maxsize': 128, 'hits': 0, 'misses': 0, 'hit_rate': 0.0}\n",
      "\n",
      "After same-shape call:\n",
      "  Jaxpr cache: {'size': 1, 'maxsize': 128, 'hits': 0, 'misses': 0, 'hit_rate': 0.0}\n",
      "  Cache size: 1 entry (same shape, no new trace)\n",
      "\n",
      "After different-shape call:\n",
      "  Jaxpr cache: {'size': 2, 'maxsize': 128, 'hits': 0, 'misses': 0, 'hit_rate': 0.0}\n",
      "  Cache size: 2 entries\n",
      "\n",
      "Caching strategy:\n",
      "  - Different shapes → new compilation\n",
      "  - Same shapes → cache reuse\n",
      "  - At most 128 entries are cached (maxsize)\n"
     ]
    }
   ],
   "source": [
    "# Example 6: Understanding compilation caching\n",
    "print(\"\\n=== Example 6: Compilation caching ===\")\n",
    "\n",
    "state = brainstate.ShortTermState(jnp.array(0.0))\n",
    "\n",
    "def cached_fn(x):\n",
    "    state.value = state.value + jnp.sum(x)\n",
    "    return state.value\n",
    "\n",
    "sf = StatefulFunction(cached_fn)\n",
    "\n",
    "# First call: compile\n",
    "sf.make_jaxpr(jnp.array([1.0, 2.0]))\n",
    "stats1 = sf.get_cache_stats()\n",
    "print(\"After first compilation:\")\n",
    "print(f\"  Jaxpr cache: {stats1['jaxpr_cache']}\")\n",
    "\n",
    "# Same shape: the cached jaxpr is reused\n",
    "sf.make_jaxpr(jnp.array([3.0, 4.0]))\n",
    "stats2 = sf.get_cache_stats()\n",
    "print(\"\\nAfter same-shape call:\")\n",
    "print(f\"  Jaxpr cache: {stats2['jaxpr_cache']}\")\n",
    "print(f\"  Cache size: {stats2['jaxpr_cache']['size']} entry (same shape, no new trace)\")\n",
    "\n",
    "# Different shape: new compilation\n",
    "sf.make_jaxpr(jnp.array([1.0, 2.0, 3.0]))\n",
    "stats3 = sf.get_cache_stats()\n",
    "print(\"\\nAfter different-shape call:\")\n",
    "print(f\"  Jaxpr cache: {stats3['jaxpr_cache']}\")\n",
    "print(f\"  Cache size: {stats3['jaxpr_cache']['size']} entries\")\n",
    "\n",
    "print(\"\\nCaching strategy:\")\n",
    "print(\"  - Different shapes → new compilation\")\n",
    "print(\"  - Same shapes → cache reuse\")\n",
    "print(\"  - At most 128 entries are cached (maxsize)\")"
   ]
  },
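  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "JAX's `jax.jit` follows the same shape-keyed rule, and it can be observed without cache statistics: Python side effects, such as appending to a list, run only while a function is being traced, so the list records each (re)trace. A minimal plain-JAX sketch:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sketch (plain JAX): observe retracing with a trace-time side effect\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "trace_log = []\n",
    "\n",
    "@jax.jit\n",
    "def summed(x):\n",
    "    trace_log.append(x.shape)  # executes only when JAX (re)traces the function\n",
    "    return jnp.sum(x)\n",
    "\n",
    "summed(jnp.array([1.0, 2.0]))       # shape (2,): first trace\n",
    "summed(jnp.array([3.0, 4.0]))       # same shape and dtype: cached, no trace\n",
    "summed(jnp.array([1.0, 2.0, 3.0]))  # shape (3,): traced again\n",
    "print('traced shapes:', trace_log)  # [(2,), (3,)]"
   ]
  },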
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3. Debugging with `jax.debug.print`\n",
    "\n",
    "`jax.debug.print` enables runtime debugging in JIT-compiled code. Unlike regular `print`, it:\n",
    "- Executes during runtime (not tracing)\n",
    "- Works inside `@jit`, `vmap`, `grad`, etc.\n",
    "- Supports formatted output\n",
    "- Can print array values and shapes\n",
    "\n",
    "**Key principle: Debug prints happen at execution time, not trace time**"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 3.1 Basic Debug Printing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-11T07:37:52.835955Z",
     "start_time": "2025-10-11T07:37:52.802648Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "=== Example 1: Debug printing in JIT ===\n",
      "Input: [1. 2. 3.]\n",
      "After square: [1. 4. 9.]\n",
      "Sum: 14.0\n",
      "\n",
      "Final result: 14.0\n",
      "\n",
      "Note: Debug prints appear during execution, not compilation\n"
     ]
    }
   ],
   "source": [
    "# Example 1: Basic debug printing in JIT\n",
    "print(\"=== Example 1: Debug printing in JIT ===\")\n",
    "\n",
    "@brainstate.transform.jit\n",
    "def compute_with_debug(x):\n",
    "    jax.debug.print(\"Input: {x}\", x=x)\n",
    "    y = x ** 2\n",
    "    jax.debug.print(\"After square: {y}\", y=y)\n",
    "    z = jnp.sum(y)\n",
    "    jax.debug.print(\"Sum: {z}\", z=z)\n",
    "    return z\n",
    "\n",
    "result = compute_with_debug(jnp.array([1.0, 2.0, 3.0]))\n",
    "print(f\"\\nFinal result: {result}\")\n",
    "print(\"\\nNote: Debug prints appear during execution, not compilation\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 3.2 Debugging State Updates"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-11T07:37:54.137187Z",
     "start_time": "2025-10-11T07:37:53.992908Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "=== Example 2: Debug state updates ===\n",
      "Before update - state: [0. 0. 0.]\n",
      "Computed new state: [ 0.97147846  0.9105761  -0.79975927]\n",
      "After update - state: [ 0.97147846  0.9105761  -0.79975927]\n",
      "\n",
      "Output: [ 0.97147846  0.9105761  -0.79975927]\n"
     ]
    }
   ],
   "source": [
    "# Example 2: Debugging stateful computations\n",
    "print(\"\\n=== Example 2: Debug state updates ===\")\n",
    "\n",
    "class DebuggableCell(brainstate.nn.Module):\n",
    "    def __init__(self, size):\n",
    "        super().__init__()\n",
    "        self.state = brainstate.ShortTermState(jnp.zeros(size))\n",
    "        self.weight = brainstate.ParamState(jax.random.normal(jax.random.PRNGKey(0), (size, size)))\n",
    "    \n",
    "    def step(self, x):\n",
    "        jax.debug.print(\"Before update - state: {s}\", s=self.state.value)\n",
    "        \n",
    "        # Update\n",
    "        new_state = jnp.tanh(x @ self.weight.value + self.state.value)\n",
    "        jax.debug.print(\"Computed new state: {s}\", s=new_state)\n",
    "        \n",
    "        self.state.value = new_state\n",
    "        jax.debug.print(\"After update - state: {s}\", s=self.state.value)\n",
    "        \n",
    "        return new_state\n",
    "\n",
    "cell = DebuggableCell(size=3)\n",
    "\n",
    "@brainstate.transform.jit\n",
    "def update_step(x):\n",
    "    return cell.step(x)\n",
    "\n",
    "x = jnp.array([1.0, 0.0, -1.0])\n",
    "output = update_step(x)\n",
    "print(f\"\\nOutput: {output}\")"
   ]
  },
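  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Example 2 reads naturally only if the \"before\" and \"after\" prints appear in program order. Inside a compiled function, JAX does not guarantee the relative order of `jax.debug.print` calls by default; passing `ordered=True` asks JAX to emit ordered prints in program order. The cell below is a minimal plain-JAX sketch of the same state update; `cell_step` is an illustrative name and the weight matrix is omitted for brevity."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "@jax.jit\n",
    "def cell_step(state, x):\n",
    "    # ordered=True: ordered prints are emitted in program order\n",
    "    jax.debug.print(\"before: {s}\", s=state, ordered=True)\n",
    "    new_state = jnp.tanh(x + state)\n",
    "    jax.debug.print(\"after:  {s}\", s=new_state, ordered=True)\n",
    "    return new_state\n",
    "\n",
    "new_state = cell_step(jnp.zeros(3), jnp.array([1.0, 0.0, -1.0]))"
   ]
  },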
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 3.3 Debugging Gradients"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-11T07:37:55.231168Z",
     "start_time": "2025-10-11T07:37:55.129160Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "=== Example 3: Debug gradients ===\n",
      "\n",
      "Computing gradients:\n",
      "Forward - param: [2. 3.], input: [0.5 1. ]\n",
      "Forward - prediction: [1. 3.]\n",
      "Forward - loss: 10.0\n",
      "\n",
      "Gradients: [1. 6.]\n",
      "\n",
      "Note: Debug prints show forward pass values during gradient computation\n"
     ]
    }
   ],
   "source": [
    "# Example 3: Debug gradient computation\n",
    "print(\"\\n=== Example 3: Debug gradients ===\")\n",
    "\n",
    "param = brainstate.ParamState(jnp.array([2.0, 3.0]))\n",
    "\n",
    "def loss_with_debug(x):\n",
    "    jax.debug.print(\"Forward - param: {p}, input: {x}\", p=param.value, x=x)\n",
    "    \n",
    "    pred = param.value * x\n",
    "    jax.debug.print(\"Forward - prediction: {pred}\", pred=pred)\n",
    "    \n",
    "    loss = jnp.sum(pred ** 2)\n",
    "    jax.debug.print(\"Forward - loss: {loss}\", loss=loss)\n",
    "    \n",
    "    return loss\n",
    "\n",
    "# Gradient computation\n",
    "x = jnp.array([0.5, 1.0])\n",
    "grad_fn = brainstate.transform.grad(loss_with_debug, param)\n",
    "\n",
    "print(\"\\nComputing gradients:\")\n",
    "grads = grad_fn(x)\n",
    "print(f\"\\nGradients: {grads}\")\n",
    "print(\"\\nNote: Debug prints show forward pass values during gradient computation\")"
   ]
  },
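  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Prints placed in the loss function report forward-pass values only. To inspect the gradient flowing backward through a particular value, one common pattern (a plain-JAX sketch, not a BrainState API) is an identity function with a `jax.custom_vjp` whose backward rule prints the incoming cotangent. The helper name `print_grad` is hypothetical; the parameter and input reuse the values from Example 3, so the printed gradient should match the one computed there."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "@jax.custom_vjp\n",
    "def print_grad(x):\n",
    "    # Identity in the forward pass\n",
    "    return x\n",
    "\n",
    "def _print_grad_fwd(x):\n",
    "    return x, None  # no residuals are needed for the backward pass\n",
    "\n",
    "def _print_grad_bwd(_, g):\n",
    "    # g is the cotangent (gradient) arriving here during the backward pass\n",
    "    jax.debug.print(\"Backward - grad: {g}\", g=g)\n",
    "    return (g,)  # pass the gradient through unchanged\n",
    "\n",
    "print_grad.defvjp(_print_grad_fwd, _print_grad_bwd)\n",
    "\n",
    "def loss(w, x):\n",
    "    pred = print_grad(w) * x\n",
    "    return jnp.sum(pred ** 2)\n",
    "\n",
    "grads = jax.grad(loss)(jnp.array([2.0, 3.0]), jnp.array([0.5, 1.0]))"
   ]
  },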
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 3.4 Debugging Vectorized Code"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-11T07:37:56.503099Z",
     "start_time": "2025-10-11T07:37:56.228085Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "=== Example 4: Debug vectorized code ===\n",
      "\n",
      "Processing batch:\n",
      "Processing item 0: 1.0\n",
      "Processing item 1: 2.0\n",
      "Processing item 2: 3.0\n",
      "Processing item 3: 4.0\n",
      "Processing item 0: 1.0\n",
      "Processing item 1: 2.0\n",
      "Processing item 2: 3.0\n",
      "Processing item 3: 4.0\n",
      "\n",
      "Results: [ 1.  4.  9. 16.]\n",
      "\n",
      "Note: Debug prints execute for each element in the batch\n"
     ]
    }
   ],
   "source": [
    "# Example 4: Debug vmap\n",
    "print(\"\\n=== Example 4: Debug vectorized code ===\")\n",
    "\n",
    "def process_item(x, index):\n",
    "    jax.debug.print(\"Processing item {i}: {x}\", i=index, x=x)\n",
    "    return x ** 2\n",
    "\n",
    "# Vmap over both arguments\n",
    "vmapped_fn = brainstate.transform.vmap2(process_item)\n",
    "\n",
    "batch_x = jnp.array([1.0, 2.0, 3.0, 4.0])\n",
    "indices = jnp.arange(len(batch_x))\n",
    "\n",
    "print(\"\\nProcessing batch:\")\n",
    "results = vmapped_fn(batch_x, indices)\n",
    "print(f\"\\nResults: {results}\")\n",
    "print(\"\\nNote: Debug prints execute for each element in the batch\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 3.5 Conditional Debugging"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-11T07:38:30.549520Z",
     "start_time": "2025-10-11T07:38:30.512063Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "=== Example 5: Conditional debugging ===\n",
      "\n",
      "Running 10 training steps:\n",
      "Iteration 3: x=[ 0.36057416  1.2849895  -0.73873436  1.1830745  -0.20641916]\n",
      "Iteration 6: x=[-0.08437306  1.4110229   0.63048154 -1.3100973   1.3689315 ]\n",
      "Iteration 9: x=[-0.55150557 -1.369112    2.7549403   0.5639917  -1.0112009 ]\n",
      "\n",
      "Note: Debug prints only at iterations 3, 6, 9\n"
     ]
    }
   ],
   "source": [
    "# Example 5: Conditional debug prints\n",
    "print(\"\\n=== Example 5: Conditional debugging ===\")\n",
    "\n",
    "iteration = brainstate.ShortTermState(jnp.array(0))\n",
    "\n",
    "def training_step_with_debug(x, debug_every=5):\n",
    "    # Update the iteration counter\n",
    "    iteration.value = iteration.value + 1\n",
    "    it = iteration.value\n",
    "    \n",
    "    # Print only every `debug_every` iterations. The predicate depends on a\n",
    "    # traced value, so a Python `if` cannot be used; `jax.lax.cond` picks the\n",
    "    # branch at run time, and only the taken branch emits the print.\n",
    "    jax.lax.cond(\n",
    "        it % debug_every == 0,\n",
    "        lambda i, v: jax.debug.print(\"Iteration {iter}: x={x}\", iter=i, x=v),\n",
    "        lambda i, v: None,\n",
    "        it,\n",
    "        x,\n",
    "    )\n",
    "    \n",
    "    loss = jnp.sum(x ** 2)\n",
    "    return loss\n",
    "\n",
    "@brainstate.transform.jit\n",
    "def train_step(x):\n",
    "    return training_step_with_debug(x, debug_every=3)\n",
    "\n",
    "print(\"\\nRunning 10 training steps:\")\n",
    "for i in range(10):\n",
    "    x = jax.random.normal(jax.random.PRNGKey(i), (5,))\n",
    "    loss = train_step(x)\n",
    "\n",
    "print(\"\\nNote: Debug prints only at iterations 3, 6, 9\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 3.6 Advanced: Custom Debug Callbacks"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "=== Example 6: Custom debug callbacks ===\n",
      "\n",
      "Executing with custom debug callbacks:\n",
      "[DEBUG input]:\n",
      "  Shape: (5, 5)\n",
      "  Dtype: float32\n",
      "  Min: -1.9389\n",
      "  Max: 1.4458\n",
      "  Mean: 0.1084\n",
      "  Std: 0.8442\n",
      "[DEBUG after_tanh]:\n",
      "  Shape: (5, 5)\n",
      "  Dtype: float32\n",
      "  Min: -0.9594\n",
      "  Max: 0.8949\n",
      "  Mean: 0.1294\n",
      "  Std: 0.5656\n",
      "[DEBUG output]:\n",
      "  Shape: (5, 5)\n",
      "  Dtype: float32\n",
      "  Min: -1.2547\n",
      "  Max: 2.2312\n",
      "  Mean: 0.1165\n",
      "  Std: 0.9866\n",
      "\n",
      "Final result shape: (5, 5)\n"
     ]
    }
   ],
   "source": [
    "# Example 6: Custom debugging with callbacks\n",
    "print(\"\\n=== Example 6: Custom debug callbacks ===\")\n",
    "\n",
    "def custom_debug_callback(name, value):\n",
    "    \"\"\"Custom callback for detailed debugging.\"\"\"\n",
    "    print(f\"[DEBUG {name}]:\")\n",
    "    print(f\"  Shape: {value.shape}\")\n",
    "    print(f\"  Dtype: {value.dtype}\")\n",
    "    print(f\"  Min: {jnp.min(value):.4f}\")\n",
    "    print(f\"  Max: {jnp.max(value):.4f}\")\n",
    "    print(f\"  Mean: {jnp.mean(value):.4f}\")\n",
    "    print(f\"  Std: {jnp.std(value):.4f}\")\n",
    "\n",
    "@brainstate.transform.jit\n",
    "def compute_with_callback(x):\n",
    "    # Use debug callback for detailed inspection\n",
    "    jax.debug.callback(custom_debug_callback, \"input\", x)\n",
    "    \n",
    "    y = jnp.tanh(x)\n",
    "    jax.debug.callback(custom_debug_callback, \"after_tanh\", y)\n",
    "    \n",
    "    z = y @ y.T\n",
    "    jax.debug.callback(custom_debug_callback, \"output\", z)\n",
    "    \n",
    "    return z\n",
    "\n",
    "x = jax.random.normal(jax.random.PRNGKey(42), (5, 5))\n",
    "print(\"\\nExecuting with custom debug callbacks:\")\n",
    "result = compute_with_callback(x)\n",
    "print(f\"\\nFinal result shape: {result.shape}\")"
   ]
  },
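  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Several of the debugging practices recommended in the summary (a debug flag, printing statistics instead of full arrays, checking for NaN/Inf) can be combined into a small helper. The cell below is an illustrative sketch, not a BrainState utility: `DEBUG`, `debug_stats`, and `log_layer` are hypothetical names. Because `DEBUG` is an ordinary Python value, it is read at trace time: when it is `False`, no print is staged into the compiled computation, and changing it after compilation only takes effect once the function is traced again. JAX's `jax_debug_nans` configuration option is a built-in alternative for catching NaNs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "DEBUG = True  # hypothetical flag, read when the function is traced\n",
    "\n",
    "def debug_stats(name, x):\n",
    "    \"\"\"Print summary statistics and a finiteness check instead of the full array.\"\"\"\n",
    "    if DEBUG:  # plain Python `if`: evaluated once, during tracing\n",
    "        jax.debug.print(\n",
    "            name + \": mean={m}, max_abs={a}, all_finite={f}\",\n",
    "            m=jnp.mean(x),\n",
    "            a=jnp.max(jnp.abs(x)),\n",
    "            f=jnp.all(jnp.isfinite(x)),\n",
    "        )\n",
    "\n",
    "@jax.jit\n",
    "def log_layer(x):\n",
    "    debug_stats(\"input\", x)\n",
    "    y = jnp.log(x)  # -inf at 0 and NaN for negative inputs\n",
    "    debug_stats(\"log(x)\", y)\n",
    "    return y\n",
    "\n",
    "out = log_layer(jnp.array([1.0, 0.0, -1.0]))"
   ]
  },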
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Summary\n",
    "\n",
    "This tutorial covered three essential utilities for working with BrainState:\n",
    "\n",
    "### 1. `checkpoint`: Memory-Efficient Gradients\n",
    "- **Purpose**: Reduce memory usage during gradient computation\n",
    "- **Mechanism**: Recompute activations during backward pass instead of storing them\n",
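    "- **Tradeoff**: Roughly 2x forward computation (activations are recomputed during the backward pass) for significant memory savings (e.g. O(√n) vs O(n) activation memory when checkpoints are placed every √n layers)\n",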
    "- **When to use**: Deep networks (>10 layers), wide networks (>256 hidden), long sequences\n",
    "- **Advanced**: Custom policies control what to save vs. recompute\n",
    "- **State-aware**: Works seamlessly with BrainState's `State` objects\n",
    "\n",
    "### 2. `make_jaxpr` and `StatefulFunction`: Code Inspection\n",
    "- **Purpose**: Understand how JAX compiles and optimizes your code\n",
    "- **Jaxpr**: JAX's intermediate representation showing primitive operations and data flow\n",
    "- **StatefulFunction**: Core mechanism enabling all BrainState transformations\n",
    "  - Identifies state reads and writes\n",
    "  - Compiles to Jaxpr with explicit state handling\n",
    "  - Caches compilations for efficiency (LRU cache, 128 entries)\n",
    "  - Manages state values automatically\n",
    "- **Use cases**: Debugging compilation issues, understanding transformations, optimization analysis\n",
    "\n",
    "### 3. `jax.debug.print`: Runtime Debugging\n",
    "- **Purpose**: Debug JIT-compiled code during execution\n",
    "- **Key features**:\n",
    "  - Prints at runtime (not trace time)\n",
    "  - Works inside `@jit`, `vmap`, `grad`, etc.\n",
    "  - Supports formatted output and array inspection\n",
    "- **Best practices**:\n",
    "  - Use debug flags to enable/disable\n",
    "  - Print statistics, not full arrays\n",
    "  - Check for NaN/Inf in critical ops\n",
    "  - Use callbacks for complex debugging\n",
    "  - Disable in production\n",
    "\n",
    "### Integration with BrainState\n",
    "All three tools are **state-aware**:\n",
    "- `checkpoint` preserves state semantics during rematerialization\n",
    "- `make_jaxpr` reveals state reads/writes in compiled code\n",
    "- `jax.debug.print` can inspect state values during execution\n",
    "\n",
    "These utilities are essential for developing, optimizing, and debugging complex stateful models in BrainState."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Ecosystem-py",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
