{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Projections\n",
    "\n",
    "Projections are `brainpy.state`'s mechanism for connecting neural populations.\n",
    "They implement the **Communication-Synapse-Output (Comm-Syn-Out)** architecture,\n",
    "which separates connectivity, synaptic dynamics, and output computation into modular components.\n",
    "\n",
    "This guide explains what projections are, how each of their three components works, and how to combine them into complete networks.\n",
    "\n",
    "## Overview\n",
    "\n",
    "### What are Projections?\n",
    "\n",
    "A **projection** connects a presynaptic population to a postsynaptic population through:\n",
    "\n",
    "1. **Communication (Comm)**: How spikes propagate through connections\n",
    "2. **Synapse (Syn)**: Temporal filtering and synaptic dynamics\n",
    "3. **Output (Out)**: How synaptic currents affect postsynaptic neurons\n",
    "\n",
    "\n",
    "**Key benefits:**\n",
    "\n",
    "- Modular design (swap components independently)\n",
    "- Biologically realistic (separate connectivity and dynamics)\n",
    "- Efficient (optimized sparse operations)\n",
    "- Flexible (combine components in different ways)\n",
    "\n",
    "\n",
    "### The Comm-Syn-Out Architecture"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-11-13T11:46:03.343127Z",
     "start_time": "2025-11-13T11:46:03.339748Z"
    },
    "execution": {
     "iopub.execute_input": "2026-05-11T06:19:35.133817Z",
     "iopub.status.busy": "2026-05-11T06:19:35.133599Z",
     "iopub.status.idle": "2026-05-11T06:19:37.351945Z",
     "shell.execute_reply": "2026-05-11T06:19:37.350925Z"
    }
   },
   "outputs": [],
   "source": [
    "import brainstate\n",
    "import braintools\n",
    "import saiunit as u\n",
    "import numpy as np\n",
    "\n",
    "import brainpy.state"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-11-13T11:46:27.174926Z",
     "start_time": "2025-11-13T11:46:27.170504Z"
    },
    "execution": {
     "iopub.execute_input": "2026-05-11T06:19:37.354573Z",
     "iopub.status.busy": "2026-05-11T06:19:37.354187Z",
     "iopub.status.idle": "2026-05-11T06:19:37.358283Z",
     "shell.execute_reply": "2026-05-11T06:19:37.357272Z"
    }
   },
   "outputs": [],
   "source": [
    "brainstate.environ.set(dt=0.1 * u.ms)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "```text\n",
    "Presynaptic        Communication         Synapse          Output        Postsynaptic\n",
    "Population    ──►  (Connectivity)  ──►  (Dynamics)  ──►  (Current) ──►  Population\n",
    "\n",
    "Spikes        ──►  Weight matrix   ──►  g(t)        ──►  I_syn     ──►  Neurons\n",
    "                   Sparse/Dense         Expon/Alpha     CUBA/COBA\n",
    "```\n",
    "\n",
    "**Flow:**\n",
    "\n",
    "1. Presynaptic spikes arrive\n",
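    "2. Communication: Spikes propagate through the connectivity (weight) matrix\n",
    "3. Synapse: Temporal dynamics filter the weighted input\n",
    "4. Output: The filtered signal is converted into a synaptic current, either directly (CUBA) or through a conductance and a reversal potential (COBA)\n",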
    "5. Postsynaptic neurons receive input"
   ]
  },
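  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The following minimal NumPy sketch makes the flow concrete by performing one step of a single projection. It shows only the arithmetic behind the five steps above and does not reflect how `brainpy.state` implements projections internally. The dense random weights `w`, the exponential decay with time constant `tau`, and the reversal potential `E_syn` are illustrative assumptions.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "rng = np.random.default_rng(0)\n",
    "n_pre, n_post = 4, 3\n",
    "dt, tau = 0.1, 5.0  # ms\n",
    "E_syn = 0.0  # mV, reversal potential (used by the COBA variant only)\n",
    "\n",
    "w = rng.uniform(0.0, 1.0, size=(n_pre, n_post))  # Comm: connectivity weights\n",
    "g = np.zeros(n_post)  # Syn: one synaptic variable per postsynaptic neuron\n",
    "V_post = np.full(n_post, -65.0)  # mV, postsynaptic membrane potentials\n",
    "\n",
    "\n",
    "def projection_step(pre_spikes, g):\n",
    "    # Communication: spikes propagate through the weight matrix\n",
    "    syn_input = pre_spikes.astype(float) @ w\n",
    "    # Synapse: exponential decay of the old value plus the new input\n",
    "    g = g * np.exp(-dt / tau) + syn_input\n",
    "    # Output: CUBA uses g directly, COBA scales it by the driving force\n",
    "    I_cuba = g\n",
    "    I_coba = g * (E_syn - V_post)\n",
    "    return g, I_cuba, I_coba\n",
    "\n",
    "\n",
    "pre_spikes = np.array([True, False, True, False])  # neurons 0 and 2 fire\n",
    "g, I_cuba, I_coba = projection_step(pre_spikes, g)\n",
    "print(I_cuba, I_coba)\n",
    "```"
   ]
  },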
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n",
    "\n",
    "### Types of Projections\n",
    "\n",
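    "BrainPy provides two main projection types, which differ in where the synaptic state variables are stored:\n",
    "\n",
    "**AlignPostProj**\n",
    "   - Synaptic states are aligned with (stored per) postsynaptic neuron; the synapse is created with the postsynaptic size\n",
    "   - Most common choice for standard spiking networks\n",
    "   - Memory for synaptic states scales with the postsynaptic population size\n",
    "\n",
    "**AlignPreProj**\n",
    "   - Synaptic states are aligned with (stored per) presynaptic neuron\n",
    "   - Useful for certain learning rules\n",
    "   - Memory for synaptic states scales with the presynaptic population size\n",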
    "\n",
    "For most use cases, use `AlignPostProj`.\n",
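    "\n",
    "Why can one synaptic variable per neuron be enough? The minimal NumPy sketch below assumes a *linear* synapse such as the exponential one. Under that assumption, three orderings give the same postsynaptic input:\n",
    "\n",
    "- filtering every connection separately and then summing;\n",
    "- filtering presynaptic spikes and then communicating (pre-aligned);\n",
    "- communicating and then filtering (post-aligned).\n",
    "\n",
    "The aligned forms need only `n_pre` or `n_post` state variables instead of `n_pre * n_post`. This illustrates the idea and is not `brainpy.state`'s internal code. The equivalence relies on the synaptic dynamics being linear.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "rng = np.random.default_rng(1)\n",
    "n_pre, n_post, steps = 20, 5, 200\n",
    "decay = np.exp(-0.1 / 5.0)  # exponential synapse: dt = 0.1 ms, tau = 5 ms\n",
    "w = rng.uniform(0.0, 1.0, (n_pre, n_post))\n",
    "spikes = rng.random((steps, n_pre)) < 0.05\n",
    "\n",
    "g_syn = np.zeros((n_pre, n_post))  # one state per connection\n",
    "g_pre = np.zeros(n_pre)  # one state per presynaptic neuron\n",
    "g_post = np.zeros(n_post)  # one state per postsynaptic neuron\n",
    "\n",
    "for s in spikes:\n",
    "    x = s.astype(float)\n",
    "    g_syn = g_syn * decay + x[:, None] * w  # filter every connection\n",
    "    g_pre = g_pre * decay + x  # filter first, communicate later\n",
    "    g_post = g_post * decay + x @ w  # communicate first, then filter\n",
    "\n",
    "per_synapse = g_syn.sum(axis=0)\n",
    "pre_aligned = g_pre @ w\n",
    "post_aligned = g_post\n",
    "print(np.allclose(per_synapse, pre_aligned), np.allclose(per_synapse, post_aligned))  # True True\n",
    "```\n",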
    "\n",
    "## Communication Layer\n",
    "\n",
    "The Communication layer defines **how spikes propagate** through connections.\n",
    "\n",
    "### Dense Connectivity\n",
    "\n",
    "Every presynaptic neuron is potentially connected to every postsynaptic neuron (though some weights may be zero).\n",
    "\n",
    "**Use case:** Small networks, fully connected layers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-05-11T06:19:37.360707Z",
     "iopub.status.busy": "2026-05-11T06:19:37.360494Z",
     "iopub.status.idle": "2026-05-11T06:19:37.918507Z",
     "shell.execute_reply": "2026-05-11T06:19:37.917537Z"
    }
   },
   "outputs": [],
   "source": [
    "# Dense linear transformation\n",
    "comm = brainstate.nn.Linear(\n",
    "    100,  # in_size\n",
    "    50,  # out_size\n",
    "    w_init=braintools.init.KaimingNormal(),\n",
    "    b_init=None  # No bias for synapses\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Characteristics:**\n",
    "\n",
    "- Memory: O(n_pre × n_post)\n",
    "- Computation: Full matrix multiplication\n",
    "- Best for: Small networks, fully connected architectures\n",
    "\n",
    "### Sparse Connectivity\n",
    "\n",
    "Only a subset of the possible connections exists, which is more biologically realistic.\n",
    "\n",
    "**Use case:** Large networks, biological connectivity patterns\n",
    "\n",
    "#### Event-Based Fixed Probability\n",
    "\n",
    "Connect neurons with fixed probability."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-05-11T06:19:37.921231Z",
     "iopub.status.busy": "2026-05-11T06:19:37.920963Z",
     "iopub.status.idle": "2026-05-11T06:19:37.951568Z",
     "shell.execute_reply": "2026-05-11T06:19:37.950522Z"
    }
   },
   "outputs": [],
   "source": [
    "# Sparse random connectivity (2% connection probability)\n",
    "comm = brainstate.nn.EventFixedProb(\n",
    "    1000,  # pre_size\n",
    "    800,  # post_size\n",
    "    conn_num=0.02,  # 2% connectivity\n",
    "    conn_weight=0.5  # Synaptic weight (unitless here; weights may also carry units, e.g. 0.5 * u.mS)\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Characteristics:**\n",
    "\n",
    "- Memory: O(n_pre × n_post × prob)\n",
    "- Computation: Only the connections of presynaptic neurons that spiked are processed (event-driven)\n",
    "- Best for: Large-scale networks, biological models\n",
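    "\n",
    "The event-driven idea can be sketched in NumPy. The sketch makes two assumptions: each presynaptic neuron has a fixed number of targets (`conn_prob * post_size`), and all connections share one weight. It illustrates the principle behind `EventFixedProb`, not its actual implementation. Only the target lists of neurons that spiked are visited, and only `pre_size * n_targets` connections are stored.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "rng = np.random.default_rng(2)\n",
    "pre_size, post_size = 1000, 800\n",
    "conn_prob, weight = 0.02, 0.5\n",
    "n_targets = int(conn_prob * post_size)  # 16 targets per presynaptic neuron\n",
    "\n",
    "# targets[i] holds the postsynaptic indices of presynaptic neuron i\n",
    "targets = np.stack([rng.choice(post_size, n_targets, replace=False) for _ in range(pre_size)])\n",
    "\n",
    "\n",
    "def event_propagate(spikes):\n",
    "    out = np.zeros(post_size)\n",
    "    for i in np.flatnonzero(spikes):  # visit only neurons that fired\n",
    "        out[targets[i]] += weight  # targets are unique within a row\n",
    "    return out\n",
    "\n",
    "\n",
    "spikes = rng.random(pre_size) < 0.01\n",
    "out = event_propagate(spikes)\n",
    "\n",
    "# Cross-check against the equivalent dense weight matrix\n",
    "dense = np.zeros((pre_size, post_size))\n",
    "np.put_along_axis(dense, targets, weight, axis=1)\n",
    "print(np.allclose(out, spikes.astype(float) @ dense))\n",
    "print('stored connections:', targets.size, 'of', pre_size * post_size)\n",
    "```\n",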
    "\n",
    "#### Event-Based All-to-All\n",
    "\n",
    "All neurons are connected. Passing a single scalar weight, as below, gives every connection the same strength."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-05-11T06:19:37.953891Z",
     "iopub.status.busy": "2026-05-11T06:19:37.953628Z",
     "iopub.status.idle": "2026-05-11T06:19:37.957369Z",
     "shell.execute_reply": "2026-05-11T06:19:37.956342Z"
    }
   },
   "outputs": [],
   "source": [
    "# All-to-all sparse (event-driven)\n",
    "comm = brainstate.nn.AllToAll(\n",
    "    100,  # pre_size\n",
    "    100,  # post_size\n",
    "    0.3  # Unitless weight\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Event-Based One-to-One\n",
    "\n",
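    "Neuron *i* of the presynaptic population connects only to neuron *i* of the postsynaptic population, so both populations must have the same size."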
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-05-11T06:19:37.959139Z",
     "iopub.status.busy": "2026-05-11T06:19:37.958979Z",
     "iopub.status.idle": "2026-05-11T06:19:37.962313Z",
     "shell.execute_reply": "2026-05-11T06:19:37.961528Z"
    }
   },
   "outputs": [],
   "source": [
    "size = 100\n",
    "weight = 1.0\n",
    "\n",
    "# One-to-one connections\n",
    "comm = brainstate.nn.OneToOne(\n",
    "    size,\n",
    "    weight  # Unitless weight\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Use case:** Feedforward pathways, identity mappings\n",
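    "\n",
    "In array terms, the one-to-one and all-to-all patterns above reduce to very simple arithmetic. The NumPy sketch below assumes a single scalar weight shared by all connections, as in the cells above. It is an illustration only, not the `brainstate.nn` implementation.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "spikes = np.array([1.0, 0.0, 1.0, 1.0])  # presynaptic activity\n",
    "\n",
    "# One-to-one: neuron i drives only neuron i\n",
    "one_to_one = spikes * 1.0\n",
    "\n",
    "# All-to-all with one shared weight: every target gets the same summed input\n",
    "w_all, n_post = 0.3, 3\n",
    "all_to_all = np.full(n_post, spikes.sum() * w_all)\n",
    "\n",
    "# Same result from an explicit dense weight matrix\n",
    "dense = np.full((spikes.size, n_post), w_all)\n",
    "print(one_to_one, all_to_all, np.allclose(all_to_all, spikes @ dense))\n",
    "```\n",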
    "\n",
    "\n",
    "## Synapse Layer\n",
    "\n",
    "The Synapse layer defines **temporal dynamics** of synaptic transmission.\n",
    "\n",
    "### Exponential Synapse\n",
    "\n",
    "Single exponential decay (most common).\n",
    "\n",
    "**Dynamics:**\n",
    "\n",
    "\n",
    "$$\n",
    "\\tau \\frac{dg}{dt} = -g + \\sum_k \\delta(t - t_k)\n",
    "$$\n",
    "\n",
    "**Implementation:**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-05-11T06:19:37.964911Z",
     "iopub.status.busy": "2026-05-11T06:19:37.964571Z",
     "iopub.status.idle": "2026-05-11T06:19:37.968114Z",
     "shell.execute_reply": "2026-05-11T06:19:37.967481Z"
    }
   },
   "outputs": [],
   "source": [
    "# Exponential synapse with 5ms time constant\n",
    "syn = brainpy.state.Expon(\n",
    "    in_size=100,  # Postsynaptic population size\n",
    "    tau=5.0 * u.ms  # Decay time constant\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Characteristics:**\n",
    "\n",
    "- Single time constant\n",
    "- Fast computation\n",
    "- Good for most applications\n",
    "\n",
    "**When to use:** Default choice for most models\n",
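    "\n",
    "A minimal NumPy sketch of the exponential synapse equation above, using the exact per-step decay factor $e^{-\\Delta t/\\tau}$. Following the equation as written, each spike increments $g$ by $1/\\tau$. The step size and spike times are illustrative assumptions.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "dt, tau = 0.1, 5.0  # ms\n",
    "n_steps = 500  # 50 ms\n",
    "spike_steps = {100, 300}  # spikes at t = 10 ms and t = 30 ms\n",
    "\n",
    "decay = np.exp(-dt / tau)\n",
    "g = 0.0\n",
    "trace = np.empty(n_steps)\n",
    "for k in range(n_steps):\n",
    "    g *= decay  # exponential decay between spikes\n",
    "    if k in spike_steps:\n",
    "        g += 1.0 / tau  # integrating the delta input gives a jump of 1/tau\n",
    "    trace[k] = g\n",
    "\n",
    "# One time constant after the first spike, g has fallen to 1/e of its peak\n",
    "print(trace[100 + 50] / trace[100])  # ~0.368\n",
    "```\n",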
    "\n",
    "### Alpha Synapse\n",
    "\n",
    "Alpha function: the spike train is filtered twice with the same time constant, producing a smooth rise followed by a decay.\n",
    "\n",
    "**Dynamics:**\n",
    "\n",
    "\n",
    "$$\n",
    "\\begin{aligned}\n",
    "\\tau \\frac{dg}{dt} &= -g + h \\\\\n",
    "\\tau \\frac{dh}{dt} &= -h + \\sum_k \\delta(t - t_k)\n",
    "\\end{aligned}\n",
    "$$\n",
    "\n",
    "**Implementation:**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-05-11T06:19:37.970458Z",
     "iopub.status.busy": "2026-05-11T06:19:37.970269Z",
     "iopub.status.idle": "2026-05-11T06:19:37.973665Z",
     "shell.execute_reply": "2026-05-11T06:19:37.972656Z"
    }
   },
   "outputs": [],
   "source": [
    "# Alpha synapse\n",
    "syn = brainpy.state.Alpha(\n",
    "    in_size=100,\n",
    "    tau=10.0 * u.ms  # Characteristic time\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Characteristics:**\n",
    "\n",
    "- Finite rise time (the exponential synapse jumps instantly)\n",
    "- Smoother response\n",
    "- Slightly slower computation (two state variables instead of one)\n",
    "\n",
    "**When to use:** When rise time matters, more biological realism\n",
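    "\n",
    "The cascade of two filters can be checked numerically. Below is a minimal forward-Euler sketch of the two equations above, for a single spike at $t = 0$ (which sets $h$ to $1/\\tau$). The step size and duration are illustrative. The response peaks at $t \\approx \\tau$, matching the analytic solution $g(t) = \\frac{t}{\\tau^2} e^{-t/\\tau}$.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "dt, tau = 0.01, 10.0  # ms; a small step keeps the Euler error low\n",
    "n_steps = 6000  # 60 ms\n",
    "\n",
    "h, g = 1.0 / tau, 0.0  # state right after one spike at t = 0\n",
    "trace = np.empty(n_steps)\n",
    "for k in range(n_steps):\n",
    "    trace[k] = g\n",
    "    dh = -h / tau\n",
    "    dg = (-g + h) / tau\n",
    "    h += dt * dh  # forward Euler for both equations\n",
    "    g += dt * dg\n",
    "\n",
    "t_peak = np.argmax(trace) * dt\n",
    "print(t_peak, trace.max(), 1.0 / (tau * np.e))  # peak near tau, height near 1/(tau e)\n",
    "```\n",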
    "\n",
    "### NMDA Synapse\n",
    "\n",
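    "NMDA receptors have slow kinetics, and their conductance is voltage-dependent because extracellular Mg²⁺ blocks the channel at hyperpolarized potentials.\n",
    "\n",
    "**Mg²⁺ block:**\n",
    "\n",
    "\n",
    "$$\n",
    "g_{NMDA} = \\frac{g}{1 + \\eta [Mg^{2+}] e^{-\\gamma V}}\n",
    "$$\n",
    "\n",
    "Here $g$ is the gating variable produced by the synapse model and $V$ is the postsynaptic membrane potential. The parameters $\\eta$ and $\\gamma$ set the sensitivity to Mg²⁺ and to voltage, respectively. In `brainpy.state`, the synapse model below provides the slow gating dynamics, while the `MgBlock` output layer described in the Output Layer section applies the voltage-dependent block.\n",
    "\n",
    "**Implementation:**"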
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-05-11T06:19:37.975818Z",
     "iopub.status.busy": "2026-05-11T06:19:37.975605Z",
     "iopub.status.idle": "2026-05-11T06:19:37.979512Z",
     "shell.execute_reply": "2026-05-11T06:19:37.978698Z"
    }
   },
   "outputs": [],
   "source": [
    "# Biophysical NMDA receptor kinetics (slow gating variable).\n",
    "# The voltage-dependent Mg²⁺ block is applied separately by the MgBlock output.\n",
    "syn = brainpy.state.BioNMDA(\n",
    "    in_size=100,\n",
    "    T=1.0 * u.mM,  # Transmitter concentration during a release event\n",
    "    T_dur=0.5 * u.ms,  # Duration of each transmitter release\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Characteristics:**\n",
    "\n",
    "- Voltage-dependent\n",
    "- Slow kinetics\n",
    "- Important for plasticity\n",
    "\n",
    "**When to use:** Long-term potentiation, working memory models\n",
    "\n",
    "### AMPA Synapse\n",
    "\n",
    "Fast glutamatergic transmission."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-11-13T11:42:57.610829Z",
     "start_time": "2025-11-13T11:42:57.606831Z"
    },
    "execution": {
     "iopub.execute_input": "2026-05-11T06:19:37.982157Z",
     "iopub.status.busy": "2026-05-11T06:19:37.981683Z",
     "iopub.status.idle": "2026-05-11T06:19:37.985941Z",
     "shell.execute_reply": "2026-05-11T06:19:37.985173Z"
    }
   },
   "outputs": [],
   "source": [
    "# AMPA receptor (fast excitation)\n",
    "syn = brainpy.state.AMPA(\n",
    "    in_size=100,\n",
    "    beta=0.5 / u.ms,  # Fast decay (~2ms)\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**When to use:** Fast excitatory transmission\n",
    "\n",
    "### GABA Synapse\n",
    "\n",
    "Inhibitory transmission.\n",
    "\n",
    "**GABAa (fast):**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-11-13T11:43:19.181623Z",
     "start_time": "2025-11-13T11:43:19.177719Z"
    },
    "execution": {
     "iopub.execute_input": "2026-05-11T06:19:37.988050Z",
     "iopub.status.busy": "2026-05-11T06:19:37.987803Z",
     "iopub.status.idle": "2026-05-11T06:19:37.991731Z",
     "shell.execute_reply": "2026-05-11T06:19:37.990950Z"
    }
   },
   "outputs": [],
   "source": [
    "# GABAa receptor (fast inhibition)\n",
    "syn = brainpy.state.GABAa(\n",
    "    in_size=100,\n",
    "    beta=0.16 / u.ms,  # ~6ms decay\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
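    "**GABAb-like (slow):** the cell below approximates slow inhibition using the `GABAa` kinetic model with a slower decay. It does not model the G-protein signaling of real GABAb receptors."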
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-11-13T11:43:24.009249Z",
     "start_time": "2025-11-13T11:43:24.005919Z"
    },
    "execution": {
     "iopub.execute_input": "2026-05-11T06:19:37.993886Z",
     "iopub.status.busy": "2026-05-11T06:19:37.993649Z",
     "iopub.status.idle": "2026-05-11T06:19:37.997475Z",
     "shell.execute_reply": "2026-05-11T06:19:37.996522Z"
    }
   },
   "outputs": [],
   "source": [
    "# Slow, GABAb-like inhibition approximated with the GABAa kinetic model:\n",
    "# a small unbinding rate beta gives a slow decay (~1/beta)\n",
    "syn = brainpy.state.GABAa(\n",
    "    in_size=100,\n",
    "    beta=0.01 / u.ms,  # ~100ms decay\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**When to use:**\n",
    "- GABAa: Fast inhibition, cortical networks\n",
    "- GABAb: Slow inhibition, rhythm generation\n",
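    "\n",
    "The `beta` arguments above set how fast each receptor's conductance decays. The NumPy sketch below assumes the common two-state kinetic scheme $\\frac{dg}{dt} = \\alpha [T] (1 - g) - \\beta g$, in which the transmitter concentration $[T]$ is present for a short pulse after each spike. Once the pulse ends, $g$ decays with time constant $1/\\beta$. The values of `alpha`, `T` and `T_dur` are illustrative assumptions, not `brainpy.state` defaults.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "dt = 0.01  # ms\n",
    "alpha, T, T_dur = 1.0, 1.0, 1.0  # 1/(mM ms), mM, ms (illustrative values)\n",
    "\n",
    "\n",
    "def decay_time(beta, t_end=300.0):\n",
    "    # One spike at t = 0: transmitter is present for T_dur ms, then g decays.\n",
    "    # Returns the time g needs to fall from its peak to 1/e of the peak.\n",
    "    n = int(round(t_end / dt))\n",
    "    g = 0.0\n",
    "    trace = np.empty(n)\n",
    "    for k in range(n):\n",
    "        transmitter = T if k * dt < T_dur else 0.0\n",
    "        g += dt * (alpha * transmitter * (1.0 - g) - beta * g)\n",
    "        trace[k] = g\n",
    "    k_peak = int(np.argmax(trace))\n",
    "    below = np.flatnonzero(trace[k_peak:] < trace[k_peak] / np.e)\n",
    "    return below[0] * dt\n",
    "\n",
    "\n",
    "for name, beta in [('AMPA-like', 0.5), ('GABAa-like', 0.16), ('slow', 0.01)]:\n",
    "    print(f'{name}: beta = {beta}/ms, decay ~ {decay_time(beta):.2f} ms, 1/beta = {1 / beta:.2f} ms')\n",
    "```\n",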
    "\n",
    "### Custom Synapses\n",
    "\n",
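    "Create custom synaptic dynamics by subclassing `brainpy.state.Synapse`. The example below splits each input between a fast (70%) and a slow (30%) exponential component and integrates both with forward Euler."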
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-11-13T11:43:26.083188Z",
     "start_time": "2025-11-13T11:43:26.077812Z"
    },
    "execution": {
     "iopub.execute_input": "2026-05-11T06:19:37.999601Z",
     "iopub.status.busy": "2026-05-11T06:19:37.999404Z",
     "iopub.status.idle": "2026-05-11T06:19:38.004817Z",
     "shell.execute_reply": "2026-05-11T06:19:38.004048Z"
    }
   },
   "outputs": [],
   "source": [
    "import jax.numpy as jnp\n",
    "\n",
    "\n",
    "class DoubleExpSynapse(brainpy.state.Synapse):\n",
    "    \"\"\"Custom synapse with two time constants.\"\"\"\n",
    "\n",
    "    def __init__(self, size, tau_fast=2 * u.ms, tau_slow=10 * u.ms, **kwargs):\n",
    "        super().__init__(size, **kwargs)\n",
    "        self.tau_fast = tau_fast\n",
    "        self.tau_slow = tau_slow\n",
    "\n",
    "        # State variables\n",
    "        self.g_fast = brainstate.ShortTermState(jnp.zeros(size))\n",
    "        self.g_slow = brainstate.ShortTermState(jnp.zeros(size))\n",
    "\n",
    "    def reset_state(self, batch_size=None):\n",
    "        shape = self.varshape if batch_size is None else (batch_size, *self.varshape)\n",
    "        self.g_fast.value = jnp.zeros(shape)\n",
    "        self.g_slow.value = jnp.zeros(shape)\n",
    "\n",
    "    def update(self, x):\n",
    "        dt = brainstate.environ.get_dt()\n",
    "\n",
    "        # Fast component: forward-Euler decay step plus 70% of the input\n",
    "        dg_fast = -self.g_fast.value / self.tau_fast.to_decimal(u.ms)\n",
    "        self.g_fast.value += dg_fast * dt.to_decimal(u.ms) + x * 0.7\n",
    "\n",
    "        # Slow component: forward-Euler decay step plus 30% of the input\n",
    "        dg_slow = -self.g_slow.value / self.tau_slow.to_decimal(u.ms)\n",
    "        self.g_slow.value += dg_slow * dt.to_decimal(u.ms) + x * 0.3\n",
    "\n",
    "        return self.g_fast.value + self.g_slow.value"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Output Layer\n",
    "\n",
    "The Output layer defines **how synaptic conductance affects neurons**.\n",
    "\n",
    "### CUBA (Current-Based)\n",
    "\n",
    "The synaptic variable is used directly as the input current, independent of the postsynaptic membrane potential.\n",
    "\n",
    "**Model:**\n",
    "\n",
    "\n",
    "$$\n",
    "I_{syn} = g_{syn}\n",
    "$$\n",
    "\n",
    "**Implementation:**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-11-13T11:43:28.874215Z",
     "start_time": "2025-11-13T11:43:28.869215Z"
    },
    "execution": {
     "iopub.execute_input": "2026-05-11T06:19:38.006965Z",
     "iopub.status.busy": "2026-05-11T06:19:38.006732Z",
     "iopub.status.idle": "2026-05-11T06:19:38.023393Z",
     "shell.execute_reply": "2026-05-11T06:19:38.022457Z"
    }
   },
   "outputs": [],
   "source": [
    "# Current-based output: the synaptic variable is passed on directly as current\n",
    "out = brainpy.state.CUBA()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Characteristics:**\n",
    "\n",
    "- Simple and fast\n",
    "- No voltage dependence\n",
    "- Good for rate-based models\n",
    "\n",
    "**When to use:**\n",
    "- Abstract models\n",
    "- When voltage dependence is not important\n",
    "- When faster computation is needed\n",
    "\n",
    "### COBA (Conductance-Based)\n",
    "\n",
    "The synaptic variable acts as a conductance that drives the membrane toward the reversal potential $E_{syn}$.\n",
    "\n",
    "**Model:**\n",
    "\n",
    "\n",
    "$$\n",
    "I_{syn} = g_{syn} (E_{syn} - V_{post})\n",
    "$$\n",
    "\n",
    "**Implementation:**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-11-13T11:43:29.757135Z",
     "start_time": "2025-11-13T11:43:29.753741Z"
    },
    "execution": {
     "iopub.execute_input": "2026-05-11T06:19:38.025749Z",
     "iopub.status.busy": "2026-05-11T06:19:38.025563Z",
     "iopub.status.idle": "2026-05-11T06:19:38.029465Z",
     "shell.execute_reply": "2026-05-11T06:19:38.028771Z"
    }
   },
   "outputs": [],
   "source": [
    "# Excitatory conductance-based\n",
    "out_exc = brainpy.state.COBA(E=0.0 * u.mV)\n",
    "\n",
    "# Inhibitory conductance-based\n",
    "out_inh = brainpy.state.COBA(E=-80.0 * u.mV)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Characteristics:**\n",
    "\n",
    "- Voltage-dependent\n",
    "- Biologically realistic\n",
    "- Self-limiting: the driving force $(E_{syn} - V_{post})$ shrinks to zero as $V_{post}$ approaches the reversal potential\n",
    "\n",
    "**When to use:**\n",
    "- Biologically detailed models\n",
    "- When voltage dependence matters\n",
    "- When shunting inhibition is needed\n",
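    "\n",
    "The NumPy comparison below evaluates the two output models at a fixed synaptic value $g_{syn}$. The value and the potentials are illustrative, and units are arbitrary. The CUBA current ignores the membrane potential. The COBA current shrinks as $V_{post}$ approaches $E_{syn}$ and reverses sign beyond it.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "g_syn = 0.5  # fixed synaptic conductance (illustrative value)\n",
    "V_post = np.array([-80.0, -65.0, -50.0, 0.0, 20.0])  # mV\n",
    "E_exc, E_inh = 0.0, -80.0  # mV, reversal potentials used above\n",
    "\n",
    "I_cuba = np.full_like(V_post, g_syn)  # I_syn = g_syn at every V\n",
    "I_coba_exc = g_syn * (E_exc - V_post)  # I_syn = g_syn (E_syn - V_post)\n",
    "I_coba_inh = g_syn * (E_inh - V_post)\n",
    "\n",
    "for V, ic, ie, ii in zip(V_post, I_cuba, I_coba_exc, I_coba_inh):\n",
    "    print(f'V = {V:6.1f}  CUBA = {ic:5.2f}  COBA exc = {ie:6.2f}  COBA inh = {ii:6.2f}')\n",
    "```\n",
    "\n",
    "In the COBA columns, the excitatory current vanishes at $V_{post} = 0$ mV and reverses above it. The inhibitory current is zero at $-80$ mV even though the synapse is open. That open conductance still opposes any current that pushes $V_{post}$ away from $E_{syn}$, which is the basis of shunting inhibition.\n",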
    "\n",
    "### MgBlock (NMDA)\n",
    "\n",
    "Voltage-dependent magnesium block for NMDA."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-11-13T11:43:31.336047Z",
     "start_time": "2025-11-13T11:43:31.332070Z"
    },
    "execution": {
     "iopub.execute_input": "2026-05-11T06:19:38.031513Z",
     "iopub.status.busy": "2026-05-11T06:19:38.031283Z",
     "iopub.status.idle": "2026-05-11T06:19:38.034451Z",
     "shell.execute_reply": "2026-05-11T06:19:38.033827Z"
    }
   },
   "outputs": [],
   "source": [
    "# NMDA with Mg²⁺ block\n",
    "out_nmda = brainpy.state.MgBlock(\n",
    "    E=0.0 * u.mV,\n",
    "    cc_Mg=1.2 * u.mM,\n",
    "    alpha=0.062 / u.mV,\n",
    "    beta=3.57\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**When to use:** NMDA receptors, voltage-dependent plasticity\n",
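    "\n",
    "The NumPy sketch below shows how strongly the block depends on voltage. It evaluates the unblocked fraction $B(V) = 1 / (1 + \\eta [Mg^{2+}] e^{-\\gamma V})$ from the NMDA equation in the Synapse Layer section. The values $\\eta = 1/3.57$ mM⁻¹ and $\\gamma = 0.062$ mV⁻¹ are assumptions chosen to mirror the `beta` and `alpha` arguments passed to `MgBlock` above. How `MgBlock` uses those arguments internally is not shown here.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "cc_Mg = 1.2  # mM, extracellular Mg2+ concentration\n",
    "eta = 1.0 / 3.57  # 1/mM (assumed, mirrors beta=3.57 above)\n",
    "gamma = 0.062  # 1/mV (assumed, mirrors alpha=0.062/mV above)\n",
    "\n",
    "\n",
    "def mg_block(V):\n",
    "    # Unblocked fraction of the NMDA conductance at membrane potential V (mV)\n",
    "    return 1.0 / (1.0 + eta * cc_Mg * np.exp(-gamma * V))\n",
    "\n",
    "\n",
    "for V in (-80.0, -65.0, -40.0, 0.0):\n",
    "    print(f'V = {V:6.1f} mV -> unblocked fraction {mg_block(V):.3f}')\n",
    "```\n",
    "\n",
    "The unblocked fraction grows steeply as the membrane depolarizes, so NMDA current flows mainly when the postsynaptic neuron is already depolarized.\n",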
    "\n",
    "## Complete Projection Examples\n",
    "\n",
    "### Example 1: Simple Feedforward"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-11-13T11:47:02.873592Z",
     "start_time": "2025-11-13T11:47:02.423022Z"
    },
    "execution": {
     "iopub.execute_input": "2026-05-11T06:19:38.037164Z",
     "iopub.status.busy": "2026-05-11T06:19:38.036790Z",
     "iopub.status.idle": "2026-05-11T06:19:38.462632Z",
     "shell.execute_reply": "2026-05-11T06:19:38.461819Z"
    }
   },
   "outputs": [],
   "source": [
    "# Create populations\n",
    "pre = brainpy.state.LIF(100, V_rest=-65 * u.mV, V_th=-50 * u.mV, tau=10 * u.ms)\n",
    "post = brainpy.state.LIF(50, V_rest=-65 * u.mV, V_th=-50 * u.mV, tau=10 * u.ms)\n",
    "\n",
    "# Create projection: 100 → 50 neurons\n",
    "proj = brainpy.state.AlignPostProj(\n",
    "    comm=brainstate.nn.EventFixedProb(\n",
    "        100,  # pre_size\n",
    "        50,  # post_size\n",
    "        conn_num=0.1,  # 10% connectivity\n",
    "        conn_weight=0.5 * u.mS  # Weight\n",
    "    ),\n",
    "    syn=brainpy.state.Expon(\n",
    "        in_size=50,  # Postsynaptic size\n",
    "        tau=5.0 * u.ms\n",
    "    ),\n",
    "    out=brainpy.state.CUBA(),\n",
    "    post=post  # Postsynaptic population\n",
    ")\n",
    "\n",
    "# Initialize\n",
    "brainstate.nn.init_all_states([pre, post, proj])\n",
    "\n",
    "\n",
    "# Simulate\n",
    "def step(t, i, inp):\n",
    "    with brainstate.environ.context(t=t, i=i):\n",
    "        # Update neurons\n",
    "        pre(inp)\n",
    "\n",
    "        # Get presynaptic spikes\n",
    "        pre_spikes = pre.get_spike()\n",
    "\n",
    "        # Update projection\n",
    "        proj(pre_spikes)\n",
    "\n",
    "        post(0.0 * u.nA)  # No external input; post is driven by the projection\n",
    "\n",
    "        return pre.get_spike(), post.get_spike()\n",
    "\n",
    "\n",
    "indices = np.arange(1000)\n",
    "times = indices * brainstate.environ.get_dt()\n",
    "inputs = brainstate.random.uniform(30., 50., indices.shape) * u.nA\n",
    "_ = brainstate.transform.for_loop(step, times, indices, inputs)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Example 2: Excitatory-Inhibitory Network"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-11-13T11:51:00.592366Z",
     "start_time": "2025-11-13T11:50:59.048927Z"
    },
    "execution": {
     "iopub.execute_input": "2026-05-11T06:19:38.465232Z",
     "iopub.status.busy": "2026-05-11T06:19:38.464994Z",
     "iopub.status.idle": "2026-05-11T06:19:39.172419Z",
     "shell.execute_reply": "2026-05-11T06:19:39.168040Z"
    }
   },
   "outputs": [],
   "source": [
    "class EINetwork(brainstate.nn.Module):\n",
    "    def __init__(self, n_exc=800, n_inh=200):\n",
    "        super().__init__()\n",
    "\n",
    "        # Populations\n",
    "        self.E = brainpy.state.LIF(n_exc, V_rest=-65 * u.mV, V_th=-50 * u.mV, tau=15 * u.ms)\n",
    "        self.I = brainpy.state.LIF(n_inh, V_rest=-65 * u.mV, V_th=-50 * u.mV, tau=10 * u.ms)\n",
    "\n",
    "        # E → E projection (AMPA, excitatory)\n",
    "        self.E2E = brainpy.state.AlignPostProj(\n",
    "            comm=brainstate.nn.EventFixedProb(n_exc, n_exc, conn_num=0.02, conn_weight=0.6 * u.mS),\n",
    "            syn=brainpy.state.Expon(n_exc, tau=2. * u.ms),\n",
    "            out=brainpy.state.COBA(E=0.0 * u.mV),\n",
    "            post=self.E\n",
    "        )\n",
    "\n",
    "        # E → I projection (AMPA, excitatory)\n",
    "        self.E2I = brainpy.state.AlignPostProj(\n",
    "            comm=brainstate.nn.EventFixedProb(n_exc, n_inh, conn_num=0.02, conn_weight=0.6 * u.mS),\n",
    "            syn=brainpy.state.Expon(n_inh, tau=2. * u.ms),\n",
    "            out=brainpy.state.COBA(E=0.0 * u.mV),\n",
    "            post=self.I\n",
    "        )\n",
    "\n",
    "        # I → E projection (GABAa, inhibitory)\n",
    "        self.I2E = brainpy.state.AlignPostProj(\n",
    "            comm=brainstate.nn.EventFixedProb(n_inh, n_exc, conn_num=0.02, conn_weight=6.7 * u.mS),\n",
    "            syn=brainpy.state.Expon(n_exc, tau=6. * u.ms),\n",
    "            out=brainpy.state.COBA(E=-80.0 * u.mV),\n",
    "            post=self.E\n",
    "        )\n",
    "\n",
    "        # I → I projection (GABAa, inhibitory)\n",
    "        self.I2I = brainpy.state.AlignPostProj(\n",
    "            comm=brainstate.nn.EventFixedProb(n_inh, n_inh, conn_num=0.02, conn_weight=6.7 * u.mS),\n",
    "            syn=brainpy.state.Expon(n_inh, tau=6. * u.ms),\n",
    "            out=brainpy.state.COBA(E=-80.0 * u.mV),\n",
    "            post=self.I\n",
    "        )\n",
    "\n",
    "    def update(self, i, inp_e, inp_i):\n",
    "        t = brainstate.environ.get_dt() * i\n",
    "        with brainstate.environ.context(t=t, i=i):\n",
    "            # Get spikes BEFORE updating neurons\n",
    "            spk_e = self.E.get_spike()\n",
    "            spk_i = self.I.get_spike()\n",
    "\n",
    "            # Update all projections\n",
    "            self.E2E(spk_e)\n",
    "            self.E2I(spk_e)\n",
    "            self.I2E(spk_i)\n",
    "            self.I2I(spk_i)\n",
    "\n",
    "            # Update neurons (projections provide synaptic input)\n",
    "            self.E(inp_e)\n",
    "            self.I(inp_i)\n",
    "\n",
    "            return spk_e, spk_i\n",
    "\n",
    "\n",
    "net = EINetwork()\n",
    "brainstate.nn.init_all_states(net)\n",
    "_ = brainstate.transform.for_loop(net.update, indices, inputs, inputs)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Example 3: Multi-Timescale Synapses\n",
    "\n",
    "Combine AMPA (fast) and NMDA (slow) for realistic excitation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-05-11T06:19:39.183052Z",
     "iopub.status.busy": "2026-05-11T06:19:39.182582Z",
     "iopub.status.idle": "2026-05-11T06:19:39.199400Z",
     "shell.execute_reply": "2026-05-11T06:19:39.198054Z"
    }
   },
   "outputs": [],
   "source": [
    "class DualExcitatory(brainstate.nn.Module):\n",
    "    \"\"\"E → E with both AMPA and NMDA.\"\"\"\n",
    "\n",
    "    def __init__(self, n_pre=100, n_post=100):\n",
    "        super().__init__()\n",
    "\n",
    "        self.post = brainpy.state.LIF(n_post, V_rest=-65 * u.mV, V_th=-50 * u.mV, tau=10 * u.ms)\n",
    "\n",
    "        # Fast AMPA component\n",
    "        self.ampa_proj = brainpy.state.AlignPostProj(\n",
    "            comm=brainstate.nn.EventFixedProb(n_pre, n_post, conn_num=0.1, conn_weight=0.3 * u.mS),\n",
    "            syn=brainpy.state.AMPA(n_post, beta=0.5 / u.ms),  # ~2 ms decay\n",
    "            out=brainpy.state.COBA(E=0.0 * u.mV),\n",
    "            post=self.post\n",
    "        )\n",
    "\n",
    "        # Slow NMDA component\n",
    "        self.nmda_proj = brainpy.state.AlignPostProj(\n",
    "            comm=brainstate.nn.EventFixedProb(n_pre, n_post, conn_num=0.1, conn_weight=0.3 * u.mS),\n",
    "            syn=brainpy.state.NMDA(n_post, tau_decay=100.0 * u.ms, tau_rise=2.0 * u.ms),\n",
    "            out=brainpy.state.MgBlock(E=0.0 * u.mV, cc_Mg=1.2 * u.mM),\n",
    "            post=self.post\n",
    "        )\n",
    "\n",
    "    def update(self, t, i, pre_spikes):\n",
    "        with brainstate.environ.context(t=t, i=i):\n",
    "            # Both projections share same presynaptic spikes\n",
    "            self.ampa_proj(pre_spikes)\n",
    "            self.nmda_proj(pre_spikes)\n",
    "\n",
    "            # Post receives combined input\n",
    "            self.post(0.0 * u.nA)\n",
    "\n",
    "            return self.post.get_spike()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Example 4: Delay Projections\n",
    "\n",
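    "Add a synaptic delay to a projection. In the example below, the population projects back onto itself (recurrent connections). At each step it receives the spikes it emitted `delay_time` earlier."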
   ]
  },
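  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Conceptually, a fixed transmission delay is a ring buffer holding the last `delay_steps` spike vectors. Each step reads the entry written `delay_steps` steps ago and then overwrites it with the current spikes. Below is a minimal NumPy sketch of that idea, using the hypothetical helper class `RingDelay`, which is not part of `brainstate`. How `output_delay` stores its buffer internally is not shown here. The cell after this one shows the `brainpy.state` version.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "class RingDelay:\n",
    "    # Hypothetical helper: fixed delay of `delay_steps` steps for a spike vector\n",
    "    def __init__(self, delay_steps, size):\n",
    "        self.buffer = np.zeros((delay_steps, size), dtype=bool)\n",
    "        self.pos = 0\n",
    "\n",
    "    def step(self, spikes):\n",
    "        delayed = self.buffer[self.pos].copy()  # written delay_steps steps ago\n",
    "        self.buffer[self.pos] = spikes  # overwrite with the current spikes\n",
    "        self.pos = (self.pos + 1) % len(self.buffer)\n",
    "        return delayed\n",
    "\n",
    "\n",
    "dt, delay_time = 0.1, 5.0  # ms\n",
    "delay = RingDelay(int(round(delay_time / dt)), size=3)\n",
    "outputs = []\n",
    "for k in range(70):\n",
    "    spikes = np.array([k == 0, k == 10, False])  # neuron 0 fires at step 0, neuron 1 at step 10\n",
    "    outputs.append(delay.step(spikes))\n",
    "\n",
    "arrivals = [k for k, s in enumerate(outputs) if s.any()]\n",
    "print(arrivals)  # [50, 60]\n",
    "```"
   ]
  },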
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-11-13T11:57:15.058629Z",
     "start_time": "2025-11-13T11:57:14.596654Z"
    },
    "execution": {
     "iopub.execute_input": "2026-05-11T06:19:39.201835Z",
     "iopub.status.busy": "2026-05-11T06:19:39.201614Z",
     "iopub.status.idle": "2026-05-11T06:19:39.979189Z",
     "shell.execute_reply": "2026-05-11T06:19:39.971934Z"
    }
   },
   "outputs": [],
   "source": [
    "# Transmission delay, implemented with the population's output_delay buffer\n",
    "delay_time = 5.0 * u.ms\n",
    "\n",
    "\n",
    "# Create a network with delay\n",
    "class DelayedProjection(brainstate.nn.Module):\n",
    "    def __init__(self, pre_size, post_size):\n",
    "        super().__init__()\n",
    "\n",
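    "        # The population projects onto itself (recurrent), so pre_size must equal post_size\n",
    "        assert pre_size == post_size, 'recurrent projection requires pre_size == post_size'\n",
    "        self.post = brainpy.state.LIF(post_size, V_rest=-65 * u.mV, V_th=-50 * u.mV, tau=10 * u.ms)\n",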
    "        self.delay = self.post.output_delay(delay_time)\n",
    "\n",
    "        # Standard projection\n",
    "        self.proj = brainpy.state.AlignPostProj(\n",
    "            comm=brainstate.nn.EventFixedProb(pre_size, post_size, conn_num=0.1, conn_weight=0.5 * u.mS),\n",
    "            syn=brainpy.state.Expon(post_size, tau=5.0 * u.ms),\n",
    "            out=brainpy.state.CUBA(),\n",
    "            post=self.post\n",
    "        )\n",
    "\n",
    "    def update(self, inp=0. * u.nA):\n",
    "        # Retrieve delayed spikes\n",
    "        delayed_spikes = self.delay()\n",
    "        # Update projection with delayed spikes\n",
    "        self.proj(delayed_spikes)\n",
    "        self.post(inp)\n",
    "        # Store current spikes in delay buffer\n",
    "        self.delay(self.post.get_spike())\n",
    "\n",
    "    def step_run(self, i, inp):\n",
    "        t = brainstate.environ.get_dt() * i\n",
    "        with brainstate.environ.context(t=t, i=i):\n",
    "            # Update post neurons\n",
    "            self.update(inp)\n",
    "            return self.post.get_spike()\n",
    "\n",
    "\n",
    "net = DelayedProjection(100, 100)\n",
    "brainstate.nn.init_all_states(net)\n",
    "_ = brainstate.transform.for_loop(net.step_run, indices, inputs)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Ecosystem-py",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.13.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
