{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "9f5e709d97002dcb",
   "metadata": {},
   "source": [
    "# Collective Operations\n",
    "\n",
    "The `brainstate.nn._collective_ops` module provides helpers, available directly under `brainstate.nn` (for example `brainstate.nn.init_all_states`), that act on *every* module inside a model. They make it easy to initialise, reset, batch, and restore stateful objects without manually traversing the module graph. This notebook introduces the core APIs with practical examples.\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "acff0b3212a14a16",
   "metadata": {},
   "source": [
    "## Prerequisites\n",
    "\n",
    "- Familiarity with `brainstate.nn` modules and states\n",
    "- `brainunit` installed (required by the BrainState package)\n",
    "- Basic understanding of JAX and `vmap`\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 62,
   "id": "a4ebdd437565c738",
   "metadata": {},
   "outputs": [],
   "source": [
    "import brainstate\n",
    "import jax.numpy as jnp"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "551d39d7e6c05198",
   "metadata": {},
   "source": [
    "## Overview of the API\n",
    "\n",
    "`brainstate.nn._collective_ops` exposes several utilities:\n",
    "\n",
    "- `call_order` — decorator that fixes the execution order of methods\n",
    "- `call_all_fns` / `vmap_call_all_fns` — call the same method on each node in a model\n",
    "- `init_all_states` / `vmap_init_all_states` — initialise state variables everywhere\n",
    "- `reset_all_states` / `vmap_reset_all_states` — reset existing states\n",
    "- `assign_state_values` — restore state values from dictionaries keyed by absolute paths\n",
    "\n",
    "We'll examine each group below.\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "dc0728250b37d466",
   "metadata": {},
   "source": [
    "## Ordering Calls with `call_order`\n",
    "\n",
    "By default, `call_all_fns` calls a method on the nodes in the order they appear in the module graph, but complex modules may need explicit ordering. The `call_order(level)` decorator attaches a `call_order` attribute to a method; when a collective helper calls that method across nodes, lower levels run first.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 63,
   "id": "5025361e5bd07383",
   "metadata": {},
   "outputs": [],
   "source": [
    "class EncoderDecoder(brainstate.nn.Module):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.encoder = brainstate.nn.Linear((16,), (32,))\n",
    "        self.decoder = brainstate.nn.Linear((32,), (16,))\n",
    "\n",
    "    @brainstate.nn.call_order(0)\n",
    "    def init_state(self):\n",
    "        self.encoder.init_state()\n",
    "        self.decoder.init_state()\n",
    "\n",
    "    @brainstate.nn.call_order(1)\n",
    "    def reset_state(self):\n",
    "        self.encoder.reset_state()\n",
    "        self.decoder.reset_state()\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bbd706da0ce43d9e",
   "metadata": {},
   "source": [
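    "A level only affects when this module's `init_state` (or `reset_state`) runs relative to the same method on *other* nodes; the levels on `init_state` and `reset_state` are independent of each other. Because the collective helpers visit every node, `encoder` and `decoder` are also called directly, so forwarding calls like these means the children's methods run twice: once from `EncoderDecoder` and once from the traversal itself.\n"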
   ]
  },
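  {
   "cell_type": "markdown",
   "id": "5b7e2c9a0d41f386",
   "metadata": {},
   "source": [
    "The scheduling idea can be modelled in plain Python. This is a simplified sketch, not BrainState's implementation: `order` and `call_all` are hypothetical stand-ins for `call_order` and `call_all_fns`, and running undecorated methods first (in traversal order) before the decorated ones (sorted by level) is an assumption of the sketch.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a2c94e1b7f0d5368",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hypothetical stand-ins for call_order / call_all_fns (not BrainState code).\n",
    "def order(level):\n",
    "    # Attach the level to the function object, like a call_order-style decorator.\n",
    "    def wrap(fn):\n",
    "        fn.call_order = level\n",
    "        return fn\n",
    "    return wrap\n",
    "\n",
    "\n",
    "class Node:\n",
    "    def __init__(self, name, log):\n",
    "        self.name = name\n",
    "        self.log = log\n",
    "\n",
    "    def reset_state(self):\n",
    "        self.log.append(self.name)\n",
    "\n",
    "\n",
    "class Buffers(Node):\n",
    "    @order(0)\n",
    "    def reset_state(self):\n",
    "        self.log.append(self.name)\n",
    "\n",
    "\n",
    "class Hidden(Node):\n",
    "    @order(1)\n",
    "    def reset_state(self):\n",
    "        self.log.append(self.name)\n",
    "\n",
    "\n",
    "def call_all(nodes, fn_name):\n",
    "    deferred = []\n",
    "    for node in nodes:\n",
    "        fn = getattr(node, fn_name)\n",
    "        if hasattr(fn, 'call_order'):\n",
    "            deferred.append(node)  # decorated: schedule by level later\n",
    "        else:\n",
    "            fn()  # undecorated: run immediately, in traversal order\n",
    "    # Lowest level first; sorted() is stable, so equal levels keep traversal order.\n",
    "    for node in sorted(deferred, key=lambda n: getattr(n, fn_name).call_order):\n",
    "        getattr(node, fn_name)()\n",
    "\n",
    "\n",
    "log = []\n",
    "graph = [Hidden('hidden', log), Node('plain', log), Buffers('buffers', log)]\n",
    "call_all(graph, 'reset_state')\n",
    "print(log)  # ['plain', 'buffers', 'hidden']"
   ]
  },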
  {
   "cell_type": "markdown",
   "id": "9d7b4080f469aa00",
   "metadata": {},
   "source": [
    "## Initialising Every Module\n",
    "\n",
    "The simplest helper is `init_all_states`. It walks the module graph and calls `init_state` on each node. Keyword arguments such as `batch_size` are forwarded to every `init_state` call, and `node_to_exclude` skips nodes that match a filter (for example, a module type).\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 64,
   "id": "ea4019df2c6c213d",
   "metadata": {},
   "outputs": [],
   "source": [
    "model = brainstate.nn.Sequential(\n",
    "    brainstate.nn.Linear((10,), (32,)),\n",
    "    brainstate.nn.GELU(),\n",
    "    brainstate.nn.Dropout(prob=0.1)\n",
    ")\n",
    "\n",
    "# Initialise the entire stack at once; keyword arguments reach every init_state.\n",
    "brainstate.nn.init_all_states(model, batch_size=4)\n",
    "\n",
    "# Alternatively, skip nodes that match a filter (here: every Dropout layer).\n",
    "# Each call in this cell re-runs the initialisation from scratch.\n",
    "brainstate.nn.init_all_states(model, node_to_exclude=brainstate.nn.Dropout)\n",
    "\n",
    "# The function returns its target, so the result can be assigned or chained.\n",
    "model = brainstate.nn.init_all_states(model)\n"
   ]
  },
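  {
   "cell_type": "markdown",
   "id": "e04f7a3b9c6d2815",
   "metadata": {},
   "source": [
    "A minimal model of what the traversal does, assuming each module keeps its children in a plain list. `Mod`, `walk`, and `init_all` are hypothetical names; BrainState's graph traversal is more general. Excluding by type mirrors the `node_to_exclude=brainstate.nn.Dropout` filter above, and returning the root mirrors the chaining pattern.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "71d8b5e2a9c3f640",
   "metadata": {},
   "outputs": [],
   "source": [
    "class Mod:\n",
    "    # Hypothetical minimal module whose children live in a list.\n",
    "    def __init__(self, *children):\n",
    "        self.children = list(children)\n",
    "        self.initialised_with = None\n",
    "\n",
    "    def init_state(self, **kwargs):\n",
    "        self.initialised_with = kwargs\n",
    "\n",
    "\n",
    "class Lin(Mod):\n",
    "    pass\n",
    "\n",
    "\n",
    "class Drop(Mod):\n",
    "    pass\n",
    "\n",
    "\n",
    "def walk(root):\n",
    "    # Depth-first traversal that yields every module once, root first.\n",
    "    seen, stack = set(), [root]\n",
    "    while stack:\n",
    "        node = stack.pop()\n",
    "        if id(node) in seen:\n",
    "            continue\n",
    "        seen.add(id(node))\n",
    "        yield node\n",
    "        stack.extend(reversed(node.children))\n",
    "\n",
    "\n",
    "def init_all(root, node_to_exclude=None, **kwargs):\n",
    "    for node in walk(root):\n",
    "        if node_to_exclude is not None and isinstance(node, node_to_exclude):\n",
    "            continue  # skip excluded node types\n",
    "        node.init_state(**kwargs)  # forward keyword arguments to every node\n",
    "    return root  # returning the target allows chaining\n",
    "\n",
    "\n",
    "toy_model = init_all(Mod(Lin(), Drop()), node_to_exclude=Drop, batch_size=4)\n",
    "print([type(n).__name__ for n in walk(toy_model)])  # ['Mod', 'Lin', 'Drop']\n",
    "print([n.initialised_with for n in walk(toy_model)])  # [{'batch_size': 4}, {'batch_size': 4}, None]"
   ]
  },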
  {
   "cell_type": "markdown",
   "id": "e71baa8c9c850b9a",
   "metadata": {},
   "source": [
    "## Resetting State Between Sequences\n",
    "\n",
    "For recurrent models, you typically initialise once and then reset the states between sequences. `reset_all_states` calls `reset_state` on every node of the module in a single pass.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 65,
   "id": "4779da017a9abf48",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "ValinaRNNCell(\n",
       "  in_size=(8,),\n",
       "  out_size=(16,),\n",
       "  num_out=16,\n",
       "  num_in=8,\n",
       "  state_initializer=ZeroInit(\n",
       "    unit=Unit(10.0^0)\n",
       "  ),\n",
       "  activation=<function relu at 0x000001863944C360>,\n",
       "  W=Linear(\n",
       "    in_size=(24,),\n",
       "    out_size=(16,),\n",
       "    w_mask=None,\n",
       "    weight=ParamState(\n",
       "      value={\n",
       "        'bias': ShapedArray(float32[16]),\n",
       "        'weight': ShapedArray(float32[24,16])\n",
       "      }\n",
       "    )\n",
       "  ),\n",
       "  h=HiddenState(\n",
        "    value=ShapedArray(float32[2,16])\n",
        "  )\n",
        ")"
       ]
      },
      "execution_count": 65,
      "metadata": {},
      "output_type": "execute_result"
     }
    ],
    "source": [
     "rnn = brainstate.nn.ValinaRNNCell(num_in=8, num_out=16)\n",
     "brainstate.nn.init_all_states(rnn, batch_size=2)\n",
     "\n",
     "# ... run some inference / training ...\n",
     "\n",
     "# Reset hidden states before the next sequence; pass batch_size again to keep the batch axis.\n",
     "brainstate.nn.reset_all_states(rnn, batch_size=2)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f9e3ee23e74a7632",
   "metadata": {},
   "source": [
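    "`reset_all_states` accepts the same extra arguments and `node_to_exclude` filter as `init_all_states`. Pass the `batch_size` used at initialisation; otherwise cells such as `ValinaRNNCell` reset their hidden state to an unbatched shape. Methods decorated with `call_order` are still scheduled by level, so you can, for example, reset one module's buffers before another module's hidden states.\n"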
   ]
  },
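  {
   "cell_type": "markdown",
   "id": "c8a36f0e2b5d1794",
   "metadata": {},
   "source": [
    "The init-once, reset-per-sequence pattern can be sketched with a toy cell. `ToyCell` is hypothetical and keeps its hidden state in nested Python lists; it only illustrates why a reset should receive the same `batch_size` as the initial `init_state` call.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0f9e4d27b1a8c563",
   "metadata": {},
   "outputs": [],
   "source": [
    "class ToyCell:\n",
    "    # Hypothetical recurrent cell; the hidden state is a nested Python list.\n",
    "    def __init__(self, num_out):\n",
    "        self.num_out = num_out\n",
    "        self.h = None\n",
    "\n",
    "    def _zeros(self, batch_size):\n",
    "        row = [0.0] * self.num_out\n",
    "        if batch_size is None:\n",
    "            return row  # unbatched: num_out values\n",
    "        return [list(row) for _ in range(batch_size)]  # batch_size rows\n",
    "\n",
    "    def init_state(self, batch_size=None):\n",
    "        self.h = self._zeros(batch_size)\n",
    "\n",
    "    def reset_state(self, batch_size=None):\n",
    "        self.h = self._zeros(batch_size)\n",
    "\n",
    "    def __call__(self, xs):\n",
    "        # Stand-in update: add each batch element's input to its hidden row.\n",
    "        self.h = [[h + x for h in row] for row, x in zip(self.h, xs)]\n",
    "        return self.h\n",
    "\n",
    "\n",
    "cell = ToyCell(num_out=3)\n",
    "cell.init_state(batch_size=2)  # initialise once\n",
    "sequences = [[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0]]]  # steps of per-element inputs\n",
    "for seq in sequences:\n",
    "    for step in seq:\n",
    "        cell(step)\n",
    "    cell.reset_state(batch_size=2)  # reset between sequences, keeping the batch axis\n",
    "after_loop = cell.h\n",
    "print(after_loop)  # [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]\n",
    "\n",
    "cell.reset_state()  # omitting batch_size drops the batch axis in this toy cell\n",
    "print(cell.h)  # [0.0, 0.0, 0.0]"
   ]
  },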
  {
   "cell_type": "markdown",
   "id": "9faadda9ab022d5b",
   "metadata": {},
   "source": [
    "## Batched Initialisation with `vmap_*`\n",
    "\n",
    "To create several independent copies of a model's states (ensembles or Monte-Carlo batches), use the vectorised variants. They take an `axis_size` argument: states created during the call gain a leading axis of that size, and each copy receives its own random key.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 66,
   "id": "2a3ae8b184b66c43",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Weight shape with batching: (4, 64)\n"
     ]
    }
   ],
   "source": [
    "policy = brainstate.nn.Sequential(\n",
    "    brainstate.nn.Linear((4,), (64,)),\n",
    "    brainstate.nn.GELU(),\n",
    "    brainstate.nn.Linear((64,), (2,))\n",
    ")\n",
    "\n",
    "# Initialise 8 independent copies of the policy's states.\n",
    "brainstate.nn.vmap_init_all_states(policy, axis_size=8)\n",
    "\n",
    "# The printed weight shape stays (4, 64): Linear's parameters are not created by\n",
    "# init_state, so vmap_init_all_states does not add a leading axis to them.\n",
    "weights = policy.layers[0].weight.value\n",
    "print('Weight shape with batching:', weights['weight'].shape)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 67,
   "id": "e87a6bd411f8987d",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Sequential(\n",
       "  in_size=(4,),\n",
       "  out_size=(2,),\n",
       "  layers=[\n",
       "    Linear(\n",
       "      in_size=(4,),\n",
       "      out_size=(64,),\n",
       "      w_mask=None,\n",
       "      weight=ParamState(\n",
       "        value={\n",
       "          'bias': ShapedArray(float32[64]),\n",
       "          'weight': ShapedArray(float32[4,64])\n",
       "        }\n",
       "      )\n",
       "    ),\n",
       "    GELU(approximate=False),\n",
       "    Linear(\n",
       "      in_size=(64,),\n",
       "      out_size=(2,),\n",
       "      w_mask=None,\n",
       "      weight=ParamState(\n",
       "        value={\n",
       "          'bias': ShapedArray(float32[2]),\n",
       "          'weight': ShapedArray(float32[64,2])\n",
       "        }\n",
       "      )\n",
       "    )\n",
       "  ]\n",
       ")"
      ]
     },
     "execution_count": 67,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# When finished with a rollout, reset all batched states at once.\n",
    "brainstate.nn.vmap_reset_all_states(policy, axis_size=8)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a4c680332d131e66",
   "metadata": {},
   "source": [
    "If certain states should stay shared (for example, statistics buffers), pass a `state_to_exclude` filter to `vmap_init_all_states`. Excluded states are not batched: they keep their original shape and are shared by all copies.\n"
   ]
  },
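  {
   "cell_type": "markdown",
   "id": "b5d1e8a4c7f20369",
   "metadata": {},
   "source": [
    "The batching idea can be sketched directly in JAX: split one key into `axis_size` keys and `jax.vmap` an initialiser over them, so every array the initialiser creates gains a leading axis and each copy draws different random values. `init_copy` and the shapes are hypothetical, and this is not how `vmap_init_all_states` is implemented. A shared buffer, like an excluded state, is created outside the vmapped function and keeps its shape.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4a7c0e9f3d2b8156",
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "axis_size = 8\n",
    "\n",
    "def init_copy(key):\n",
    "    # Hypothetical per-copy initialiser: random weights plus a zero hidden state.\n",
    "    return {\n",
    "        'weight': jax.random.normal(key, (4, 64)) * 0.1,\n",
    "        'h': jnp.zeros((64,)),\n",
    "    }\n",
    "\n",
    "keys = jax.random.split(jax.random.PRNGKey(0), axis_size)  # one key per copy\n",
    "batched = jax.vmap(init_copy)(keys)  # every output gains a leading axis\n",
    "shared_buffer = jnp.zeros((64,))  # created outside vmap: no leading axis\n",
    "\n",
    "print(batched['weight'].shape)  # (8, 4, 64)\n",
    "print(batched['h'].shape)       # (8, 64)\n",
    "print(shared_buffer.shape)      # (64,)"
   ]
  },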
  {
   "cell_type": "markdown",
   "id": "153d665e5dacbec1",
   "metadata": {},
   "source": [
    "## Calling Arbitrary Methods Collectively\n",
    "\n",
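    "`call_all_fns` is the primitive behind the init/reset helpers: it calls a method, given by name, on every node it visits, so each visited node must implement that method. Here `net` also contains nodes without a `log_stats` method, such as the inner `Linear` layers, so the example below dispatches it with a manual loop over the `LoggingLayer` instances instead.\n"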
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 68,
   "id": "ad3845bd533ec7c0",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Logged means per layer: [[Array(0.0521806, dtype=float32)], [Array(0.03177379, dtype=float32)]]\n"
     ]
    }
   ],
   "source": [
    "class LoggingLayer(brainstate.nn.Module):\n",
    "    def __init__(self, size):\n",
    "        super().__init__()\n",
    "        self.linear = brainstate.nn.Linear((size,), (size,))\n",
    "        self.logged = []\n",
    "\n",
    "    def init_state(self):\n",
    "        self.linear.init_state()\n",
    "\n",
    "    def log_stats(self):\n",
    "        weight = self.linear.weight.value['weight']\n",
    "        self.logged.append(jnp.mean(weight))\n",
    "\n",
    "net = brainstate.nn.Sequential(\n",
    "    LoggingLayer(size=8),\n",
    "    LoggingLayer(size=8)\n",
    ")\n",
    "\n",
    "brainstate.nn.init_all_states(net)\n",
    "# Manual dispatch: nodes such as the inner Linear layers do not define log_stats.\n",
    "for layer in net.layers:\n",
    "    layer.log_stats()\n",
    "\n",
    "stats = [layer.logged for layer in net.layers]\n",
    "print('Logged means per layer:', stats)\n"
   ]
  },
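  {
   "cell_type": "markdown",
   "id": "d93b2f6e1c0a8475",
   "metadata": {},
   "source": [
    "Dispatching a method by name amounts to a `getattr` call on every visited node. In the hypothetical sketch below (`dispatch` and its `skip_missing` flag are not BrainState API), strict mode raises on the first node that lacks the method, matching the requirement stated above, while `skip_missing=True` calls the method only where it exists, which is what the manual loop in the previous cell achieves for `log_stats`.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6e0a5c8d4f1b9237",
   "metadata": {},
   "outputs": [],
   "source": [
    "class Child:\n",
    "    pass  # hypothetical child module without log_stats\n",
    "\n",
    "\n",
    "class ToyLoggingLayer:\n",
    "    def __init__(self):\n",
    "        self.child = Child()\n",
    "        self.calls = 0\n",
    "\n",
    "    def log_stats(self):\n",
    "        self.calls += 1\n",
    "\n",
    "\n",
    "def dispatch(nodes, fn_name, *args, skip_missing=False, **kwargs):\n",
    "    # Call fn_name on each node; raise (or skip) when a node does not define it.\n",
    "    for node in nodes:\n",
    "        fn = getattr(node, fn_name, None)\n",
    "        if fn is None:\n",
    "            if skip_missing:\n",
    "                continue\n",
    "            raise AttributeError(f'{type(node).__name__} has no method {fn_name!r}')\n",
    "        fn(*args, **kwargs)\n",
    "\n",
    "\n",
    "layers = [ToyLoggingLayer(), ToyLoggingLayer()]\n",
    "nodes = [n for layer in layers for n in (layer.child, layer)]  # flattened graph\n",
    "\n",
    "try:\n",
    "    dispatch(nodes, 'log_stats')\n",
    "except AttributeError as err:\n",
    "    print('strict dispatch failed:', err)  # Child has no method 'log_stats'\n",
    "\n",
    "dispatch(nodes, 'log_stats', skip_missing=True)\n",
    "print([layer.calls for layer in layers])  # [1, 1]"
   ]
  },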
  {
   "cell_type": "markdown",
   "id": "414705df883d7db",
   "metadata": {},
   "source": [
    "Use `vmap_call_all_fns` to repeat the same method across `axis_size` independent instances. It shares the interface and filter options.\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6c0953d4d600ea2c",
   "metadata": {},
   "source": [
    "## Restoring States with `assign_state_values`\n",
    "\n",
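    "Serialisation often requires mapping absolute state paths back to state objects. `assign_state_values` assigns each value whose key matches a state path and returns two lists: `unexpected` (keys that match no state) and `missing` (states that received no value).\n"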
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 69,
   "id": "c95b2810b67a57bb",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Unexpected keys: [('layers', 0, 'weight', 'bias'), ('layers', 0, 'weight', 'weight'), ('layers', 2, 'weight', 'bias'), ('layers', 2, 'weight', 'weight')]\n",
      "Missing keys: [('layers', 0, 'weight'), ('layers', 2, 'weight')]\n"
     ]
    }
   ],
   "source": [
    "autoencoder = brainstate.nn.Sequential(\n",
    "    brainstate.nn.Linear((16,), (8,)),\n",
    "    brainstate.nn.ReLU(),\n",
    "    brainstate.nn.Linear((8,), (16,))\n",
    ")\n",
    "brainstate.nn.init_all_states(autoencoder)\n",
    "\n",
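    "# Save values in a dict keyed by absolute state paths. This loop also splits\n",
    "# dict-valued states (each Linear `weight` ParamState) into per-entry keys;\n",
    "# as the printed lists show, those keys match no state path, so the parameters\n",
    "# are reported as unexpected/missing instead of being restored.\n",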
    "state_snapshot = {}\n",
    "for path, state in autoencoder.states().items():\n",
    "    if isinstance(state.value, dict):\n",
    "        for key, value in state.value.items():\n",
    "            new_path = path + (key,)\n",
    "            state_snapshot[new_path] = value\n",
    "    else:\n",
    "        state_snapshot[path] = state.value\n",
    "\n",
    "# ... modify weights or states ...\n",
    "\n",
    "unexpected, missing = brainstate.nn.assign_state_values(autoencoder, state_snapshot)\n",
    "print('Unexpected keys:', unexpected)\n",
    "print('Missing keys:', missing)\n"
   ]
  },
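  {
   "cell_type": "markdown",
   "id": "f2c7a9e05b3d6184",
   "metadata": {},
   "source": [
    "Path-keyed restoration reduces to set arithmetic on the keys. The plain-Python sketch below is hypothetical (`assign_by_path` is not BrainState API) and reproduces the pattern of the report printed above: flattened entries such as `('layers', 0, 'weight', 'bias')` match no state, while the state path `('layers', 0, 'weight')` receives no value. In the sketch, keying the snapshot by whole-state paths makes both lists empty.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9b4e1d7a2c8f0653",
   "metadata": {},
   "outputs": [],
   "source": [
    "def assign_by_path(states, snapshot):\n",
    "    # `states` maps absolute paths to mutable holders; only exact path matches are assigned.\n",
    "    for path in states.keys() & snapshot.keys():\n",
    "        states[path]['value'] = snapshot[path]\n",
    "    unexpected = sorted(snapshot.keys() - states.keys())  # keys with no matching state\n",
    "    missing = sorted(states.keys() - snapshot.keys())     # states that received no value\n",
    "    return unexpected, missing\n",
    "\n",
    "\n",
    "states = {\n",
    "    ('layers', 0, 'weight'): {'value': {'weight': 0.0, 'bias': 0.0}},\n",
    "    ('h',): {'value': 0.0},\n",
    "}\n",
    "\n",
    "# Flattening the dict-valued state produces keys one level too deep.\n",
    "flat = {('layers', 0, 'weight', 'weight'): 1.0, ('layers', 0, 'weight', 'bias'): 2.0, ('h',): 3.0}\n",
    "print(assign_by_path(states, flat))\n",
    "# ([('layers', 0, 'weight', 'bias'), ('layers', 0, 'weight', 'weight')], [('layers', 0, 'weight')])\n",
    "\n",
    "# Keying by the state paths themselves matches every state in this sketch.\n",
    "whole = {('layers', 0, 'weight'): {'weight': 1.0, 'bias': 2.0}, ('h',): 3.0}\n",
    "print(assign_by_path(states, whole))  # ([], [])"
   ]
  },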
  {
   "cell_type": "markdown",
   "id": "4b94a57e65dbb362",
   "metadata": {},
   "source": [
    "## Putting It All Together\n",
    "\n",
    "The snippet below demonstrates a typical lifecycle for a batched recurrent network: initialise, run a rollout, reset, and restore states from a snapshot.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 70,
   "id": "f9ce6f9f55b93434",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
       "Resetting states...\n"
      ]
     },
     {
      "data": {
       "text/plain": [
        "([('W', 'weight', 'bias'), ('W', 'weight', 'weight')], [('W', 'weight')])"
       ]
      },
      "execution_count": 70,
      "metadata": {},
      "output_type": "execute_result"
     }
    ],
    "source": [
     "rnn = brainstate.nn.ValinaRNNCell(num_in=4, num_out=8)\n",
     "brainstate.nn.vmap_init_all_states(rnn, axis_size=4)\n",
     "\n",
     "# Save a snapshot of the initial states (dict-valued states are split per entry).\n",
     "snapshot = {}\n",
     "for path, state in rnn.states().items():\n",
     "    if isinstance(state.value, dict):\n",
     "        for key, value in state.value.items():\n",
     "            new_path = path + (key,)\n",
     "            snapshot[new_path] = value\n",
     "    else:\n",
     "        snapshot[path] = state.value\n",
     "\n",
     "# Simulate a 12-step rollout: each step feeds 4 sequences with 4 features each.\n",
     "inputs = brainstate.random.randn(12, 4, 4)\n",
     "for t in range(inputs.shape[0]):\n",
     "    output = rnn(inputs[t])\n",
     "\n",
     "# Reset hidden states before the next episode.\n",
     "print(\"Resetting states...\")\n",
     "brainstate.nn.vmap_reset_all_states(rnn, axis_size=4)\n",
     "\n",
     "# Restore states from the snapshot. The returned (unexpected, missing) lists show\n",
     "# that the split parameter entries do not match the ('W', 'weight') state path.\n",
     "brainstate.nn.assign_state_values(rnn, snapshot)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ba56fa29671e8a8b",
   "metadata": {},
   "source": [
    "## Best Practices\n",
    "\n",
    "- Always call `init_all_states` once after constructing a module.\n",
    "- Decorate stateful methods with `call_order` when the order in which modules run them matters.\n",
    "- Use filters (`node_to_exclude`, `state_to_exclude`) to fine-tune traversal.\n",
    "- Inspect the return values from `assign_state_values` to catch mismatched checkpoints.\n",
    "- Use the vmapped helpers for ensembles, but remember that batched states gain a leading axis.\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b207cc3ffa64e858",
   "metadata": {},
   "source": [
    "## Further Reading\n",
    "\n",
    "- [Module Basics](01_module_basics.ipynb)\n",
    "- [Recurrent Networks](04_recurrent_networks.ipynb)\n",
    "- API reference: `brainstate.nn._collective_ops`\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Ecosystem-py",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
