{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "31d782f4",
   "metadata": {},
   "source": [
    "# Quickstart"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4c2ae882",
   "metadata": {},
   "source": [
    "This page builds and trains a small model end to end. It assumes BrainState is already\n",
    "installed (see [Installation](installation.md)) and takes about five minutes. The goal is not to\n",
    "explain every idea — the [Core](../tutorials/core/index.rst) track does that — but to show the\n",
    "shape of a BrainState program: wrap mutable arrays in `State`, build a `Module`, and train it with\n",
    "state-aware transformations."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "110d5461",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-05-30T17:43:42.610974Z",
     "iopub.status.busy": "2026-05-30T17:43:42.610727Z",
     "iopub.status.idle": "2026-05-30T17:43:44.844203Z",
     "shell.execute_reply": "2026-05-30T17:43:44.843411Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'0.4.0'"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "import brainstate\n",
    "\n",
    "brainstate.random.seed(0)\n",
    "brainstate.__version__"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8f6c424b",
   "metadata": {},
   "source": [
    "## A model built from state"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d49d894f",
   "metadata": {},
   "source": [
    "A model is a `brainstate.nn.Module` whose trainable arrays are `ParamState` objects. Here is a\n",
    "two-layer perceptron for regression. Each `Linear` layer is assigned as an attribute and holds its\n",
    "own `ParamState`; BrainState discovers them automatically."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "dd73fc85",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-05-30T17:43:44.846232Z",
     "iopub.status.busy": "2026-05-30T17:43:44.845947Z",
     "iopub.status.idle": "2026-05-30T17:43:48.205764Z",
     "shell.execute_reply": "2026-05-30T17:43:48.205083Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "MLP(\n",
       "  hidden=Linear(\n",
       "    in_size=(4,),\n",
       "    out_size=(32,),\n",
       "    weight=ParamState(\n",
       "      value={\n",
       "        'bias': ShapedArray(float32[32]),\n",
       "        'weight': ShapedArray(float32[4,32])\n",
       "      }\n",
       "    )\n",
       "  ),\n",
       "  out=Linear(\n",
       "    in_size=(32,),\n",
       "    out_size=(1,),\n",
       "    weight=ParamState(\n",
       "      value={\n",
       "        'bias': ShapedArray(float32[1]),\n",
       "        'weight': ShapedArray(float32[32,1])\n",
       "      }\n",
       "    )\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "class MLP(brainstate.nn.Module):\n",
    "    def __init__(self, din, dhidden, dout):\n",
    "        super().__init__()\n",
    "        self.hidden = brainstate.nn.Linear(din, dhidden)\n",
    "        self.out = brainstate.nn.Linear(dhidden, dout)\n",
    "\n",
    "    def __call__(self, x):\n",
    "        return self.out(jnp.tanh(self.hidden(x)))\n",
    "\n",
    "model = MLP(4, 32, 1)\n",
    "model"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "453d901c",
   "metadata": {},
   "source": [
    "`model.states(brainstate.ParamState)` returns every trainable parameter, keyed by its path in the\n",
    "module tree, such as `('hidden', 'weight')`. This collection is what we will differentiate and update."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "69e1bc18",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-05-30T17:43:48.207940Z",
     "iopub.status.busy": "2026-05-30T17:43:48.207572Z",
     "iopub.status.idle": "2026-05-30T17:43:48.212459Z",
     "shell.execute_reply": "2026-05-30T17:43:48.211465Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[('hidden', 'weight'), ('out', 'weight')]"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "params = model.states(brainstate.ParamState)\n",
    "list(params.keys())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "88b14714",
   "metadata": {},
   "source": [
    "## A forward pass"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3d6c8565",
   "metadata": {},
   "source": [
    "Calling the model runs the forward computation on a batch of inputs. Here the batch holds 128\n",
    "samples with 4 features each, so the output has shape `(128, 1)`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "2fbec3fa",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-05-30T17:43:48.214792Z",
     "iopub.status.busy": "2026-05-30T17:43:48.214518Z",
     "iopub.status.idle": "2026-05-30T17:43:48.586858Z",
     "shell.execute_reply": "2026-05-30T17:43:48.585821Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(128, 1)"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x = brainstate.random.randn(128, 4)\n",
    "y = jnp.sum(x ** 2, axis=-1, keepdims=True)   # a simple nonlinear target\n",
    "\n",
    "model(x).shape"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "be085b40",
   "metadata": {},
   "source": [
    "## Gradients with respect to parameters"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7cec9c3b",
   "metadata": {},
   "source": [
    "`brainstate.transform.grad` differentiates a function with respect to a collection of states\n",
    "(here, the parameters) rather than its positional arguments. It returns a new function; calling\n",
    "that function yields a dictionary of gradients keyed the same way as `params`. Pass\n",
    "`return_value=True` to get the loss alongside the gradients, as a `(grads, loss)` pair."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "1e7da6cf",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-05-30T17:43:48.589601Z",
     "iopub.status.busy": "2026-05-30T17:43:48.589328Z",
     "iopub.status.idle": "2026-05-30T17:43:49.357145Z",
     "shell.execute_reply": "2026-05-30T17:43:49.356444Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "initial loss: 23.490819931030273\n",
      "gradient keys match params: True\n"
     ]
    }
   ],
   "source": [
    "def loss_fn():\n",
    "    return jnp.mean((model(x) - y) ** 2)\n",
    "\n",
    "grads, loss = brainstate.transform.grad(loss_fn, params, return_value=True)()\n",
    "print('initial loss:', float(loss))\n",
    "print('gradient keys match params:', set(grads) == set(params))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "905e3967",
   "metadata": {},
   "source": [
    "## A training loop"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9224972a",
   "metadata": {},
   "source": [
    "A training step computes gradients and applies a gradient-descent update in place. Each parameter\n",
    "value can be a small PyTree (a `Linear` holds its weight and bias together), so the update walks it\n",
    "with `jax.tree.map`. Wrapping the step in `brainstate.transform.jit` compiles it; because the\n",
    "transform is state-aware, the parameter updates persist across calls automatically."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "784a9aa4",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-05-30T17:43:49.359908Z",
     "iopub.status.busy": "2026-05-30T17:43:49.359727Z",
     "iopub.status.idle": "2026-05-30T17:43:49.559096Z",
     "shell.execute_reply": "2026-05-30T17:43:49.558037Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "step   0  loss 23.4908\n",
      "step  50  loss 4.9425\n",
      "step 100  loss 2.3319\n",
      "step 150  loss 1.1466\n",
      "step 200  loss 0.6898\n"
     ]
    }
   ],
   "source": [
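    "@brainstate.transform.jit\n",
    "def train_step():\n",
    "    # Gradients of the loss with respect to every ParamState in `params`.\n",
    "    grads, loss = brainstate.transform.grad(loss_fn, params, return_value=True)()\n",
    "    # Plain gradient descent (learning rate 0.05), applied leaf by leaf to each state's PyTree.\n",
    "    for key in params:\n",
    "        params[key].value = jax.tree.map(lambda p, g: p - 0.05 * g, params[key].value, grads[key])\n",
    "    return loss\n",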
    "\n",
    "for step in range(201):\n",
    "    loss = train_step()\n",
    "    if step % 50 == 0:\n",
    "        print(f'step {step:>3}  loss {float(loss):.4f}')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fdb360e5",
   "metadata": {},
   "source": [
    "The loss falls steadily, from 23.49 at step 0 to 0.69 at step 200. For real training you would\n",
    "reach for an optimizer such as `braintools.optim.Adam` rather than hand-written gradient descent,\n",
    "but the structure is identical: differentiate with respect to the parameter states, then update them."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e0b05ced",
   "metadata": {},
   "source": [
    "## Where to go next\n",
    "\n",
    "- [Thinking in BrainState](thinking_in_brainstate.md) — the mental model behind what you just wrote.\n",
    "- [Core tutorials](../tutorials/core/index.rst) — `State`, modules, transformations, and training in depth.\n",
    "- [Why state-based?](../concepts/why_state_based.md) — the design rationale."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.13.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
