{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "7734a6e7baffeea1",
   "metadata": {},
   "source": [
    "# Parallelisation"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "17858f3d9a756db2",
   "metadata": {},
   "source": [
    "`brainstate.transform.pmap2` mirrors `jax.pmap` while keeping BrainState `State`\n",
    "objects consistent across devices. This notebook explains how to configure the\n",
    "API, how random states behave under device parallelism, and how `pmap2` reuses\n",
    "the same `StatefulMapping` infrastructure as `vmap`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "8be4b21cd4e40f88",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-11T06:22:57.770996Z",
     "start_time": "2025-10-11T06:22:56.340531Z"
    }
   },
   "outputs": [],
   "source": [
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "import brainstate\n",
    "from brainstate.transform import pmap2\n",
    "from brainstate.util.filter import OfType"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9c5e2c3c60c26a3e",
   "metadata": {},
   "source": [
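    "## Configuring devices\n",
    "\n",
    "For CPU-only demonstrations we can provision multiple devices per host by\n",
    "setting the `jax_num_cpu_devices` option before JAX initialises its backends,\n",
    "that is, before the first computation or device query. Importing JAX first is\n",
    "fine; this notebook imports it in the cell above. (If you are running on GPU or\n",
    "TPU you can skip this cell; the environment will report the hardware devices it\n",
    "already sees.)"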
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "7f19fd9547753f9a",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-11T06:23:03.484777Z",
     "start_time": "2025-10-11T06:23:03.431934Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "local device count: 8\n"
     ]
    }
   ],
   "source": [
    "jax.config.update('jax_num_cpu_devices', 8)\n",
    "print('local device count:', jax.local_device_count())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c48f16ddfad7255",
   "metadata": {},
   "source": [
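    "## 1. Core arguments of `pmap2`\n",
    "\n",
    "`pmap2` accepts the same arguments as `jax.pmap` plus the BrainState-specific\n",
    "keywords `state_in_axes`, `state_out_axes`, and `unexpected_out_state_mapping`.\n",
    "Use `axis_name` to enable collectives and `devices` / `backend` to pin the\n",
    "computation to specific hardware. `state_in_axes` and `state_out_axes` map an\n",
    "axis index to a filter that selects which states are split along (and\n",
    "reassembled on) that axis; below, `OfType(brainstate.ParamState)` selects the\n",
    "model's parameters."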
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "e6d7e01fdb5c7bc6",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-11T06:24:20.631566Z",
     "start_time": "2025-10-11T06:24:20.412166Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "updated shape: (8, 4)\n",
      "final weights: [[ 1.  2.  3.  4.]\n",
      " [ 5.  6.  7.  8.]\n",
      " [ 9. 10. 11. 12.]\n",
      " [13. 14. 15. 16.]\n",
      " [17. 18. 19. 20.]\n",
      " [21. 22. 23. 24.]\n",
      " [25. 26. 27. 28.]\n",
      " [29. 30. 31. 32.]]\n"
     ]
    }
   ],
   "source": [
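    "class Affine(brainstate.nn.Module):\n",
    "    def __init__(self, size, features=4):\n",
    "        super().__init__()\n",
    "        # One row of `features` weights per device; pmap2 splits the leading axis.\n",
    "        self.weight = brainstate.ParamState(jnp.ones((size, features)))\n",
    "\n",
    "    def __call__(self, delta):\n",
    "        # Inside pmap2 each device sees and updates its own (features,) row.\n",
    "        self.weight.value = self.weight.value + delta\n",
    "        return self.weight.value\n",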
    "\n",
    "\n",
    "model = Affine(size=jax.local_device_count())\n",
    "axis_name = 'devices'\n",
    "\n",
    "pmapped_update = pmap2(\n",
    "    model,\n",
    "    axis_name=axis_name,\n",
    "    in_axes=0,\n",
    "    out_axes=0,\n",
    "    state_in_axes={0: OfType(brainstate.ParamState)},\n",
    "    state_out_axes={0: OfType(brainstate.ParamState)},\n",
    ")\n",
    "\n",
    "# Each device receives a different delta vector\n",
    "per_device_delta = jnp.arange(jax.local_device_count() * 4.).reshape(jax.local_device_count(), 4)\n",
    "updated = pmapped_update(per_device_delta)\n",
    "print('updated shape:', updated.shape)\n",
    "print('final weights:', model.weight.value)"
   ]
  },
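  {
   "cell_type": "markdown",
   "id": "3f9d1c7a2b6e4058",
   "metadata": {},
   "source": [
    "The `axis_name` argument is what lets collectives such as `jax.lax.pmean` and\n",
    "`jax.lax.psum` address the mapped device axis. The sketch below uses plain\n",
    "`jax.pmap` so that it does not depend on BrainState's state handling; we assume,\n",
    "without having checked it here, that the same body also works under `pmap2`\n",
    "with the same `axis_name`. `center` is an illustrative name. With eight devices\n",
    "each entry should equal its index minus 3.5."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8e2b5d41c0f7a963",
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "n_devices = jax.local_device_count()\n",
    "\n",
    "def center(x):\n",
    "    # pmean averages x over every device on the 'devices' axis, so each device\n",
    "    # subtracts the same global mean from its own value.\n",
    "    return x - jax.lax.pmean(x, axis_name='devices')\n",
    "\n",
    "centered = jax.pmap(center, axis_name='devices')(jnp.arange(n_devices, dtype=jnp.float32))\n",
    "centered"
   ]
  },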
  {
   "cell_type": "markdown",
   "id": "f1fc7511fae0a22e",
   "metadata": {},
   "source": [
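    "### `axis_size` and `devices`\n",
    "\n",
    "The size of the mapped axis is normally inferred from the leading dimension of\n",
    "the mapped inputs, so `axis_size` rarely needs to be set explicitly. `devices`\n",
    "lets you provide an explicit list of JAX devices to map over; the leading axis\n",
    "of the mapped inputs must then equal the number of devices listed. The cell\n",
    "below uses this to run on two of the available devices, a smaller logical mesh\n",
    "than the full set of physical devices."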
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "db3d36613fa7f9ce",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-11T06:25:03.866684Z",
     "start_time": "2025-10-11T06:25:03.689168Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "weights after pairwise update: [[2. 2. 2. 2.]\n",
      " [0. 0. 0. 0.]]\n"
     ]
    }
   ],
   "source": [
    "logical_devices = jax.devices()[:2]\n",
    "model = Affine(size=len(logical_devices))\n",
    "\n",
    "pairwise_update = pmap2(\n",
    "    model,\n",
    "    axis_name='pair',\n",
    "    in_axes=0,\n",
    "    out_axes=0,\n",
    "    devices=logical_devices,\n",
    "    state_in_axes={0: OfType(brainstate.ParamState)},\n",
    "    state_out_axes={0: OfType(brainstate.ParamState)},\n",
    ")\n",
    "\n",
    "deltas = jnp.stack([jnp.ones((4,)), -jnp.ones((4,))], axis=0)\n",
    "pairwise_update(deltas)\n",
    "print('weights after pairwise update:', model.weight.value)"
   ]
  },
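  {
   "cell_type": "markdown",
   "id": "5a7c3e9b1d2f4086",
   "metadata": {},
   "source": [
    "A minimal sketch of the `devices` rule with plain `jax.pmap`: the mapped\n",
    "function runs only on the listed devices, and the leading axis of its input must\n",
    "match their number. We assume `pmap2` forwards `devices` to `jax.pmap`\n",
    "unchanged, as the cell above suggests. On a single-device machine the subset\n",
    "simply holds one device. `double_on_subset` is an illustrative name."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c4d8e2a6f1b39075",
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "# At most two devices: two on the 8-device CPU setup above, one on a single-device host.\n",
    "subset = jax.devices()[:2]\n",
    "double_on_subset = jax.pmap(lambda x: 2.0 * x, devices=subset)\n",
    "\n",
    "# The leading axis size must equal len(subset) because `devices` was given explicitly.\n",
    "doubled = double_on_subset(jnp.arange(len(subset), dtype=jnp.float32))\n",
    "print(doubled, doubled.devices())"
   ]
  },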
  {
   "cell_type": "markdown",
   "id": "772cb4f2e8f05b8e",
   "metadata": {},
   "source": [
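    "### Handling static arguments and donation\n",
    "\n",
    "Most `jax.pmap` flags pass straight through. `static_broadcasted_argnums` marks\n",
    "positional arguments as compile-time constants shared by every device (they must\n",
    "be hashable, and each new value triggers a recompilation), while `donate_argnums`\n",
    "lets XLA reuse the buffers of donated inputs for the outputs, reducing memory use.\n",
    "An ordinary array or scalar that every device should share does not need to be\n",
    "static: map it with `in_axes=None`, as `scale` is in the cell below."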
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "9556f5a8aca34258",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-11T06:25:11.341985Z",
     "start_time": "2025-10-11T06:25:11.245213Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Array([0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5], dtype=float32, weak_type=True)"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "@pmap2(axis_name=axis_name, in_axes=(0, None), out_axes=0)\n",
    "def add_with_scale(delta, scale):\n",
    "    return delta + scale\n",
    "\n",
    "add_with_scale(jnp.arange(jax.local_device_count()), 0.5)"
   ]
  },
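  {
   "cell_type": "markdown",
   "id": "9b1e4f7c3a5d2068",
   "metadata": {},
   "source": [
    "For contrast with the `in_axes=None` broadcast above, this sketch marks an\n",
    "argument with `static_broadcasted_argnums`. It uses plain `jax.pmap` to isolate\n",
    "the flag from BrainState; that `pmap2` forwards it is taken from the text above\n",
    "rather than checked here. Because `power` is static it must be hashable, it is\n",
    "baked into the compiled program, and calling with a different value compiles a\n",
    "new version. `raise_to` is an illustrative name."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2d6a8c0e4b1f7935",
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "n_devices = jax.local_device_count()\n",
    "\n",
    "def raise_to(x, power):\n",
    "    # `power` is a plain Python int here, never a tracer, so it could also drive\n",
    "    # Python-level control flow.\n",
    "    return x ** power\n",
    "\n",
    "# Argument 1 is static and broadcast to every device; argument 0 is mapped.\n",
    "raise_to_p = jax.pmap(raise_to, static_broadcasted_argnums=(1,))\n",
    "squares = raise_to_p(jnp.arange(n_devices, dtype=jnp.float32), 2)\n",
    "squares"
   ]
  },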
  {
   "cell_type": "markdown",
   "id": "2440701a45463c4b",
   "metadata": {},
   "source": [
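    "## 2. Random-number semantics\n",
    "\n",
    "As with `vmap`, BrainState splits `RandomState` keys automatically, giving each\n",
    "device its own random stream without manual key management. In the cell below,\n",
    "`OfType(brainstate.random.RandomState)` maps the default generator that\n",
    "`brainstate.random.normal` draws from; seed it with `brainstate.random.seed` if\n",
    "results must be reproducible across runs."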
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "4f7b228894a6e56c",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-11T06:25:28.178067Z",
     "start_time": "2025-10-11T06:25:27.630770Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Array([ 1.23822   , -0.2782504 , -1.9162552 , -0.21000428,  0.41403982,\n",
       "       -0.7870412 , -1.6281602 , -1.1573448 ], dtype=float32)"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# brainstate.random.normal uses the default RandomState, which the filter below maps.\n",
    "\n",
    "@pmap2(\n",
    "    axis_name='devices',\n",
    "    in_axes=0,\n",
    "    out_axes=0,\n",
    "    state_in_axes={0: OfType(brainstate.random.RandomState)},\n",
    "    state_out_axes={0: OfType(brainstate.random.RandomState)},\n",
    ")\n",
    "def sample_normal(scale):\n",
    "    return brainstate.random.normal(0.0, scale)\n",
    "\n",
    "per_device_scales = jnp.linspace(1.0, 2.0, jax.local_device_count())\n",
    "sample_normal(per_device_scales)"
   ]
  },
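  {
   "cell_type": "markdown",
   "id": "6e3f1a9d7c5b2480",
   "metadata": {},
   "source": [
    "For comparison, the sketch below shows roughly what the automatic splitting saves\n",
    "you from writing by hand with plain `jax.random`: derive one independent key per\n",
    "device with `jax.random.split` and map over the keys. It illustrates the idea\n",
    "only; BrainState's own splitting scheme is not shown and may produce different\n",
    "numbers. Reusing the same base key reproduces the same samples."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b7c9d1e3f5a24068",
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "n_devices = jax.local_device_count()\n",
    "base_key = jax.random.PRNGKey(0)\n",
    "\n",
    "# One independent key per device; pmap maps over the leading axis of this array.\n",
    "device_keys = jax.random.split(base_key, n_devices)\n",
    "\n",
    "@jax.pmap\n",
    "def draw_standard_normal(key):\n",
    "    # Each device receives its own key and therefore its own sample.\n",
    "    return jax.random.normal(key, ())\n",
    "\n",
    "per_device_samples = draw_standard_normal(device_keys)\n",
    "per_device_samples"
   ]
  },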
  {
   "cell_type": "markdown",
   "id": "b58e58627b4415ad",
   "metadata": {},
   "source": [
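    "If you need identical keys on all devices, use `jax.random` explicitly and\n",
    "broadcast the key with `in_axes=None`, so every device receives the same\n",
    "unsplit key (the argument is replicated, not compiled in as a static value)."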
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "478560605577d8a1",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-11T06:25:32.558040Z",
     "start_time": "2025-10-11T06:25:32.406393Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Array([1.6226422, 1.8544483, 2.0862544, 2.3180604, 2.5498662, 2.7816722,\n",
       "       3.0134785, 3.2452843], dtype=float32)"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "shared_key = jax.random.PRNGKey(0)\n",
    "\n",
    "@pmap2(axis_name='devices', in_axes=(None, 0), out_axes=0)\n",
    "def sample_shared(key, scale):\n",
    "    return jax.random.normal(key, ()) * scale\n",
    "\n",
    "sample_shared(shared_key, per_device_scales)"
   ]
  },
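  {
   "cell_type": "markdown",
   "id": "0a4c6e8b2d1f3579",
   "metadata": {},
   "source": [
    "A middle ground between automatic per-device streams and a single shared key is\n",
    "to broadcast one key and derive a distinct key on each device by folding in the\n",
    "device's position along the named axis with `jax.random.fold_in` and\n",
    "`jax.lax.axis_index`. This is a general JAX pattern, not something BrainState\n",
    "does for you; the sketch uses plain `jax.pmap`, and `sample_folded` is an\n",
    "illustrative name."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f2e4a6c8b0d13579",
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "n_devices = jax.local_device_count()\n",
    "base_key = jax.random.PRNGKey(0)\n",
    "\n",
    "def sample_folded(key, scale):\n",
    "    # axis_index gives this device's position (0, 1, ...) along the 'devices' axis;\n",
    "    # folding it into the shared key yields a different key on every device.\n",
    "    device_key = jax.random.fold_in(key, jax.lax.axis_index('devices'))\n",
    "    return jax.random.normal(device_key, ()) * scale\n",
    "\n",
    "# The key is broadcast (in_axes=None); only the scales are split across devices.\n",
    "sample_folded_p = jax.pmap(sample_folded, axis_name='devices', in_axes=(None, 0))\n",
    "folded_samples = sample_folded_p(base_key, jnp.ones((n_devices,)))\n",
    "folded_samples"
   ]
  },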
  {
   "cell_type": "markdown",
   "id": "d2b16f68416a8faf",
   "metadata": {},
   "source": [
    "## 3. Relationship to `StatefulMapping`\n",
    "\n",
    "`pmap2` returns a `StatefulMapping`, just like `vmap`. The wrapper analyzes\n",
    "which states the function uses, traces it into an intermediate representation\n",
    "(IR) of the mapped computation, and writes the resulting state values back\n",
    "after every parallel execution."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "ff36ae93e0cfd1d6",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-11T06:25:38.032610Z",
     "start_time": "2025-10-11T06:25:38.025362Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<class 'brainstate.transform.StatefulMapping'>\n",
      "origin fun: Affine(\n",
      "  weight=ParamState(\n",
      "    value=ShapedArray(float32[2,4])\n",
      "  )\n",
      ")\n",
      "state_in_axes: {0: OfType(<class 'brainstate.ParamState'>)}\n"
     ]
    }
   ],
   "source": [
    "parallel_mapping = pmap2(\n",
    "    model,\n",
    "    axis_name='devices',\n",
    "    in_axes=0,\n",
    "    out_axes=0,\n",
    "    state_in_axes={0: OfType(brainstate.ParamState)},\n",
    "    state_out_axes={0: OfType(brainstate.ParamState)},\n",
    ")\n",
    "\n",
    "print(type(parallel_mapping))\n",
    "print('origin fun:', parallel_mapping.origin_fun)\n",
    "print('state_in_axes:', parallel_mapping.state_in_axes)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "671255a65c08792d",
   "metadata": {},
   "source": [
    "Advanced users can construct `StatefulMapping` directly, selecting their own\n",
    "mapping primitive. Below we recreate the earlier example but pass an explicit\n",
    "`jax.pmap` with custom donation settings."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "fc3ea270681d6517",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-11T06:26:32.633717Z",
     "start_time": "2025-10-11T06:26:32.603684Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Array([[2., 2., 2., 2.],\n",
       "       [2., 2., 2., 2.],\n",
       "       [2., 2., 2., 2.],\n",
       "       [2., 2., 2., 2.],\n",
       "       [2., 2., 2., 2.],\n",
       "       [2., 2., 2., 2.],\n",
       "       [2., 2., 2., 2.],\n",
       "       [2., 2., 2., 2.]], dtype=float32)"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from brainstate.transform import StatefulMapping\n",
    "\n",
    "model = Affine(size=jax.local_device_count())\n",
    "\n",
    "custom_pmap = StatefulMapping(\n",
    "    model,\n",
    "    in_axes=0,\n",
    "    out_axes=0,\n",
    "    state_in_axes={0: OfType(brainstate.ParamState)},\n",
    "    state_out_axes={0: OfType(brainstate.ParamState)},\n",
    "    axis_name='devices',\n",
    "    # Forward StatefulMapping's arguments to jax.pmap, adding a donation setting.\n",
    "    mapping_fn=lambda fun, *a, **kw: jax.pmap(fun, *a, donate_argnums=(0,), **kw),\n",
    ")\n",
    "\n",
    "custom_pmap(jnp.ones((jax.local_device_count(), 4)))\n",
    "\n",
    "model.weight.value"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "23ad26f7149082d0",
   "metadata": {},
   "source": [
    "## Summary\n",
    "\n",
    "- `brainstate.transform.pmap2` supports the full `jax.pmap` interface and adds\n",
    "  state-specific controls via `state_in_axes`, `state_out_axes`, and\n",
    "  `unexpected_out_state_mapping`.\n",
    "- Random states are split automatically so each device receives its own key.\n",
    "  Use `jax.random` with `in_axes=None` to broadcast a shared key instead.\n",
    "- Like `vmap`, `pmap2` returns a `StatefulMapping` that identifies state axis\n",
    "  mappings and compiles the computation into a state-aware IR."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Ecosystem-py",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
