{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Migrating Concepts from PyTorch to BrainState"
   ],
   "id": "c53c6000519b753f"
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "BrainState borrows many familiar ideas from PyTorch—tensor computations,\n",
    "modules with parameters, automatic differentiation—while leaning on JAX for\n",
    "JIT compilation and functional programming. This note contrasts the key\n",
    "building blocks so you can translate existing PyTorch workflows into\n",
    "BrainState idioms quickly."
   ],
   "id": "a01ec349a4388f03"
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Concept map\n",
    "\n",
    "| PyTorch | BrainState | Notes |\n",
    "| --- | --- | --- |\n",
    "| `torch.Tensor` | `jax.Array` (`jnp.ndarray`) | Manipulated with `jax.numpy` semantics. |\n",
    "| `nn.Module` | `brainstate.nn.Module` | Define `State` attributes (e.g. `ParamState`, `HiddenState`). |\n",
    "| `nn.Parameter` | `brainstate.ParamState` | Holds differentiable weights; collected with `model.states(brainstate.ParamState)`. |\n",
    "| `autograd.grad` / `backward()` | `brainstate.transform.grad` | Explicitly select which states or arguments receive gradients. |\n",
    "| `torch.optim` optimisers | `braintools.optim` (optional) | Works on `.states(brainstate.ParamState)`. |\n",
    "| `torch.jit.script` / `torch.jit.trace` | `brainstate.transform.jit` | JIT compile pure or stateful functions; integrates with JAX. |\n",
    "| `state_dict()` / `load_state_dict()` | `brainstate.graph.treefy_states` / `brainstate.graph.update_states` | Serialize/restore state trees. |\n",
    "| Random number generators (`torch.manual_seed`) | `brainstate.random.seed` / `RandomState` | Backed by JAX PRNG keys, which are split automatically inside transforms. |"
   ],
   "id": "46aea1897895b2d3"
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## PyTorch baseline\n",
    "\n",
    "Consider a minimal linear regression in PyTorch:\n",
    "\n",
    "```python\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "\n",
    "class TorchLinear(nn.Module):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.linear = nn.Linear(1, 1)\n",
    "\n",
    "    def forward(self, x):\n",
    "        return self.linear(x)\n",
    "\n",
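    "# Synthetic data: y = 3x + 1 plus a little noise\n",
    "inputs = torch.rand(64, 1) * 2 - 1\n",
    "targets = 3.0 * inputs + 1.0 + 0.1 * torch.randn(64, 1)\n",
    "\n",
    "model = TorchLinear()\n",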
    "criterion = nn.MSELoss()\n",
    "optimizer = optim.SGD(model.parameters(), lr=1e-1)\n",
    "\n",
    "for step in range(100):\n",
    "    optimizer.zero_grad()\n",
    "    preds = model(inputs)\n",
    "    loss = criterion(preds, targets)\n",
    "    loss.backward()\n",
    "    optimizer.step()\n",
    "```\n",
    "\n",
    "BrainState follows the same logic but makes states and gradients explicit."
   ],
   "id": "c50ad19a2eae9b18"
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## BrainState equivalent"
   ],
   "id": "c494324cc75b57d5"
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-11T07:56:40.036039Z",
     "start_time": "2025-10-11T07:56:36.562878Z"
    }
   },
   "cell_type": "code",
   "source": [
    "import braintools.file\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "import numpy as np\n",
    "\n",
    "import brainstate\n",
    "from brainstate.transform import grad, jit\n",
    "import braintools.optim as optim\n",
    "\n",
    "# Synthetic dataset\n",
    "def make_dataset(n=64):\n",
    "    rng = np.random.default_rng(0)\n",
    "    x = rng.uniform(-1.0, 1.0, (n, 1)).astype(np.float32)\n",
    "    y = 3.0 * x + 1.0 + rng.normal(0.0, 0.1, (n, 1)).astype(np.float32)\n",
    "    return jnp.asarray(x), jnp.asarray(y)\n",
    "\n",
    "x_train, y_train = make_dataset()\n",
    "\n",
    "class LinearModel(brainstate.nn.Module):\n",
    "    def __init__(self, in_features, out_features):\n",
    "        super().__init__()\n",
    "        k1, k2 = jax.random.split(jax.random.PRNGKey(0))\n",
    "        self.weight = brainstate.ParamState(jax.random.normal(k1, (in_features, out_features)))\n",
    "        self.bias = brainstate.ParamState(jax.random.normal(k2, (out_features,)))\n",
    "\n",
    "    def __call__(self, x):\n",
    "        return x @ self.weight.value + self.bias.value\n",
    "\n",
    "model = LinearModel(1, 1)\n",
    "# Collect the trainable states and register them with the optimiser\n",
    "params = model.states(brainstate.ParamState)\n",
    "optimizer = optim.SGD(lr=1e-1)\n",
    "optimizer.register_trainable_weights(params)\n",
    "\n",
    "@jit\n",
    "def train_step(x, y):\n",
    "    def loss_fn():\n",
    "        preds = model(x)\n",
    "        return jnp.mean((preds - y) ** 2)\n",
    "\n",
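    "    # Differentiate w.r.t. the ParamStates only; return_value=True also returns the loss\n",
    "    (grads, loss) = grad(loss_fn, grad_states=params, return_value=True)()\n",
    "    # Apply the SGD update to the registered states\n",
    "    optimizer.update(grads)\n",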
    "    return loss\n",
    "\n",
    "for step in range(200):\n",
    "    loss = train_step(x_train, y_train)\n",
    "    if step % 40 == 0:\n",
    "        print(f\"step {step:3d}, loss = {float(loss):.4f}\")\n",
    "\n",
    "@jit\n",
    "def predict(x):\n",
    "    return model(x)\n",
    "\n",
    "print('predictions for x=0:', predict(jnp.array([[0.0]])))"
   ],
   "id": "9ce626d44f75ba1c",
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "step   0, loss = 13.1320\n",
      "step  40, loss = 0.0144\n",
      "step  80, loss = 0.0097\n",
      "step 120, loss = 0.0097\n",
      "step 160, loss = 0.0097\n",
      "predictions for x=0: [[1.0059681]]\n"
     ]
    }
   ],
   "execution_count": 1
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Key observations\n",
    "\n",
    "- Parameters are stored in `ParamState` objects, so gradients come back as a\n",
    "  mapping keyed by state paths such as `('weight',)`, much like the keys of a\n",
    "  PyTorch `state_dict()`.\n",
    "- `grad` explicitly lists `grad_states`; argument gradients can be included via\n",
    "  `argnums` (comparable to setting `requires_grad=True` on an input tensor in PyTorch).\n",
    "- The optimiser is registered against a collection of states and updated with the\n",
    "  matching gradients, rather than reading `.grad` attributes off each parameter."
   ],
   "id": "91d2580cd9dbbf51"
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Saving and loading state"
   ],
   "id": "be3105064b5d730"
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-11T07:56:40.058370Z",
     "start_time": "2025-10-11T07:56:40.050338Z"
    }
   },
   "cell_type": "code",
   "source": [
    "state_tree = brainstate.graph.treefy_states(model)\n",
    "print('stored keys:', list(state_tree.to_flat().keys()))\n",
    "\n",
    "# Later (or in another process):\n",
    "restored = LinearModel(1, 1)\n",
    "brainstate.graph.update_states(restored, state_tree)\n",
    "print('restored weight:', restored.weight.value)"
   ],
   "id": "e5f19e0e7a4ac0ea",
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "stored keys: [('bias',), ('weight',)]\n",
      "restored weight: [[3.0168793]]\n"
     ]
    }
   ],
   "execution_count": 2
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "Alternatively, you can write states to disk with `braintools.file.msgpack_save` and restore them with `braintools.file.msgpack_load`.",
   "id": "1d29a44f8ed99736"
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-11T07:56:40.068506Z",
     "start_time": "2025-10-11T07:56:40.062874Z"
    }
   },
   "cell_type": "code",
   "source": "braintools.file.msgpack_save('example.msgpack', model.states(brainstate.ParamState))",
   "id": "981220fb674cc6e3",
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Saving checkpoint into example.msgpack\n"
     ]
    }
   ],
   "execution_count": 3
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "## Gradients with additional arguments\n",
    "\n",
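    "Below, we take derivatives with respect to both the model parameters and an\n",
    "explicit scalar argument. With `grad_states` and `argnums` both set, `grad`\n",
    "returns the state gradients and the argument gradient as a pair; with\n",
    "`return_value=True`, the loss value is returned alongside them."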
   ],
   "id": "3906042a7cdeff82"
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-11T07:57:11.661883Z",
     "start_time": "2025-10-11T07:57:11.634187Z"
    }
   },
   "cell_type": "code",
   "source": [
    "scale = jnp.array(0.1)\n",
    "\n",
    "def scaled_loss(alpha, inputs, targets):\n",
    "    preds = model(inputs)\n",
    "    mse = jnp.mean((preds - targets) ** 2)\n",
    "    return mse + alpha * jnp.sum(model.weight.value ** 2)\n",
    "\n",
    "(grads_state, alpha_grad), loss_val = grad(\n",
    "    scaled_loss,\n",
    "    grad_states=params,\n",
    "    argnums=0,\n",
    "    return_value=True,\n",
    ")(scale, x_train, y_train)\n",
    "\n",
    "print('loss:', float(loss_val))\n",
    "print('grad w.r.t alpha:', float(alpha_grad))\n",
    "for path, g in grads_state.items():\n",
    "    print(path, g.shape)"
   ],
   "id": "3d26a0040c97e1c1",
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "loss: 0.9198333024978638\n",
      "grad w.r.t alpha: 9.101560592651367\n",
      "('bias',) (1,)\n",
      "('weight',) (1, 1)\n"
     ]
    }
   ],
   "execution_count": 5
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Random numbers\n",
    "\n",
    "BrainState wraps JAX PRNG keys. Use `brainstate.random.seed` to set the global\n",
    "seed, or instantiate a `RandomState` for module-specific randomness. Transforms\n",
    "like `vmap` and `pmap` split keys automatically per batch element."
   ],
   "id": "dfcbdfa994a7baa4"
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-11T07:57:13.110058Z",
     "start_time": "2025-10-11T07:57:12.913366Z"
    }
   },
   "source": [
    "import brainstate.random as brandom\n",
    "\n",
    "brandom.seed(42)\n",
    "rs = brandom.RandomState()\n",
    "print('single sample:', rs.normal(size=(2,)))"
   ],
   "id": "ba0e691206b438e4",
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "single sample: [ 0.6630465  -0.72396195]\n"
     ]
    }
   ],
   "execution_count": 6
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Debugging and JIT\n",
    "\n",
    "BrainState leans on JAX's tooling. `brainstate.transform.jit` works on stateful\n",
    "functions, while `brainstate.transform.make_jaxpr` traces a function and returns\n",
    "its jaxpr together with the states it reads or writes, as the output below shows."
   ],
   "id": "f33635ea74f7fb92"
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-11T07:57:15.274227Z",
     "start_time": "2025-10-11T07:57:15.243277Z"
    }
   },
   "source": [
    "from brainstate.transform import make_jaxpr\n",
    "\n",
    "jaxpr = make_jaxpr(model)\n",
    "print(jaxpr(jnp.ones((1,))))"
   ],
   "id": "ec853e79b22b7fec",
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "({ lambda ; a:f32[1] b:f32[1,1] c:f32[1]. let\n",
      "    d:f32[1] = dot_general[\n",
      "      dimension_numbers=(([0], [0]), ([], []))\n",
      "      preferred_element_type=float32\n",
      "    ] a b\n",
      "    e:f32[1] = add d c\n",
      "  in (e, b, c) }, (ParamState(\n",
      "  value=ShapedArray(float32[1,1])\n",
      "), ParamState(\n",
      "  value=ShapedArray(float32[1])\n",
      ")))\n"
     ]
    }
   ],
   "execution_count": 7
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Summary\n",
    "\n",
    "- Replace `nn.Module` + `nn.Parameter` with `brainstate.nn.Module` + `ParamState`.\n",
    "- Use `brainstate.transform.grad`/`jit` instead of PyTorch autograd and scripting.\n",
    "- Retrieve and update parameter trees via `graph.treefy_states` and\n",
    "  `graph.update_states`.\n",
    "- Optimisers in `braintools.optim` follow the familiar PyTorch pattern but take\n",
    "  registered states and an explicit gradient mapping instead of reading `.grad`.\n",
    "\n",
    "With these substitutions, most PyTorch training loops can be ported one module at\n",
    "a time to BrainState."
   ],
   "id": "4368066d82f504f9"
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
