{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "a1b2c3d4e5f60001",
   "metadata": {},
   "source": [
    "# Batching Strategies\n",
    "\n",
    "Online learning algorithms need to handle batched data efficiently. In braintrace, there are two main batching strategies:\n",
    "\n",
    "- **Vmap-based batching** (recommended): Compile the computation graph for a single sample, then use `vmap` to automatically vectorize across the batch dimension.\n",
    "- **Single-sample mode**: Process one sample at a time, without any batching.\n",
    "\n",
    "The choice of strategy affects how model states are initialized and how the online learning algorithm is called.\n",
    "\n",
    "This tutorial walks through each strategy with concrete examples and shows how to build a full training loop using vmap-based batching."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a1b2c3d4e5f60002",
   "metadata": {},
   "source": [
    "## Vmap-Based Batching (Recommended)\n",
    "\n",
    "The recommended approach is to compile the online learning graph for a **single sample**, then leverage JAX's `vmap` to parallelize across the batch. The key steps are:\n",
    "\n",
    "1. Create the online learning algorithm (e.g., `D_RTRL`).\n",
    "2. Use `brainstate.transform.vmap_new_states` to initialize per-sample states and compile the graph with a single-sample input shape.\n",
    "3. Wrap the algorithm with `brainstate.nn.Vmap` for parallel execution across the batch.\n",
    "4. Call the vmapped algorithm on batched inputs.\n",
    "\n",
    "This pattern keeps the model definition simple (single-sample logic) while gaining efficient batch parallelism automatically."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a1b2c3d4e5f60003",
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax\n",
    "import jax.numpy as jnp\n",
    "import brainstate\n",
    "import braintools\n",
    "import braintrace"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a1b2c3d4e5f60004",
   "metadata": {},
   "outputs": [],
   "source": [
    "class SimpleGRU(brainstate.nn.Module):\n",
    "    def __init__(self, n_in, n_rec, n_out):\n",
    "        super().__init__()\n",
    "        self.rnn = braintrace.nn.GRUCell(n_in, n_rec)\n",
    "        self.out = braintrace.nn.Linear(n_rec, n_out)\n",
    "\n",
    "    def update(self, x):\n",
    "        return self.out(self.rnn(x))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a1b2c3d4e5f60005",
   "metadata": {},
   "outputs": [],
   "source": [
    "model = SimpleGRU(10, 64, 5)\n",
    "batch_size = 16\n",
    "\n",
    "# Step 1: Create the algorithm\n",
    "algo = braintrace.D_RTRL(model)\n",
    "\n",
    "# Step 2: Initialize per-sample states via vmap\n",
    "@brainstate.transform.vmap_new_states(state_tag='new', axis_size=batch_size)\n",
    "def init():\n",
    "    brainstate.nn.init_all_states(model)\n",
    "    algo.compile_graph(jnp.zeros(10))  # single sample shape\n",
    "init()\n",
    "\n",
    "# Step 3: Wrap for parallel execution\n",
    "algo_vmapped = brainstate.nn.Vmap(algo, vmap_states='new')\n",
    "\n",
    "# Step 4: Run on batched input\n",
    "x_batch = jnp.ones((batch_size, 10))\n",
    "out = algo_vmapped(x_batch)\n",
    "print(\"Output shape:\", out.shape)  # (16, 5)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a1b2c3d4e5f60006",
   "metadata": {},
   "source": [
    "**How it works:**\n",
    "\n",
    "- `vmap_new_states` vectorizes the initialization function over a new batch axis of size `axis_size`. Every state *created inside* the function, such as the hidden states set up by `init_all_states`, gets `axis_size` independent copies. The parameters already exist before `init` runs, so they remain shared across samples. The `state_tag='new'` argument labels the newly created states so they can be identified later.\n",
    "- `brainstate.nn.Vmap(algo, vmap_states='new')` wraps the algorithm. On each call it maps the batched input and the `'new'`-tagged states over their leading axis, runs the forward pass independently for each sample, and stacks the outputs.\n",
    "- The model itself only ever sees single-sample inputs, and all batch handling is transparent."
   ]
  },
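  {
   "cell_type": "markdown",
   "id": "a1b2c3d4e5f60101",
   "metadata": {},
   "source": [
    "The same idea can be sketched in plain JAX, without braintrace, to show what the vmap wrapping does conceptually. This is a minimal illustration under stated assumptions, not braintrace's implementation. `rnn_step` is a hypothetical tanh cell standing in for `SimpleGRU`, and the parameters are passed explicitly as a dict. The per-sample hidden states are an explicit `(batch_size, n_rec)` array that plays the role of the states tagged `'new'`. Calling `jax.vmap` with `in_axes=(None, 0, 0)` shares the parameters across samples and maps the hidden state and the input over the batch axis."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a1b2c3d4e5f60102",
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "# Hypothetical single-sample step: a plain tanh cell standing in for SimpleGRU.\n",
    "# h has shape (n_rec,) and x has shape (n_in,); there is no batch axis here.\n",
    "def rnn_step(params, h, x):\n",
    "    h_new = jnp.tanh(x @ params['w_in'] + h @ params['w_rec'])\n",
    "    y = h_new @ params['w_out']\n",
    "    return h_new, y\n",
    "\n",
    "n_in, n_rec, n_out, batch_size = 10, 64, 5, 16\n",
    "k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)\n",
    "params = {\n",
    "    'w_in': jax.random.normal(k1, (n_in, n_rec)) * 0.1,\n",
    "    'w_rec': jax.random.normal(k2, (n_rec, n_rec)) * 0.1,\n",
    "    'w_out': jax.random.normal(k3, (n_rec, n_out)) * 0.1,\n",
    "}\n",
    "\n",
    "# One hidden-state row per sample (the analogue of the 'new'-tagged states).\n",
    "h_batch = jnp.zeros((batch_size, n_rec))\n",
    "\n",
    "# Share params (None); map hidden states and inputs over axis 0.\n",
    "batched_step = jax.vmap(rnn_step, in_axes=(None, 0, 0))\n",
    "\n",
    "x_batch = jnp.ones((batch_size, n_in))\n",
    "h_batch, out = batched_step(params, h_batch, x_batch)\n",
    "print('Output shape:', out.shape)  # (16, 5)"
   ]
  },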
  {
   "cell_type": "markdown",
   "id": "a1b2c3d4e5f60007",
   "metadata": {},
   "source": [
    "## Single-Sample Mode\n",
    "\n",
    "For debugging or situations where batch processing is unnecessary, you can compile and run the algorithm on individual samples directly. No `vmap` or state replication is needed."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a1b2c3d4e5f60008",
   "metadata": {},
   "outputs": [],
   "source": [
    "model2 = SimpleGRU(10, 64, 5)\n",
    "brainstate.nn.init_all_states(model2)\n",
    "\n",
    "algo2 = braintrace.D_RTRL(model2)\n",
    "algo2.compile_graph(jnp.zeros(10))\n",
    "\n",
    "# Process one sample at a time\n",
    "x_single = jnp.ones(10)\n",
    "out = algo2(x_single)\n",
    "print(\"Single sample output shape:\", out.shape)  # (5,)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a1b2c3d4e5f60009",
   "metadata": {},
   "source": [
    "This mode is straightforward: initialize the model, compile the graph, and call the algorithm. It is useful for step-by-step debugging or when processing a single stream of data."
   ]
  },
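  {
   "cell_type": "markdown",
   "id": "a1b2c3d4e5f60201",
   "metadata": {},
   "source": [
    "Single-sample mode is also a convenient reference when debugging batched code. A vmapped computation should agree with running the same single-sample function on each sample in turn. The plain-JAX sketch below checks this for a hypothetical tanh cell. `rnn_step` and its parameter dict are illustrative assumptions, not braintrace API. Comparing the two paths in this way can help show whether a problem lies in the model logic or in the batching."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a1b2c3d4e5f60202",
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "# Hypothetical single-sample step (no batch axis), standing in for the model.\n",
    "def rnn_step(params, h, x):\n",
    "    h_new = jnp.tanh(x @ params['w_in'] + h @ params['w_rec'])\n",
    "    return h_new, h_new @ params['w_out']\n",
    "\n",
    "k1, k2, k3, k4 = jax.random.split(jax.random.PRNGKey(0), 4)\n",
    "params = {\n",
    "    'w_in': jax.random.normal(k1, (10, 64)) * 0.1,\n",
    "    'w_rec': jax.random.normal(k2, (64, 64)) * 0.1,\n",
    "    'w_out': jax.random.normal(k3, (64, 5)) * 0.1,\n",
    "}\n",
    "x_batch = jax.random.normal(k4, (4, 10))  # 4 independent samples\n",
    "h0 = jnp.zeros(64)                        # same initial state for every sample\n",
    "\n",
    "# Single-sample mode: call the step function on each sample in turn.\n",
    "outs_single = jnp.stack([rnn_step(params, h0, x)[1] for x in x_batch])\n",
    "\n",
    "# Vmap mode: the same function mapped over the batch axis of x only.\n",
    "_, outs_vmap = jax.vmap(rnn_step, in_axes=(None, None, 0))(params, h0, x_batch)\n",
    "\n",
    "print(outs_single.shape)  # (4, 5)\n",
    "print(jnp.allclose(outs_single, outs_vmap, atol=1e-5))"
   ]
  },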
  {
   "cell_type": "markdown",
   "id": "a1b2c3d4e5f60010",
   "metadata": {},
   "source": [
    "## Multi-Step Data\n",
    "\n",
    "braintrace provides `SingleStepData` and `MultiStepData` wrappers to control how the algorithm processes input along the time dimension.\n",
    "\n",
    "- **`SingleStepData`**: Wraps data for a single time step. The algorithm processes it as one forward pass.\n",
    "- **`MultiStepData`**: Wraps a sequence of time steps. The algorithm internally scans over all steps in the sequence.\n",
    "\n",
    "This is useful when you want to pass an entire sequence to the algorithm and have it handle the temporal loop internally, rather than manually iterating over time steps."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a1b2c3d4e5f60011",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Single-step: process one time step at a time\n",
    "x_single = braintrace.SingleStepData(jnp.ones(10))\n",
    "\n",
    "# Multi-step: process a sequence\n",
    "sequence = jnp.ones((20, 10))  # 20 time steps, 10 features\n",
    "x_multi = braintrace.MultiStepData(sequence)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a1b2c3d4e5f60012",
   "metadata": {},
   "source": [
    "When a `MultiStepData` object is passed to the algorithm, it will iterate over the first axis (time steps) internally. When a `SingleStepData` object (or a plain array) is passed, the algorithm processes it as a single forward step."
   ]
  },
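  {
   "cell_type": "markdown",
   "id": "a1b2c3d4e5f60301",
   "metadata": {},
   "source": [
    "Conceptually, the two wrappers correspond to two ways of driving the same step function. In one, the caller writes the time loop. In the other, the loop runs inside `jax.lax.scan` over the leading axis. The plain-JAX sketch below is illustrative only. `cell` is a hypothetical stand-in for the model, and the sketch assumes, as described above, that multi-step processing amounts to scanning over the first (time) axis. Both paths produce the same per-step outputs and the same final state."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a1b2c3d4e5f60302",
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "# Hypothetical single-step cell: returns (new carry, per-step output).\n",
    "def cell(h, x):\n",
    "    h_new = jnp.tanh(x + 0.5 * h)\n",
    "    return h_new, h_new\n",
    "\n",
    "h0 = jnp.zeros(10)\n",
    "sequence = jnp.ones((20, 10))  # 20 time steps, 10 features\n",
    "\n",
    "# Single-step style: the caller owns the time loop and feeds one step per call.\n",
    "h = h0\n",
    "outs = []\n",
    "for x_t in sequence:\n",
    "    h, y = cell(h, x_t)\n",
    "    outs.append(y)\n",
    "outs_loop = jnp.stack(outs)\n",
    "\n",
    "# Multi-step style: the time loop runs inside lax.scan over the first axis.\n",
    "h_final, outs_scan = jax.lax.scan(cell, h0, sequence)\n",
    "\n",
    "print(outs_scan.shape)  # (20, 10)\n",
    "print(jnp.allclose(outs_loop, outs_scan), jnp.allclose(h, h_final))"
   ]
  },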
  {
   "cell_type": "markdown",
   "id": "a1b2c3d4e5f60013",
   "metadata": {},
   "source": [
    "## Full Training Loop with Vmap Batching\n",
    "\n",
    "Below is an example that combines vmap-based batching with a temporal loop to compute training gradients. The pattern is:\n",
    "\n",
    "1. **Initialize** model states and compile the graph for a single sample.\n",
    "2. **Vmap** the algorithm across the batch dimension.\n",
    "3. **Scan** over time steps, accumulating gradients at each step.\n",
    "4. **Return** the accumulated gradients, which an optimizer then uses to update the parameters."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a1b2c3d4e5f60014",
   "metadata": {},
   "outputs": [],
   "source": [
    "@brainstate.transform.jit\n",
    "def train_step(inputs, targets):\n",
    "    \"\"\"inputs: (n_steps, batch_size, n_in), targets: (batch_size, n_out)\"\"\"\n",
    "    # Trainable parameters, shared by all samples in the batch\n",
    "    weights = model.states(brainstate.ParamState)\n",
    "    algo = braintrace.D_RTRL(model)\n",
    "\n",
    "    # Create per-sample hidden states and compile the graph on one sample (inputs[0, 0])\n",
    "    @brainstate.transform.vmap_new_states(state_tag='new', axis_size=inputs.shape[1])\n",
    "    def init():\n",
    "        brainstate.nn.init_all_states(model)\n",
    "        algo.compile_graph(inputs[0, 0])\n",
    "    init()\n",
    "    vmapped_algo = brainstate.nn.Vmap(algo, vmap_states='new')\n",
    "\n",
    "    # One time step: batch loss, gradients w.r.t. weights, added to the running sum\n",
    "    def step_fn(prev_grads, inp):\n",
    "        def loss_fn(inp):\n",
    "            out = vmapped_algo(inp)  # (batch_size, n_out)\n",
    "            return jnp.mean((out - targets) ** 2)\n",
    "        cur_grads = brainstate.transform.grad(loss_fn, weights)(inp)\n",
    "        return jax.tree.map(lambda a, b: a + b, prev_grads, cur_grads), None\n",
    "\n",
    "    # Start from zero gradients and scan over the leading (time) axis of inputs\n",
    "    grads = jax.tree.map(jnp.zeros_like, weights.to_dict_values())\n",
    "    grads, _ = brainstate.transform.scan(step_fn, grads, inputs)\n",
    "    return grads"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a1b2c3d4e5f60015",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Example usage\n",
    "model = SimpleGRU(10, 64, 5)\n",
    "inputs = jnp.ones((20, 16, 10))  # 20 steps, batch 16, 10 features\n",
    "targets = jnp.zeros((16, 5))\n",
    "grads = train_step(inputs, targets)\n",
    "print(\"Gradient keys:\", list(grads.keys()))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a1b2c3d4e5f60016",
   "metadata": {},
   "source": [
    "**What happens in `train_step`:**\n",
    "\n",
    "1. `weights` collects all `ParamState` objects from the model.\n",
    "2. The `init` function, decorated with `vmap_new_states`, initializes per-sample hidden states and compiles the computation graph using a single-sample input.\n",
    "3. `vmapped_algo` wraps the algorithm for batch-parallel execution.\n",
    "4. `brainstate.transform.scan` iterates over the time dimension (`inputs` has shape `(n_steps, batch_size, n_in)`). At each step, `step_fn` computes the loss and its gradients with respect to `weights`, then accumulates them.\n",
    "5. The returned `grads` dictionary can be passed to an optimizer (e.g., `braintools.optim.Adam`) for a parameter update."
   ]
  },
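  {
   "cell_type": "markdown",
   "id": "a1b2c3d4e5f60401",
   "metadata": {},
   "source": [
    "The scan-and-accumulate structure of `train_step` can be mirrored in plain JAX to make each moving part explicit. This is only an illustration, not how braintrace computes gradients. It replaces the `D_RTRL` gradient estimate with a simple one-step truncated gradient, because the hidden state carried by `jax.lax.scan` is treated as a constant. It also uses a hypothetical tanh cell with an explicit parameter dict, and it applies a plain SGD step in place of an optimizer such as `braintools.optim.Adam`. What it does reproduce is the pattern itself: vmap a single-sample step over the batch, scan over time, and add each step's gradients to a running sum before updating the parameters."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a1b2c3d4e5f60402",
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "n_steps, batch_size, n_in, n_rec, n_out = 20, 16, 10, 32, 5\n",
    "k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)\n",
    "params = {\n",
    "    'w_in': jax.random.normal(k1, (n_in, n_rec)) * 0.1,\n",
    "    'w_rec': jax.random.normal(k2, (n_rec, n_rec)) * 0.1,\n",
    "    'w_out': jax.random.normal(k3, (n_rec, n_out)) * 0.1,\n",
    "}\n",
    "\n",
    "# Hypothetical single-sample tanh cell, standing in for the braintrace model.\n",
    "def rnn_step(params, h, x):\n",
    "    h_new = jnp.tanh(x @ params['w_in'] + h @ params['w_rec'])\n",
    "    return h_new, h_new @ params['w_out']\n",
    "\n",
    "# Shared params; per-sample hidden state and input.\n",
    "batched_step = jax.vmap(rnn_step, in_axes=(None, 0, 0))\n",
    "\n",
    "@jax.jit\n",
    "def train_step(params, inputs, targets):\n",
    "    # inputs: (n_steps, batch_size, n_in), targets: (batch_size, n_out)\n",
    "    def step_fn(carry, x_t):\n",
    "        h, acc = carry\n",
    "        def loss_fn(p):\n",
    "            # h comes from the carry, so it is a constant here: the gradient\n",
    "            # only flows through the current step (one-step truncation).\n",
    "            h_new, out = batched_step(p, h, x_t)\n",
    "            return jnp.mean((out - targets) ** 2), h_new\n",
    "        (loss, h_new), g = jax.value_and_grad(loss_fn, has_aux=True)(params)\n",
    "        acc = jax.tree.map(lambda a, b: a + b, acc, g)  # running gradient sum\n",
    "        return (h_new, acc), loss\n",
    "    h0 = jnp.zeros((inputs.shape[1], n_rec))\n",
    "    acc0 = jax.tree.map(jnp.zeros_like, params)\n",
    "    (_, grads), losses = jax.lax.scan(step_fn, (h0, acc0), inputs)\n",
    "    return grads, losses\n",
    "\n",
    "inputs = jnp.ones((n_steps, batch_size, n_in))\n",
    "targets = jnp.zeros((batch_size, n_out))\n",
    "grads, losses = train_step(params, inputs, targets)\n",
    "\n",
    "# Plain SGD step as a stand-in for an optimizer such as Adam.\n",
    "lr = 1e-2\n",
    "new_params = jax.tree.map(lambda p, g: p - lr * g, params, grads)\n",
    "print(sorted(grads.keys()), losses.shape)"
   ]
  },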
  {
   "cell_type": "markdown",
   "id": "a1b2c3d4e5f60017",
   "metadata": {},
   "source": [
    "## Summary\n",
    "\n",
    "- **Vmap-based batching is recommended** for most use cases. It keeps model code simple (single-sample logic) while achieving efficient batch parallelism.\n",
    "- The workflow is: **compile for a single sample, then vmap across the batch**.\n",
    "- `SingleStepData` and `MultiStepData` control whether the algorithm processes one time step or scans over an entire sequence internally.\n",
    "- The typical training pattern is: **init states -> compile graph -> vmap -> scan over time steps -> accumulate gradients -> update parameters**."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbformat_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}