{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "34104b00",
   "metadata": {},
   "source": [
    "# Training and Metrics"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "08e40da5",
   "metadata": {},
   "source": [
    "This tutorial assembles the pieces you have met so far — `Module`, `ParamState`, `grad`, and\n",
    "`jit` — into a complete training loop, and introduces the tools BrainState provides around\n",
    "that loop: parameter counting, gradient clipping, optimizers, and metric tracking.\n",
    "\n",
    "We use a small synthetic classification problem so the notebook runs in seconds with no\n",
    "downloads. The mechanics are identical for real datasets — only the data source changes.\n",
    "\n",
    "You will learn to:\n",
    "\n",
    "- Inspect a model with [`count_parameters`](../../apis/nn/index.rst).\n",
    "- Drive parameter updates with a `braintools.optim` optimizer.\n",
    "- Stabilise training with [`clip_grad_norm`](../../apis/nn/index.rst).\n",
    "- Accumulate evaluation statistics with the `MultiMetric` system."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "2af5ce5d",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-05-30T16:20:49.079627Z",
     "iopub.status.busy": "2026-05-30T16:20:49.079399Z",
     "iopub.status.idle": "2026-05-30T16:20:53.677273Z",
     "shell.execute_reply": "2026-05-30T16:20:53.676699Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'0.4.0'"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import jax.numpy as jnp\n",
    "\n",
    "import brainstate\n",
    "import braintools\n",
    "from brainstate.nn import count_parameters, clip_grad_norm, MultiMetric, AverageMetric, AccuracyMetric\n",
    "from braintools.metric import softmax_cross_entropy_with_integer_labels\n",
    "\n",
    "brainstate.random.seed(42)\n",
    "brainstate.__version__"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f4a5a806",
   "metadata": {},
   "source": [
    "## A self-contained dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "41d47961",
   "metadata": {},
   "source": [
    "Our task is to classify points drawn from three Gaussian clusters in an 8-dimensional space.\n",
    "`make_blobs` samples `n_per` points around each class centre using `brainstate.random`, so the\n",
    "data is reproducible under the seed set above. We draw 200 points per class for training (600 in\n",
    "total) and 50 per class for testing (150 in total)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "037f1e75",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-05-30T16:20:53.679867Z",
     "iopub.status.busy": "2026-05-30T16:20:53.679457Z",
     "iopub.status.idle": "2026-05-30T16:20:54.610772Z",
     "shell.execute_reply": "2026-05-30T16:20:54.609712Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "((600, 8), (600,))"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "DIM, N_CLASSES = 8, 3\n",
    "centers = brainstate.random.randn(N_CLASSES, DIM) * 2.0\n",
    "\n",
    "def make_blobs(n_per):\n",
    "    xs = [brainstate.random.randn(n_per, DIM) + centers[c] for c in range(N_CLASSES)]\n",
    "    ys = [jnp.full((n_per,), c, dtype=jnp.int32) for c in range(N_CLASSES)]\n",
    "    return jnp.concatenate(xs), jnp.concatenate(ys)\n",
    "\n",
    "x_train, y_train = make_blobs(200)\n",
    "x_test, y_test = make_blobs(50)\n",
    "x_train.shape, y_train.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "29f1603d",
   "metadata": {},
   "source": [
    "Neural networks train far more reliably on inputs with comparable scales across features, so we\n",
    "**standardise** the data to zero mean and unit variance. The statistics come from the training\n",
    "set only — the test set is transformed with those same numbers, never its own, to avoid leaking\n",
    "test information into preprocessing."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "dac0e37c",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-05-30T16:20:54.613146Z",
     "iopub.status.busy": "2026-05-30T16:20:54.612904Z",
     "iopub.status.idle": "2026-05-30T16:20:54.827472Z",
     "shell.execute_reply": "2026-05-30T16:20:54.826866Z"
    }
   },
   "outputs": [],
   "source": [
    "mean, std = x_train.mean(axis=0), x_train.std(axis=0)\n",
    "x_train = (x_train - mean) / std\n",
    "x_test = (x_test - mean) / std"
   ]
  },
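  {
   "cell_type": "markdown",
   "id": "5e0c1a01",
   "metadata": {},
   "source": [
    "As a quick check of that reasoning, the sketch below standardises some made-up data with plain\n",
    "`jax.numpy`. The arrays `demo_train` and `demo_test` are hypothetical (they are not the blobs\n",
    "above). The cell draws from `jax.random` with its own key, so it leaves the `brainstate.random`\n",
    "stream seeded earlier untouched. The training columns come out with mean 0 and standard deviation 1\n",
    "up to float32 rounding. The test columns, transformed with the *training* statistics, are only\n",
    "approximately standardised, as expected for data the statistics were not computed from."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5e0c1a02",
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "# Hypothetical data: 8 features with mean 3 and standard deviation 5\n",
    "k_train, k_test = jax.random.split(jax.random.PRNGKey(0))\n",
    "demo_train = 3.0 + 5.0 * jax.random.normal(k_train, (600, 8))\n",
    "demo_test = 3.0 + 5.0 * jax.random.normal(k_test, (150, 8))\n",
    "\n",
    "# Statistics are estimated on the training split only\n",
    "mu, sigma = demo_train.mean(axis=0), demo_train.std(axis=0)\n",
    "demo_train_std = (demo_train - mu) / sigma\n",
    "demo_test_std = (demo_test - mu) / sigma  # reuse the training statistics\n",
    "\n",
    "# Training columns: mean ~0 and std ~1 up to float32 rounding\n",
    "print(jnp.abs(demo_train_std.mean(axis=0)).max(), jnp.abs(demo_train_std.std(axis=0) - 1.0).max())\n",
    "# Test columns: close to, but not exactly, mean 0 and std 1\n",
    "print(demo_test_std.mean(axis=0), demo_test_std.std(axis=0))"
   ]
  },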
  {
   "cell_type": "markdown",
   "id": "70aa0e2a",
   "metadata": {},
   "source": [
    "## Defining the model"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "42a80363",
   "metadata": {},
   "source": [
    "A two-layer perceptron is enough for this problem. We compose two `brainstate.nn.Linear` layers\n",
    "with a ReLU nonlinearity; the final layer emits one *logit* per class."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "c0993dcb",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-05-30T16:20:54.829540Z",
     "iopub.status.busy": "2026-05-30T16:20:54.829286Z",
     "iopub.status.idle": "2026-05-30T16:20:59.249362Z",
     "shell.execute_reply": "2026-05-30T16:20:59.248680Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "MLP(\n",
       "  fc1=Linear(\n",
       "    in_size=(8,),\n",
       "    out_size=(32,),\n",
       "    weight=ParamState(\n",
       "      value={\n",
       "        'bias': ShapedArray(float32[32]),\n",
       "        'weight': ShapedArray(float32[8,32])\n",
       "      }\n",
       "    )\n",
       "  ),\n",
       "  fc2=Linear(\n",
       "    in_size=(32,),\n",
       "    out_size=(3,),\n",
       "    weight=ParamState(\n",
       "      value={\n",
       "        'bias': ShapedArray(float32[3]),\n",
       "        'weight': ShapedArray(float32[32,3])\n",
       "      }\n",
       "    )\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "class MLP(brainstate.nn.Module):\n",
    "    def __init__(self, din, dhidden, dout):\n",
    "        super().__init__()\n",
    "        self.fc1 = brainstate.nn.Linear(din, dhidden)\n",
    "        self.fc2 = brainstate.nn.Linear(dhidden, dout)\n",
    "\n",
    "    def __call__(self, x):\n",
    "        return self.fc2(brainstate.nn.relu(self.fc1(x)))\n",
    "\n",
    "model = MLP(DIM, 32, N_CLASSES)\n",
    "model"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8c878a50",
   "metadata": {},
   "source": [
    "## Counting parameters"
   ]
  },
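  {
   "cell_type": "markdown",
   "id": "edd5e134",
   "metadata": {},
   "source": [
    "Before training, it is worth knowing how large a model is. `count_parameters` walks the module\n",
    "tree, sums the sizes of every `ParamState`, and (by default) prints a per-state breakdown. Each\n",
    "`Linear` layer stores its weight matrix and bias vector together in a single `ParamState` (see the\n",
    "`value={'bias': ..., 'weight': ...}` entries in the model repr above). The table therefore has one\n",
    "row per layer, labelled `('fc1', 'weight')` and `('fc2', 'weight')`, and each count includes the bias."
   ]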
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "740bbe78",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-05-30T16:20:59.251684Z",
     "iopub.status.busy": "2026-05-30T16:20:59.251321Z",
     "iopub.status.idle": "2026-05-30T16:20:59.261685Z",
     "shell.execute_reply": "2026-05-30T16:20:59.260773Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+-------------------+------------+\n",
      "|      Modules      | Parameters |\n",
      "+-------------------+------------+\n",
      "| ('fc1', 'weight') |    288     |\n",
      "| ('fc2', 'weight') |     99     |\n",
      "|       Total       |    387     |\n",
      "+-------------------+------------+\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "387"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "n_params = count_parameters(model)\n",
    "n_params"
   ]
  },
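  {
   "cell_type": "markdown",
   "id": "5e0c1a03",
   "metadata": {},
   "source": [
    "The totals are easy to verify by hand. A `Linear` layer holds a `din`-by-`dout` weight matrix plus\n",
    "one bias per output unit. That gives `fc1` 8 * 32 + 32 = 288 entries and `fc2` 32 * 3 + 3 = 99,\n",
    "for 387 in all. The sketch below repeats the count on a plain pytree of zero arrays whose shapes are\n",
    "copied from the `MLP` repr above. `demo_params` is a hypothetical stand-in for the real\n",
    "`ParamState`s. Summing leaf sizes is an assumption about how the count is formed, not a\n",
    "description of `count_parameters` internals."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5e0c1a04",
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "# Hypothetical stand-in pytree; shapes copied from the MLP repr above\n",
    "demo_params = {\n",
    "    'fc1': {'weight': jnp.zeros((8, 32)), 'bias': jnp.zeros((32,))},\n",
    "    'fc2': {'weight': jnp.zeros((32, 3)), 'bias': jnp.zeros((3,))},\n",
    "}\n",
    "\n",
    "# Per-layer count: total number of elements across every leaf (weights and biases)\n",
    "per_layer = {name: sum(leaf.size for leaf in jax.tree.leaves(tree))\n",
    "             for name, tree in demo_params.items()}\n",
    "demo_total = sum(per_layer.values())\n",
    "print(per_layer, demo_total)  # {'fc1': 288, 'fc2': 99} 387"
   ]
  },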
  {
   "cell_type": "markdown",
   "id": "8b2f8bac",
   "metadata": {},
   "source": [
    "Pass `return_table=True` to capture the formatted table as a string instead of printing it —\n",
    "useful for logging or for embedding in a report."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "80f94234",
   "metadata": {},
   "source": [
    "## The optimizer"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "82779b06",
   "metadata": {},
   "source": [
    "`braintools.optim` provides the usual family of optimizers (`SGD`, `Adam`, `AdamW`, `Lion`, …).\n",
    "After constructing one, register the states it is allowed to update. Passing\n",
    "`model.states(brainstate.ParamState)` registers exactly the model's `ParamState` instances. Other\n",
    "state types (counters, running statistics) are never registered, so the optimizer leaves them\n",
    "untouched."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "51d9aeac",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-05-30T16:20:59.263971Z",
     "iopub.status.busy": "2026-05-30T16:20:59.263722Z",
     "iopub.status.idle": "2026-05-30T16:20:59.323613Z",
     "shell.execute_reply": "2026-05-30T16:20:59.322603Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Adam(\n",
       "  betas=(0.9, 0.999),\n",
       "  eps=1e-08,\n",
       "  amsgrad=False,\n",
       "  param_states=<braintools.optim.UniqueStateManager object at 0x778921c5c2f0>,\n",
       "  weight_decay=0.0,\n",
       "  step_count=OptimState(\n",
       "    value=ShapedArray(int32[], weak_type=True)\n",
       "  ),\n",
       "  param_groups=[\n",
       "    {\n",
       "      'params': {\n",
       "        ('fc1', 'weight'): ParamState(\n",
       "          value={\n",
       "            'bias': ShapedArray(float32[32]),\n",
       "            'weight': ShapedArray(float32[8,32])\n",
       "          }\n",
       "        ),\n",
       "        ('fc2', 'weight'): ParamState(\n",
       "          value={\n",
       "            'bias': ShapedArray(float32[3]),\n",
       "            'weight': ShapedArray(float32[32,3])\n",
       "          }\n",
       "        )\n",
       "      },\n",
       "      'lr': OptimState(\n",
       "        value=ShapedArray(float32[], weak_type=True)\n",
       "      ),\n",
       "      'weight_decay': 0.0\n",
       "    }\n",
       "  ],\n",
       "  param_groups_opt_states=[],\n",
       "  _schedulers=[],\n",
       "  _lr_scheduler=<braintools.optim.ConstantLR object at 0x778921c5cec0>,\n",
       "  _base_lr=0.01,\n",
       "  _current_lr=OptimState(...),\n",
       "  tx=GradientTransformationExtraArgs(init=<function chain.<locals>.init_fn at 0x778921c82200>, update=<function chain.<locals>.update_fn at 0x778921c822a0>),\n",
       "  opt_state=OptimState(\n",
       "    value=(ScaleByAdamState(count=ShapedArray(int32[]), mu={('fc1', 'weight'): {'bias': ShapedArray(float32[32]), 'weight': ShapedArray(float32[8,32])}, ('fc2', 'weight'): {'bias': ShapedArray(float32[3]), 'weight': ShapedArray(float32[32,3])}}, nu={('fc1', 'weight'): {'bias': ShapedArray(float32[32]), 'weight': ShapedArray(float32[8,32])}, ('fc2', 'weight'): {'bias': ShapedArray(float32[3]), 'weight': ShapedArray(float32[32,3])}}), ScaleByScheduleState(count=ShapedArray(int32[])))\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "optimizer = braintools.optim.Adam(lr=1e-2)\n",
    "optimizer.register_trainable_weights(model.states(brainstate.ParamState))"
   ]
  },
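  {
   "cell_type": "markdown",
   "id": "5e0c1a05",
   "metadata": {},
   "source": [
    "`optimizer.update(grads)` hides the arithmetic of the update rule. For intuition, the sketch below\n",
    "writes out the textbook Adam step in plain `jax.numpy`, using the hyperparameters visible in the\n",
    "repr above (`lr=0.01`, `betas=(0.9, 0.999)`, `eps=1e-08`). The `mu` and `nu` fields of\n",
    "`ScaleByAdamState` in that repr play the roles of the moment estimates `m` and `v`. `adam_step` is a\n",
    "hypothetical helper that illustrates the standard algorithm. It is not braintools' implementation,\n",
    "which may differ in details such as where `eps` enters. On the first step the bias-corrected update\n",
    "reduces to roughly `lr * sign(grad)`, whatever the gradient's magnitude."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5e0c1a06",
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax.numpy as jnp\n",
    "\n",
    "def adam_step(param, grad, m, v, t, lr=1e-2, b1=0.9, b2=0.999, eps=1e-8):\n",
    "    # Exponential moving averages of the gradient and of the squared gradient\n",
    "    m = b1 * m + (1 - b1) * grad\n",
    "    v = b2 * v + (1 - b2) * grad ** 2\n",
    "    # Bias correction offsets the zero initialisation of m and v (t counts from 1)\n",
    "    m_hat = m / (1 - b1 ** t)\n",
    "    v_hat = v / (1 - b2 ** t)\n",
    "    param = param - lr * m_hat / (jnp.sqrt(v_hat) + eps)\n",
    "    return param, m, v\n",
    "\n",
    "p0 = jnp.array([1.0, -2.0])\n",
    "g0 = jnp.array([0.5, -3.0])\n",
    "p1, m1, v1 = adam_step(p0, g0, jnp.zeros_like(p0), jnp.zeros_like(p0), t=1)\n",
    "print(p1)  # close to [0.99, -1.99]: each entry moves by about lr against its gradient's sign"
   ]
  },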
  {
   "cell_type": "markdown",
   "id": "596eb9c5",
   "metadata": {},
   "source": [
    "## The training step"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7ffc46bc",
   "metadata": {},
   "source": [
    "A single training step does three things:\n",
    "\n",
    "1. `grad` differentiates the loss with respect to the states passed as its second argument, here\n",
    "   the model's `ParamState`s. With `return_value=True` it returns `(grads, loss)` in one pass, so\n",
    "   we get the loss for free.\n",
    "2. `clip_grad_norm` rescales the gradients so their global norm never exceeds `max_norm`. This\n",
    "   is cheap insurance against the occasional exploding update.\n",
    "3. `optimizer.update(grads)` applies the clipped gradients to the registered parameters in place.\n",
    "\n",
    "Wrapping the whole step in `brainstate.transform.jit` compiles it on the first call and reuses the\n",
    "compiled version on later calls with the same input shapes and dtypes. State reads and writes are\n",
    "tracked automatically across the transform boundary, so there is no manual parameter threading."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "01aa15ba",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-05-30T16:20:59.325985Z",
     "iopub.status.busy": "2026-05-30T16:20:59.325724Z",
     "iopub.status.idle": "2026-05-30T16:20:59.330468Z",
     "shell.execute_reply": "2026-05-30T16:20:59.329696Z"
    }
   },
   "outputs": [],
   "source": [
    "params = model.states(brainstate.ParamState)\n",
    "\n",
    "@brainstate.transform.jit\n",
    "def train_step(x, y):\n",
    "    def loss_fn():\n",
    "        logits = model(x)\n",
    "        return softmax_cross_entropy_with_integer_labels(logits, y).mean()\n",
    "\n",
    "    grads, loss = brainstate.transform.grad(loss_fn, params, return_value=True)()\n",
    "    grads = clip_grad_norm(grads, max_norm=1.0)\n",
    "    optimizer.update(grads)\n",
    "    return loss"
   ]
  },
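  {
   "cell_type": "markdown",
   "id": "5e0c1a07",
   "metadata": {},
   "source": [
    "For contrast, here is the same three-part step in plain JAX. Nothing is tracked implicitly, and the\n",
    "parameters have to be passed in and returned by hand. The sketch rests on several assumptions.\n",
    "The `init_mlp` and `forward` helpers and the dictionary layout of the parameters are hypothetical.\n",
    "`clip_by_global_norm` follows the usual definition (scale every gradient by\n",
    "`min(1, max_norm / global_norm)`), which may differ in detail from `clip_grad_norm`. Plain SGD\n",
    "stands in for Adam. Note that `jax.value_and_grad` returns `(loss, grads)`, the reverse of the\n",
    "`(grads, loss)` order that `brainstate.transform.grad(..., return_value=True)` uses above."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5e0c1a08",
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "def init_mlp(key, din=8, dhidden=32, dout=3):\n",
    "    # Hypothetical parameter layout: one dict of arrays per layer\n",
    "    k1, k2 = jax.random.split(key)\n",
    "    return {\n",
    "        'fc1': {'weight': 0.1 * jax.random.normal(k1, (din, dhidden)), 'bias': jnp.zeros(dhidden)},\n",
    "        'fc2': {'weight': 0.1 * jax.random.normal(k2, (dhidden, dout)), 'bias': jnp.zeros(dout)},\n",
    "    }\n",
    "\n",
    "def forward(params, x):\n",
    "    h = jax.nn.relu(x @ params['fc1']['weight'] + params['fc1']['bias'])\n",
    "    return h @ params['fc2']['weight'] + params['fc2']['bias']\n",
    "\n",
    "def loss_fn(params, x, y):\n",
    "    # Mean softmax cross-entropy with integer labels\n",
    "    log_probs = jax.nn.log_softmax(forward(params, x))\n",
    "    return -jnp.take_along_axis(log_probs, y[:, None], axis=1).mean()\n",
    "\n",
    "def clip_by_global_norm(grads, max_norm):\n",
    "    # One norm over all leaves; scale <= 1, so small gradients pass through unchanged\n",
    "    norm = jnp.sqrt(sum(jnp.sum(g ** 2) for g in jax.tree.leaves(grads)))\n",
    "    scale = jnp.minimum(1.0, max_norm / (norm + 1e-6))\n",
    "    return jax.tree.map(lambda g: g * scale, grads)\n",
    "\n",
    "@jax.jit\n",
    "def sgd_train_step(params, x, y):\n",
    "    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)  # note the (value, grads) order\n",
    "    grads = clip_by_global_norm(grads, max_norm=1.0)\n",
    "    params = jax.tree.map(lambda p, g: p - 0.1 * g, params, grads)  # plain SGD, lr = 0.1\n",
    "    return params, loss  # the updated parameters must be threaded back out explicitly\n",
    "\n",
    "# Hypothetical toy data: labels taken from the largest of the first three features\n",
    "kx, kp = jax.random.split(jax.random.PRNGKey(1))\n",
    "demo_x = jax.random.normal(kx, (96, 8))\n",
    "demo_y = jnp.argmax(demo_x[:, :3], axis=1)\n",
    "demo_p = init_mlp(kp)\n",
    "demo_losses = []\n",
    "for _ in range(50):\n",
    "    demo_p, demo_loss = sgd_train_step(demo_p, demo_x, demo_y)\n",
    "    demo_losses.append(float(demo_loss))\n",
    "print(demo_losses[0], demo_losses[-1])"
   ]
  },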
  {
   "cell_type": "markdown",
   "id": "37393326",
   "metadata": {},
   "source": [
    "## Tracking metrics"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b2078d9c",
   "metadata": {},
   "source": [
    "`MultiMetric` bundles several metrics that are updated together. Each sub-metric reads the\n",
    "keyword arguments it needs from a single `update(...)` call: `AverageMetric('loss')` consumes\n",
    "`loss=...`, while `AccuracyMetric` consumes `logits=...` and `labels=...`. The lifecycle is\n",
    "always **reset → update (per batch) → compute**."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "e0115a34",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-05-30T16:20:59.332529Z",
     "iopub.status.busy": "2026-05-30T16:20:59.332294Z",
     "iopub.status.idle": "2026-05-30T16:20:59.338145Z",
     "shell.execute_reply": "2026-05-30T16:20:59.337347Z"
    }
   },
   "outputs": [],
   "source": [
    "metrics = MultiMetric(\n",
    "    loss=AverageMetric('loss'),\n",
    "    accuracy=AccuracyMetric(),\n",
    ")\n",
    "\n",
    "@brainstate.transform.jit\n",
    "def eval_step(x, y):\n",
    "    logits = model(x)\n",
    "    loss = softmax_cross_entropy_with_integer_labels(logits, y).mean()\n",
    "    metrics.update(loss=loss, logits=logits, labels=y)"
   ]
  },
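  {
   "cell_type": "markdown",
   "id": "5e0c1a09",
   "metadata": {},
   "source": [
    "The reset, update, compute lifecycle is easy to mimic in plain Python. Doing so shows why an\n",
    "accumulating metric should keep running totals rather than average per-batch results when batches\n",
    "differ in size. `RunningAccuracy` below is a hypothetical, eager-only stand-in. It calls `int()`, so\n",
    "unlike the `MultiMetric` above it cannot be updated inside `jit`. It is not how `AccuracyMetric`\n",
    "is implemented. The predicted class is the arg-max of each row of logits."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5e0c1a10",
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax.numpy as jnp\n",
    "\n",
    "class RunningAccuracy:\n",
    "    # Hypothetical reset -> update -> compute metric; eager mode only\n",
    "    def __init__(self):\n",
    "        self.reset()\n",
    "\n",
    "    def reset(self):\n",
    "        self.correct = 0\n",
    "        self.count = 0\n",
    "\n",
    "    def update(self, logits, labels):\n",
    "        preds = jnp.argmax(logits, axis=-1)  # predicted class per example\n",
    "        self.correct += int((preds == labels).sum())\n",
    "        self.count += int(labels.shape[0])\n",
    "\n",
    "    def compute(self):\n",
    "        return self.correct / self.count\n",
    "\n",
    "demo_acc = RunningAccuracy()\n",
    "demo_acc.reset()\n",
    "# Batch 1: predictions [0, 2, 1, 0] vs labels [0, 2, 2, 1], so 2 of 4 correct\n",
    "demo_acc.update(jnp.array([[2., 0., 0.], [0., 1., 3.], [0., 5., 1.], [1., 0., 0.]]), jnp.array([0, 2, 2, 1]))\n",
    "# Batch 2: prediction [2] vs label [2], so 1 of 1 correct\n",
    "demo_acc.update(jnp.array([[0., 0., 1.]]), jnp.array([2]))\n",
    "print(demo_acc.compute())  # 0.6, i.e. 3 of 5; averaging the two batch accuracies would give 0.75"
   ]
  },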
  {
   "cell_type": "markdown",
   "id": "6ad2cf0c",
   "metadata": {},
   "source": [
    "## The training loop"
   ]
  },
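  {
   "cell_type": "markdown",
   "id": "cf267b6b",
   "metadata": {},
   "source": [
    "We iterate in mini-batches, reshuffling each epoch with `brainstate.random.permutation`. After\n",
    "every epoch we reset the metrics, run the evaluation step over the test set, and read back the\n",
    "accumulated statistics with `compute`. The 150-point test set fits in a single `eval_step` call.\n",
    "With a larger test set you would call `eval_step` once per test batch between `reset` and\n",
    "`compute`. To keep the output short, we print every third epoch and the final one."
   ]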
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "d83504cb",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-05-30T16:20:59.340325Z",
     "iopub.status.busy": "2026-05-30T16:20:59.340158Z",
     "iopub.status.idle": "2026-05-30T16:21:00.710387Z",
     "shell.execute_reply": "2026-05-30T16:21:00.709488Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch  0 | test loss 0.1440 | test acc 0.973\n",
      "epoch  3 | test loss 0.0471 | test acc 0.980\n",
      "epoch  6 | test loss 0.0411 | test acc 0.980\n",
      "epoch  9 | test loss 0.0355 | test acc 0.987\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 12 | test loss 0.0367 | test acc 0.987\n",
      "epoch 14 | test loss 0.0394 | test acc 0.987\n"
     ]
    }
   ],
   "source": [
    "def iter_batches(x, y, batch_size):\n",
    "    order = brainstate.random.permutation(len(x))\n",
    "    for i in range(0, len(x), batch_size):\n",
    "        idx = order[i:i + batch_size]\n",
    "        yield x[idx], y[idx]\n",
    "\n",
    "for epoch in range(15):\n",
    "    for xb, yb in iter_batches(x_train, y_train, batch_size=32):\n",
    "        train_step(xb, yb)\n",
    "\n",
    "    metrics.reset()\n",
    "    eval_step(x_test, y_test)\n",
    "    stats = metrics.compute()\n",
    "    if epoch % 3 == 0 or epoch == 14:\n",
    "        print(f\"epoch {epoch:2d} | test loss {float(stats['loss']):.4f} | \"\n",
    "              f\"test acc {float(stats['accuracy']):.3f}\")"
   ]
  },
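  {
   "cell_type": "markdown",
   "id": "5e0c1a11",
   "metadata": {},
   "source": [
    "`iter_batches` keeps the remainder. Batches of 32 over 600 training points give 18 full batches plus\n",
    "a final batch of 600 - 18 * 32 = 24 points. JAX's `jit` compiles one version per distinct input\n",
    "shape. Assuming `brainstate.transform.jit` inherits this behaviour, the short batch costs one\n",
    "extra compilation of `train_step` in the first epoch and is cached afterwards. The sketch below\n",
    "reproduces the index logic with `jax.random.permutation` in a hypothetical helper,\n",
    "`batch_index_splits`. It also checks that every example is visited exactly once per epoch."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5e0c1a12",
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "def batch_index_splits(key, n, batch_size):\n",
    "    # Shuffle once, then slice consecutive chunks; the last chunk may be short\n",
    "    order = jax.random.permutation(key, n)\n",
    "    return [order[i:i + batch_size] for i in range(0, n, batch_size)]\n",
    "\n",
    "demo_batches = batch_index_splits(jax.random.PRNGKey(0), 600, 32)\n",
    "demo_sizes = [int(b.shape[0]) for b in demo_batches]\n",
    "print(len(demo_batches), demo_sizes[-1])  # 19 24\n",
    "\n",
    "# Each index 0..599 appears exactly once across the epoch\n",
    "assert jnp.array_equal(jnp.sort(jnp.concatenate(demo_batches)), jnp.arange(600))"
   ]
  },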
  {
   "cell_type": "markdown",
   "id": "20cae043",
   "metadata": {},
   "source": [
    "## Evaluating"
   ]
  },
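  {
   "cell_type": "markdown",
   "id": "def08718",
   "metadata": {},
   "source": [
    "`metrics.compute()` returns a plain dictionary keyed by the names given to `MultiMetric` (`loss`\n",
    "and `accuracy`), so the final numbers are ordinary arrays you can log, compare, or assert on. The\n",
    "metrics were last reset before the final epoch's evaluation, so calling `compute()` again returns\n",
    "the statistics of that pass."
   ]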
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "381fc042",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-05-30T16:21:00.712514Z",
     "iopub.status.busy": "2026-05-30T16:21:00.712284Z",
     "iopub.status.idle": "2026-05-30T16:21:00.716852Z",
     "shell.execute_reply": "2026-05-30T16:21:00.716079Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "final test accuracy: 0.987\n"
     ]
    }
   ],
   "source": [
    "final = metrics.compute()\n",
    "print(f\"final test accuracy: {float(final['accuracy']):.3f}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fe246765",
   "metadata": {},
   "source": [
    "## Summary\n",
    "\n",
    "A BrainState training loop is built from four composable pieces:\n",
    "\n",
    "- **`count_parameters`** — inspect model size before you commit to training.\n",
    "- **`braintools.optim`** — optimizers that update registered `ParamState`s in place.\n",
    "- **`grad` + `clip_grad_norm` + `jit`** — a compiled step that differentiates, stabilises, and applies updates with no manual state bookkeeping.\n",
    "- **`MultiMetric`** — reset/update/compute accumulation for evaluation statistics.\n",
    "\n",
    "### See also\n",
    "\n",
    "- [Transformations, the essentials](06_transformations_essentials.ipynb) — the `jit`/`grad`/`vmap` mechanics underpinning the training step.\n",
    "- [Parameters, transforms, and regularization](05_parameters_transforms_regularization.ipynb) — constraining and penalising the parameters you train here.\n",
    "- The [transformations track](../transformations/index.rst) — autodiff, vectorization, and compilation in depth."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.13.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
