{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "8fa84d70",
   "metadata": {},
   "source": [
    "# Transformations, the Essentials"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2aea9e7a",
   "metadata": {},
   "source": [
    "`brainstate.transform` mirrors the JAX transformation API — `jit`, `grad`, `vmap` — but every\n",
    "transform is **state-aware**: it tracks the `State` objects your model reads and writes, so you\n",
    "never thread parameters and buffers through function arguments by hand.\n",
    "\n",
    "This tutorial is the gateway. It shows the three transforms you reach for daily and how they\n",
    "compose. The dedicated [transformations track](../transformations/index.rst) then covers each in\n",
    "depth — compilation internals, advanced autodiff, batched ensembles, control flow, error\n",
    "checking, and debugging."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "3e3d2e34",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-05-30T16:20:41.048599Z",
     "iopub.status.busy": "2026-05-30T16:20:41.048455Z",
     "iopub.status.idle": "2026-05-30T16:20:43.192166Z",
     "shell.execute_reply": "2026-05-30T16:20:43.191348Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'0.4.0'"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "import brainstate\n",
    "\n",
    "brainstate.random.seed(0)\n",
    "brainstate.__version__"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4a3bf376",
   "metadata": {},
   "source": [
    "## Why state-aware transformations"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3e9b29cc",
   "metadata": {},
   "source": [
    "JAX transformations operate on **pure functions** — outputs depend only on inputs, with no side\n",
    "effects. A BrainState model is the opposite: it keeps mutable `State`, and calling it reads and\n",
    "writes that state. Hand such a model to raw `jax.jit` and the `State` write is *silently\n",
    "discarded* — the counter is recomputed from its initial value on every call and never advances:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "5c9994b7",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-05-30T16:20:43.194500Z",
     "iopub.status.busy": "2026-05-30T16:20:43.194039Z",
     "iopub.status.idle": "2026-05-30T16:20:43.226875Z",
     "shell.execute_reply": "2026-05-30T16:20:43.226047Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "raw jax.jit returns: [1, 1, 1, 1]\n"
     ]
    }
   ],
   "source": [
    "class Counter(brainstate.nn.Module):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.n = brainstate.State(jnp.array(0))   # mutable counter state\n",
    "\n",
    "    def __call__(self):\n",
    "        self.n.value += 1   # side effect: writes to the State\n",
    "        return self.n.value\n",
    "\n",
    "# Raw jax.jit knows nothing about State, so the write above does not persist.\n",
    "broken = jax.jit(Counter())\n",
    "print('raw jax.jit returns:', [int(broken()) for _ in range(4)])   # 1, 1, 1, 1: update lost"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "228a2de7",
   "metadata": {},
   "source": [
    "`brainstate.transform.jit` understands `State`. It captures every read and write, so mutations\n",
    "persist correctly across calls:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "5611eda8",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-05-30T16:20:43.229695Z",
     "iopub.status.busy": "2026-05-30T16:20:43.229295Z",
     "iopub.status.idle": "2026-05-30T16:20:43.262407Z",
     "shell.execute_reply": "2026-05-30T16:20:43.261779Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[1, 2, 3, 4]\n"
     ]
    }
   ],
   "source": [
    "counter = Counter()\n",
    "fast_counter = brainstate.transform.jit(counter)\n",
    "print([int(fast_counter()) for _ in range(4)])   # 1, 2, 3, 4 — state survives"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e2a7d5ec",
   "metadata": {},
   "source": [
    "This is the rule that makes the rest of BrainState work: **wrap a model in a `brainstate`\n",
    "transform and its state is handled for you.**"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7c618cb0",
   "metadata": {},
   "source": [
    "## `jit`: compile once, run fast"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "35c0605e",
   "metadata": {},
   "source": [
    "`jit` traces a function the first time it is called, compiles it with XLA, and reuses the\n",
    "compiled version on later calls with the same input shapes and dtypes; new shapes trigger a\n",
    "retrace. Use it on whole steps, such as a forward pass or a training step, rather than on\n",
    "tiny operations. We will reuse this small linear model throughout."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "538fe151",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-05-30T16:20:43.264275Z",
     "iopub.status.busy": "2026-05-30T16:20:43.264029Z",
     "iopub.status.idle": "2026-05-30T16:20:43.855153Z",
     "shell.execute_reply": "2026-05-30T16:20:43.854367Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(64, 1)"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "class Linear(brainstate.nn.Module):\n",
    "    def __init__(self, din, dout):\n",
    "        super().__init__()\n",
    "        # ParamState marks trainable parameters so they can be selected by type later.\n",
    "        self.w = brainstate.ParamState(brainstate.random.randn(din, dout) * 0.1)\n",
    "        self.b = brainstate.ParamState(jnp.zeros(dout))\n",
    "\n",
    "    def __call__(self, x):\n",
    "        return x @ self.w.value + self.b.value\n",
    "\n",
    "model = Linear(3, 1)\n",
    "x = brainstate.random.randn(64, 3)\n",
    "y = brainstate.random.randn(64, 1)\n",
    "\n",
    "forward = brainstate.transform.jit(model)\n",
    "forward(x).shape"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "af9b70e6",
   "metadata": {},
   "source": [
    "## `grad`: differentiate with respect to states"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "118ea37f",
   "metadata": {},
   "source": [
    "`grad` differentiates a function with respect to a collection of `State`s, rather than with\n",
    "respect to positional arguments as plain `jax.grad` does. You pass the states to differentiate,\n",
    "and it returns a dictionary of gradients keyed by each state's path in the module tree."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "c05e0588",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-05-30T16:20:43.856907Z",
     "iopub.status.busy": "2026-05-30T16:20:43.856764Z",
     "iopub.status.idle": "2026-05-30T16:20:44.243498Z",
     "shell.execute_reply": "2026-05-30T16:20:44.242598Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{('b',): (1,), ('w',): (3, 1)}"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
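    "# Collect every trainable parameter, keyed by its path in the module tree.\n",
    "params = model.states(brainstate.ParamState)\n",
    "\n",
    "def loss_fn():\n",
    "    # Mean squared error; `model`, `x`, and `y` are read from the enclosing scope.\n",
    "    return jnp.mean((model(x) - y) ** 2)\n",
    "\n",
    "# grad(...) returns a function; calling it evaluates the gradients.\n",
    "grads = brainstate.transform.grad(loss_fn, params)()\n",
    "{key: g.shape for key, g in grads.items()}"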
   ]
  },
  {
   "cell_type": "markdown",
   "id": "264c6b78",
   "metadata": {},
   "source": [
    "The gradient keys match the parameter keys exactly, so applying an update is a simple loop. Pass\n",
    "`return_value=True` to also get the loss from the same pass, returned as `(grads, loss)`, or\n",
    "`has_aux=True` to return extra diagnostics from the loss function."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "2dfdd7d0",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-05-30T16:20:44.245498Z",
     "iopub.status.busy": "2026-05-30T16:20:44.245355Z",
     "iopub.status.idle": "2026-05-30T16:20:44.516445Z",
     "shell.execute_reply": "2026-05-30T16:20:44.515624Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "loss: 1.0843076705932617\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "loss after one step: 1.023072600364685\n"
     ]
    }
   ],
   "source": [
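    "# return_value=True makes the call return (gradients, loss) from one pass.\n",
    "grads, loss = brainstate.transform.grad(loss_fn, params, return_value=True)()\n",
    "print('loss:', float(loss))\n",
    "# One step of plain gradient descent with learning rate 0.1.\n",
    "for key in params:\n",
    "    params[key].value -= 0.1 * grads[key]\n",
    "print('loss after one step:', float(loss_fn()))"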
   ]
  },
  {
   "cell_type": "markdown",
   "id": "68b3ad36",
   "metadata": {},
   "source": [
    "## `vmap`: vectorize over a batch"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "df9dece5",
   "metadata": {},
   "source": [
    "`vmap` adds a batch dimension to a function written for a single example, replacing a\n",
    "Python-level loop with a single vectorized call. Here `predict_one` is written for one input row;\n",
    "`vmap` maps it over the leading axis of `x`, so the whole batch runs at once."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "04047a89",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-05-30T16:20:44.519254Z",
     "iopub.status.busy": "2026-05-30T16:20:44.518984Z",
     "iopub.status.idle": "2026-05-30T16:20:44.798446Z",
     "shell.execute_reply": "2026-05-30T16:20:44.797591Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(64, 1)"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
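    "def predict_one(x_row):\n",
    "    # x_row has shape (3,): add a leading axis for the matmul, then drop it again.\n",
    "    return jnp.tanh(model(x_row[None, :]))[0]\n",
    "\n",
    "# Map predict_one over the leading axis of x, which has shape (64, 3).\n",
    "predict_batch = brainstate.transform.vmap(predict_one)\n",
    "predict_batch(x).shape"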
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6aa07bd3",
   "metadata": {},
   "source": [
    "Because it is state-aware, `vmap` can also map over the *states* themselves — for example to run\n",
    "an ensemble of models in parallel — via its `in_states` / `out_states` arguments. That, along\n",
    "with `vmap2`, `pmap2`, and `shard_map`, is covered in\n",
    "[vectorization](../transformations/03_vectorization.ipynb) and\n",
    "[advanced batching](../transformations/04_advanced_batching.ipynb)."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cb2e68f8",
   "metadata": {},
   "source": [
    "## Composing transformations"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8f514b04",
   "metadata": {},
   "source": [
    "Transforms compose. The common pattern is `jit` around `grad`: compute the gradients and apply\n",
    "the update inside one function, then compile that function so the whole training step runs as a\n",
    "single XLA program."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "4aa982d6",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-05-30T16:20:44.800652Z",
     "iopub.status.busy": "2026-05-30T16:20:44.800393Z",
     "iopub.status.idle": "2026-05-30T16:20:44.877592Z",
     "shell.execute_reply": "2026-05-30T16:20:44.876583Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "loss trajectory: [1.0231, 0.9841, 0.9592, 0.9431, 0.9327]\n"
     ]
    }
   ],
   "source": [
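    "@brainstate.transform.jit\n",
    "def train_step():\n",
    "    # Gradients and loss come from a single pass over loss_fn.\n",
    "    grads, loss = brainstate.transform.grad(loss_fn, params, return_value=True)()\n",
    "    # Gradient-descent update; the state-aware jit persists these writes across calls.\n",
    "    for key in params:\n",
    "        params[key].value -= 0.1 * grads[key]\n",
    "    return loss\n",
    "\n",
    "# Each call reports the loss computed before that step's update is applied.\n",
    "losses = [float(train_step()) for _ in range(5)]\n",
    "print('loss trajectory:', [round(v, 4) for v in losses])"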
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b1af5a54",
   "metadata": {},
   "source": [
    "## Summary\n",
    "\n",
    "- BrainState transforms are **state-aware** counterparts of the JAX transforms: they track\n",
    "  `State` reads and writes so you never thread state through arguments.\n",
    "- **`jit`** compiles a function once and reuses it; apply it to whole steps.\n",
    "- **`grad`** differentiates with respect to a collection of states and returns a gradient dict\n",
    "  keyed by state path; `return_value` and `has_aux` carry extra outputs.\n",
    "- **`vmap`** vectorizes a single-example function over a batch, and can map over states too.\n",
    "- Transforms **compose**: `jit` around `grad` is the backbone of a typical training loop.\n",
    "\n",
    "### See also\n",
    "\n",
    "- [Training and metrics](07_training_and_metrics.ipynb) — these transforms assembled into a full training loop.\n",
    "- The [transformations track](../transformations/index.rst) — `jit`, autodiff, vectorization, control flow, error handling, and debugging in depth.\n",
    "- [Transformation semantics](../../concepts/transformation_semantics.md) — how state threading works under the hood."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.13.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
