{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "9cd0f532",
   "metadata": {},
   "source": [
    "# Event-Driven Operators"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "38829ce0",
   "metadata": {},
   "source": [
    "In a spiking neural network only a small fraction of neurons fire on any given step. A dense\n",
    "matrix multiply ignores this: it pays for every connection whether or not the presynaptic neuron\n",
    "spiked. **Event-driven operators** exploit the sparsity of spike trains: their cost scales with\n",
    "the number of *active* inputs rather than the total number of neurons, yet they produce the same\n",
    "result as the dense computation (up to floating-point rounding).\n",
    "\n",
    "BrainState provides several such operators, covering the connectivity patterns used in large-scale SNNs:\n",
    "\n",
    "- `EventLinear` — event-driven dense connectivity;\n",
    "- `EventFixedProb` — sparse random connectivity with a fixed connection probability;\n",
    "- `FixedNumConn` / `EventFixedNumConn` — a fixed number of connections per neuron."
   ]
  },
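  {
   "cell_type": "markdown",
   "id": "5e1d0a01",
   "metadata": {},
   "source": [
    "To make the scaling claim concrete, the next cell counts the work for one step of a layer with\n",
    "200 inputs and 100 outputs. The sizes and the spike count (25) are illustrative assumptions, and the\n",
    "counts are idealised: a dense multiply touches every weight, while an event-driven one touches only\n",
    "the weights leaving neurons that spiked."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5e1d0a02",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Illustrative sizes, assumed for this count only.\n",
    "n_in, n_out = 200, 100\n",
    "n_spiking = 25  # presynaptic neurons that fired on this step\n",
    "\n",
    "dense_ops = n_in * n_out       # dense mat-vec: one multiply-add per weight\n",
    "event_ops = n_spiking * n_out  # event-driven: one add per weight of an active input\n",
    "\n",
    "print('dense multiply-adds:', dense_ops)\n",
    "print('event-driven adds:  ', event_ops)\n",
    "print('fraction of dense work:', event_ops / dense_ops)"
   ]
  },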
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "0982a700",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-05-30T16:47:48.786987Z",
     "iopub.status.busy": "2026-05-30T16:47:48.786833Z",
     "iopub.status.idle": "2026-05-30T16:47:53.464091Z",
     "shell.execute_reply": "2026-05-30T16:47:53.463202Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'0.4.0'"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import jax.numpy as jnp\n",
    "\n",
    "import brainstate\n",
    "import brainstate.nn as nn\n",
    "import braintools\n",
    "\n",
    "brainstate.random.seed(0)\n",
    "brainstate.__version__"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b7f43d49",
   "metadata": {},
   "source": [
    "## Sparse spike vectors"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "aec4e5e7",
   "metadata": {},
   "source": [
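    "A layer's input is a spike vector: mostly zeros, with a one at each neuron that fired. Below we draw\n",
    "spikes for `n_pre = 200` presynaptic neurons, each firing with probability 0.1, and fix `n_post = 100`\n",
    "postsynaptic neurons for the layers that follow."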
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "1f2d4e11",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-05-30T16:47:53.466178Z",
     "iopub.status.busy": "2026-05-30T16:47:53.465713Z",
     "iopub.status.idle": "2026-05-30T16:47:53.823624Z",
     "shell.execute_reply": "2026-05-30T16:47:53.822536Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "active inputs: 25 of 200\n"
     ]
    }
   ],
   "source": [
    "n_pre, n_post = 200, 100\n",
    "spikes = (brainstate.random.rand(n_pre) < 0.1).astype(float)\n",
    "print('active inputs:', int(spikes.sum()), 'of', n_pre)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1013767d",
   "metadata": {},
   "source": [
    "## `EventLinear`: event-driven dense connectivity"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a387eaff",
   "metadata": {},
   "source": [
    "`EventLinear` holds a full weight matrix, exactly like `nn.Linear`, but computes the matrix-vector\n",
    "product in an event-driven way: with the weight stored as `(n_pre, n_post)`, it accumulates only the\n",
    "rows selected by spiking inputs. The result equals the dense product up to floating-point rounding;\n",
    "only the cost differs."
   ]
  },
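  {
   "cell_type": "markdown",
   "id": "5e1d0a03",
   "metadata": {},
   "source": [
    "The row accumulation can be sketched in plain NumPy. This illustrates the idea and is not BrainState's\n",
    "kernel; the names `w`, `s` and `active` are local to this sketch. Because the weight is stored as\n",
    "`(n_pre, n_post)`, each spike selects one row, and summing those rows reproduces `s @ w`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5e1d0a04",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "rng = np.random.default_rng(0)\n",
    "n_in, n_out = 200, 100\n",
    "w = rng.normal(scale=0.1, size=(n_in, n_out))  # w[i, j]: weight from input i to output j\n",
    "s = (rng.random(n_in) < 0.1).astype(float)     # binary spike vector, ~10% active\n",
    "\n",
    "active = np.flatnonzero(s)         # event list: indices of inputs that spiked\n",
    "out_event = w[active].sum(axis=0)  # accumulate only the selected rows\n",
    "\n",
    "out_dense = s @ w                  # reference dense product\n",
    "print('active inputs:', active.size)\n",
    "print('matches dense:', np.allclose(out_event, out_dense))"
   ]
  },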
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "cea37d0e",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-05-30T16:47:53.825971Z",
     "iopub.status.busy": "2026-05-30T16:47:53.825765Z",
     "iopub.status.idle": "2026-05-30T16:47:54.524072Z",
     "shell.execute_reply": "2026-05-30T16:47:54.523270Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "output shape: (100,)\n",
      "matches dense matmul: True\n"
     ]
    }
   ],
   "source": [
    "weight = brainstate.random.randn(n_pre, n_post) * 0.1\n",
    "event_linear = nn.EventLinear(n_pre, n_post, weight=weight)\n",
    "\n",
    "event_out = event_linear(spikes)\n",
    "dense_out = spikes @ weight\n",
    "\n",
    "print('output shape:', event_out.shape)\n",
    "print('matches dense matmul:', bool(jnp.allclose(event_out, dense_out, atol=1e-5)))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "631a7680",
   "metadata": {},
   "source": [
    "Spikes are naturally boolean, and `EventLinear` accepts boolean input directly — the most\n",
    "efficient form, since a spike is simply present or absent."
   ]
  },
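  {
   "cell_type": "markdown",
   "id": "5e1d0a05",
   "metadata": {},
   "source": [
    "A NumPy sketch (an assumed illustration, not the library's implementation) of why boolean input is\n",
    "the natural case: a boolean spike vector can index the weight rows directly, so the update is pure\n",
    "selection plus summation. Graded, non-binary inputs can still be processed event by event, but each\n",
    "selected row must then be scaled by its input value. The graded vector `x` is hypothetical."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5e1d0a06",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "rng = np.random.default_rng(1)\n",
    "w = rng.normal(scale=0.1, size=(200, 100))\n",
    "fired = rng.random(200) < 0.1              # boolean spikes: True means the neuron fired\n",
    "\n",
    "out_bool = w[fired].sum(axis=0)            # boolean mask selects rows, no multiplication\n",
    "\n",
    "x = np.where(fired, rng.random(200), 0.0)  # hypothetical graded input on the same neurons\n",
    "nz = np.flatnonzero(x)\n",
    "out_graded = (x[nz, None] * w[nz]).sum(axis=0)  # scale each selected row by its value\n",
    "\n",
    "print('boolean matches dense:', np.allclose(out_bool, fired.astype(float) @ w))\n",
    "print('graded matches dense: ', np.allclose(out_graded, x @ w))"
   ]
  },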
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "f1f61817",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-05-30T16:47:54.526650Z",
     "iopub.status.busy": "2026-05-30T16:47:54.526454Z",
     "iopub.status.idle": "2026-05-30T16:47:54.689478Z",
     "shell.execute_reply": "2026-05-30T16:47:54.688551Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "boolean spikes give the same result: True\n"
     ]
    }
   ],
   "source": [
    "bool_out = event_linear(spikes.astype(bool))\n",
    "print('boolean spikes give the same result:', bool(jnp.allclose(bool_out, dense_out, atol=1e-5)))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2acf875a",
   "metadata": {},
   "source": [
    "## `EventFixedProb`: sparse random connectivity"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2493d2d7",
   "metadata": {},
   "source": [
    "Cortical-scale models are not densely connected: each neuron contacts only a small random subset\n",
    "of targets. `EventFixedProb` represents this directly. Here `conn_num=0.2` is the connection\n",
    "probability, and the operator never materialises the full dense matrix, so memory scales with the\n",
    "number of synapses that actually exist rather than with `n_pre * n_post`."
   ]
  },
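  {
   "cell_type": "markdown",
   "id": "5e1d0a07",
   "metadata": {},
   "source": [
    "One way to store such connectivity without a dense matrix is an adjacency list: for each presynaptic\n",
    "neuron, the indices of its targets. The NumPy sketch below assumes independent Bernoulli draws with\n",
    "probability `p` and a single shared weight `w_syn` (both hypothetical names); BrainState's own sampling\n",
    "scheme and storage layout may differ. Only the target lists of neurons that spiked are visited."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5e1d0a08",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "rng = np.random.default_rng(2)\n",
    "n_in, n_out, p, w_syn = 200, 100, 0.2, 0.5\n",
    "\n",
    "# Adjacency list: targets[i] holds the output indices that input i connects to.\n",
    "targets = [np.flatnonzero(rng.random(n_out) < p) for _ in range(n_in)]\n",
    "n_syn = sum(t.size for t in targets)\n",
    "print('stored synapses:', n_syn, 'vs dense entries:', n_in * n_out)\n",
    "\n",
    "ev = rng.random(n_in) < 0.1    # boolean spikes for this sketch\n",
    "out = np.zeros(n_out)\n",
    "for i in np.flatnonzero(ev):   # loop over events only\n",
    "    out[targets[i]] += w_syn   # indices within targets[i] are unique\n",
    "\n",
    "# Check against a dense matrix built only for verification.\n",
    "dense_w = np.zeros((n_in, n_out))\n",
    "for i, t in enumerate(targets):\n",
    "    dense_w[i, t] = w_syn\n",
    "print('matches dense:', np.allclose(out, ev.astype(float) @ dense_w))"
   ]
  },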
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "787a9444",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-05-30T16:47:54.691524Z",
     "iopub.status.busy": "2026-05-30T16:47:54.691285Z",
     "iopub.status.idle": "2026-05-30T16:47:54.823656Z",
     "shell.execute_reply": "2026-05-30T16:47:54.822551Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "sparse output shape: (100,)\n"
     ]
    }
   ],
   "source": [
    "sparse_syn = nn.EventFixedProb(\n",
    "    n_pre, n_post,\n",
    "    conn_num=0.2,        # each post neuron receives ~20% of pre neurons\n",
    "    conn_weight=0.5,\n",
    ")\n",
    "\n",
    "sparse_out = sparse_syn(spikes)\n",
    "print('sparse output shape:', sparse_out.shape)"
   ]
  },
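  {
   "cell_type": "markdown",
   "id": "5e1d0a09",
   "metadata": {},
   "source": [
    "A back-of-the-envelope calculation shows why storing only existing synapses matters at scale. The\n",
    "numbers are hypothetical: 100,000 neurons per population, connection probability 0.01, 4-byte weights,\n",
    "and one 4-byte target index per synapse (small format overheads such as row pointers are ignored)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5e1d0a10",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hypothetical network size and storage formats.\n",
    "n = 100_000                # neurons in each population\n",
    "p = 0.01                   # connection probability\n",
    "bytes_w, bytes_idx = 4, 4  # float32 weight, int32 index\n",
    "\n",
    "dense_bytes = n * n * bytes_w                 # full n x n weight matrix\n",
    "n_syn = round(n * n * p)                      # expected number of synapses\n",
    "sparse_bytes = n_syn * (bytes_w + bytes_idx)  # one weight + one index per synapse\n",
    "\n",
    "print(f'dense : {dense_bytes / 1e9:.1f} GB')\n",
    "print(f'sparse: {sparse_bytes / 1e9:.1f} GB')"
   ]
  },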
  {
   "cell_type": "markdown",
   "id": "8c06cff8",
   "metadata": {},
   "source": [
    "## `FixedNumConn`: a fixed number of connections"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "633b4887",
   "metadata": {},
   "source": [
    "When biological fan-in must be controlled exactly, `FixedNumConn` wires each neuron to a fixed\n",
    "*number* of partners rather than a probability. `EventFixedNumConn` is its event-driven form for\n",
    "spiking input."
   ]
  },
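  {
   "cell_type": "markdown",
   "id": "5e1d0a11",
   "metadata": {},
   "source": [
    "The difference from fixed-probability wiring shows up in the spread of fan-in. In this NumPy sketch\n",
    "(assumed sizes, not BrainState's sampler), connecting every pair with probability `k / n_src` gives\n",
    "each postsynaptic neuron about `k` inputs on average, while drawing `k` distinct partners per neuron\n",
    "gives exactly `k`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5e1d0a12",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "rng = np.random.default_rng(3)\n",
    "n_src, n_tgt, k = 200, 100, 10\n",
    "\n",
    "# Fixed probability: each (post, pre) pair is connected independently with p = k / n_src.\n",
    "fan_in_prob = (rng.random((n_tgt, n_src)) < k / n_src).sum(axis=1)\n",
    "\n",
    "# Fixed number: each post neuron draws exactly k distinct presynaptic partners.\n",
    "partners = np.stack([rng.choice(n_src, size=k, replace=False) for _ in range(n_tgt)])\n",
    "fan_in_fixed = np.array([np.unique(row).size for row in partners])\n",
    "\n",
    "print('fixed probability: fan-in from', fan_in_prob.min(), 'to', fan_in_prob.max())\n",
    "print('fixed number:      fan-in from', fan_in_fixed.min(), 'to', fan_in_fixed.max())"
   ]
  },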
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "0b6b67f7",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-05-30T16:47:54.826310Z",
     "iopub.status.busy": "2026-05-30T16:47:54.825985Z",
     "iopub.status.idle": "2026-05-30T16:47:58.085469Z",
     "shell.execute_reply": "2026-05-30T16:47:58.084514Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "fixed-fan-in output shape: (100,)\n"
     ]
    }
   ],
   "source": [
    "fixed_syn = nn.FixedNumConn(\n",
    "    n_pre, n_post,\n",
    "    conn_num=10,         # exactly 10 connections per neuron\n",
    "    conn_weight=0.5,\n",
    ")\n",
    "\n",
    "fixed_out = fixed_syn(spikes)\n",
    "print('fixed-fan-in output shape:', fixed_out.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ed1a6130",
   "metadata": {},
   "source": [
    "## When to use event-driven operators\n",
    "\n",
    "Reach for these operators whenever a connection's input is a spike train:\n",
    "\n",
    "- **Sparse activity** — if only a few percent of neurons fire per step, event-driven evaluation\n",
    "  avoids the wasted work of a dense multiply.\n",
    "- **Sparse connectivity** — `EventFixedProb` and `EventFixedNumConn` store only real synapses, so\n",
    "  large networks fit in memory.\n",
    "- **Drop-in correctness** — given the same weights, the result matches the dense computation up to\n",
    "  floating-point rounding, so you can prototype with `nn.Linear` and switch to `EventLinear` for scale\n",
    "  without changing the model's behaviour.\n",
    "\n",
    "### See also\n",
    "\n",
    "- [Building a spiking neural network](04_building_an_snn.ipynb) — these operators as synapses in a full SNN.\n",
    "- [Dynamics and integration](01_dynamics_and_integration.ipynb) — the neuron models that emit the spikes."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.13.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
