{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "b0e8ae9a",
   "metadata": {},
   "source": [
    "# SNN Online Learning with BrainTrace\n",
    "\n",
    "**Train a spiking neural network using ES-D-RTRL**"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "595ee30d",
   "metadata": {},
   "source": [
    "## Introduction\n",
    "\n",
    "Spiking Neural Networks (SNNs) process information through discrete spike events, mimicking the communication mechanism of biological neurons. Unlike traditional artificial neural networks that operate on continuous activations, SNNs emphasize the timing and frequency of spikes, making them inherently suited for temporal data processing.\n",
    "\n",
    "**Online learning** is a natural fit for SNNs because they process inputs sequentially, one time step at a time. Instead of storing the entire computation graph for backpropagation through time (BPTT), online learning algorithms update weight gradients incrementally at each time step. This eliminates the need to unroll the network over the full sequence length, resulting in constant memory usage with respect to sequence length.\n",
    "\n",
    "In this tutorial, we will use **ES-D-RTRL** (Eligibility-trace Scalable Decoupled Real-Time Recurrent Learning), an efficient online learning algorithm provided by `braintrace`. ES-D-RTRL factorizes the eligibility trace into input and output components, achieving **O(B(I+O))** memory complexity (where B is batch size, I is input dimension, and O is output dimension). This makes it highly scalable for large spiking networks.\n",
    "\n",
    "**What you will learn:**\n",
    "- How to build an SNN model using `brainstate` neurons and `braintrace.nn` layers\n",
    "- How to set up online learning with `braintrace.ES_D_RTRL` (ES-D-RTRL)\n",
    "- How to train the SNN on random spike data\n",
    "- The key differences between D-RTRL and ES-D-RTRL"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "58cdd32f",
   "metadata": {},
   "source": [
    "## 1. Setup\n",
    "\n",
    "First, let us import the required packages. The key components are:\n",
    "- `brainstate`: provides neuron models (LIF), state management, and JAX-based transformations\n",
    "- `braintrace`: provides online learning algorithms and ETP-aware neural network layers\n",
    "- `braintools`: provides initializers, optimizers, surrogate gradient functions, and metrics\n",
    "- `saiunit`: provides physical units (ms, mV, etc.) for biologically meaningful parameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "b133c828",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-04-17T09:29:47.493193Z",
     "iopub.status.busy": "2026-04-17T09:29:47.492972Z",
     "iopub.status.idle": "2026-04-17T09:29:50.401980Z",
     "shell.execute_reply": "2026-04-17T09:29:50.400723Z"
    }
   },
   "outputs": [
   ],
   "source": [
    "import jax\n",
    "import jax.numpy as jnp\n",
    "import brainstate\n",
    "import braintools\n",
    "import braintrace\n",
    "import saiunit as u\n",
    "import brainpy.state\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ea5313f3",
   "metadata": {},
   "source": [
    "## 2. SNN Model\n",
    "\n",
    "We build a simple recurrent SNN with the following architecture:\n",
    "\n",
    "1. **Input + Recurrent Projection**: A `braintrace.nn.Linear` layer that projects the concatenation of input spikes and recurrent spikes into the hidden layer. Using `braintrace.nn.Linear` (instead of a plain matrix multiply) marks this projection for participation in online learning via ETP primitives.\n",
    "\n",
    "2. **LIF Neuron**: A Leaky Integrate-and-Fire neuron from `brainpy.state.LIF`. The LIF neuron integrates its input current, fires a spike when the membrane potential exceeds a threshold, and then resets. We use `braintools.surrogate.ReluGrad()` as the surrogate gradient function for differentiability.\n",
    "\n",
    "3. **Readout**: A `braintrace.nn.LeakyRateReadout` that applies leaky integration to the recurrent spikes and produces a continuous output signal for classification. This layer is also ETP-aware.\n",
    "\n",
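    "To make the neuron dynamics in item 2 concrete, here is a minimal NumPy sketch of one leaky integrate-and-fire step with a soft reset. It is an illustration only: the helper name `lif_step`, the forward-Euler update, and the unitless parameters are assumptions made for this sketch and do not reproduce the integrator used inside `brainpy.state.LIF`.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def lif_step(v, current, tau=20.0, dt=1.0, v_th=1.0, r=1.0):\n",
    "    # Hypothetical helper: leak toward rest (0) while integrating the input.\n",
    "    v = v + (dt / tau) * (-v + r * current)\n",
    "    # Emit a spike wherever the membrane potential reaches the threshold.\n",
    "    spike = (v >= v_th).astype(v.dtype)\n",
    "    # Soft reset: subtract the threshold instead of clamping to V_reset.\n",
    "    v = v - spike * v_th\n",
    "    return v, spike\n",
    "\n",
    "\n",
    "v = np.zeros(3)\n",
    "for t in range(30):\n",
    "    v, spike = lif_step(v, current=np.array([0.0, 1.5, 3.0]))\n",
    "    if spike.any():\n",
    "        print(t, spike)\n",
    "```\n",
    "\n",
    "A stronger constant current reaches the threshold sooner and therefore fires more often, while a zero current never fires.\n",
    "\n",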
    "The recurrent connectivity is achieved by concatenating the neuron's own spike output with the external input at each time step."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "cef17cc9",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-04-17T09:29:50.406592Z",
     "iopub.status.busy": "2026-04-17T09:29:50.405853Z",
     "iopub.status.idle": "2026-04-17T09:29:50.412169Z",
     "shell.execute_reply": "2026-04-17T09:29:50.411320Z"
    }
   },
   "outputs": [],
   "source": [
    "class LIF_SNN(brainstate.nn.Module):\n",
    "    \"\"\"A simple recurrent SNN with LIF neurons for online learning.\"\"\"\n",
    "\n",
    "    def __init__(self, n_in, n_rec, n_out, tau_mem=20. * u.ms, tau_out=20. * u.ms):\n",
    "        super().__init__()\n",
    "\n",
    "        # Input + recurrent projection (ETP-aware: participates in online learning).\n",
    "        # Weights are in current units so that ``I * R`` lands in mV inside the LIF\n",
    "        # neuron (LIF integrates ``-V + I*R``; ``mA * ohm = mV`` matches V_th below).\n",
    "        self.linear = braintrace.nn.Linear(\n",
    "            n_in + n_rec, n_rec,\n",
    "            w_init=braintools.init.KaimingNormal(unit=u.mA),\n",
    "            b_init=braintools.init.ZeroInit(unit=u.mA),\n",
    "        )\n",
    "\n",
    "        # LIF neuron with surrogate gradient for differentiability.\n",
    "        self.neuron = brainpy.state.LIF(\n",
    "            n_rec,\n",
    "            tau=tau_mem,\n",
    "            R=1. * u.ohm,\n",
    "            V_th=1. * u.mV,\n",
    "            V_reset=0. * u.mV,\n",
    "            V_rest=0. * u.mV,\n",
    "            spk_fun=braintools.surrogate.ReluGrad(),\n",
    "            spk_reset='soft',\n",
    "        )\n",
    "\n",
    "        # Readout layer (ETP-aware: participates in online learning).\n",
    "        self.readout = braintrace.nn.LeakyRateReadout(\n",
    "            n_rec, n_out,\n",
    "            tau=tau_out,\n",
    "            w_init=braintools.init.KaimingNormal(),\n",
    "        )\n",
    "\n",
    "    def update(self, spike_input):\n",
    "        # Concatenate input spikes with recurrent spikes.\n",
    "        rec_spk = self.neuron.get_spike()\n",
    "        x = jnp.concatenate([spike_input, rec_spk], axis=-1)\n",
    "\n",
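    "        # Linear projection -> LIF neuron dynamics -> readout.\n",
    "        # Update the neuron once per step and feed its returned spikes to the readout\n",
    "        # (calling ``self.neuron()`` a second time would advance the dynamics again).\n",
    "        spk = self.neuron(self.linear(x))\n",
    "        return self.readout(spk)"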
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3938aef7",
   "metadata": {},
   "source": [
    "Let us verify that the model can be instantiated and produce output for a single sample."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "8f0885a3",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-04-17T09:29:50.414465Z",
     "iopub.status.busy": "2026-04-17T09:29:50.414180Z",
     "iopub.status.idle": "2026-04-17T09:29:52.699967Z",
     "shell.execute_reply": "2026-04-17T09:29:52.699344Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Output shape: (10,)\n",
      "Output values: [0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n"
     ]
    }
   ],
   "source": [
    "with brainstate.environ.context(dt=1. * u.ms):\n",
    "    model = LIF_SNN(n_in=50, n_rec=128, n_out=10)\n",
    "    brainstate.nn.init_all_states(model)\n",
    "\n",
    "    # Single time step with random spike input\n",
    "    test_input = jnp.array(np.random.binomial(1, 0.1, (50,)).astype(np.float32))\n",
    "    output = model(test_input)\n",
    "    print(f\"Output shape: {output.shape}\")\n",
    "    print(f\"Output values: {output}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "43f58871",
   "metadata": {},
   "source": [
    "## 3. Training with ES-D-RTRL\n",
    "\n",
    "We set up online learning using `braintrace.ES_D_RTRL` (also exposed as `braintrace.pp_prop` and the lower-level `braintrace.IODimVjpAlgorithm`). The key steps are:\n",
    "\n",
    "1. **Wrap the model** with `ES_D_RTRL`, supplying the `decay_or_rank` parameter:\n",
    "\n",
    "   * **`decay_or_rank=float` in (0, 1]** -- exponentially-smoothed trace. The value is the decay factor applied per step; `0.99` is a common choice. Memory cost: `O(B * (I + O))` per layer. Approximation: the trace is a leaky moving average of recent activity.\n",
    "   * **`decay_or_rank=int >= 1`** -- low-rank trace. The integer is the rank used to factorise the trace. Memory cost: `O(B * rank * (I + O))`. Approximation: the trace is projected onto the top `rank` modes.\n",
    "\n",
    "   Pick the **decay form** when you want a single hyper-parameter that you can sweep cheaply, and the **rank form** when you want a tunable accuracy/memory trade-off independent of any time-scale assumption.\n",
    "\n",
    "2. **Initialize per-sample states** using `vmap_new_states` so each sample in the batch has independent hidden states and eligibility traces.\n",
    "\n",
    "3. **Compile the graph** by calling `algo.compile_graph(sample_input)`.\n",
    "\n",
    "4. **Define the gradient function** using `brainstate.transform.grad`.\n",
    "\n",
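    "The decay form in step 1 can be pictured as two exponentially smoothed vectors whose outer product stands in for a weight-shaped trace. The sketch below is a plain NumPy illustration under stated assumptions: the update rule `trace = decay * trace + value`, the names `x_trace` and `g_trace`, and the random stand-in signals are all hypothetical and are not taken from braintrace's implementation.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "rng = np.random.default_rng(0)\n",
    "n_in, n_out, decay = 6, 4, 0.99\n",
    "\n",
    "x_trace = np.zeros(n_in)   # smoothed presynaptic activity: I numbers\n",
    "g_trace = np.zeros(n_out)  # smoothed postsynaptic sensitivity: O numbers\n",
    "\n",
    "for _ in range(20):\n",
    "    x = rng.binomial(1, 0.1, n_in).astype(float)  # stand-in input spikes\n",
    "    g = rng.normal(size=n_out)                    # stand-in local derivative\n",
    "    x_trace = decay * x_trace + x\n",
    "    g_trace = decay * g_trace + g\n",
    "\n",
    "# A weight-shaped quantity is rebuilt on demand instead of being stored.\n",
    "approx_trace = np.outer(x_trace, g_trace)\n",
    "print(approx_trace.shape)           # (6, 4)\n",
    "print(x_trace.size + g_trace.size)  # 10 stored values instead of 24\n",
    "```\n",
    "\n",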
    "`braintrace.D_RTRL` is the alternative algorithm; it stores the *full* parameter-dimension trace (`O(B * theta)`) and is exact rather than approximate. Use `D_RTRL` when memory permits and `ES_D_RTRL` for larger networks."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "2e58fc27",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-04-17T09:29:52.701588Z",
     "iopub.status.busy": "2026-04-17T09:29:52.701443Z",
     "iopub.status.idle": "2026-04-17T09:29:52.707626Z",
     "shell.execute_reply": "2026-04-17T09:29:52.706752Z"
    }
   },
   "outputs": [],
   "source": [
    "def train_snn(n_steps=100, n_epochs=50, batch_size=32, n_in=50, n_rec=128, n_out=10, lr=1e-3):\n",
    "    \"\"\"Train a recurrent SNN using ES-D-RTRL online learning.\"\"\"\n",
    "\n",
    "    with brainstate.environ.context(dt=1. * u.ms):\n",
    "        # Create model and optimizer\n",
    "        model = LIF_SNN(n_in, n_rec, n_out)\n",
    "        opt = braintools.optim.Adam(lr)\n",
    "        weights = model.states(brainstate.ParamState)\n",
    "        opt.register_trainable_weights(weights)\n",
    "\n",
    "        @brainstate.transform.jit\n",
    "        def train_step(inputs, targets):\n",
    "            # Wrap model with ES-D-RTRL (decay_or_rank=0.99 means decay factor of 0.99)\n",
    "            algo = braintrace.ES_D_RTRL(model, decay_or_rank=0.99)\n",
    "\n",
    "            # Initialize per-sample states (each sample in the batch gets independent states)\n",
    "            @brainstate.transform.vmap_new_states(state_tag='new', axis_size=inputs.shape[1])\n",
    "            def init():\n",
    "                brainstate.nn.init_all_states(model)\n",
    "                algo.compile_graph(inputs[0, 0])\n",
    "\n",
    "            init()\n",
    "            vmapped_algo = brainstate.nn.Vmap(algo, vmap_states='new')\n",
    "\n",
    "            def loss_fn(inp):\n",
    "                out = vmapped_algo(inp)\n",
    "                loss = braintools.metric.softmax_cross_entropy_with_integer_labels(\n",
    "                    out, targets\n",
    "                ).mean()\n",
    "                return loss, out\n",
    "\n",
    "            def scan_step(prev_grads, inp):\n",
    "                f_grad = brainstate.transform.grad(\n",
    "                    loss_fn, weights, has_aux=True, return_value=True\n",
    "                )\n",
    "                cur_grads, cur_loss, out = f_grad(inp)\n",
    "                next_grads = jax.tree.map(lambda a, b: a + b, prev_grads, cur_grads)\n",
    "                return next_grads, cur_loss\n",
    "\n",
    "            # Accumulate gradients over all time steps\n",
    "            grads = jax.tree.map(jnp.zeros_like, weights.to_dict_values())\n",
    "            grads, losses = brainstate.transform.scan(scan_step, grads, inputs)\n",
    "\n",
    "            # Clip gradients and update weights\n",
    "            grads = brainstate.nn.clip_grad_norm(grads, 1.0)\n",
    "            opt.update(grads)\n",
    "\n",
    "            return losses.mean()\n",
    "\n",
    "        # Training loop with random spike data\n",
    "        losses = []\n",
    "        for epoch in range(n_epochs):\n",
    "            # Generate random spike inputs (Bernoulli with firing probability 0.1)\n",
    "            inputs = np.random.binomial(1, 0.1, (n_steps, batch_size, n_in)).astype(np.float32)\n",
    "            targets = np.random.randint(0, n_out, batch_size)\n",
    "\n",
    "            loss = train_step(jnp.array(inputs), jnp.array(targets))\n",
    "            losses.append(float(loss))\n",
    "\n",
    "            if epoch % 10 == 0:\n",
    "                print(f\"Epoch {epoch:3d}, Loss: {loss:.4f}\")\n",
    "\n",
    "        return losses"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "72b54601",
   "metadata": {},
   "source": [
    "Let us run the training loop. Note that the first epoch will be slower due to JAX's JIT compilation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "3113f9a6",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-04-17T09:29:52.709538Z",
     "iopub.status.busy": "2026-04-17T09:29:52.709310Z",
     "iopub.status.idle": "2026-04-17T09:29:55.207498Z",
     "shell.execute_reply": "2026-04-17T09:29:55.206577Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch   0, Loss: 2.3026\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch  10, Loss: 2.3026\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch  20, Loss: 2.3026\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch  30, Loss: 2.3026\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch  40, Loss: 2.3026\n"
     ]
    }
   ],
   "source": [
    "losses = train_snn(n_steps=50, n_epochs=50, batch_size=32, n_in=50, n_rec=128, n_out=10, lr=1e-3)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e6b876ff",
   "metadata": {},
   "source": [
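    "We can visualize the training loss curve. Because the inputs and targets in this example are random, there is nothing to learn: the loss stays at 2.3026, which is ln(10), the chance-level cross-entropy for 10 classes. The plot therefore confirms that the pipeline runs end to end rather than that the network is learning."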
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "516e6f4a",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-04-17T09:29:55.210035Z",
     "iopub.status.busy": "2026-04-17T09:29:55.209786Z",
     "iopub.status.idle": "2026-04-17T09:29:55.241184Z",
     "shell.execute_reply": "2026-04-17T09:29:55.240596Z"
    }
   },
   "outputs": [
   ],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "\n",
    "plt.figure(figsize=(8, 4))\n",
    "plt.plot(losses)\n",
    "plt.xlabel('Epoch')\n",
    "plt.ylabel('Loss')\n",
    "plt.title('Training Loss (ES-D-RTRL Online Learning)')\n",
    "plt.grid(True, alpha=0.3)\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cdbdc871",
   "metadata": {},
   "source": [
    "## 4. Key Differences: D-RTRL vs ES-D-RTRL\n",
    "\n",
    "BrainTrace provides two main online learning algorithms. Understanding their trade-offs helps you choose the right one for your application.\n",
    "\n",
    "| Aspect | D-RTRL (`ParamDimVjpAlgorithm`) | ES-D-RTRL (`ES_D_RTRL`) |\n",
    "|---|---|---|\n",
    "| **Eligibility Trace** | Full trace per weight parameter | Factorized into input/output components |\n",
    "| **Memory Complexity** | O(B * theta) where theta = total parameters | O(B * (I + O)) where I = input dim, O = output dim |\n",
    "| **Computation** | Exact gradient computation | Approximation via low-rank factorization |\n",
    "| **Scalability** | Suitable for small networks | Scales to large networks (hundreds/thousands of neurons) |\n",
    "| **Use Cases** | Research requiring exact gradients | Practical SNN training, large-scale networks |\n",
    "| **BrainTrace API** | `braintrace.D_RTRL(model)` | `braintrace.ES_D_RTRL(model, decay_or_rank)` |\n",
    "\n",
    "### When to use which?\n",
    "\n",
    "- **D-RTRL** stores the full eligibility trace for each weight, giving exact online gradients. However, this requires O(B * theta) memory, which grows linearly with the number of parameters. For a network with N hidden neurons and a recurrent weight matrix of size N x N, theta is N^2, so the trace has on the order of N^2 entries per sample, compared with roughly 2N for a factorized trace. This limits D-RTRL to small networks (typically < 100 neurons).\n",
    "\n",
    "- **ES-D-RTRL** factorizes the eligibility trace into input and output components, reducing memory to O(B * (I + O)). The `decay_or_rank` parameter controls the approximation: a float value (e.g., 0.99) sets the trace decay factor, while an integer value sets the rank of the low-rank approximation. ES-D-RTRL is the recommended choice for SNNs, where networks often have hundreds or thousands of neurons.\n",
    "\n",
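    "As a rough worked example of the memory figures above, the snippet below counts trace entries per sample for the input-plus-recurrent projection used in this tutorial (`n_in=50`, `n_rec=128`). It only plugs concrete numbers into the O(theta) and O(I + O) formulas. Actual memory use in braintrace may differ by constant factors and by whatever else is stored.\n",
    "\n",
    "```python\n",
    "n_in, n_rec = 50, 128\n",
    "\n",
    "in_dim = n_in + n_rec      # input dimension I of the projection\n",
    "out_dim = n_rec            # output dimension O of the projection\n",
    "theta = in_dim * out_dim   # weight count of the projection (bias ignored)\n",
    "\n",
    "print(theta)               # 22784 entries per sample for a full per-weight trace\n",
    "print(in_dim + out_dim)    # 306 entries per sample for a factorized trace\n",
    "```\n",
    "\n",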
    "Both algorithms use the same `braintrace.nn.Linear` and `braintrace.nn.LeakyRateReadout` layers. Switching between them requires only changing the algorithm wrapper:\n",
    "\n",
    "```python\n",
    "# D-RTRL (exact, high memory)\n",
    "algo = braintrace.D_RTRL(model)\n",
    "\n",
    "# ES-D-RTRL (approximate, scalable)\n",
    "algo = braintrace.ES_D_RTRL(model, decay_or_rank=0.99)\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fa87fa0a",
   "metadata": {},
   "source": [
    "## 5. Summary\n",
    "\n",
    "In this tutorial, we demonstrated how to train a spiking neural network with online learning using BrainTrace. Here are the key takeaways:\n",
    "\n",
    "1. **Model Construction**: Use `braintrace.nn.Linear` and `braintrace.nn.LeakyRateReadout` for layers that should participate in online learning (ETP-aware). Combine them with spiking neuron models from `brainpy.state` (e.g., `LIF`).\n",
    "\n",
    "2. **Online Learning Setup**: Wrap the model with `braintrace.ES_D_RTRL` (ES-D-RTRL), call `compile_graph()` to trace the computation graph, and use `brainstate.transform.grad` to compute gradients at each time step.\n",
    "\n",
    "3. **Scalability**: ES-D-RTRL achieves O(B(I+O)) memory complexity, making it practical for large spiking networks. The `decay_or_rank` parameter controls the trace approximation quality.\n",
    "\n",
    "4. **Batching**: Use `brainstate.transform.vmap_new_states` and `brainstate.nn.Vmap` to process multiple samples in parallel, with each sample maintaining independent hidden states and eligibility traces.\n",
    "\n",
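    "The per-step gradient accumulation from item 2 can be sketched in plain JAX, without braintrace, as a `jax.lax.scan` whose carry holds the running gradient sum. The toy linear model, the `step_loss` function, and the random data are assumptions for illustration. In the tutorial, the per-step gradient comes from the ES-D-RTRL-wrapped model instead.\n",
    "\n",
    "```python\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "\n",
    "def step_loss(w, x, y):\n",
    "    # Toy per-time-step loss standing in for the SNN readout loss.\n",
    "    return jnp.mean((x @ w - y) ** 2)\n",
    "\n",
    "\n",
    "def scan_step(carry, xy):\n",
    "    w, grad_sum = carry\n",
    "    x, y = xy\n",
    "    loss, grad = jax.value_and_grad(step_loss)(w, x, y)\n",
    "    # Only the running sum is carried forward; no history is stored.\n",
    "    return (w, grad_sum + grad), loss\n",
    "\n",
    "\n",
    "key_x, key_y = jax.random.split(jax.random.PRNGKey(0))\n",
    "xs = jax.random.normal(key_x, (50, 8, 3))  # (time, batch, features)\n",
    "ys = jax.random.normal(key_y, (50, 8, 2))  # (time, batch, outputs)\n",
    "w = jnp.zeros((3, 2))\n",
    "\n",
    "(_, grads), losses = jax.lax.scan(scan_step, (w, jnp.zeros_like(w)), (xs, ys))\n",
    "print(grads.shape, losses.shape)  # (3, 2) (50,)\n",
    "```\n",
    "\n",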
    "For more advanced topics, including training on real neuromorphic datasets (N-MNIST) and comparing online learning with BPTT, see the detailed tutorials:\n",
    "- [SNN Online Learning (detailed)](./snn_online_learning-en.ipynb)\n",
    "- [Key Concepts](./concepts-en.ipynb)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.13.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
