{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Vectorization\n",
    "\n",
    "Vectorization is a fundamental technique for efficient computation in machine learning and scientific computing. BrainState provides `brainstate.transform.vmap2`, a state-aware wrapper around JAX's `jax.vmap` that can vectorize functions which read and write `State` objects.\n",
    "\n",
    "This tutorial covers:\n",
    "\n",
    "1. **Basic usage of `vmap2`** with detailed parameter explanations and examples\n",
    "2. **Random number semantics** and how `vmap2` automatically handles `RandomState`\n",
    "3. **Understanding `StatefulMapping`**, the underlying abstraction that powers `vmap2`\n",
    "4. **Advanced patterns and best practices**, starting with nested `vmap2`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-11T06:22:12.852853Z",
     "start_time": "2025-10-11T06:22:11.470123Z"
    }
   },
   "outputs": [],
   "source": [
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "import brainstate\n",
    "from brainstate.transform import vmap2\n",
    "from brainstate.util.filter import OfType"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1. Basic Usage: Understanding `vmap2` Parameters\n",
    "\n",
    "### 1.1 The `in_axes` Parameter\n",
    "\n",
    "The `in_axes` parameter controls how batch dimensions are mapped over function arguments. It works identically to `jax.vmap`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-11T06:22:13.070202Z",
     "start_time": "2025-10-11T06:22:12.863871Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Input shape: (4,)\n",
      "Output: [ 1.  4.  9. 16.]\n",
      "Output shape: (4,)\n"
     ]
    }
   ],
   "source": [
    "# Example 1: Single scalar-to-scalar function\n",
    "def square(x):\n",
    "    return x ** 2\n",
    "\n",
    "\n",
    "# Vectorize over the first axis (default)\n",
    "vmap_square = vmap2(square, in_axes=0)\n",
    "\n",
    "xs = jnp.array([1.0, 2.0, 3.0, 4.0])\n",
    "print(\"Input shape:\", xs.shape)\n",
    "print(\"Output:\", vmap_square(xs))\n",
    "print(\"Output shape:\", vmap_square(xs).shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-11T06:22:13.187431Z",
     "start_time": "2025-10-11T06:22:13.076837Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Batched x: [1. 2. 3.]\n",
      "Single weight: 2.0\n",
      "Result: [2. 4. 6.]\n"
     ]
    }
   ],
   "source": [
    "# Example 2: Multiple arguments with different in_axes\n",
    "def scale_by_weight(x, weight):\n",
    "    \"\"\"Multiply x by weight (elementwise).\"\"\"\n",
    "    return x * weight\n",
    "\n",
    "\n",
    "# Vectorize over x (batched), but broadcast weight (a single value)\n",
    "vmap_weighted = vmap2(scale_by_weight, in_axes=(0, None))\n",
    "\n",
    "batch_x = jnp.array([1.0, 2.0, 3.0])\n",
    "single_weight = 2.0\n",
    "\n",
    "result = vmap_weighted(batch_x, single_weight)\n",
    "print(\"Batched x:\", batch_x)\n",
    "print(\"Single weight:\", single_weight)\n",
    "print(\"Result:\", result)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-11T06:22:13.270604Z",
     "start_time": "2025-10-11T06:22:13.189678Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Input shapes: (4, 3, 2) (4, 2)\n",
      "Output shape: (4, 3)\n"
     ]
    }
   ],
   "source": [
    "# Example 3: Vectorizing along different axes\n",
    "def matrix_vector_product(matrix, vector):\n",
    "    return matrix @ vector\n",
    "\n",
    "\n",
    "# Batch of matrices: shape (batch, m, n)\n",
    "# Batch of vectors: shape (batch, n)\n",
    "batch_matrices = jnp.ones((4, 3, 2))  # 4 matrices of shape (3, 2)\n",
    "batch_vectors = jnp.ones((4, 2))  # 4 vectors of shape (2,)\n",
    "\n",
    "# Map over the first axis of both arguments\n",
    "vmap_matmul = vmap2(matrix_vector_product, in_axes=(0, 0))\n",
    "result = vmap_matmul(batch_matrices, batch_vectors)\n",
    "\n",
    "print(\"Input shapes:\", batch_matrices.shape, batch_vectors.shape)\n",
    "print(\"Output shape:\", result.shape)  # (4, 3)"
   ]
  },
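  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Because `in_axes` follows `jax.vmap` semantics, two less common cases can be checked with plain JAX: mapping over a non-leading axis, and passing a pytree of axes for a dict argument. The sketch below uses only `jax.vmap`; the function and variable names are illustrative, and it assumes the same `in_axes` values carry over to `vmap2`.\n",
    "\n",
    "```python\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "# in_axes=1 maps over columns: each call sees one column of shape (2,)\n",
    "m = jnp.arange(6.0).reshape(2, 3)  # [[0. 1. 2.] [3. 4. 5.]]\n",
    "col_sums = jax.vmap(lambda col: col.sum(), in_axes=1)(m)\n",
    "print(col_sums)  # [3. 5. 7.]\n",
    "\n",
    "\n",
    "# A dict argument takes a dict of axes: 'x' is mapped, 'w' is broadcast\n",
    "def scale(p):\n",
    "    return p['x'] * p['w']\n",
    "\n",
    "\n",
    "params = {'x': jnp.array([1.0, 2.0, 3.0]), 'w': jnp.array(2.0)}\n",
    "scaled = jax.vmap(scale, in_axes=({'x': 0, 'w': None},))(params)\n",
    "print(scaled)  # [2. 4. 6.]\n",
    "```"
   ]
  },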
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 1.2 The `out_axes` Parameter\n",
    "\n",
    "The `out_axes` parameter controls where the batch dimension appears in the output."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-11T06:22:13.425489Z",
     "start_time": "2025-10-11T06:22:13.277099Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "out_axes=0, shape: (2, 3)\n",
      "[[1. 2. 3.]\n",
      " [2. 4. 6.]]\n",
      "\n",
      "out_axes=1, shape: (3, 2)\n",
      "[[1. 2.]\n",
      " [2. 4.]\n",
      " [3. 6.]]\n"
     ]
    }
   ],
   "source": [
    "def create_vector(scalar):\n",
    "    \"\"\"Create a length-3 vector from a scalar.\"\"\"\n",
    "    return jnp.array([scalar, scalar * 2, scalar * 3])\n",
    "\n",
    "\n",
    "# Default: batch dimension at axis 0\n",
    "vmap_default = vmap2(create_vector, in_axes=0, out_axes=0)\n",
    "result_axis0 = vmap_default(jnp.array([1.0, 2.0]))\n",
    "print(\"out_axes=0, shape:\", result_axis0.shape)  # (2, 3)\n",
    "print(result_axis0)\n",
    "\n",
    "# Batch dimension at axis 1\n",
    "vmap_axis1 = vmap2(create_vector, in_axes=0, out_axes=1)\n",
    "result_axis1 = vmap_axis1(jnp.array([1.0, 2.0]))\n",
    "print(\"\\nout_axes=1, shape:\", result_axis1.shape)  # (3, 2)\n",
    "print(result_axis1)"
   ]
  },
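  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "When a function returns several outputs, `jax.vmap` accepts a tuple for `out_axes` with one entry per output. A minimal plain-JAX sketch follows; `stats` is a hypothetical helper, and it is an assumption (based on the `jax.vmap` compatibility described above) that `vmap2` accepts the same tuple form.\n",
    "\n",
    "```python\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "\n",
    "def stats(v):\n",
    "    # Two outputs: the vector scaled by 2, and its sum\n",
    "    return v * 2, v.sum()\n",
    "\n",
    "\n",
    "batch = jnp.ones((4, 3))  # 4 vectors of length 3\n",
    "\n",
    "# First output gets the batch axis at position 1, second output at position 0\n",
    "doubled, totals = jax.vmap(stats, in_axes=0, out_axes=(1, 0))(batch)\n",
    "print(doubled.shape)  # (3, 4)\n",
    "print(totals.shape)   # (4,)\n",
    "```"
   ]
  },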
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 1.3 The `axis_name` Parameter\n",
    "\n",
    "The `axis_name` parameter allows you to name the mapped axis, enabling collective operations like `jax.lax.pmean`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-11T06:22:13.509323Z",
     "start_time": "2025-10-11T06:22:13.432392Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Input: [1. 2. 3. 4.]\n",
      "Batch mean: 2.5\n",
      "Normalized: [-1.5 -0.5  0.5  1.5]\n",
      "New mean: 0.0\n"
     ]
    }
   ],
   "source": [
    "def normalize_batch(x):\n",
    "    \"\"\"Normalize by subtracting the batch mean.\"\"\"\n",
    "    # Compute mean across the 'batch' axis\n",
    "    batch_mean = jax.lax.pmean(x, axis_name='batch')\n",
    "    return x - batch_mean\n",
    "\n",
    "\n",
    "# Name the mapped axis as 'batch'\n",
    "vmap_normalize = vmap2(normalize_batch, in_axes=0, axis_name='batch')\n",
    "\n",
    "batch_data = jnp.array([1.0, 2.0, 3.0, 4.0])\n",
    "normalized = vmap_normalize(batch_data)\n",
    "\n",
    "print(\"Input:\", batch_data)\n",
    "print(\"Batch mean:\", jnp.mean(batch_data))\n",
    "print(\"Normalized:\", normalized)\n",
    "print(\"New mean:\", jnp.mean(normalized))  # Should be ~0"
   ]
  },
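  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Other collectives also work over a named axis. The plain-JAX sketch below uses `jax.lax.pmean` twice to standardize a batch (zero mean, unit population variance) and `jax.lax.psum` to express each value as a fraction of the batch total; `standardize` and `share` are illustrative helpers, not part of BrainState.\n",
    "\n",
    "```python\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "\n",
    "def standardize(x):\n",
    "    # Collectives over the named axis see every element of the batch\n",
    "    mean = jax.lax.pmean(x, axis_name='batch')\n",
    "    var = jax.lax.pmean((x - mean) ** 2, axis_name='batch')\n",
    "    return (x - mean) / jnp.sqrt(var)\n",
    "\n",
    "\n",
    "def share(x):\n",
    "    # psum returns the same batch total to every element\n",
    "    return x / jax.lax.psum(x, axis_name='batch')\n",
    "\n",
    "\n",
    "data = jnp.array([1.0, 2.0, 3.0, 4.0])\n",
    "z = jax.vmap(standardize, axis_name='batch')(data)\n",
    "fractions = jax.vmap(share, axis_name='batch')(data)\n",
    "print(z)          # zero mean, unit population variance\n",
    "print(fractions)  # [0.1 0.2 0.3 0.4]\n",
    "```"
   ]
  },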
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 1.4 The `axis_size` Parameter\n",
    "\n",
    "The `axis_size` parameter explicitly specifies the size of the mapped axis. It is optional when the size can be inferred from a mapped argument, and required when no argument is mapped (for example, `in_axes=None`)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-11T06:22:13.625184Z",
     "start_time": "2025-10-11T06:22:13.516075Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Generated sequences:\n",
      "[[0 1 2]\n",
      " [0 1 2]\n",
      " [0 1 2]\n",
      " [0 1 2]\n",
      " [0 1 2]]\n",
      "Shape: (5, 3)\n"
     ]
    }
   ],
   "source": [
    "def generate_sequence(unused=None):\n",
    "    \"\"\"Generate a sequence (for demonstration).\"\"\"\n",
    "    return jnp.arange(3)\n",
    "\n",
    "\n",
    "# No argument is mapped (in_axes=None), so axis_size must be given explicitly\n",
    "vmap_generate = vmap2(generate_sequence, in_axes=None, axis_size=5)\n",
    "\n",
    "result = vmap_generate()\n",
    "print(\"Generated sequences:\")\n",
    "print(result)\n",
    "print(\"Shape:\", result.shape)  # (5, 3)"
   ]
  },
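  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The same rule holds in plain `jax.vmap`: with no mapped argument there is nothing to infer the batch size from. The sketch below (with an illustrative `ramp` helper) broadcasts an unmapped argument into 5 rows, then shows that omitting `axis_size` is rejected with a `ValueError`.\n",
    "\n",
    "```python\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "\n",
    "def ramp(scale):\n",
    "    return scale * jnp.arange(3.0)\n",
    "\n",
    "\n",
    "# The argument is unmapped, so axis_size supplies the batch size\n",
    "tiled = jax.vmap(ramp, in_axes=None, axis_size=5)(2.0)\n",
    "print(tiled.shape)  # (5, 3)\n",
    "\n",
    "# Without axis_size, JAX cannot determine the size of the mapped axis\n",
    "try:\n",
    "    jax.vmap(ramp, in_axes=None)(2.0)\n",
    "except ValueError as err:\n",
    "    print('ValueError:', err)\n",
    "```"
   ]
  },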
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 1.5 State-Aware Parameters: `state_in_axes` and `state_out_axes`\n",
    "\n",
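    "These BrainState-specific parameters control which `State` objects are batched and along which axis: `state_in_axes` applies to state values as they enter the mapped function, and `state_out_axes` to the updated values written back afterwards. Each maps an axis index to a state filter such as `OfType(...)`."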
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-11T06:22:13.670425Z",
     "start_time": "2025-10-11T06:22:13.632190Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Deltas: [1. 2. 3. 4.]\n",
      "Counts: [1. 2. 3. 4.]\n",
      "Final counter value: [1. 2. 3. 4.]\n"
     ]
    }
   ],
   "source": [
    "class Counter(brainstate.nn.Module):\n",
    "    \"\"\"A simple counter using ShortTermState.\"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.count = brainstate.ShortTermState(jnp.zeros(4))\n",
    "\n",
    "    def __call__(self, delta):\n",
    "        \"\"\"Increment counter by delta.\"\"\"\n",
    "        self.count.value = self.count.value + delta\n",
    "        return self.count.value\n",
    "\n",
    "\n",
    "counter = Counter()\n",
    "\n",
    "# Vectorize with state batching\n",
    "vmap_counter = vmap2(\n",
    "    counter,\n",
    "    in_axes=0,  # Batch over input deltas\n",
    "    out_axes=0,  # Batch over output counts\n",
    "    # Batch the counter state along axis 0\n",
    "    state_in_axes={0: OfType(brainstate.ShortTermState)},\n",
    "    state_out_axes={0: OfType(brainstate.ShortTermState)},\n",
    ")\n",
    "\n",
    "deltas = jnp.array([1.0, 2.0, 3.0, 4.0])\n",
    "counts = vmap_counter(deltas)\n",
    "\n",
    "print(\"Deltas:\", deltas)\n",
    "print(\"Counts:\", counts)\n",
    "print(\"Final counter value:\", counter.count.value)  # One count per batch element"
   ]
  },
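  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A rough mental model for `state_in_axes`/`state_out_axes` is to treat the state's value as an extra argument mapped on the way in and an extra output mapped on the way out. The sketch below writes that model with plain `jax.vmap`; it is a simplification for intuition, not BrainState's implementation, and `counter_step` is a hypothetical pure version of `Counter.__call__`.\n",
    "\n",
    "```python\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "\n",
    "def counter_step(count, delta):\n",
    "    # Pure function: returns (new state value, output)\n",
    "    new_count = count + delta\n",
    "    return new_count, new_count\n",
    "\n",
    "\n",
    "count = jnp.zeros(4)  # like Counter.count: one slot per batch element\n",
    "deltas = jnp.array([1.0, 2.0, 3.0, 4.0])\n",
    "\n",
    "# The axes on the state argument/output play the role of state_in_axes/state_out_axes\n",
    "count, outputs = jax.vmap(counter_step, in_axes=(0, 0), out_axes=(0, 0))(count, deltas)\n",
    "print(count)    # [1. 2. 3. 4.]\n",
    "print(outputs)  # [1. 2. 3. 4.]\n",
    "```\n",
    "\n",
    "Each batch element only sees its own slice of `count`, starting from 0, which is why the final value equals `deltas` element by element rather than their sum."
   ]
  },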
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 1.6 Working with Module States\n",
    "\n",
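    "By default, states used inside an `nn.Module` are shared (broadcast) across the batch unless `state_in_axes` selects them for batching. In the example below no state axes are given, so the `ParamState` weight and bias are shared by every batch element."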
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-11T06:22:13.781274Z",
     "start_time": "2025-10-11T06:22:13.676250Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Input shape: (4, 3)\n",
      "Output shape: (4, 2)\n",
      "Output:\n",
      "[[3. 3.]\n",
      " [3. 3.]\n",
      " [3. 3.]\n",
      " [3. 3.]]\n"
     ]
    }
   ],
   "source": [
    "class LinearLayer(brainstate.nn.Module):\n",
    "    \"\"\"Simple linear layer.\"\"\"\n",
    "\n",
    "    def __init__(self, in_features, out_features):\n",
    "        super().__init__()\n",
    "        # Parameters are ParamState\n",
    "        self.weight = brainstate.ParamState(jnp.ones((in_features, out_features)))\n",
    "        self.bias = brainstate.ParamState(jnp.zeros((out_features,)))\n",
    "\n",
    "    def __call__(self, x):\n",
    "        return x @ self.weight.value + self.bias.value\n",
    "\n",
    "\n",
    "layer = LinearLayer(3, 2)\n",
    "\n",
    "# Vectorize over batch of inputs\n",
    "# Parameters are shared (broadcast) across the batch\n",
    "vmap_layer = vmap2(layer, in_axes=0, out_axes=0)\n",
    "\n",
    "batch_inputs = jnp.ones((4, 3))  # Batch of 4 inputs\n",
    "batch_outputs = vmap_layer(batch_inputs)\n",
    "\n",
    "print(\"Input shape:\", batch_inputs.shape)  # (4, 3)\n",
    "print(\"Output shape:\", batch_outputs.shape)  # (4, 2)\n",
    "print(\"Output:\")\n",
    "print(batch_outputs)"
   ]
  },
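  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In plain JAX, sharing parameters corresponds to passing them as an argument with `in_axes=None`. The sketch below stores the parameters in a dict as an illustrative stand-in for the `ParamState` values; it reproduces the shapes and values of the `vmap2` example above.\n",
    "\n",
    "```python\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "# Illustrative stand-in for the layer's ParamState values\n",
    "params = {'weight': jnp.ones((3, 2)), 'bias': jnp.zeros((2,))}\n",
    "\n",
    "\n",
    "def linear(p, x):\n",
    "    return x @ p['weight'] + p['bias']\n",
    "\n",
    "\n",
    "# in_axes=None broadcasts the whole parameter dict; only x is mapped\n",
    "batch_outputs = jax.vmap(linear, in_axes=(None, 0))(params, jnp.ones((4, 3)))\n",
    "print(batch_outputs.shape)  # (4, 2)\n",
    "```"
   ]
  },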
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 1.7 The `unexpected_out_state_mapping` Parameter\n",
    "\n",
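    "This parameter controls what happens when the mapped function writes to a state that is not covered by `state_out_axes`. In the second example below, `vmap2` is given `'ignore'` and no state axes at all; after the call, the uncovered `LongTermState` holds the batched array `[1. 2. 3.]` instead of its original scalar value."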
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-11T06:22:13.876485Z",
     "start_time": "2025-10-11T06:22:13.788279Z"
    }
   },
   "outputs": [],
   "source": [
    "temp_state = brainstate.ShortTermState(jnp.zeros(3))\n",
    "write_state = brainstate.LongTermState(jnp.asarray(0.))\n",
    "\n",
    "\n",
    "def update_temp(x):\n",
    "    \"\"\"Function that writes to a state.\"\"\"\n",
    "    temp_state.value = temp_state.value + x\n",
    "    write_state.value = temp_state.value\n",
    "    return temp_state.value\n",
    "\n",
    "\n",
    "# Example 1: temp_state is covered by state_in_axes/state_out_axes, but write_state is not\n",
    "vmap_proper = vmap2(\n",
    "    update_temp,\n",
    "    in_axes=0,\n",
    "    state_in_axes={0: OfType(brainstate.ShortTermState)},\n",
    "    state_out_axes={0: OfType(brainstate.ShortTermState)},\n",
    "    unexpected_out_state_mapping='ignore',  # Default\n",
    ")\n",
    "\n",
    "try:\n",
    "    result = vmap_proper(jnp.array([1.0, 2.0, 3.0]))\n",
    "except Exception as e:\n",
    "    print(e)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-11T06:22:13.916343Z",
     "start_time": "2025-10-11T06:22:13.883401Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Before vmapping, original write state value: 0.0\n",
      "With 'ignore' policy: [1. 2. 3.]\n",
      "With 'ignore' policy, write state value after vmapping: [1. 2. 3.]\n"
     ]
    }
   ],
   "source": [
    "\n",
    "# Example 2: Using 'ignore' to allow unexpected states\n",
    "temp_state2 = brainstate.ShortTermState(jnp.array(0.0))\n",
    "write_state2 = brainstate.LongTermState(jnp.asarray(0.))\n",
    "\n",
    "\n",
    "def update_temp2(x):\n",
    "    temp_state2.value = temp_state2.value + x\n",
    "    write_state2.value = temp_state2.value\n",
    "    return temp_state2.value\n",
    "\n",
    "\n",
    "print('Before vmapping, original write state value:', write_state2.value)\n",
    "\n",
    "vmap_ignore = vmap2(\n",
    "    update_temp2,\n",
    "    in_axes=0,\n",
    "    # Note: not specifying state_in_axes/state_out_axes\n",
    "    unexpected_out_state_mapping='ignore',\n",
    ")\n",
    "\n",
    "result2 = vmap_ignore(jnp.array([1.0, 2.0, 3.0]))\n",
    "print(\"With 'ignore' policy:\", result2)\n",
    "print(\"With 'ignore' policy, write state value after vmapping:\", write_state2.value)"
   ]
  },
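  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Plain `jax.vmap` has a close analogue of the situation `unexpected_out_state_mapping` deals with: declaring an output unbatched (`out_axes=None`) when its value depends on the mapped input is an error. The sketch below is only an analogy written in JAX (`update` is a hypothetical helper); it does not show how BrainState implements its policies.\n",
    "\n",
    "```python\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "\n",
    "def update(shared, x):\n",
    "    # The new 'shared' value depends on x, so it becomes batched\n",
    "    return shared + x\n",
    "\n",
    "\n",
    "xs = jnp.array([1.0, 2.0, 3.0])\n",
    "\n",
    "# Declaring the result unbatched fails: it differs per batch element\n",
    "try:\n",
    "    jax.vmap(update, in_axes=(None, 0), out_axes=None)(0.0, xs)\n",
    "except ValueError as err:\n",
    "    print('ValueError:', err)\n",
    "\n",
    "# Declaring it batched works and returns one value per element\n",
    "per_element = jax.vmap(update, in_axes=(None, 0), out_axes=0)(0.0, xs)\n",
    "print(per_element)  # [1. 2. 3.]\n",
    "```\n",
    "\n",
    "The second call matches what the `'ignore'` example above shows for `write_state2`: the written value comes back as a batched array."
   ]
  },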
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. Random Number Semantics\n",
    "\n",
    "### 2.1 Automatic Key Splitting for `RandomState`\n",
    "\n",
    "**Important**: `vmap2` automatically splits the PRNG key held by `brainstate.random.RandomState`, so each batch element receives its own random key."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-11T06:22:14.231974Z",
     "start_time": "2025-10-11T06:22:13.924741Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Scales: [1. 2. 3. 4.]\n",
      "Samples: [-1.0413289 -1.4796011  2.222502   6.412178 ]\n",
      "\n",
      "Note: Each sample is different (independent random key per batch element)\n"
     ]
    }
   ],
   "source": [
    "# Reset random state\n",
    "brainstate.random.seed(42)\n",
    "\n",
    "\n",
    "def sample_normal(scale):\n",
    "    \"\"\"Sample from a normal distribution.\"\"\"\n",
    "    return brainstate.random.normal(0.0, scale)\n",
    "\n",
    "\n",
    "# Vectorize the sampling function\n",
    "vmap_sample = vmap2(\n",
    "    sample_normal,\n",
    "    in_axes=0,\n",
    "    # RandomState is automatically handled!\n",
    "    # state_in_axes={0: OfType(brainstate.random.RandomState)},\n",
    "    # state_out_axes={0: OfType(brainstate.random.RandomState)},\n",
    ")\n",
    "\n",
    "scales = jnp.array([1.0, 2.0, 3.0, 4.0])\n",
    "samples = vmap_sample(scales)\n",
    "\n",
    "print(\"Scales:\", scales)\n",
    "print(\"Samples:\", samples)\n",
    "print(\"\\nNote: Each sample is different (independent random key per batch element)\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-11T06:22:14.550898Z",
     "start_time": "2025-10-11T06:22:14.234227Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Means: [0. 1. 2.]\n",
      "Results: [1.063001  2.0858884 3.2780576]\n",
      "\n",
      "Each batch element uses independent random keys for both operations\n"
     ]
    }
   ],
   "source": [
    "# Example 2: Multiple random operations\n",
    "brainstate.random.seed(123)\n",
    "\n",
    "\n",
    "def sample_multiple(mean):\n",
    "    \"\"\"Sample multiple random numbers.\"\"\"\n",
    "    sample1 = brainstate.random.uniform(0.0, 1.0)\n",
    "    sample2 = brainstate.random.normal(mean, 1.0)\n",
    "    return sample1 + sample2\n",
    "\n",
    "\n",
    "vmap_multiple = vmap2(sample_multiple, in_axes=0)\n",
    "\n",
    "means = jnp.array([0.0, 1.0, 2.0])\n",
    "results = vmap_multiple(means)\n",
    "\n",
    "print(\"Means:\", means)\n",
    "print(\"Results:\", results)\n",
    "print(\"\\nEach batch element uses independent random keys for both operations\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 2.2 Controlling Random Keys: Using JAX's Random API\n",
    "\n",
    "If you need **shared random keys** across batch elements (the same random numbers), pass an explicit `jax.random` key to the function and broadcast it with `in_axes=None`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-11T06:22:14.712300Z",
     "start_time": "2025-10-11T06:22:14.558914Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Samples with shared key: [1.6226422 3.2452843 4.8679266 6.4905686]\n",
      "Notice: All samples use the same base random number, just scaled differently\n",
      "\n",
      "Samples with unique keys: [ 1.0040143 -4.8849115  3.8869078 -2.4877744]\n",
      "Notice: Each sample is independent\n"
     ]
    }
   ],
   "source": [
    "def sample_with_jax_key(key, scale):\n",
    "    \"\"\"Sample using JAX's random API.\"\"\"\n",
    "    return jax.random.normal(key, ()) * scale\n",
    "\n",
    "\n",
    "# Shared key across all batch elements\n",
    "vmap_shared_key = vmap2(\n",
    "    sample_with_jax_key,\n",
    "    in_axes=(None, 0),  # key is None (broadcast), scale is batched\n",
    ")\n",
    "\n",
    "shared_key = jax.random.PRNGKey(0)\n",
    "scales = jnp.array([1.0, 2.0, 3.0, 4.0])\n",
    "samples_shared = vmap_shared_key(shared_key, scales)\n",
    "\n",
    "print(\"Samples with shared key:\", samples_shared)\n",
    "print(\"Notice: All samples use the same base random number, just scaled differently\")\n",
    "\n",
    "\n",
    "# Compare with unique keys per batch element: reuse the same function,\n",
    "# but map over a batch of keys as well\n",
    "vmap_unique_keys = vmap2(\n",
    "    sample_with_jax_key,\n",
    "    in_axes=(0, 0),  # Both key and scale are batched\n",
    ")\n",
    "\n",
    "# Split key into batch\n",
    "key = jax.random.PRNGKey(0)\n",
    "keys = jax.random.split(key, len(scales))\n",
    "samples_unique = vmap_unique_keys(keys, scales)\n",
    "\n",
    "print(\"\\nSamples with unique keys:\", samples_unique)\n",
    "print(\"Notice: Each sample is independent\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 2.3 Practical Example: Dropout with Reproducibility"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-11T06:22:14.905945Z",
     "start_time": "2025-10-11T06:22:14.718305Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Original data:\n",
      "[[1. 1. 1. 1. 1.]\n",
      " [1. 1. 1. 1. 1.]\n",
      " [1. 1. 1. 1. 1.]\n",
      " [1. 1. 1. 1. 1.]]\n",
      "\n",
      "After dropout:\n",
      "[[0.        1.4285715 1.4285715 0.        1.4285715]\n",
      " [1.4285715 1.4285715 1.4285715 1.4285715 0.       ]\n",
      " [1.4285715 0.        1.4285715 0.        0.       ]\n",
      " [0.        1.4285715 1.4285715 1.4285715 0.       ]]\n",
      "\n",
      "Note: Each row has a different dropout pattern\n"
     ]
    }
   ],
   "source": [
    "class Dropout(brainstate.nn.Module):\n",
    "    \"\"\"Dropout layer using BrainState random.\"\"\"\n",
    "\n",
    "    def __init__(self, rate=0.5):\n",
    "        super().__init__()\n",
    "        self.rate = rate\n",
    "\n",
    "    def __call__(self, x, training=True):\n",
    "        if not training:\n",
    "            return x\n",
    "        # Each call gets independent random mask\n",
    "        keep_mask = brainstate.random.uniform(0.0, 1.0, x.shape) > self.rate\n",
    "        return jnp.where(keep_mask, x / (1 - self.rate), 0.0)\n",
    "\n",
    "\n",
    "brainstate.random.seed(456)\n",
    "dropout = Dropout(rate=0.3)\n",
    "\n",
    "# Vectorize dropout application\n",
    "vmap_dropout = vmap2(\n",
    "    lambda x: dropout(x, training=True),\n",
    "    in_axes=0,\n",
    ")\n",
    "\n",
    "batch_data = jnp.ones((4, 5))  # 4 samples, 5 features\n",
    "dropped = vmap_dropout(batch_data)\n",
    "\n",
    "print(\"Original data:\")\n",
    "print(batch_data)\n",
    "print(\"\\nAfter dropout:\")\n",
    "print(dropped)\n",
    "print(\"\\nNote: Each row has a different dropout pattern\")"
   ]
  },
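  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Reproducibility comes from deriving every per-row key deterministically from one seed. The plain-JAX sketch below makes that explicit: it splits one key into one key per row and maps over them. It uses `jax.random.bernoulli` instead of the uniform comparison in the `Dropout` class, and it is not BrainState's key-splitting scheme, so it will not reproduce the exact masks printed above; `dropout_row` and `batched_dropout` are hypothetical helpers.\n",
    "\n",
    "```python\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "\n",
    "def dropout_row(key, row, rate=0.3):\n",
    "    # Keep each entry with probability 1 - rate and rescale the survivors\n",
    "    keep = jax.random.bernoulli(key, 1.0 - rate, row.shape)\n",
    "    return jnp.where(keep, row / (1.0 - rate), 0.0)\n",
    "\n",
    "\n",
    "def batched_dropout(key, batch):\n",
    "    # One independent key per row, derived deterministically from a single key\n",
    "    keys = jax.random.split(key, batch.shape[0])\n",
    "    return jax.vmap(dropout_row)(keys, batch)\n",
    "\n",
    "\n",
    "batch = jnp.ones((4, 5))\n",
    "first = batched_dropout(jax.random.PRNGKey(456), batch)\n",
    "second = batched_dropout(jax.random.PRNGKey(456), batch)\n",
    "print(jnp.array_equal(first, second))  # True: same key, same masks\n",
    "```"
   ]
  },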
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3. Under the Hood: `StatefulMapping`\n",
    "\n",
    "`vmap2` is a thin wrapper around `brainstate.transform.StatefulMapping`, which provides the core state-aware mapping functionality.\n",
    "\n",
    "### 3.1 Understanding the Architecture\n",
    "\n",
    "`StatefulMapping` performs several key operations:\n",
    "\n",
    "1. **State Discovery**: Identifies all `State` objects accessed by the function\n",
    "2. **In/Out Axis Mapping**: Determines which states are batched and along which axes\n",
    "3. **IR Compilation**: Traces the function into JAX's intermediate representation (Jaxpr)\n",
    "4. **State Management**: Manages state values before and after execution"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-11T06:22:14.918588Z",
     "start_time": "2025-10-11T06:22:14.913951Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Type: <class 'brainstate.transform.StatefulMapping'>\n",
      "Origin function: <function accumulate at 0x0000024533179760>\n",
      "in_axes: 0\n",
      "out_axes: 0\n",
      "state_in_axes: {0: OfType(<class 'brainstate.ShortTermState'>)}\n",
      "state_out_axes: {}\n",
      "axis_name: None\n",
      "axis_size: 4\n"
     ]
    }
   ],
   "source": [
    "# Example: Inspecting StatefulMapping\n",
    "accumulator = brainstate.ShortTermState(jnp.zeros(4))\n",
    "\n",
    "\n",
    "def accumulate(x):\n",
    "    accumulator.value = accumulator.value + x\n",
    "    return accumulator.value\n",
    "\n",
    "\n",
    "# vmap2 returns a StatefulMapping object\n",
    "mapped_accumulate = vmap2(\n",
    "    accumulate,\n",
    "    in_axes=0,\n",
    "    out_axes=0,\n",
    "    axis_size=4,\n",
    "    state_in_axes={0: OfType(brainstate.ShortTermState)},\n",
    ")\n",
    "\n",
    "# Inspect the StatefulMapping object\n",
    "print(\"Type:\", type(mapped_accumulate))\n",
    "print(\"Origin function:\", mapped_accumulate.origin_fun)\n",
    "print(\"in_axes:\", mapped_accumulate.in_axes)\n",
    "print(\"out_axes:\", mapped_accumulate.out_axes)\n",
    "print(\"state_in_axes:\", mapped_accumulate.state_in_axes)\n",
    "print(\"state_out_axes:\", mapped_accumulate.state_out_axes)\n",
    "print(\"axis_name:\", mapped_accumulate.axis_name)\n",
    "print(\"axis_size:\", mapped_accumulate.axis_size)"
   ]
  },
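  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the read, map, and write-back cycle concrete, here is a toy version written with plain JAX. It is a hypothetical illustration, not how `StatefulMapping` is implemented: `store`, `state_axes`, `step`, and `toy_stateful_vmap` are invented names, and states are listed by hand in a dict instead of being discovered by tracing.\n",
    "\n",
    "```python\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "# Hypothetical toy state store: a plain dict of arrays\n",
    "store = {'temp': jnp.zeros(3), 'param': jnp.array(2.0)}\n",
    "\n",
    "# Hypothetical per-state axes: 0 = batched, None = shared across the batch\n",
    "state_axes = {'temp': 0, 'param': None}\n",
    "\n",
    "\n",
    "def step(states, x):\n",
    "    # Pure function: reads state values, returns (new state values, output)\n",
    "    new_temp = states['temp'] + x\n",
    "    return {'temp': new_temp, 'param': states['param']}, new_temp * states['param']\n",
    "\n",
    "\n",
    "def toy_stateful_vmap(fn, store, state_axes):\n",
    "    def run(x):\n",
    "        states = dict(store)  # read the current state values\n",
    "        mapped = jax.vmap(fn, in_axes=(state_axes, 0),\n",
    "                          out_axes=(state_axes, 0))  # map states and input\n",
    "        new_states, out = mapped(states, x)  # run the mapped computation\n",
    "        store.update(new_states)  # write the new values back\n",
    "        return out\n",
    "    return run\n",
    "\n",
    "\n",
    "out = toy_stateful_vmap(step, store, state_axes)(jnp.array([1.0, 2.0, 3.0]))\n",
    "print(out)           # [2. 4. 6.]\n",
    "print(store['temp'])  # [1. 2. 3.]\n",
    "```"
   ]
  },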
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 3.2 Compilation and Caching\n",
    "\n",
    "`StatefulMapping` compiles the function and caches:\n",
    "- The Jaxpr (JAX intermediate representation)\n",
    "- State traces (which states are accessed)\n",
    "- Batch axis mappings\n",
    "\n",
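    "This compilation happens lazily, on the first call. Later calls whose arguments produce the same cache key (see `get_arg_cache_key` in Section 3.5) reuse the cached Jaxpr instead of tracing the function again."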
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-11T06:22:14.941345Z",
     "start_time": "2025-10-11T06:22:14.932594Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Before first call, count: 0\n",
      "After first call, count: 1 (compilation trace)\n",
      "After second call, count: 0 (no recompilation)\n",
      "\n",
      "Results:\n",
      "First: [2. 4. 6.]\n",
      "Second: [ 8. 10. 12.]\n"
     ]
    }
   ],
   "source": [
    "# Example: Observing compilation\n",
    "call_count = [0]\n",
    "\n",
    "\n",
    "def counting_function(x):\n",
    "    call_count[0] += 1\n",
    "    return x * 2\n",
    "\n",
    "\n",
    "vmap_counting = vmap2(counting_function, in_axes=0)\n",
    "\n",
    "# First call: triggers compilation\n",
    "print(\"Before first call, count:\", call_count[0])\n",
    "result1 = vmap_counting(jnp.array([1.0, 2.0, 3.0]))\n",
    "print(\"After first call, count:\", call_count[0], \"(compilation trace)\")\n",
    "\n",
    "# Second call: uses cached compilation\n",
    "call_count[0] = 0\n",
    "result2 = vmap_counting(jnp.array([4.0, 5.0, 6.0]))\n",
    "print(\"After second call, count:\", call_count[0], \"(no recompilation)\")\n",
    "\n",
    "print(\"\\nResults:\")\n",
    "print(\"First:\", result1)\n",
    "print(\"Second:\", result2)"
   ]
  },
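  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For contrast, plain `jax.vmap` does not cache: it re-runs the Python function on every call. In plain JAX the caching comes from `jax.jit`, which keeps one trace per input shape and dtype. The sketch below counts traces with the same side-effect trick as above; it describes JAX's behavior only, not `StatefulMapping`'s cache key.\n",
    "\n",
    "```python\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "trace_count = [0]\n",
    "\n",
    "\n",
    "def double(x):\n",
    "    trace_count[0] += 1  # Python side effect: runs only while JAX traces the function\n",
    "    return x * 2\n",
    "\n",
    "\n",
    "# Plain jax.vmap traces the Python function again on every call\n",
    "plain = jax.vmap(double)\n",
    "plain(jnp.ones(3))\n",
    "plain(jnp.ones(3))\n",
    "plain_traces = trace_count[0]\n",
    "\n",
    "# jax.jit caches the traced computation per input shape and dtype\n",
    "trace_count[0] = 0\n",
    "cached = jax.jit(jax.vmap(double))\n",
    "cached(jnp.ones(3))\n",
    "cached(jnp.ones(3))\n",
    "same_shape_traces = trace_count[0]\n",
    "cached(jnp.ones(4))  # new shape: traced again\n",
    "new_shape_traces = trace_count[0]\n",
    "\n",
    "print(plain_traces, same_shape_traces, new_shape_traces)  # 2 1 2\n",
    "```"
   ]
  },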
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 3.3 State Axis Inference\n",
    "\n",
    "`StatefulMapping` automatically infers which states need to be batched based on:\n",
    "1. Explicit `state_in_axes` filters\n",
    "2. State usage patterns during tracing\n",
    "3. Batch dimensions in state values"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-11T06:22:14.959387Z",
     "start_time": "2025-10-11T06:22:14.948443Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Inputs: [1. 2. 3.]\n",
      "Outputs: [1. 2. 3.]\n",
      "Final temp state: [1. 2. 3.]\n",
      "Param (unchanged): 1.0\n"
     ]
    }
   ],
   "source": [
    "# Example: Complex state interactions\n",
    "class StatefulComputation(brainstate.nn.Module):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        # Different types of states\n",
    "        self.temp = brainstate.ShortTermState(jnp.zeros(3))\n",
    "        self.param = brainstate.ParamState(jnp.array(1.0))\n",
    "\n",
    "    def __call__(self, x):\n",
    "        # temp is batched (accumulates per batch element)\n",
    "        self.temp.value = self.temp.value + x\n",
    "        # param is shared (broadcast across batch)\n",
    "        return self.temp.value * self.param.value\n",
    "\n",
    "\n",
    "model = StatefulComputation()\n",
    "\n",
    "# Only batch ShortTermState, ParamState is shared\n",
    "vmap_model = vmap2(\n",
    "    model,\n",
    "    in_axes=0,\n",
    "    out_axes=0,\n",
    "    state_in_axes={0: OfType(brainstate.ShortTermState)},\n",
    "    state_out_axes={0: OfType(brainstate.ShortTermState)},\n",
    ")\n",
    "\n",
    "inputs = jnp.array([1.0, 2.0, 3.0])\n",
    "outputs = vmap_model(inputs)\n",
    "\n",
    "print(\"Inputs:\", inputs)\n",
    "print(\"Outputs:\", outputs)\n",
    "print(\"Final temp state:\", model.temp.value)  # One accumulated value per batch element\n",
    "print(\"Param (unchanged):\", model.param.value)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 3.4 Direct Use of `StatefulMapping`\n",
    "\n",
    "Advanced users can instantiate `StatefulMapping` directly for custom mapping primitives."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-11T06:22:14.975301Z",
     "start_time": "2025-10-11T06:22:14.965781Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Custom mapping results: [1. 2. 3.]\n",
      "Final counter: [1. 2. 3.]\n"
     ]
    }
   ],
   "source": [
    "from brainstate.transform import StatefulMapping\n",
    "import functools\n",
    "\n",
    "# Example: Using a custom mapping function\n",
    "counter_state = brainstate.ShortTermState(jnp.zeros(3))\n",
    "\n",
    "\n",
    "def increment(delta):\n",
    "    counter_state.value = counter_state.value + delta\n",
    "    return counter_state.value\n",
    "\n",
    "\n",
    "# Create StatefulMapping with custom mapping_fn\n",
    "# (Here we still use jax.vmap; in principle another mapping primitive, such as jax.pmap, could be supplied)\n",
    "custom_mapping = StatefulMapping(\n",
    "    increment,\n",
    "    in_axes=0,\n",
    "    out_axes=0,\n",
    "    state_in_axes={0: OfType(brainstate.ShortTermState)},\n",
    "    state_out_axes={0: OfType(brainstate.ShortTermState)},\n",
    "    name=\"custom_increment\",\n",
    "    mapping_fn=functools.partial(jax.vmap, spmd_axis_name=None),\n",
    ")\n",
    "\n",
    "deltas = jnp.array([1.0, 2.0, 3.0])\n",
    "results = custom_mapping(deltas)\n",
    "\n",
    "print(\"Custom mapping results:\", results)\n",
    "print(\"Final counter:\", counter_state.value)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 3.5 Understanding the IR (Intermediate Representation)\n",
    "\n",
    "`StatefulMapping` traces your function into a Jaxpr (JAX expression), JAX's intermediate representation, which:\n",
    "- Represents the computation as a functional program\n",
    "- Explicitly tracks all inputs and outputs (including state values)\n",
    "- Enables optimizations and transformations"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-11T06:22:15.012582Z",
     "start_time": "2025-10-11T06:22:14.983920Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Compiled Jaxpr:\n",
       "{ lambda ; a:f32[2] b:f32[]. let\n",
       "    c:key<fry>[] = random_seed[impl=fry] 0:i32[]\n",
       "    d:u32[2] = random_unwrap c\n",
       "    e:key<fry>[] = random_wrap[impl=fry] d\n",
       "    f:key<fry>[2] = random_split[shape=(2,)] e\n",
       "    _:u32[2,2] = random_unwrap f\n",
       "    g:f32[] = convert_element_type[new_dtype=float32 weak_type=False] b\n",
       "    h:f32[2] = add a g\n",
       "    _:f32[2] = mul h 2.0:f32[]\n",
       "    i:f32[] = convert_element_type[new_dtype=float32 weak_type=False] b\n",
       "    j:f32[2] = add a i\n",
       "    k:f32[2] = mul j 2.0:f32[]\n",
       "  in (k, j) }\n",
      "\n",
      "This represents the function's computation graph at an abstract level\n"
     ]
    }
   ],
   "source": [
    "# Example: Inspecting the Jaxpr\n",
    "simple_state = brainstate.State(jnp.array(1.0))\n",
    "\n",
    "\n",
    "def simple_op(x):\n",
    "    result = x + simple_state.value\n",
    "    simple_state.value = result\n",
    "    return result * 2\n",
    "\n",
    "\n",
    "# Create a simple mapping\n",
    "simple_vmap = vmap2(\n",
    "    simple_op,\n",
    "    in_axes=0,\n",
    "    state_out_axes={0: OfType(brainstate.State)},\n",
    ")\n",
    "\n",
    "# Call once to trigger compilation\n",
    "test_input = jnp.array([1.0, 2.0])\n",
    "_ = simple_vmap(test_input)\n",
    "\n",
    "# Access the compiled Jaxpr\n",
    "cache_key = simple_vmap.get_arg_cache_key(test_input)\n",
    "jaxpr = simple_vmap.get_jaxpr_by_cache(cache_key)\n",
    "\n",
    "print(\"Compiled Jaxpr:\")\n",
    "print(jaxpr)\n",
    "print(\"\\nThis represents the function's computation graph at an abstract level\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 4. Advanced Patterns and Best Practices\n",
    "\n",
    "### 4.1 Nested `vmap2`\n",
    "\n",
    "You can nest multiple `vmap2` calls for multi-dimensional batching."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-11T06:22:15.098471Z",
     "start_time": "2025-10-11T06:22:15.018588Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Matrix A shape: (3, 4)\n",
      "Matrix B shape: (3, 4)\n",
      "Result shape: (3, 4)\n",
      "Result:\n",
      "[[ 0.  1.  2.  3.]\n",
      " [ 4.  5.  6.  7.]\n",
      " [ 8.  9. 10. 11.]]\n"
     ]
    }
   ],
   "source": [
    "def matrix_elem_product(x, y):\n",
    "    \"\"\"Element-wise product.\"\"\"\n",
    "    return x * y\n",
    "\n",
    "\n",
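    "# Inner vmap: maps over the last axis (the elements within one row)\n",
    "vmap_rows = vmap2(matrix_elem_product, in_axes=(0, 0))\n",
    "\n",
    "# Outer vmap: maps over the first axis (the rows) and hands each row to the inner vmap\n",
    "vmap_matrix = vmap2(vmap_rows, in_axes=(0, 0))\n",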
    "\n",
    "# Create 2D inputs\n",
    "matrix_a = jnp.ones((3, 4))\n",
    "matrix_b = jnp.arange(12).reshape(3, 4)\n",
    "\n",
    "result = vmap_matrix(matrix_a, matrix_b)\n",
    "\n",
    "print(\"Matrix A shape:\", matrix_a.shape)\n",
    "print(\"Matrix B shape:\", matrix_b.shape)\n",
    "print(\"Result shape:\", result.shape)\n",
    "print(\"Result:\")\n",
    "print(result)"
   ]
  },
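  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The two levels do not have to map over the same argument. As a plain `jax.vmap` sketch (using the same `in_axes` convention as `vmap2`, with `None` marking an unmapped argument), mapping the inner level over `y` and the outer level over `x` builds an outer product, `result[i, j] = a[i] * b[j]`. The helper `mul` is a hypothetical name."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "\n",
    "def mul(x, y):\n",
    "    # Hypothetical scalar helper\n",
    "    return x * y\n",
    "\n",
    "\n",
    "a = jnp.array([1.0, 2.0, 3.0])  # shape (3,)\n",
    "b = jnp.array([1.0, 10.0, 100.0, 1000.0])  # shape (4,)\n",
    "\n",
    "# Inner vmap: keep x fixed, map over the elements of y -> shape (4,)\n",
    "inner = jax.vmap(mul, in_axes=(None, 0))\n",
    "# Outer vmap: map over the elements of x, pass y through whole -> shape (3, 4)\n",
    "outer_product = jax.vmap(inner, in_axes=(0, None))\n",
    "\n",
    "result = outer_product(a, b)\n",
    "print(\"Result shape:\", result.shape)\n",
    "print(result)"
   ]
  },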
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 4.2 Combining with Other Transforms\n",
    "\n",
    "`vmap2` can be composed with other BrainState transforms, such as `jit` and `grad` from `brainstate.transform`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-11T06:22:15.144950Z",
     "start_time": "2025-10-11T06:22:15.103478Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Inputs: [1. 2. 3.]\n",
      "Targets: [2. 4. 6.]\n",
      "Gradients: [-4.  0. 36.]\n"
     ]
    }
   ],
   "source": [
    "from brainstate.transform import grad, jit\n",
    "\n",
    "\n",
    "# Define a loss function\n",
    "def loss_fn(x, target):\n",
    "    pred = x ** 2\n",
    "    return jnp.sum((pred - target) ** 2)\n",
    "\n",
    "\n",
    "# Compose from the inside out: grad -> vmap2 -> jit\n",
    "batched_grad = vmap2(\n",
    "    grad(loss_fn, argnums=0),\n",
    "    in_axes=(0, 0),\n",
    ")\n",
    "batched_grad_jit = jit(batched_grad)\n",
    "\n",
    "# Batch of inputs and targets\n",
    "batch_x = jnp.array([1.0, 2.0, 3.0])\n",
    "batch_targets = jnp.array([2.0, 4.0, 6.0])\n",
    "\n",
    "gradients = batched_grad_jit(batch_x, batch_targets)\n",
    "\n",
    "print(\"Inputs:\", batch_x)\n",
    "print(\"Targets:\", batch_targets)\n",
    "print(\"Gradients:\", gradients)"
   ]
  },
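  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The order of composition determines what is computed. Above, `grad` sits inside `vmap2`, giving one gradient per example. The plain-JAX sketch below compares that with differentiating the summed loss over the whole batch. Because this `loss_fn` is a sum of independent per-element terms, both orders give the same values here; for a loss that couples elements (for example, one that normalizes across the batch), they generally differ."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "\n",
    "def loss_fn(x, target):\n",
    "    pred = x ** 2\n",
    "    return jnp.sum((pred - target) ** 2)\n",
    "\n",
    "\n",
    "batch_x = jnp.array([1.0, 2.0, 3.0])\n",
    "batch_targets = jnp.array([2.0, 4.0, 6.0])\n",
    "\n",
    "# Per-example gradients: grad inside vmap, with jit applied outermost\n",
    "per_example = jax.jit(jax.vmap(jax.grad(loss_fn), in_axes=(0, 0)))(batch_x, batch_targets)\n",
    "\n",
    "# Gradient of the summed batch loss with respect to the full input vector\n",
    "full_batch = jax.jit(jax.grad(loss_fn))(batch_x, batch_targets)\n",
    "\n",
    "print(\"Per-example:\", per_example)\n",
    "print(\"Full batch:\", full_batch)"
   ]
  },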
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Summary\n",
    "\n",
    "In this tutorial, we covered:\n",
    "\n",
    "### 1. **`vmap2` Parameters**\n",
    "- `in_axes`: Controls how inputs are batched\n",
    "- `out_axes`: Controls where batch dimension appears in outputs\n",
    "- `axis_name`: Names the mapped axis for collective operations\n",
    "- `axis_size`: Explicitly specifies batch size when needed\n",
    "- `state_in_axes` / `state_out_axes`: Control state batching (BrainState-specific)\n",
    "- `unexpected_out_state_mapping`: Handles unexpected state writes\n",
    "\n",
    "### 2. **Random Number Semantics**\n",
    "- **Automatic key splitting**: `brainstate.random.RandomState` is automatically split per batch element\n",
    "- **Shared keys**: Use `jax.random` APIs with `in_axes=None` for shared random numbers\n",
    "- Each batch element gets independent random streams by default\n",
    "\n",
    "### 3. **`StatefulMapping` Architecture**\n",
    "- `vmap2` is a wrapper around `StatefulMapping`\n",
    "- Performs state discovery, axis mapping, and IR compilation\n",
    "- Compiles to Jaxpr (JAX intermediate representation)\n",
    "- Caches compilations for reuse\n",
    "- Manages state values before and after execution\n",
    "\n",
    "### Key Takeaways\n",
    "\n",
    "- `vmap2` vectorizes stateful computations, with `state_in_axes` / `state_out_axes` deciding how each `State` is batched\n",
    "- `RandomState` keys are split automatically, so each batch element draws independent random numbers\n",
    "- The underlying `StatefulMapping` handles state discovery, axis mapping, and compilation caching for state-aware transformations\n",
    "- Inspecting the compiled Jaxpr helps debug and optimize vectorized code"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Ecosystem-py",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
