{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Automatic Differentiation\n",
    "\n",
    "BrainState provides an automatic differentiation system built on top of JAX and designed for stateful computations. This tutorial focuses on `brainstate.transform.grad` and related gradient transformations, and shows how to compute gradients with respect to both function arguments and `State` objects.\n",
    "\n",
    "## Key Concepts\n",
    "\n",
    "BrainState's gradient system revolves around two key concepts:\n",
    "\n",
    "1. **`argnums`**: Select which function arguments to differentiate with respect to (inherited from JAX)\n",
    "2. **`grad_states`**: Select which `State` objects should receive gradients (BrainState's extension)\n",
    "\n",
    "Additionally, BrainState uses **`ParamState`** to mark trainable parameters in neural networks and provides utilities to discover and manage states in arbitrary functions."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-11T03:33:21.923144Z",
     "start_time": "2025-10-11T03:33:21.919200Z"
    }
   },
   "outputs": [],
   "source": [
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "import brainstate\n",
    "from brainstate.transform import grad, StateFinder"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1. Understanding `argnums`: Gradients w.r.t. Function Arguments\n",
    "\n",
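    "The `argnums` parameter behaves as in `jax.grad`: it selects the positional arguments to differentiate with respect to. An integer yields a single gradient; a sequence of integers yields one gradient per selected argument, in the same order."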
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-11T03:33:24.779450Z",
     "start_time": "2025-10-11T03:33:24.761445Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Gradient w.r.t. x: [2. 2. 2.]\n",
      "Gradient w.r.t. x: [2. 2. 2.]\n",
      "Gradient w.r.t. y: [-2. -2. -2.]\n"
     ]
    }
   ],
   "source": [
    "def loss_fn(x, y, scale):\n",
    "    \"\"\"Simple loss function with multiple arguments.\"\"\"\n",
    "    return scale * jnp.sum((x - y) ** 2)\n",
    "\n",
    "x = jnp.array([1.0, 2.0, 3.0])\n",
    "y = jnp.array([0.5, 1.5, 2.5])\n",
    "scale = 2.0\n",
    "\n",
    "# Gradient w.r.t. the first argument (x)\n",
    "grad_fn_x = grad(loss_fn, argnums=0)\n",
    "grad_x = grad_fn_x(x, y, scale)\n",
    "print(\"Gradient w.r.t. x:\", grad_x)\n",
    "\n",
    "# Gradient w.r.t. multiple arguments\n",
    "grad_fn_xy = grad(loss_fn, argnums=[0, 1])\n",
    "grad_x, grad_y = grad_fn_xy(x, y, scale)\n",
    "print(\"Gradient w.r.t. x:\", grad_x)\n",
    "print(\"Gradient w.r.t. y:\", grad_y)"
   ]
  },
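  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Because `argnums` follows `jax.grad` semantics, you can cross-check the gradients above with plain JAX and against the closed forms `dL/dx = 2 * scale * (x - y)` and `dL/dy = -dL/dx`. The sketch below does this. The `_chk` names are illustrative and not part of BrainState:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "\n",
    "def loss_chk(x, y, scale):\n",
    "    return scale * jnp.sum((x - y) ** 2)\n",
    "\n",
    "\n",
    "x_chk = jnp.array([1.0, 2.0, 3.0])\n",
    "y_chk = jnp.array([0.5, 1.5, 2.5])\n",
    "scale_chk = 2.0\n",
    "\n",
    "# A tuple of argnums returns one gradient per selected argument, in order\n",
    "gx, gy = jax.grad(loss_chk, argnums=(0, 1))(x_chk, y_chk, scale_chk)\n",
    "\n",
    "# Closed form: dL/dx = 2 * scale * (x - y) and dL/dy = -dL/dx\n",
    "assert jnp.allclose(gx, 2 * scale_chk * (x_chk - y_chk))\n",
    "assert jnp.allclose(gy, -2 * scale_chk * (x_chk - y_chk))\n",
    "print(\"dL/dx:\", gx, \"dL/dy:\", gy)"
   ]
  },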
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. Understanding `grad_states`: Gradients w.r.t. State Objects\n",
    "\n",
    "### 2.1 ParamState for Trainable Parameters\n",
    "\n",
    "In BrainState, **`ParamState`** is used to mark parameters that should receive gradients during training. This is the standard way to define trainable parameters in neural network modules."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-11T03:34:06.293839Z",
     "start_time": "2025-10-11T03:34:06.269236Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: 5.5000\n",
      "\n",
      "Parameter gradients:\n",
      "  ('bias',): [-2.]\n",
      "  ('weight',): [[-3.]]\n"
     ]
    }
   ],
   "source": [
    "class LinearRegressor(brainstate.nn.Module):\n",
    "    \"\"\"Simple linear regression model.\"\"\"\n",
    "    \n",
    "    def __init__(self, in_features: int, out_features: int = 1):\n",
    "        super().__init__()\n",
    "        # ParamState marks these as trainable parameters\n",
    "        self.weight = brainstate.ParamState(jnp.zeros((in_features, out_features)))\n",
    "        self.bias = brainstate.ParamState(jnp.zeros((out_features,)))\n",
    "\n",
    "    def __call__(self, x: jax.Array) -> jax.Array:\n",
    "        return x @ self.weight.value + self.bias.value\n",
    "\n",
    "\n",
    "# Create model and training data\n",
    "model = LinearRegressor(1)\n",
    "xs = jnp.linspace(-1.0, 1.0, 5).reshape(-1, 1)\n",
    "y_true = 3.0 * xs + 1.0\n",
    "\n",
    "\n",
    "def mse_loss(x: jax.Array, target: jax.Array) -> jax.Array:\n",
    "    \"\"\"Mean squared error loss.\"\"\"\n",
    "    pred = model(x)\n",
    "    return jnp.mean((pred - target) ** 2)\n",
    "\n",
    "\n",
    "# Compute gradients w.r.t. model parameters\n",
    "loss_grad = grad(\n",
    "    mse_loss,\n",
    "    grad_states=model.states(brainstate.ParamState),  # Get all ParamState instances\n",
    "    return_value=True,\n",
    ")\n",
    "\n",
    "param_grads, loss_value = loss_grad(xs, y_true)\n",
    "print(f\"Loss: {float(loss_value):.4f}\")\n",
    "print(\"\\nParameter gradients:\")\n",
    "for path, g in param_grads.items():\n",
    "    print(f\"  {path}: {g}\")"
   ]
  },
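  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For comparison, the sketch below performs the same computation in plain JAX. The parameters are passed explicitly as a dictionary instead of being stored in `ParamState`s. `jax.value_and_grad` plays the role of `return_value=True`, but the order differs: JAX returns `(value, grads)`, whereas BrainState returns `(grads, value)`. The names `mse_ref` and `params_ref` are illustrative:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "\n",
    "def mse_ref(params, x, target):\n",
    "    \"\"\"Linear model + MSE, with the parameters passed explicitly.\"\"\"\n",
    "    pred = x @ params[\"weight\"] + params[\"bias\"]\n",
    "    return jnp.mean((pred - target) ** 2)\n",
    "\n",
    "\n",
    "xs_ref = jnp.linspace(-1.0, 1.0, 5).reshape(-1, 1)\n",
    "ys_ref = 3.0 * xs_ref + 1.0\n",
    "params_ref = {\"weight\": jnp.zeros((1, 1)), \"bias\": jnp.zeros((1,))}\n",
    "\n",
    "# value_and_grad returns (value, grads): the reverse of BrainState's (grads, value)\n",
    "loss_ref, grads_ref = jax.value_and_grad(mse_ref)(params_ref, xs_ref, ys_ref)\n",
    "print(f\"Loss: {float(loss_ref):.4f}\")\n",
    "print(\"bias grad:\", grads_ref[\"bias\"], \"weight grad:\", grads_ref[\"weight\"])"
   ]
  },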
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 2.2 Retrieving States from Modules\n",
    "\n",
    "BrainState provides two main ways to retrieve states from modules:\n",
    "\n",
    "1. **`module.states(*filter)`**: Get states directly from a `Module` instance\n",
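    "2. **`brainstate.graph.treefy_states(node, *filter)`**: Get states from any graph node, not only a `Module`. The result is a nested state tree; call `.to_flat()` to get the same path-keyed view that `module.states()` returns"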
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-11T03:34:32.456863Z",
     "start_time": "2025-10-11T03:34:32.436565Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Using model.states():\n",
      "  ('bias',): shape=(1,)\n",
      "  ('weight',): shape=(1, 1)\n",
      "\n",
      "Using brainstate.graph.treefy_states():\n",
      "  ('bias',): shape=(1,)\n",
      "  ('weight',): shape=(1, 1)\n"
     ]
    }
   ],
   "source": [
    "# Method 1: Using module.states()\n",
    "params_method1 = model.states(brainstate.ParamState)\n",
    "print(\"Using model.states():\")\n",
    "for path, state in params_method1.items():\n",
    "    print(f\"  {path}: shape={state.value.shape}\")\n",
    "\n",
    "# Method 2: Using brainstate.graph.treefy_states()\n",
    "params_method2 = brainstate.graph.treefy_states(model, brainstate.ParamState)\n",
    "print(\"\\nUsing brainstate.graph.treefy_states():\")\n",
    "for path, state in params_method2.to_flat().items():\n",
    "    print(f\"  {path}: shape={state.value.shape}\")\n",
    "\n",
    "# Both methods return the same states\n",
    "assert set(params_method1.keys()) == set(params_method2.to_flat().keys())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 2.3 Using StateFinder for Arbitrary Functions\n",
    "\n",
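    "Not every computation is wrapped in an `nn.Module`. For an arbitrary function, **`StateFinder`** discovers which states the function uses. It evaluates the function on example inputs, records the states it reads and/or writes (selected by `usage`), and keeps those that match `filter`."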
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-11T03:35:42.407961Z",
     "start_time": "2025-10-11T03:35:42.391971Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "States found by StateFinder:\n",
      "  0: scale\n",
      "  1: offset\n",
      "\n",
      "Energy: 20.1800\n",
      "Gradients:\n",
      "  0: 28.400001525878906\n",
      "  1: 11.200000762939453\n"
     ]
    }
   ],
   "source": [
    "# Create some standalone states\n",
    "scale = brainstate.ParamState(jnp.array(1.5), name=\"scale\")\n",
    "offset = brainstate.ParamState(jnp.array(-0.2), name=\"offset\")\n",
    "cache = brainstate.State(jnp.array(0.0), name=\"cache\")  # Not a ParamState\n",
    "\n",
    "\n",
    "def energy(x: jax.Array) -> jax.Array:\n",
    "    \"\"\"Energy function using external states.\"\"\"\n",
    "    shifted = x * scale.value + offset.value\n",
    "    # Update a state to track it as a write operation\n",
    "    scale.value = scale.value + 0.0  # Dummy update to mark as written\n",
    "    cache.value = jnp.sum(shifted)  # Write to cache\n",
    "    return jnp.sum(jnp.square(shifted))\n",
    "\n",
    "\n",
    "# Use StateFinder to discover states used in the function\n",
    "finder = StateFinder(\n",
    "    energy,\n",
    "    filter=brainstate.ParamState,  # Only find ParamState instances\n",
    "    usage='all',  # Find both read and write states\n",
    "    return_type='dict',  # Return as a dictionary\n",
    ")\n",
    "\n",
    "all_param_states = finder(jnp.ones((2,)))\n",
    "print(\"States found by StateFinder:\")\n",
    "for name, state in all_param_states.items():\n",
    "    print(f\"  {name}: {state.name}\")\n",
    "\n",
    "# Now compute gradients w.r.t. these discovered states\n",
    "energy_grad = grad(\n",
    "    energy,\n",
    "    grad_states=all_param_states,\n",
    "    return_value=True,\n",
    ")\n",
    "\n",
    "state_grads, energy_value = energy_grad(jnp.array([1.0, 3.0]))\n",
    "print(f\"\\nEnergy: {float(energy_value):.4f}\")\n",
    "print(\"Gradients:\")\n",
    "for key, g in state_grads.items():\n",
    "    print(f\"  {key}: {g}\")"
   ]
  },
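  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Treating `scale` and `offset` as explicit arguments gives a pure-JAX reference for the state gradients above. With `shifted = x * scale + offset`, the chain rule gives `dE/dscale = 2 * sum(shifted * x)` and `dE/doffset = 2 * sum(shifted)`. The sketch below checks these values. `energy_ref` is an illustrative name:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "\n",
    "def energy_ref(s, o, x):\n",
    "    \"\"\"Same energy as above, with scale `s` and offset `o` as explicit arguments.\"\"\"\n",
    "    shifted = x * s + o\n",
    "    return jnp.sum(jnp.square(shifted))\n",
    "\n",
    "\n",
    "x_e = jnp.array([1.0, 3.0])\n",
    "s_e, o_e = 1.5, -0.2\n",
    "\n",
    "value_e, (g_scale, g_offset) = jax.value_and_grad(energy_ref, argnums=(0, 1))(s_e, o_e, x_e)\n",
    "\n",
    "# Closed-form gradients from the chain rule\n",
    "shifted_e = x_e * s_e + o_e\n",
    "assert jnp.allclose(g_scale, 2 * jnp.sum(shifted_e * x_e))\n",
    "assert jnp.allclose(g_offset, 2 * jnp.sum(shifted_e))\n",
    "print(f\"Energy: {float(value_e):.4f}, dE/dscale: {float(g_scale):.4f}, dE/doffset: {float(g_offset):.4f}\")"
   ]
  },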
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 2.4 Important Note: Gradients are Not Limited to ParamState\n",
    "\n",
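    "While `ParamState` is the standard way to mark trainable parameters, **gradient computation works with any `State` instance**. Pass the state in `grad_states` and it is differentiated exactly like a `ParamState`. When `grad_states` is a list, the gradients come back as a list in the same order."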
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-11T03:35:52.996206Z",
     "start_time": "2025-10-11T03:35:52.942783Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Gradient w.r.t. regular State: 56.0\n"
     ]
    }
   ],
   "source": [
    "# Create a regular State (not ParamState)\n",
    "regular_state = brainstate.State(jnp.array(2.0), name=\"regular_state\")\n",
    "\n",
    "\n",
    "def compute_with_state(x):\n",
    "    return jnp.sum((x * regular_state.value) ** 2)\n",
    "\n",
    "\n",
    "# Compute gradient w.r.t. regular State\n",
    "grad_fn = grad(compute_with_state, grad_states=[regular_state])\n",
    "gradient = grad_fn(jnp.array([1.0, 2.0, 3.0]))\n",
    "print(f\"Gradient w.r.t. regular State: {gradient[0]}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3. Combining `argnums` and `grad_states`\n",
    "\n",
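    "You can compute gradients with respect to function arguments and states in a single call. The state gradients come first, so the result is `(var_grads, arg_grads)`. In the example below, the gradient w.r.t. `l2_coeff` is 0 because the freshly created model's parameters are all zero, so the L2 penalty term vanishes."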
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-11T03:36:18.972107Z",
     "start_time": "2025-10-11T03:36:18.946669Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: 5.5000\n",
      "Gradient w.r.t. l2_coeff: 0.0000\n",
      "\n",
      "State gradients:\n",
      "  ('bias',): [-2.]\n",
      "  ('weight',): [[-3.]]\n"
     ]
    }
   ],
   "source": [
    "reg_model = LinearRegressor(1)\n",
    "\n",
    "\n",
    "def penalized_loss(l2_coeff: float, inputs: jax.Array, target: jax.Array) -> jax.Array:\n",
    "    \"\"\"Loss with L2 regularization.\"\"\"\n",
    "    pred = reg_model(inputs)\n",
    "    mse = jnp.mean((pred - target) ** 2)\n",
    "    # L2 penalty on parameters\n",
    "    l2 = jnp.sum(reg_model.weight.value ** 2) + jnp.sum(reg_model.bias.value ** 2)\n",
    "    return mse + l2_coeff * l2\n",
    "\n",
    "\n",
    "# Compute gradients w.r.t. both states and the first argument\n",
    "grad_penalized = grad(\n",
    "    penalized_loss,\n",
    "    grad_states=reg_model.states(brainstate.ParamState),\n",
    "    argnums=0,  # Also differentiate w.r.t. l2_coeff\n",
    "    return_value=True,\n",
    ")\n",
    "\n",
    "(state_grads, coeff_grad), loss_val = grad_penalized(0.5, xs, y_true)\n",
    "print(f\"Loss: {float(loss_val):.4f}\")\n",
    "print(f\"Gradient w.r.t. l2_coeff: {float(coeff_grad):.4f}\")\n",
    "print(\"\\nState gradients:\")\n",
    "for path, g in state_grads.items():\n",
    "    print(f\"  {path}: {g}\")"
   ]
  },
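  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The penalized loss is linear in `l2_coeff`, so its derivative with respect to `l2_coeff` is exactly the L2 penalty. That penalty is zero for the zero-initialized `reg_model` above. The pure-JAX sketch below uses hypothetical non-zero parameters (`weight = 1.0`, `bias = 0.5`) to make the effect visible:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "\n",
    "def penalized_ref(l2_coeff, params, x, target):\n",
    "    pred = x @ params[\"weight\"] + params[\"bias\"]\n",
    "    mse = jnp.mean((pred - target) ** 2)\n",
    "    l2 = jnp.sum(params[\"weight\"] ** 2) + jnp.sum(params[\"bias\"] ** 2)\n",
    "    return mse + l2_coeff * l2\n",
    "\n",
    "\n",
    "x_pen = jnp.linspace(-1.0, 1.0, 5).reshape(-1, 1)\n",
    "y_pen = 3.0 * x_pen + 1.0\n",
    "params_nz = {\"weight\": jnp.array([[1.0]]), \"bias\": jnp.array([0.5])}  # hypothetical values\n",
    "\n",
    "g_coeff, g_params = jax.grad(penalized_ref, argnums=(0, 1))(0.5, params_nz, x_pen, y_pen)\n",
    "\n",
    "# d(loss)/d(l2_coeff) equals the penalty: 1.0**2 + 0.5**2 = 1.25\n",
    "print(\"Gradient w.r.t. l2_coeff:\", float(g_coeff))"
   ]
  },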
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 4. Return Value Structures\n",
    "\n",
    "All gradient transformations in BrainState share a common signature pattern. The return structure depends on the combination of `grad_states`, `argnums`, `has_aux`, and `return_value`.\n",
    "\n",
    "### 4.1 Basic Return Structures\n",
    "\n",
    "When `grad_states` is `None`:\n",
    "\n",
    "- `has_aux=False` + `return_value=False` → `arg_grads`\n",
    "- `has_aux=True` + `return_value=False` → `(arg_grads, aux_data)`\n",
    "- `has_aux=False` + `return_value=True` → `(arg_grads, loss_value)`\n",
    "- `has_aux=True` + `return_value=True` → `(arg_grads, loss_value, aux_data)`\n",
    "\n",
    "When `grad_states` is set and `argnums` is `None`:\n",
    "\n",
    "- `has_aux=False` + `return_value=False` → `var_grads`\n",
    "- `has_aux=True` + `return_value=False` → `(var_grads, aux_data)`\n",
    "- `has_aux=False` + `return_value=True` → `(var_grads, loss_value)`\n",
    "- `has_aux=True` + `return_value=True` → `(var_grads, loss_value, aux_data)`\n",
    "\n",
    "When both `grad_states` and `argnums` are set:\n",
    "\n",
    "- `has_aux=False` + `return_value=False` → `(var_grads, arg_grads)`\n",
    "- `has_aux=True` + `return_value=False` → `((var_grads, arg_grads), aux_data)`\n",
    "- `has_aux=False` + `return_value=True` → `((var_grads, arg_grads), loss_value)`\n",
    "- `has_aux=True` + `return_value=True` → `((var_grads, arg_grads), loss_value, aux_data)`\n",
    "\n",
    "The same rules in table form:\n",
    "\n",
    "| grad_states | argnums | has_aux | return_value | result |\n",
    "|-------------|---------|---------|--------------|--------|\n",
    "| `None` | any | `False` | `False` | `arg_grads` |\n",
    "| `None` | any | `True` | `False` | `(arg_grads, aux)` |\n",
    "| `None` | any | `False` | `True` | `(arg_grads, loss)` |\n",
    "| `None` | any | `True` | `True` | `(arg_grads, loss, aux)` |\n",
    "| not `None` | `None` | `False` | `False` | `var_grads` |\n",
    "| not `None` | `None` | `True` | `False` | `(var_grads, aux)` |\n",
    "| not `None` | `None` | `False` | `True` | `(var_grads, loss)` |\n",
    "| not `None` | `None` | `True` | `True` | `(var_grads, loss, aux)` |\n",
    "| not `None` | not `None` | `False` | `False` | `(var_grads, arg_grads)` |\n",
    "| not `None` | not `None` | `True` | `False` | `((var_grads, arg_grads), aux)` |\n",
    "| not `None` | not `None` | `False` | `True` | `((var_grads, arg_grads), loss)` |\n",
    "| not `None` | not `None` | `True` | `True` | `((var_grads, arg_grads), loss, aux)` |\n",
    "\n"
   ]
  },
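  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Plain JAX packs the same pieces in a different order, which matters when porting code. `jax.value_and_grad(f, has_aux=True)` returns `((value, aux), grads)`, whereas the table gives `(grads, value, aux)` for BrainState with `has_aux=True` and `return_value=True`. The pure-JAX illustration below shows the JAX layout. `f_aux` is an illustrative function:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "\n",
    "def f_aux(x):\n",
    "    \"\"\"Scalar loss plus an auxiliary dictionary.\"\"\"\n",
    "    return jnp.sum(x ** 2), {\"norm\": jnp.linalg.norm(x)}\n",
    "\n",
    "\n",
    "x_aux = jnp.array([3.0, 4.0])\n",
    "\n",
    "# Plain JAX layout: ((value, aux), grads)\n",
    "(value_aux, metrics_aux), grads_aux = jax.value_and_grad(f_aux, has_aux=True)(x_aux)\n",
    "print(value_aux, metrics_aux[\"norm\"], grads_aux)  # value 25, norm 5, grads 2 * x = [6, 8]"
   ]
  },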
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 4.2 Complete Example: All Return Options"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-11T03:37:25.452877Z",
     "start_time": "2025-10-11T03:37:25.339953Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: 5.5000\n",
      "\n",
      "Gradient w.r.t. l2_coeff: 0.0000\n",
      "\n",
      "State gradients:\n",
      "  ('bias',): [-2.]\n",
      "  ('weight',): [[-3.]]\n",
      "\n",
      "Auxiliary metrics:\n",
      "  l2: 0.0000\n",
      "  mae: 2.0000\n",
      "  mse: 5.5000\n"
     ]
    }
   ],
   "source": [
    "example_model = LinearRegressor(1)\n",
    "\n",
    "\n",
    "def loss_with_metrics(l2_coeff: float, x: jax.Array, target: jax.Array):\n",
    "    \"\"\"Loss function that returns auxiliary metrics.\"\"\"\n",
    "    pred = example_model(x)\n",
    "    mse = jnp.mean((pred - target) ** 2)\n",
    "    l2 = jnp.sum(example_model.weight.value ** 2)\n",
    "    loss = mse + l2_coeff * l2\n",
    "    \n",
    "    # Return loss and auxiliary metrics\n",
    "    metrics = {\n",
    "        \"mae\": jnp.mean(jnp.abs(pred - target)),\n",
    "        \"mse\": mse,\n",
    "        \"l2\": l2,\n",
    "    }\n",
    "    return loss, metrics\n",
    "\n",
    "\n",
    "# Example: grad_states + argnums + has_aux + return_value\n",
    "grad_complete = grad(\n",
    "    loss_with_metrics,\n",
    "    grad_states=example_model.states(brainstate.ParamState),\n",
    "    argnums=0,\n",
    "    has_aux=True,\n",
    "    return_value=True,\n",
    ")\n",
    "\n",
    "((state_grads, coeff_grad), loss_val, aux_metrics) = grad_complete(0.3, xs, y_true)\n",
    "\n",
    "print(f\"Loss: {float(loss_val):.4f}\")\n",
    "print(f\"\\nGradient w.r.t. l2_coeff: {float(coeff_grad):.4f}\")\n",
    "print(\"\\nState gradients:\")\n",
    "for path, g in state_grads.items():\n",
    "    print(f\"  {path}: {g}\")\n",
    "print(\"\\nAuxiliary metrics:\")\n",
    "for key, val in aux_metrics.items():\n",
    "    print(f\"  {key}: {float(val):.4f}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 5. Other Gradient Transformations\n",
    "\n",
    "BrainState provides several other gradient transformations, all sharing the same signature pattern as `grad`.\n",
    "\n",
    "### 5.1 Vector Gradient\n",
    "\n",
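    "`vector_grad` differentiates vector-valued functions. It returns the sum of the gradients of all output components. Mathematically, this is a vector-Jacobian product with a cotangent of ones, or equivalently the Jacobian summed over its output axis."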
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-11T03:37:35.506574Z",
     "start_time": "2025-10-11T03:37:35.189027Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Vector gradient: [4.5403023 5.       ]\n"
     ]
    }
   ],
   "source": [
    "from brainstate.transform import vector_grad\n",
    "\n",
    "\n",
    "def vector_fun(x):\n",
    "    \"\"\"Vector-valued function.\"\"\"\n",
    "    return jnp.array([x[0] * x[1], jnp.sin(x[0]), x[0]**2 + x[1]**2])\n",
    "\n",
    "\n",
    "x0 = jnp.array([1.0, 2.0])\n",
    "\n",
    "# Vector gradient sums gradients across all outputs\n",
    "vgrad = vector_grad(vector_fun)\n",
    "result = vgrad(x0)\n",
    "print(\"Vector gradient:\", result)"
   ]
  },
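  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Summing the gradients of all outputs is the same as a vector-Jacobian product with a cotangent of ones. The sketch below reproduces the result with plain `jax.vjp` and checks it against the Jacobian summed over its output axis. `vector_fun_ref` restates `vector_fun` so the cell is self-contained. This is an equivalence check, not necessarily how `vector_grad` is implemented:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "\n",
    "def vector_fun_ref(x):\n",
    "    return jnp.array([x[0] * x[1], jnp.sin(x[0]), x[0] ** 2 + x[1] ** 2])\n",
    "\n",
    "\n",
    "x_v = jnp.array([1.0, 2.0])\n",
    "\n",
    "# VJP with a cotangent of ones sums the gradients of all output components\n",
    "out_v, vjp_v = jax.vjp(vector_fun_ref, x_v)\n",
    "(summed,) = vjp_v(jnp.ones_like(out_v))\n",
    "\n",
    "# Same thing via the full Jacobian (shape: outputs x inputs), summed over outputs\n",
    "assert jnp.allclose(summed, jax.jacrev(vector_fun_ref)(x_v).sum(axis=0))\n",
    "print(\"Summed gradient:\", summed)"
   ]
  },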
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 5.2 Jacobians: `jacrev` and `jacfwd`\n",
    "\n",
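    "- **`jacrev`**: Jacobian via reverse-mode autodiff, built from one vector-Jacobian product per output; preferable when there are more inputs than outputs\n",
    "- **`jacfwd`**: Jacobian via forward-mode autodiff, built from one Jacobian-vector product per input; preferable when there are more outputs than inputs\n",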
    "- **`jacobian`**: Alias for `jacrev`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-11T03:37:40.121490Z",
     "start_time": "2025-10-11T03:37:39.354116Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Jacobian (reverse-mode):\n",
      "[[2.        1.       ]\n",
      " [0.5403023 0.       ]\n",
      " [0.        7.389056 ]]\n",
      "\n",
      "Jacobian (forward-mode):\n",
      "[[2.        1.       ]\n",
      " [0.5403023 0.       ]\n",
      " [0.        7.389056 ]]\n"
     ]
    }
   ],
   "source": [
    "from brainstate.transform import jacrev, jacfwd, jacobian\n",
    "\n",
    "\n",
    "def multi_output(x):\n",
    "    \"\"\"Function with multiple outputs.\"\"\"\n",
    "    return jnp.array([x[0] * x[1], jnp.sin(x[0]), jnp.exp(x[1])])\n",
    "\n",
    "\n",
    "x0 = jnp.array([1.0, 2.0])\n",
    "\n",
    "# Reverse-mode Jacobian\n",
    "jac_rev = jacrev(multi_output)\n",
    "result_rev = jac_rev(x0)\n",
    "print(\"Jacobian (reverse-mode):\")\n",
    "print(result_rev)\n",
    "\n",
    "# Forward-mode Jacobian\n",
    "jac_fwd = jacfwd(multi_output)\n",
    "result_fwd = jac_fwd(x0)\n",
    "print(\"\\nJacobian (forward-mode):\")\n",
    "print(result_fwd)\n",
    "\n",
    "# They should be the same\n",
    "assert jnp.allclose(result_rev, result_fwd)\n",
    "\n",
    "# jacobian is an alias for jacrev\n",
    "jac_alias = jacobian(multi_output)\n",
    "result_alias = jac_alias(x0)\n",
    "assert jnp.allclose(result_rev, result_alias)"
   ]
  },
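  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The efficiency rule of thumb follows from how the two modes build the matrix. Reverse mode obtains one Jacobian *row* per vector-Jacobian product (one per output). Forward mode obtains one *column* per Jacobian-vector product (one per input). The pure-JAX sketch below assembles both explicitly with `jax.vjp`, `jax.jvp`, and `jax.vmap`. `multi_output_ref` restates `multi_output`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "\n",
    "def multi_output_ref(x):\n",
    "    return jnp.array([x[0] * x[1], jnp.sin(x[0]), jnp.exp(x[1])])\n",
    "\n",
    "\n",
    "x_j = jnp.array([1.0, 2.0])\n",
    "y_j, vjp_j = jax.vjp(multi_output_ref, x_j)\n",
    "\n",
    "# Reverse mode: one VJP per output basis vector -> one row each (3 rows)\n",
    "rows = jax.vmap(lambda ct: vjp_j(ct)[0])(jnp.eye(y_j.shape[0]))\n",
    "\n",
    "# Forward mode: one JVP per input basis vector -> one column each (2 columns)\n",
    "cols = jax.vmap(lambda t: jax.jvp(multi_output_ref, (x_j,), (t,))[1])(jnp.eye(x_j.shape[0]))\n",
    "\n",
    "assert jnp.allclose(rows, cols.T)\n",
    "assert jnp.allclose(rows, jax.jacrev(multi_output_ref)(x_j))\n",
    "print(rows)"
   ]
  },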
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 5.3 Hessian\n",
    "\n",
    "`hessian` computes the matrix of second-order partial derivatives of a scalar-valued function."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-11T03:37:43.567061Z",
     "start_time": "2025-10-11T03:37:42.929574Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Hessian:\n",
      "[[2. 0.]\n",
      " [0. 2.]]\n"
     ]
    }
   ],
   "source": [
    "from brainstate.transform import hessian\n",
    "\n",
    "\n",
    "def quadratic(x):\n",
    "    \"\"\"Quadratic function.\"\"\"\n",
    "    return jnp.dot(x, x)\n",
    "\n",
    "\n",
    "x0 = jnp.array([1.0, 2.0])\n",
    "\n",
    "hess_fn = hessian(quadratic)\n",
    "result = hess_fn(x0)\n",
    "print(\"Hessian:\")\n",
    "print(result)\n",
    "\n",
    "# For a quadratic form x^T x, the Hessian is 2*I\n",
    "expected = 2 * jnp.eye(2)\n",
    "assert jnp.allclose(result, expected)"
   ]
  },
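  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In JAX, `jax.hessian(f)` is defined as forward-over-reverse differentiation, `jax.jacfwd(jax.jacrev(f))`. This notebook does not show whether BrainState's `hessian` composes the modes the same way. The pure-JAX identity can still be checked on a less trivial function, `g(x) = x0**2 * x1 + sin(x1)`, whose analytic Hessian is `[[2*x1, 2*x0], [2*x0, -sin(x1)]]`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "\n",
    "def g_ref(x):\n",
    "    return x[0] ** 2 * x[1] + jnp.sin(x[1])\n",
    "\n",
    "\n",
    "x_h = jnp.array([1.0, 2.0])\n",
    "\n",
    "h_composed = jax.jacfwd(jax.jacrev(g_ref))(x_h)  # forward-over-reverse\n",
    "h_builtin = jax.hessian(g_ref)(x_h)\n",
    "h_analytic = jnp.array([[2 * x_h[1], 2 * x_h[0]],\n",
    "                        [2 * x_h[0], -jnp.sin(x_h[1])]])\n",
    "\n",
    "assert jnp.allclose(h_composed, h_builtin)\n",
    "assert jnp.allclose(h_builtin, h_analytic)\n",
    "print(h_builtin)"
   ]
  },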
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
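    "### 5.4 Using Gradient Transformations with States\n",
    "\n",
    "Every transformation in this section also accepts `grad_states`. Below, `jacrev` differentiates the model output with respect to the model parameters. Each entry of the result has shape `output.shape + param.shape`."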
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-11T03:37:55.603858Z",
     "start_time": "2025-10-11T03:37:55.585743Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Jacobian w.r.t. parameters:\n",
      "  ('bias',): shape=(1, 1)\n",
      "  ('weight',): shape=(1, 2, 1)\n"
     ]
    }
   ],
   "source": [
    "# Example: Jacobian with states\n",
    "jac_model = LinearRegressor(2)\n",
    "\n",
    "\n",
    "def model_output(x):\n",
    "    \"\"\"Model output for a single input sample (shape (1,)).\"\"\"\n",
    "    return jac_model(x)\n",
    "\n",
    "\n",
    "# Compute Jacobian w.r.t. model parameters\n",
    "jac_states = jacrev(\n",
    "    model_output,\n",
    "    grad_states=jac_model.states(brainstate.ParamState)\n",
    ")\n",
    "\n",
    "x_input = jnp.array([1.0, 2.0])\n",
    "param_jacobian = jac_states(x_input)\n",
    "\n",
    "print(\"Jacobian w.r.t. parameters:\")\n",
    "for path, jac in param_jacobian.items():\n",
    "    print(f\"  {path}: shape={jac.shape}\")"
   ]
  },
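  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "These shapes follow the usual JAX convention for Jacobians of pytrees: each leaf has shape `output.shape + leaf.shape`. The pure-JAX sketch below is an explicit-parameter analogue of `LinearRegressor(2)`; the names are illustrative. It also checks the values: for `out = x @ W + b`, the derivative w.r.t. `b` is the identity and the derivative w.r.t. `W[i, 0]` is `x[i]`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "\n",
    "def linear_ref(params, x):\n",
    "    return x @ params[\"weight\"] + params[\"bias\"]\n",
    "\n",
    "\n",
    "params_lin = {\"weight\": jnp.zeros((2, 1)), \"bias\": jnp.zeros((1,))}\n",
    "x_lin = jnp.array([1.0, 2.0])\n",
    "\n",
    "jac_lin = jax.jacrev(linear_ref)(params_lin, x_lin)\n",
    "print({k: v.shape for k, v in jac_lin.items()})  # bias: (1, 1), weight: (1, 2, 1)\n",
    "\n",
    "assert jnp.allclose(jac_lin[\"bias\"], jnp.eye(1))\n",
    "assert jnp.allclose(jac_lin[\"weight\"][0, :, 0], x_lin)"
   ]
  },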
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 6. Custom Gradient Transformations with `GradientTransform`\n",
    "\n",
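    "You can create custom gradient transformations with the `GradientTransform` class. You supply a `transform` function that receives the function to differentiate together with `argnums` and `has_aux`, plus any extra keyword arguments given via `transform_params`. The `transform` function returns a gradient function, and `GradientTransform` adds BrainState's state handling (`grad_states`, `return_value`). In the examples below, the transform is built on `jax.grad`, and its gradient function returns `(grads, aux)`.\n",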
    "\n",
    "### 6.1 Basic Custom Transform"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-11T03:38:23.544684Z",
     "start_time": "2025-10-11T03:38:23.515793Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Scaled gradients:\n",
      "  ('bias',): [-1.]\n",
      "  ('weight',): [[-1.5]]\n",
      "\n",
      "Normal gradients:\n",
      "  ('bias',): [-2.]\n",
      "  ('weight',): [[-3.]]\n"
     ]
    }
   ],
   "source": [
    "from brainstate.transform import GradientTransform\n",
    "\n",
    "\n",
    "def scaled_grad_transform(fun, *, argnums, has_aux, scale):\n",
    "    \"\"\"Custom gradient transform that scales gradients.\"\"\"\n",
    "    # Use JAX's grad as the base transformation\n",
    "    base = jax.grad(fun, argnums=argnums, has_aux=True)\n",
    "\n",
    "    def wrapped(*args, **kwargs):\n",
    "        grads, aux = base(*args, **kwargs)\n",
    "        # Scale all gradients\n",
    "        grads = jax.tree.map(lambda g: scale * g, grads)\n",
    "        return grads, aux\n",
    "\n",
    "    return wrapped\n",
    "\n",
    "\n",
    "def scaled_grad(\n",
    "    fun,\n",
    "    *,\n",
    "    scale=1.0,\n",
    "    grad_states=None,\n",
    "    argnums=None,\n",
    "    has_aux=False,\n",
    "    return_value=False,\n",
    "):\n",
    "    \"\"\"Create a gradient function with scaled gradients.\"\"\"\n",
    "    return GradientTransform(\n",
    "        fun,\n",
    "        transform=scaled_grad_transform,\n",
    "        grad_states=grad_states,\n",
    "        argnums=argnums,\n",
    "        has_aux=has_aux,\n",
    "        return_value=return_value,\n",
    "        transform_params={\"scale\": scale},  # Pass custom parameters\n",
    "    )\n",
    "\n",
    "\n",
    "# Example usage\n",
    "custom_model = LinearRegressor(1)\n",
    "\n",
    "\n",
    "def custom_loss(x, target):\n",
    "    pred = custom_model(x)\n",
    "    return jnp.mean((pred - target) ** 2)\n",
    "\n",
    "\n",
    "# Use custom scaled gradient\n",
    "scaled_grad_fn = scaled_grad(\n",
    "    custom_loss,\n",
    "    scale=0.5,  # Scale gradients by 0.5\n",
    "    grad_states=custom_model.states(brainstate.ParamState),\n",
    ")\n",
    "\n",
    "scaled_grads = scaled_grad_fn(xs, y_true)\n",
    "print(\"Scaled gradients:\")\n",
    "for path, g in scaled_grads.items():\n",
    "    print(f\"  {path}: {g}\")\n",
    "\n",
    "# Compare with unscaled gradients\n",
    "normal_grad_fn = grad(custom_loss, grad_states=custom_model.states(brainstate.ParamState))\n",
    "normal_grads = normal_grad_fn(xs, y_true)\n",
    "print(\"\\nNormal gradients:\")\n",
    "for path, g in normal_grads.items():\n",
    "    print(f\"  {path}: {g}\")"
   ]
  },
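  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Because differentiation is linear, multiplying every gradient by a constant is equivalent to differentiating the loss scaled by that constant. That makes the scaling transform a convenient first template whose effect is easy to verify. The plain-JAX sketch below checks the identity; `loss_s` and `params_s` are illustrative names:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "\n",
    "def loss_s(params, x, target):\n",
    "    pred = x @ params[\"weight\"] + params[\"bias\"]\n",
    "    return jnp.mean((pred - target) ** 2)\n",
    "\n",
    "\n",
    "x_s = jnp.linspace(-1.0, 1.0, 5).reshape(-1, 1)\n",
    "y_s = 3.0 * x_s + 1.0\n",
    "params_s = {\"weight\": jnp.zeros((1, 1)), \"bias\": jnp.zeros((1,))}\n",
    "\n",
    "# Scale the gradients after the fact ...\n",
    "scaled_after = jax.tree.map(lambda g: 0.5 * g, jax.grad(loss_s)(params_s, x_s, y_s))\n",
    "# ... or differentiate the scaled loss directly\n",
    "scaled_loss = jax.grad(lambda p: 0.5 * loss_s(p, x_s, y_s))(params_s)\n",
    "\n",
    "for key in (\"bias\", \"weight\"):\n",
    "    assert jnp.allclose(scaled_after[key], scaled_loss[key])\n",
    "print(scaled_after)  # bias -1.0, weight -1.5, matching the scaled gradients above"
   ]
  },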
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
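    "### 6.2 Advanced: Gradient Clipping Transform\n",
    "\n",
    "The same pattern implements global-norm gradient clipping. When the combined norm of all gradients exceeds `max_norm`, every gradient is rescaled by `max_norm / global_norm`."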
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-11T03:38:35.759607Z",
     "start_time": "2025-10-11T03:38:35.357417Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Clipped gradients:\n",
      "  ('bias',): [-0.05547]\n",
      "    norm: 0.0555\n",
      "  ('weight',): [[-0.08320501]]\n",
      "    norm: 0.0832\n"
     ]
    }
   ],
   "source": [
    "def clipped_grad_transform(fun, *, argnums, has_aux, max_norm):\n",
    "    \"\"\"Custom gradient transform with gradient clipping.\"\"\"\n",
    "    base = jax.grad(fun, argnums=argnums, has_aux=True)\n",
    "\n",
    "    def wrapped(*args, **kwargs):\n",
    "        grads, aux = base(*args, **kwargs)\n",
    "        \n",
    "        # Compute global norm\n",
    "        global_norm = jnp.sqrt(\n",
    "            sum(jnp.sum(jnp.square(g)) for g in jax.tree.leaves(grads))\n",
    "        )\n",
    "        \n",
    "        # Clip gradients\n",
    "        scale = jnp.minimum(1.0, max_norm / (global_norm + 1e-6))\n",
    "        grads = jax.tree.map(lambda g: scale * g, grads)\n",
    "        \n",
    "        return grads, aux\n",
    "\n",
    "    return wrapped\n",
    "\n",
    "\n",
    "def clipped_grad(\n",
    "    fun,\n",
    "    *,\n",
    "    max_norm=1.0,\n",
    "    grad_states=None,\n",
    "    argnums=None,\n",
    "    has_aux=False,\n",
    "    return_value=False,\n",
    "):\n",
    "    \"\"\"Create a gradient function with gradient clipping.\"\"\"\n",
    "    return GradientTransform(\n",
    "        fun,\n",
    "        transform=clipped_grad_transform,\n",
    "        grad_states=grad_states,\n",
    "        argnums=argnums,\n",
    "        has_aux=has_aux,\n",
    "        return_value=return_value,\n",
    "        transform_params={\"max_norm\": max_norm},\n",
    "    )\n",
    "\n",
    "\n",
    "# Example: gradient clipping\n",
    "clip_model = LinearRegressor(1)\n",
    "\n",
    "\n",
    "def clip_loss(x, target):\n",
    "    pred = clip_model(x)\n",
    "    return jnp.mean((pred - target) ** 2)\n",
    "\n",
    "\n",
    "clipped_grad_fn = clipped_grad(\n",
    "    clip_loss,\n",
    "    max_norm=0.1,  # Clip gradients to max norm of 0.1\n",
    "    grad_states=clip_model.states(brainstate.ParamState),\n",
    ")\n",
    "\n",
    "clipped_grads = clipped_grad_fn(xs, y_true)\n",
    "print(\"Clipped gradients:\")\n",
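    "# Per-parameter norms below; it is the *global* norm over all parameters that is clipped to max_norm\n",
    "for path, g in clipped_grads.items():\n",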
    "    print(f\"  {path}: {g}\")\n",
    "    print(f\"    norm: {jnp.linalg.norm(g):.4f}\")"
   ]
  },
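  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The clipping rule can also be written as a standalone pure-JAX function. The sketch below checks it on the unclipped gradients from before (`bias = -2`, `weight = -3`, global norm `sqrt(13)`); after clipping, the global norm equals `max_norm`. `clip_by_global_norm_ref` and `global_norm_of` are illustrative names. Optax provides a comparable `optax.clip_by_global_norm` transformation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "\n",
    "def global_norm_of(tree):\n",
    "    \"\"\"L2 norm of all leaves of a pytree, taken together.\"\"\"\n",
    "    return jnp.sqrt(sum(jnp.sum(jnp.square(g)) for g in jax.tree.leaves(tree)))\n",
    "\n",
    "\n",
    "def clip_by_global_norm_ref(tree, max_norm, eps=1e-6):\n",
    "    \"\"\"Rescale all leaves jointly so that their global norm is at most max_norm.\"\"\"\n",
    "    factor = jnp.minimum(1.0, max_norm / (global_norm_of(tree) + eps))\n",
    "    return jax.tree.map(lambda g: factor * g, tree)\n",
    "\n",
    "\n",
    "raw_grads = {\"bias\": jnp.array([-2.0]), \"weight\": jnp.array([[-3.0]])}\n",
    "clipped = clip_by_global_norm_ref(raw_grads, max_norm=0.1)\n",
    "\n",
    "print(\"global norm before:\", float(global_norm_of(raw_grads)))  # sqrt(13), about 3.6056\n",
    "print(\"global norm after: \", float(global_norm_of(clipped)))  # about 0.1"
   ]
  },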
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 7. Practical Example: Training Loop\n",
    "\n",
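    "Let's put everything together in a complete training example: full-batch gradient descent on the linear regression model, with L2 regularization and auxiliary metrics."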
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-11T03:40:55.529291Z",
     "start_time": "2025-10-11T03:40:55.371512Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training started...\n",
      "Initial weight: [[0.]]\n",
      "Initial bias: [0.]\n",
      "\n",
      "Epoch 10:\n",
      "  Loss: 0.9506\n",
      "  MSE: 0.9198\n",
      "  L2: 0.030758\n",
      "  Weight: [[1.6285777]]\n",
      "  Bias: [0.9066533]\n",
      "\n",
      "Epoch 20:\n",
      "  Loss: 0.2827\n",
      "  MSE: 0.2190\n",
      "  L2: 0.063762\n",
      "  Weight: [[2.3699088]]\n",
      "  Bias: [1.0015979]\n",
      "\n",
      "Epoch 30:\n",
      "  Loss: 0.1478\n",
      "  MSE: 0.0655\n",
      "  L2: 0.082280\n",
      "  Weight: [[2.7073638]]\n",
      "  Bias: [1.0115404]\n",
      "\n",
      "Epoch 40:\n",
      "  Loss: 0.1199\n",
      "  MSE: 0.0284\n",
      "  L2: 0.091504\n",
      "  Weight: [[2.8609738]]\n",
      "  Bias: [1.0125817]\n",
      "\n",
      "Epoch 50:\n",
      "  Loss: 0.1141\n",
      "  MSE: 0.0182\n",
      "  L2: 0.095877\n",
      "  Weight: [[2.9308972]]\n",
      "  Bias: [1.0126907]\n",
      "\n",
      "Training completed!\n",
      "Final weight: [[2.9308972]] (true: 3.0)\n",
      "Final bias: [1.0126907] (true: 1.0)\n"
     ]
    }
   ],
   "source": [
    "# Create a fresh model\n",
    "training_model = LinearRegressor(1)\n",
    "\n",
    "# Generate training data\n",
    "true_weight = 3.0\n",
    "true_bias = 1.0\n",
    "x_train = jnp.linspace(-1.0, 1.0, 20).reshape(-1, 1)\n",
    "y_train = true_weight * x_train + true_bias + 0.1 * brainstate.random.normal(size=x_train.shape)\n",
    "\n",
    "\n",
    "@brainstate.transform.jit\n",
    "def training_loss(x, y):\n",
    "    \"\"\"MSE loss with L2 regularization.\"\"\"\n",
    "    pred = training_model(x)\n",
    "    mse = jnp.mean((pred - y) ** 2)\n",
    "    l2 = 0.01 * (jnp.sum(training_model.weight.value ** 2) + jnp.sum(training_model.bias.value ** 2))\n",
    "    return mse + l2, {\"mse\": mse, \"l2\": l2}\n",
    "\n",
    "\n",
    "# Create gradient function\n",
    "loss_grad_fn = grad(\n",
    "    training_loss,\n",
    "    grad_states=training_model.states(brainstate.ParamState),\n",
    "    has_aux=True,\n",
    "    return_value=True,\n",
    ")\n",
    "\n",
    "# Training loop\n",
    "learning_rate = 0.1\n",
    "num_epochs = 50\n",
    "\n",
    "print(\"Training started...\")\n",
    "print(f\"Initial weight: {training_model.weight.value}\")\n",
    "print(f\"Initial bias: {training_model.bias.value}\")\n",
    "\n",
    "for epoch in range(num_epochs):\n",
    "    # Compute gradients\n",
    "    grads, loss_val, aux = loss_grad_fn(x_train, y_train)\n",
    "    \n",
    "    # Update parameters (simple SGD)\n",
    "    for path, state in training_model.states(brainstate.ParamState).items():\n",
    "        g = grads[path]  # avoid shadowing the imported `grad` function\n",
    "        state.value = state.value - learning_rate * g\n",
    "    \n",
    "    # Print progress\n",
    "    if (epoch + 1) % 10 == 0:\n",
    "        print(f\"\\nEpoch {epoch + 1}:\")\n",
    "        print(f\"  Loss: {float(loss_val):.4f}\")\n",
    "        print(f\"  MSE: {float(aux['mse']):.4f}\")\n",
    "        print(f\"  L2: {float(aux['l2']):.6f}\")\n",
    "        print(f\"  Weight: {training_model.weight.value}\")\n",
    "        print(f\"  Bias: {training_model.bias.value}\")\n",
    "\n",
    "print(\"\\nTraining completed!\")\n",
    "print(f\"Final weight: {training_model.weight.value} (true: {true_weight})\")\n",
    "print(f\"Final bias: {training_model.bias.value} (true: {true_bias})\")"
   ]
  },
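  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`training_loss` is ridge regression: MSE plus an L2 penalty with coefficient 0.01 on both weight and bias. Its minimizer therefore has a closed form. Setting the gradient to zero gives `(X^T X / n + lam * I) theta = X^T y / n`, where `X` is the input augmented with a column of ones. The pure-JAX sketch below solves this system on a noise-free version of the data (an assumption; the loop above adds noise) and confirms that the gradient vanishes there. The penalty shrinks the optimum toward zero, to roughly `w = 2.92, b = 0.99`, so the trained weight should not be expected to reach exactly 3.0."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "lam = 0.01  # same L2 coefficient as in training_loss\n",
    "x_rr = jnp.linspace(-1.0, 1.0, 20).reshape(-1, 1)\n",
    "y_rr = 3.0 * x_rr + 1.0  # noise-free targets (assumption for this check)\n",
    "\n",
    "# Design matrix with columns [x, 1], so theta = [weight, bias]\n",
    "X_rr = jnp.concatenate([x_rr, jnp.ones_like(x_rr)], axis=1)\n",
    "n_rr = X_rr.shape[0]\n",
    "\n",
    "\n",
    "def ridge_loss(theta):\n",
    "    pred = X_rr @ theta\n",
    "    return jnp.mean((pred - y_rr[:, 0]) ** 2) + lam * jnp.sum(theta ** 2)\n",
    "\n",
    "\n",
    "# Normal equations of the ridge objective: (X^T X / n + lam * I) theta = X^T y / n\n",
    "A_rr = X_rr.T @ X_rr / n_rr + lam * jnp.eye(2)\n",
    "theta_star = jnp.linalg.solve(A_rr, X_rr.T @ y_rr[:, 0] / n_rr)\n",
    "print(\"weight, bias at the regularized optimum:\", theta_star)\n",
    "\n",
    "# The gradient of the loss vanishes at the closed-form solution\n",
    "assert jnp.allclose(jax.grad(ridge_loss)(theta_star), 0.0, atol=1e-5)"
   ]
  },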
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Summary\n",
    "\n",
    "In this tutorial, we covered:\n",
    "\n",
    "1. **`argnums`**: Specify which function arguments to differentiate (inherited from JAX)\n",
    "2. **`grad_states`**: Specify which `State` objects should receive gradients (BrainState extension)\n",
    "3. **`ParamState`**: Standard way to mark trainable parameters in modules\n",
    "4. **Retrieving states**: Use `module.states()` or `brainstate.graph.treefy_states()`\n",
    "5. **`StateFinder`**: Discover states used in arbitrary functions\n",
    "6. **Return structures**: How `grad_states`, `argnums`, `has_aux`, and `return_value` determine the output structure\n",
    "7. **Other transforms**: `vector_grad`, `jacrev`, `jacfwd`, `jacobian`, `hessian`\n",
    "8. **Custom transforms**: Build your own using `GradientTransform`\n",
    "\n",
    "### Key Takeaways\n",
    "\n",
    "- All gradient transformations share the same signature and return structure patterns\n",
    "- `ParamState` is the standard for trainable parameters, but gradients work with any `State`\n",
    "- `StateFinder` helps discover states in arbitrary functions\n",
    "- `GradientTransform` enables custom gradient transformations while maintaining state-awareness\n",
    "- The system combines JAX's autodiff with BrainState's stateful computation model"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Ecosystem-py",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
