{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "b4b71e89",
   "metadata": {},
   "source": [
    "# Interoperate with Flax and Equinox"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "145b4e42",
   "metadata": {},
   "source": [
    "`brainstate.interop` converts weight-bearing modules between BrainState and three other JAX\n",
    "neural-network APIs: [Flax NNX](https://flax.readthedocs.io/), [Flax Linen], and\n",
    "[Equinox](https://docs.kidger.site/equinox/). Use it to drop a BrainState layer into an existing\n",
    "Flax or Equinox model, or to pull a pretrained Flax/Equinox layer into a BrainState program.\n",
    "Each conversion is **structural and weight-preserving**: the rebuilt module produces the same\n",
    "output as the original (up to floating-point tolerance), which every example below checks with\n",
    "`jnp.allclose(..., atol=1e-5)`.\n",
    "\n",
    "[Flax Linen]: https://flax-linen.readthedocs.io/"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "afe04d2f",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-05-30T17:15:33.434129Z",
     "iopub.status.busy": "2026-05-30T17:15:33.433894Z",
     "iopub.status.idle": "2026-05-30T17:15:35.906276Z",
     "shell.execute_reply": "2026-05-30T17:15:35.905280Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'0.4.0'"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import jax\n",
    "import jax.numpy as jnp\n",
    "import flax.nnx as nnx\n",
    "\n",
    "import brainstate\n",
    "from brainstate import interop\n",
    "\n",
    "brainstate.random.seed(0)\n",
    "brainstate.__version__"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bb311f8a",
   "metadata": {},
   "source": [
    "## What can be converted"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fb8104f1",
   "metadata": {},
   "source": [
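    "Conversion operates on **registered layers** and on `Sequential` stacks of them. Most registered\n",
    "layers carry weights; stateless layers with a counterpart in every framework, such as `Dropout`,\n",
    "are registered as well. `supported_layers()` returns, for each framework, the names of the layer\n",
    "types it supports. Coverage is not identical: as the output below shows, the Equinox list has no\n",
    "`BatchNorm` entries."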
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "c36d0764",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-05-30T17:15:35.908729Z",
     "iopub.status.busy": "2026-05-30T17:15:35.908370Z",
     "iopub.status.idle": "2026-05-30T17:15:36.049287Z",
     "shell.execute_reply": "2026-05-30T17:15:36.048333Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "nnx      : BatchNorm1d, BatchNorm2d, BatchNorm3d, Conv1d, Conv2d, Conv3d, Dropout, Embedding, GroupNorm, LSTMCell, LayerNorm, Linear, RMSNorm\n",
      "linen    : BatchNorm1d, BatchNorm2d, BatchNorm3d, Conv1d, Conv2d, Conv3d, Dropout, Embedding, GroupNorm, LSTMCell, LayerNorm, Linear, RMSNorm\n",
      "equinox  : Conv1d, Conv2d, Conv3d, Dropout, Embedding, GroupNorm, LSTMCell, LayerNorm, Linear, RMSNorm\n"
     ]
    }
   ],
   "source": [
    "layers = interop.supported_layers()\n",
    "for framework, names in layers.items():\n",
    "    print(f'{framework:8} : {\", \".join(names)}')"
   ]
  },
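  {
   "cell_type": "markdown",
   "id": "interop-coverage-check-md",
   "metadata": {},
   "source": [
    "Because `supported_layers()` returns plain layer-type names, you can check ahead of time which\n",
    "frameworks a model can target. The sketch below reports, for each framework, which names in a\n",
    "hand-written set `needed` have no registered mapping. `needed` is an illustrative assumption that\n",
    "you fill in yourself; the library does not compute it for you. With the registry shown above, only\n",
    "the Equinox entry should report `BatchNorm1d` as missing."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "interop-coverage-check",
   "metadata": {},
   "outputs": [],
   "source": [
    "from brainstate import interop\n",
    "\n",
    "# Layer-type names the model uses (written out by hand for this sketch).\n",
    "needed = {'Linear', 'LayerNorm', 'BatchNorm1d'}\n",
    "\n",
    "# For each framework, the needed layer types that have no registered mapping.\n",
    "missing = {\n",
    "    framework: sorted(needed - set(names))\n",
    "    for framework, names in interop.supported_layers().items()\n",
    "}\n",
    "for framework, gaps in missing.items():\n",
    "    print(f'{framework:8} : {gaps or \"fully supported\"}')"
   ]
  },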
  {
   "cell_type": "markdown",
   "id": "e243ae06",
   "metadata": {},
   "source": [
    "Two consequences follow from this design, and both are deliberate:\n",
    "\n",
    "- **Activation functions are not convertible layers.** A nonlinearity like ReLU has no weights to\n",
    "  transfer and no registered mapping, so it is applied functionally in a model's forward method\n",
    "  rather than stored as a layer. Conversion reconstructs the *weighted* structure; you keep\n",
    "  activations in your own forward code.\n",
    "- **Custom forward logic is not convertible.** Only single registered layers and `Sequential`\n",
    "  stacks round-trip. A module with branching, skip connections, or hand-written control flow\n",
    "  cannot be mechanically rebuilt, so the converter raises an informative error instead of\n",
    "  guessing.\n",
    "\n",
    "The model below, a `Sequential` of two `Linear` layers around a `LayerNorm`, has exactly this\n",
    "convertible shape."
   ]
  },
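  {
   "cell_type": "markdown",
   "id": "interop-activation-compose-md",
   "metadata": {},
   "source": [
    "One way to reuse converted weights while keeping a nonlinearity in your own forward code is to\n",
    "convert each weighted layer separately and compose the results by hand. The sketch below does this\n",
    "on the NNX side. The `MLP` class is hypothetical and exists only to host a functional\n",
    "`jax.nn.relu`; the sketch relies on `to_nnx` accepting a single registered layer, as described above."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "interop-activation-compose",
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax\n",
    "import jax.numpy as jnp\n",
    "import flax.nnx as nnx\n",
    "\n",
    "import brainstate\n",
    "from brainstate import interop\n",
    "\n",
    "# Two weighted BrainState layers; the ReLU between them lives in forward code.\n",
    "hidden = brainstate.nn.Linear(4, 8)\n",
    "head = brainstate.nn.Linear(8, 2)\n",
    "x = brainstate.random.randn(3, 4)\n",
    "reference = head(jax.nn.relu(hidden(x)))\n",
    "\n",
    "\n",
    "class MLP(nnx.Module):\n",
    "    # Hypothetical wrapper: converted layers plus a functional ReLU.\n",
    "    def __init__(self, hidden, head):\n",
    "        self.hidden = hidden\n",
    "        self.head = head\n",
    "\n",
    "    def __call__(self, x):\n",
    "        return self.head(jax.nn.relu(self.hidden(x)))\n",
    "\n",
    "\n",
    "# Convert each registered layer on its own, then compose them by hand.\n",
    "mlp = MLP(\n",
    "    interop.to_nnx(hidden, rngs=nnx.Rngs(0)),\n",
    "    interop.to_nnx(head, rngs=nnx.Rngs(1)),\n",
    ")\n",
    "print('composed output matches:', bool(jnp.allclose(reference, mlp(x), atol=1e-5)))"
   ]
  },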
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "8a233c39",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-05-30T17:15:36.052181Z",
     "iopub.status.busy": "2026-05-30T17:15:36.051858Z",
     "iopub.status.idle": "2026-05-30T17:15:36.317789Z",
     "shell.execute_reply": "2026-05-30T17:15:36.316900Z"
    }
   },
   "outputs": [],
   "source": [
    "def make_model():\n",
    "    return brainstate.nn.Sequential(\n",
    "        brainstate.nn.Linear(4, 8),\n",
    "        brainstate.nn.LayerNorm(8),\n",
    "        brainstate.nn.Linear(8, 2),\n",
    "    )\n",
    "\n",
    "x = brainstate.random.randn(3, 4)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5da823be",
   "metadata": {},
   "source": [
    "## Flax NNX"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "79aee574",
   "metadata": {},
   "source": [
    "`to_nnx` builds an NNX module. It needs an `nnx.Rngs` to construct the foreign layers, whose\n",
    "freshly initialized weights are then overwritten with the converted values. `from_nnx` converts\n",
    "an NNX module back into BrainState. We check that outputs match in both directions."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "5e6e16ac",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-05-30T17:15:36.320156Z",
     "iopub.status.busy": "2026-05-30T17:15:36.319954Z",
     "iopub.status.idle": "2026-05-30T17:15:40.756800Z",
     "shell.execute_reply": "2026-05-30T17:15:40.755752Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "to_nnx output matches  : True\n",
      "from_nnx output matches: True\n"
     ]
    }
   ],
   "source": [
    "model = make_model()\n",
    "reference = model(x)\n",
    "\n",
    "# BrainState -> NNX\n",
    "nnx_model = interop.to_nnx(model, rngs=nnx.Rngs(0))\n",
    "print('to_nnx output matches  :', bool(jnp.allclose(reference, nnx_model(x), atol=1e-5)))\n",
    "\n",
    "# NNX -> BrainState\n",
    "back = interop.from_nnx(nnx_model)\n",
    "print('from_nnx output matches:', bool(jnp.allclose(reference, back(x), atol=1e-5)))"
   ]
  },
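  {
   "cell_type": "markdown",
   "id": "interop-nnx-jit-md",
   "metadata": {},
   "source": [
    "Since `to_nnx` returns an ordinary NNX module, the usual NNX transforms and state utilities should\n",
    "apply to it. A minimal sketch, assuming the converted stack is a standard NNX module graph: it runs\n",
    "the model under `nnx.jit` and counts the scalars held in its `nnx.Param` variables via `nnx.state`.\n",
    "The name `forward` is illustrative only."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "interop-nnx-jit",
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax\n",
    "import jax.numpy as jnp\n",
    "import flax.nnx as nnx\n",
    "\n",
    "import brainstate\n",
    "from brainstate import interop\n",
    "\n",
    "model = brainstate.nn.Sequential(\n",
    "    brainstate.nn.Linear(4, 8),\n",
    "    brainstate.nn.LayerNorm(8),\n",
    "    brainstate.nn.Linear(8, 2),\n",
    ")\n",
    "x = brainstate.random.randn(3, 4)\n",
    "nnx_model = interop.to_nnx(model, rngs=nnx.Rngs(0))\n",
    "\n",
    "\n",
    "# nnx.jit splits the module into graph and state and traces the state with the input.\n",
    "@nnx.jit\n",
    "def forward(m, inputs):\n",
    "    return m(inputs)\n",
    "\n",
    "\n",
    "print('jitted output matches:', bool(jnp.allclose(model(x), forward(nnx_model, x), atol=1e-5)))\n",
    "\n",
    "# Total number of scalars stored in nnx.Param variables.\n",
    "params = nnx.state(nnx_model, nnx.Param)\n",
    "n_params = sum(p.size for p in jax.tree.leaves(params))\n",
    "print('NNX parameter count  :', n_params)"
   ]
  },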
  {
   "cell_type": "markdown",
   "id": "3a435e94",
   "metadata": {},
   "source": [
    "## Flax Linen"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fdc0bc14",
   "metadata": {},
   "source": [
    "Linen separates the module definition from its parameters, so `to_linen` returns a\n",
    "`(module, params)` pair; run it with `module.apply(params, x)`. `from_linen` takes the same pair\n",
    "and rebuilds the BrainState model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "1f6bc1b8",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-05-30T17:15:40.758950Z",
     "iopub.status.busy": "2026-05-30T17:15:40.758568Z",
     "iopub.status.idle": "2026-05-30T17:15:41.079676Z",
     "shell.execute_reply": "2026-05-30T17:15:41.079001Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "to_linen output matches  : True\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "from_linen output matches: True\n"
     ]
    }
   ],
   "source": [
    "model = make_model()\n",
    "reference = model(x)\n",
    "\n",
    "# BrainState -> Linen\n",
    "linen_module, params = interop.to_linen(model)\n",
    "print('to_linen output matches  :', bool(jnp.allclose(reference, linen_module.apply(params, x), atol=1e-5)))\n",
    "\n",
    "# Linen -> BrainState\n",
    "back = interop.from_linen(linen_module, params)\n",
    "print('from_linen output matches:', bool(jnp.allclose(reference, back(x), atol=1e-5)))"
   ]
  },
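  {
   "cell_type": "markdown",
   "id": "interop-linen-idioms-md",
   "metadata": {},
   "source": [
    "The exported pair is plain Linen, so standard Linen idioms should work on it unchanged. The sketch\n",
    "below prints the parameter shapes and jits `apply`; nothing in it is specific to\n",
    "`brainstate.interop` beyond the `to_linen` call, and the exact nesting of `params` is whatever the\n",
    "converter produces."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "interop-linen-idioms",
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "import brainstate\n",
    "from brainstate import interop\n",
    "\n",
    "model = brainstate.nn.Sequential(\n",
    "    brainstate.nn.Linear(4, 8),\n",
    "    brainstate.nn.LayerNorm(8),\n",
    "    brainstate.nn.Linear(8, 2),\n",
    ")\n",
    "x = brainstate.random.randn(3, 4)\n",
    "linen_module, params = interop.to_linen(model)\n",
    "\n",
    "# Shape of every parameter array, in the same nesting as `params`.\n",
    "print(jax.tree_util.tree_map(lambda a: a.shape, params))\n",
    "\n",
    "# Jit the pure apply function; parameters and inputs are passed explicitly.\n",
    "apply_fn = jax.jit(linen_module.apply)\n",
    "print('jitted apply matches:', bool(jnp.allclose(model(x), apply_fn(params, x), atol=1e-5)))"
   ]
  },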
  {
   "cell_type": "markdown",
   "id": "b35aefbf",
   "metadata": {},
   "source": [
    "## Equinox"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "11207a28",
   "metadata": {},
   "source": [
    "`to_equinox` accepts an optional PRNG `key` for constructing the foreign layers. Equinox layers\n",
    "operate on a single unbatched example, so we `jax.vmap` the exported model over the batch\n",
    "dimension. `from_equinox` returns an ordinary BrainState model, which takes the batched input\n",
    "directly."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "5632b66a",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-05-30T17:15:41.082251Z",
     "iopub.status.busy": "2026-05-30T17:15:41.081983Z",
     "iopub.status.idle": "2026-05-30T17:15:42.097422Z",
     "shell.execute_reply": "2026-05-30T17:15:42.096420Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "to_equinox output matches  : True\n",
      "from_equinox output matches: True\n"
     ]
    }
   ],
   "source": [
    "model = make_model()\n",
    "reference = model(x)\n",
    "\n",
    "# BrainState -> Equinox  (call per-example, so vmap over the batch)\n",
    "eqx_model = interop.to_equinox(model, key=jax.random.PRNGKey(0))\n",
    "eqx_out = jax.vmap(eqx_model)(x)\n",
    "print('to_equinox output matches  :', bool(jnp.allclose(reference, eqx_out, atol=1e-5)))\n",
    "\n",
    "# Equinox -> BrainState\n",
    "back = interop.from_equinox(eqx_model)\n",
    "print('from_equinox output matches:', bool(jnp.allclose(reference, back(x), atol=1e-5)))"
   ]
  },
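  {
   "cell_type": "markdown",
   "id": "interop-roundtrip-helper-md",
   "metadata": {},
   "source": [
    "The three round trips above follow one pattern, so they can be folded into a small helper.\n",
    "`roundtrip_report` is a hypothetical convenience function, not part of `brainstate.interop`: it\n",
    "uses only the calls shown in this guide and, because it passes no `sample_input`, covers only\n",
    "models without spatial layers."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "interop-roundtrip-helper",
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax\n",
    "import jax.numpy as jnp\n",
    "import flax.nnx as nnx\n",
    "\n",
    "import brainstate\n",
    "from brainstate import interop\n",
    "\n",
    "\n",
    "def roundtrip_report(model, x, atol=1e-5):\n",
    "    # Convert `model` to each framework and back; report whether outputs still match.\n",
    "    reference = model(x)\n",
    "    outputs = {}\n",
    "\n",
    "    nnx_model = interop.to_nnx(model, rngs=nnx.Rngs(0))\n",
    "    outputs['nnx'] = interop.from_nnx(nnx_model)(x)\n",
    "\n",
    "    linen_module, params = interop.to_linen(model)\n",
    "    outputs['linen'] = interop.from_linen(linen_module, params)(x)\n",
    "\n",
    "    eqx_model = interop.to_equinox(model, key=jax.random.PRNGKey(0))\n",
    "    outputs['equinox'] = interop.from_equinox(eqx_model)(x)\n",
    "\n",
    "    return {name: bool(jnp.allclose(reference, out, atol=atol)) for name, out in outputs.items()}\n",
    "\n",
    "\n",
    "model = brainstate.nn.Sequential(\n",
    "    brainstate.nn.Linear(4, 8),\n",
    "    brainstate.nn.LayerNorm(8),\n",
    "    brainstate.nn.Linear(8, 2),\n",
    ")\n",
    "x = brainstate.random.randn(3, 4)\n",
    "report = roundtrip_report(model, x)\n",
    "print(report)"
   ]
  },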
  {
   "cell_type": "markdown",
   "id": "7791b107",
   "metadata": {},
   "source": [
    "## Spatial layers need a sample shape"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d16ada65",
   "metadata": {},
   "source": [
    "Importing a convolution or a spatial normalization layer requires the input shape, because\n",
    "BrainState materializes a layer's input size when the layer is constructed. Pass `sample_input`,\n",
    "either a single unbatched example or its shape, to `from_nnx` / `from_linen` / `from_equinox`\n",
    "when importing such layers."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "b046686a",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-05-30T17:15:42.099616Z",
     "iopub.status.busy": "2026-05-30T17:15:42.099392Z",
     "iopub.status.idle": "2026-05-30T17:15:41.768121Z",
     "shell.execute_reply": "2026-05-30T17:15:41.767220Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "converted conv output shape: (2, 8, 8, 4)\n"
     ]
    }
   ],
   "source": [
    "conv = nnx.Conv(in_features=3, out_features=4, kernel_size=(3, 3), rngs=nnx.Rngs(0))\n",
    "bst_conv = interop.from_nnx(conv, sample_input=(8, 8, 3))   # H, W, C (no batch dim)\n",
    "image = brainstate.random.randn(2, 8, 8, 3)\n",
    "print('converted conv output shape:', bst_conv(image).shape)"
   ]
  },
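  {
   "cell_type": "markdown",
   "id": "interop-conv-sample-array-md",
   "metadata": {},
   "source": [
    "Passing an example array instead of a shape should work the same way, and comparing against the\n",
    "original NNX layer checks that the convolution weights carried over. A minimal sketch: the\n",
    "all-zeros `example` is only a stand-in for one unbatched (H, W, C) input, on the assumption that\n",
    "the converter reads just its shape."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "interop-conv-sample-array",
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax.numpy as jnp\n",
    "import flax.nnx as nnx\n",
    "\n",
    "import brainstate\n",
    "from brainstate import interop\n",
    "\n",
    "conv = nnx.Conv(in_features=3, out_features=4, kernel_size=(3, 3), rngs=nnx.Rngs(0))\n",
    "\n",
    "# One unbatched example (H, W, C) instead of its shape.\n",
    "example = jnp.zeros((8, 8, 3))\n",
    "bst_conv = interop.from_nnx(conv, sample_input=example)\n",
    "\n",
    "# Compare the imported layer against the original NNX layer on a batch.\n",
    "image = brainstate.random.randn(2, 8, 8, 3)\n",
    "print('conv output matches:', bool(jnp.allclose(conv(image), bst_conv(image), atol=1e-5)))"
   ]
  },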
  {
   "cell_type": "markdown",
   "id": "7ef743e4",
   "metadata": {},
   "source": [
    "## Extending the registry"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "53ddd47c",
   "metadata": {},
   "source": [
    "`register_layer_mapping` lets you teach the converter about a layer type it does not handle out\n",
    "of the box, by supplying the to/from conversion functions. This is the extension point for\n",
    "custom layers; the built-in mappings use the same mechanism."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d49afb47",
   "metadata": {},
   "source": [
    "## Summary\n",
    "\n",
    "- `brainstate.interop` converts weight-bearing layers and `Sequential` stacks between BrainState\n",
    "  and Flax NNX, Flax Linen, and Equinox, preserving weights.\n",
    "- Directions: `to_nnx` / `from_nnx`, `to_linen` / `from_linen`, `to_equinox` / `from_equinox`.\n",
    "  `to_nnx` needs `rngs=`, `to_linen` returns `(module, params)`, and `to_equinox` accepts `key=`\n",
    "  and yields a per-example module.\n",
    "- Only registered layers convert: activations stay in your forward code, and modules with custom\n",
    "  forward logic are rejected with a clear error.\n",
    "- Importing spatial layers (`Conv`, spatial `BatchNorm`) requires `sample_input=`.\n",
    "- `supported_layers()` lists what is covered; `register_layer_mapping()` extends it.\n",
    "\n",
    "### See also\n",
    "\n",
    "- [Common layers](../tutorials/core/03_common_layers.ipynb) — the BrainState layers that convert cleanly.\n",
    "- The [`brainstate.interop` API reference](../apis/interop.rst)."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.13.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
