{
 "cells": [
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "# Simulation and Analysis of Excitatory-Inhibitory Neuronal Networks (E-I Networks)\n",
    "\n",
    "This example implements a classical excitatory-inhibitory (E-I) neuronal network with the `braincell` framework. By building an E-I network of Hodgkin-Huxley (HH) neurons, you will learn how `braincell` supports hierarchical modeling from single neurons up to networks (here combined with synaptic projections from `brainpy` and `brainstate`), and how to analyze the spike dynamics of the network.\n",
    "\n",
    "## Preparation\n",
    "First, make sure that the required libraries (`braincell`, `brainpy`, `brainstate`, `brainunit`, `matplotlib`) are installed, then import them:\n"
   ],
   "id": "2912659fbdc55848"
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-13T09:24:18.860294Z",
     "start_time": "2025-10-13T09:24:18.857911Z"
    }
   },
   "cell_type": "code",
   "source": [
    "import brainpy\n",
    "import brainunit as u\n",
    "import matplotlib.pyplot as plt\n",
    "import brainstate\n",
    "import braincell"
   ],
   "id": "9d0cab5ba9e88f66",
   "outputs": [],
   "execution_count": 19
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "## Code Explanation\n",
    "\n",
    "### Parameter Definition\n",
    "\n",
    "First, define the key biophysical parameters of each neuron: the spike threshold, the membrane area, and the total membrane capacitance. The channel conductances defined later are also scaled by this area:\n"
   ],
   "id": "7c40b3a6e85f45ae"
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-13T09:24:18.963848Z",
     "start_time": "2025-10-13T09:24:18.960475Z"
    }
   },
   "cell_type": "code",
   "source": [
    "# Membrane potential threshold used to detect action potentials (spikes)\n",
    "V_th = -20. * u.mV\n",
    "\n",
    "# Neuron membrane area: 20,000 um^2, converted to cm^2\n",
    "area = 20000 * u.um ** 2\n",
    "area = area.in_unit(u.cm ** 2)\n",
    "\n",
    "# Total membrane capacitance = specific capacitance (1 uF/cm^2) x area\n",
    "Cm = (1 * u.uF * u.cm ** -2) * area"
   ],
   "id": "c22585ffe4f19bdd",
   "outputs": [],
   "execution_count": 20
  },
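  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "As a sanity check on the unit conversions, the plain-Python sketch below redoes them by hand with bare floats instead of `brainunit` quantities (1 cm = 10⁴ µm, so 1 cm² = 10⁸ µm²). The channel densities are copied from the HH cell in the next section, and all variable names are illustrative only."
   ],
   "id": "b7e1c2a9d4f03a11"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "# Hand-computed versions of the quantities above, using plain floats (illustrative only)\n",
    "UM2_PER_CM2 = 1e8  # 1 cm^2 = (1e4 um)^2 = 1e8 um^2\n",
    "\n",
    "area_um2 = 20000.0                 # membrane area in um^2\n",
    "area_cm2 = area_um2 / UM2_PER_CM2  # 2e-4 cm^2\n",
    "\n",
    "# A specific capacitance of 1 uF/cm^2 gives the total capacitance (1 uF = 1e6 pF)\n",
    "Cm_pF = 1.0 * area_cm2 * 1e6  # about 200 pF\n",
    "\n",
    "# Maximal conductances of the HH cell from densities in mS/cm^2 (1 mS = 1e3 uS)\n",
    "g_na_uS = 100.0 * area_cm2 * 1e3  # about 20 uS\n",
    "g_k_uS = 30.0 * area_cm2 * 1e3    # about 6 uS\n",
    "\n",
    "print(f'area = {area_cm2:.1e} cm^2, Cm = {Cm_pF:.1f} pF')\n",
    "print(f'g_Na = {g_na_uS:.1f} uS, g_K = {g_k_uS:.1f} uS')"
   ],
   "id": "b7e1c2a9d4f03a12",
   "outputs": [],
   "execution_count": null
  },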
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "### Define the HH single-neuron model\n",
    "\n",
    "Use `braincell.SingleCompartment` to build a single-compartment HH neuron with a sodium channel (`INa`), a potassium channel (`IK`), and a leak current (`IL`). Together, these currents determine the firing properties of the neuron:\n"
   ],
   "id": "12e040fbf99cff1b"
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-13T09:24:19.081543Z",
     "start_time": "2025-10-13T09:24:19.077656Z"
    }
   },
   "cell_type": "code",
   "source": [
    "class HH(braincell.SingleCompartment):\n",
    "    def __init__(self, in_size):\n",
    "        # Initialize single-compartment neuron (capacitance Cm, spike threshold V_th)\n",
    "        super().__init__(in_size, C=Cm, V_th=V_th, solver='ind_exp_euler')\n",
    "\n",
    "        # Sodium channel (INa)\n",
    "        self.na = braincell.ion.SodiumFixed(in_size, E=50. * u.mV)\n",
    "        self.na.add_elem(\n",
    "            # Maximal conductance: 100 mS/cm^2 scaled by the membrane area\n",
    "            INa=braincell.channel.INa_TM1991(in_size, g_max=(100. * u.mS * u.cm ** -2) * area, V_sh=-63. * u.mV)\n",
    "        )\n",
    "\n",
    "        # Potassium channel (IK)\n",
    "        self.k = braincell.ion.PotassiumFixed(in_size, E=-90 * u.mV)\n",
    "        self.k.add_elem(\n",
    "            # Maximal conductance: 30 mS/cm^2 scaled by the membrane area\n",
    "            IK=braincell.channel.IK_TM1991(in_size, g_max=(30. * u.mS * u.cm ** -2) * area, V_sh=-63. * u.mV)\n",
    "        )\n",
    "\n",
    "        # Leak current (IL)\n",
    "        self.IL = braincell.channel.IL(\n",
    "            in_size,\n",
    "            E=-60. * u.mV,\n",
    "            g_max=(5. * u.nS * u.cm ** -2) * area  # Maximal conductance density scaled by the membrane area\n",
    "        )\n"
   ],
   "id": "c31fb6de04535794",
   "outputs": [],
   "execution_count": 21
  },
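  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "The channel classes `INa_TM1991` and `IK_TM1991` appear to refer to the Traub and Miles (1991) HH formulation. For illustration only, the sketch below evaluates steady-state gating curves with the Traub-Miles rate functions as written in the COBAHH benchmark of Brette et al. (2007), shifted by `V_SH = -63 mV` to mirror the `V_sh` argument above. We have not verified that `braincell` uses exactly these expressions, so the formulas and the helper names `exprel` and `rates` should be read as assumptions."
   ],
   "id": "c4d8e2f1a6b05c21"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "import numpy as np\n",
    "\n",
    "V_SH = -63.0  # mV; assumed to play the role of V_sh=-63 mV in the HH cell above\n",
    "\n",
    "\n",
    "def exprel(x):\n",
    "    # (exp(x) - 1) / x, with the removable singularity at x = 0 set to its limit 1\n",
    "    x = np.asarray(x, dtype=float)\n",
    "    return np.divide(np.expm1(x), x, out=np.ones_like(x), where=np.abs(x) > 1e-9)\n",
    "\n",
    "\n",
    "def rates(V):\n",
    "    # Assumed Traub-Miles opening/closing rates (1/ms) at membrane potential V (mV)\n",
    "    v = np.asarray(V, dtype=float) - V_SH\n",
    "    alpha_m = 0.32 * 4.0 / exprel((13.0 - v) / 4.0)\n",
    "    beta_m = 0.28 * 5.0 / exprel((v - 40.0) / 5.0)\n",
    "    alpha_h = 0.128 * np.exp((17.0 - v) / 18.0)\n",
    "    beta_h = 4.0 / (1.0 + np.exp((40.0 - v) / 5.0))\n",
    "    alpha_n = 0.032 * 5.0 / exprel((15.0 - v) / 5.0)\n",
    "    beta_n = 0.5 * np.exp((10.0 - v) / 40.0)\n",
    "    return alpha_m, beta_m, alpha_h, beta_h, alpha_n, beta_n\n",
    "\n",
    "\n",
    "V = np.linspace(-90.0, 20.0, 111)  # membrane potentials in mV, 1 mV apart\n",
    "a_m, b_m, a_h, b_h, a_n, b_n = rates(V)\n",
    "m_inf = a_m / (a_m + b_m)  # Na activation (I_Na assumed proportional to m^3 h)\n",
    "h_inf = a_h / (a_h + b_h)  # Na inactivation\n",
    "n_inf = a_n / (a_n + b_n)  # K activation (I_K assumed proportional to n^4)"
   ],
   "id": "c4d8e2f1a6b05c22",
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "Plotting `m_inf`, `h_inf`, and `n_inf` against `V` (for example with `plt.plot`) shows the voltage ranges over which each gate opens or closes under these assumed rates."
   ],
   "id": "c4d8e2f1a6b05c23"
  },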
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "### Define E-I Network: Connections Between Excitatory and Inhibitory Neurons\n",
    "\n",
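    "Construct a network of 3200 excitatory (E) and 800 inhibitory (I) HH neurons, randomly connected with probability 0.02, to model the balance between excitation and inhibition commonly observed in cortical networks. Spikes act on their targets through conductance-based (COBA) synapses with exponentially decaying conductances:\n"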
   ],
   "id": "a70a4069f943e1b"
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-13T09:24:19.192953Z",
     "start_time": "2025-10-13T09:24:19.189095Z"
    }
   },
   "cell_type": "code",
   "source": [
    "class EINet(brainstate.nn.Module):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        # Network scale\n",
    "        self.n_exc = 3200\n",
    "        self.n_inh = 800\n",
    "        self.num = self.n_exc + self.n_inh  # Total number of neurons: 4000\n",
    "\n",
    "        # Initialize neuron population\n",
    "        self.N = HH(self.num)\n",
    "\n",
    "        # Excitatory synaptic projection\n",
    "        self.E = brainpy.state.AlignPostProj(\n",
    "            # Connection rule\n",
    "            comm=brainstate.nn.EventFixedProb(\n",
    "                self.n_exc, self.num, conn_num=0.02,  # Connection probability\n",
    "                conn_weight=6. * u.nS  # Synaptic weight\n",
    "            ),\n",
    "            # Synaptic dynamics\n",
    "            syn=brainpy.state.Expon(self.num, tau=5. * u.ms),\n",
    "            # Postsynaptic effect\n",
    "            out=brainpy.state.COBA(E=0. * u.mV),\n",
    "            post=self.N  # Projection target is neuron population N\n",
    "        )\n",
    "\n",
    "        # Inhibitory synaptic projection\n",
    "        self.I = brainpy.state.AlignPostProj(\n",
    "            # Connection rule\n",
    "            comm=brainstate.nn.EventFixedProb(\n",
    "                self.n_inh, self.num, conn_num=0.02,\n",
    "                conn_weight=67. * u.nS\n",
    "            ),\n",
    "            # Synaptic dynamics\n",
    "            syn=brainpy.state.Expon(self.num, tau=10. * u.ms),\n",
    "            # Postsynaptic effect\n",
    "            out=brainpy.state.COBA(E=-80. * u.mV),\n",
    "            post=self.N  # Projection target is neuron population N\n",
    "        )\n",
    "\n",
    "    def update(self, t):\n",
    "        # Define network update rules over time\n",
    "        with brainstate.environ.context(t=t):\n",
    "            # Spikes emitted by the population at the previous time step\n",
    "            spk = self.N.spike.value\n",
    "\n",
    "            # Spikes from excitatory neurons drive excitatory synaptic projection\n",
    "            self.E(spk[:self.n_exc])\n",
    "\n",
    "            # Spikes from inhibitory neurons drive inhibitory synaptic projection\n",
    "            self.I(spk[self.n_exc:])\n",
    "\n",
    "            # Neurons update their states after receiving synaptic inputs; return the new spikes\n",
    "            # so that brainstate.transform.for_loop can collect them at every time step\n",
    "            return self.N(0. * u.nA)\n"
   ],
   "id": "e72dfafb35da107c",
   "outputs": [],
   "execution_count": 22
  },
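  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "With a connection probability of 0.02, each neuron receives on average 0.02 × 3200 = 64 excitatory and 0.02 × 800 = 16 inhibitory inputs. The NumPy sketch below illustrates, for a single synapse, the conductance-based exponential synapse that the `Expon` and `COBA` components are meant to express: the conductance `g` jumps by the synaptic weight at each presynaptic spike, decays as dg/dt = -g/τ, and injects the current I = g (E - V). It is an independent illustration rather than `brainpy`'s implementation; the update order (decay, then add the spike), the fixed postsynaptic potential of -65 mV, and the helper name `simulate_coba` are assumptions."
   ],
   "id": "d9a3f6b2c7e14d31"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "import numpy as np\n",
    "\n",
    "\n",
    "def simulate_coba(spike_steps, weight_nS, tau_ms, E_mV, V_mV=-65.0, dt_ms=0.1, n_steps=1000):\n",
    "    # One exponential synapse onto a neuron whose potential is held at V_mV (simplification).\n",
    "    # spike_steps: time-step indices at which presynaptic spikes arrive.\n",
    "    decay = np.exp(-dt_ms / tau_ms)  # exact solution of dg/dt = -g/tau over one step\n",
    "    arrivals = set(spike_steps)\n",
    "    g = 0.0\n",
    "    g_trace = np.zeros(n_steps)\n",
    "    for k in range(n_steps):\n",
    "        g *= decay\n",
    "        if k in arrivals:\n",
    "            g += weight_nS  # conductance jumps by the synaptic weight on each spike\n",
    "        g_trace[k] = g\n",
    "    # COBA current I = g * (E - V); nS * mV = pA, positive values depolarize\n",
    "    I_trace_pA = g_trace * (E_mV - V_mV)\n",
    "    return g_trace, I_trace_pA\n",
    "\n",
    "\n",
    "# Synaptic parameters taken from the EINet cell above; one presynaptic spike at step 0\n",
    "g_exc, I_exc = simulate_coba([0], weight_nS=6.0, tau_ms=5.0, E_mV=0.0)\n",
    "g_inh, I_inh = simulate_coba([0], weight_nS=67.0, tau_ms=10.0, E_mV=-80.0)\n",
    "\n",
    "print('peak E current (pA):', I_exc[0], ' peak I current (pA):', I_inh[0])\n",
    "print('mean E in-degree:', 0.02 * 3200, ' mean I in-degree:', 0.02 * 800)"
   ],
   "id": "d9a3f6b2c7e14d32",
   "outputs": [],
   "execution_count": null
  },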
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "### Run network simulation\n",
    "\n",
    "Initialize the network and simulate 100 ms with a time step of 0.1 ms, recording the spikes of all neurons at every time step:\n"
   ],
   "id": "fa97af7731e5f5e8"
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-13T09:24:19.848845Z",
     "start_time": "2025-10-13T09:24:19.207206Z"
    }
   },
   "cell_type": "code",
   "source": [
    "# Initialize E-I network\n",
    "net = EINet()\n",
    "brainstate.nn.init_all_states(net)  # Initialize the states of all neurons and synapses in the network\n",
    "\n",
    "# Set simulation parameters and run\n",
    "with brainstate.environ.context(dt=0.1 * u.ms):  # Time step\n",
    "    # Generate simulation time sequence\n",
    "    times = u.math.arange(0. * u.ms, 100. * u.ms, brainstate.environ.get_dt())\n",
    "\n",
    "    # Loop to update network states\n",
    "    spikes = brainstate.transform.for_loop(\n",
    "        net.update, times,\n",
    "        pbar=brainstate.transform.ProgressBar(10)  # Display progress bar\n",
    "    )\n"
   ],
   "id": "655509eeebf82b38",
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "D:\\Document\\PyCharm\\Project\\braincell(collaborator)\\.venv\\Lib\\site-packages\\braintools\\surrogate.py:72: UserWarning: Explicitly requested dtype float64 requested in asarray is not available, and will be truncated to dtype float32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/jax-ml/jax#current-gotchas for more.\n",
      "  z = jnp.asarray(x >= 0, dtype=x.dtype)\n",
      "Running for 1,000 iterations: 100%|██████████| 1000/1000 [00:00<00:00, 13624.06it/s]\n"
     ]
    }
   ],
   "execution_count": 23
  },
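  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "With `dt = 0.1 ms` and a 100 ms window, `times` contains 1000 steps, matching the 1,000 iterations reported by the progress bar. Conceptually, `brainstate.transform.for_loop` calls `net.update` once per element of `times` and stacks whatever it returns along a new leading time axis; the real function compiles the loop with JAX (as far as we know, on top of `jax.lax.scan`) instead of running Python code step by step. The pure-Python stand-in below, with the hypothetical name `naive_for_loop` and toy update functions, illustrates why `update` must return the spikes: if it returns `None`, nothing is collected."
   ],
   "id": "e2b5c8d1f4a07e41"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "import numpy as np\n",
    "\n",
    "\n",
    "def naive_for_loop(f, xs):\n",
    "    # Hypothetical stand-in for brainstate.transform.for_loop: call f on every\n",
    "    # element of xs and stack the outputs along a new leading (time) axis.\n",
    "    outs = [f(x) for x in xs]\n",
    "    if all(o is None for o in outs):\n",
    "        return None  # f returned nothing, so there is nothing to collect\n",
    "    return np.stack(outs)\n",
    "\n",
    "\n",
    "n_neurons = 5\n",
    "steps = np.arange(1000)  # 1000 steps of 0.1 ms = 100 ms\n",
    "\n",
    "\n",
    "def update_returning(k):\n",
    "    # Toy update: neuron 0 spikes every 100 steps (10 ms), the others stay silent\n",
    "    spk = np.zeros(n_neurons)\n",
    "    spk[0] = float(k % 100 == 0)\n",
    "    return spk\n",
    "\n",
    "\n",
    "def update_not_returning(k):\n",
    "    update_returning(k)  # does the work but drops the result, like an update() without return\n",
    "\n",
    "\n",
    "spikes_ok = naive_for_loop(update_returning, steps)\n",
    "spikes_bad = naive_for_loop(update_not_returning, steps)\n",
    "print(spikes_ok.shape, spikes_bad)"
   ],
   "id": "e2b5c8d1f4a07e42",
   "outputs": [],
   "execution_count": null
  },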
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "### Visualize network spike activity\n",
    "\n",
    "Plot the recorded spikes as a raster plot to show the firing pattern of every neuron over time:\n"
   ],
   "id": "e595aa0208703fac"
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-13T09:24:20.081862Z",
     "start_time": "2025-10-13T09:24:20.036413Z"
    }
   },
   "cell_type": "code",
   "source": [
    "# Extract the time-step and neuron indices of all spikes\n",
    "t_indices, n_indices = u.math.where(spikes)\n",
    "\n",
    "# Plot raster plot (spike times converted to plain numbers in ms for matplotlib)\n",
    "plt.scatter(times[t_indices].to_decimal(u.ms), n_indices, s=1)\n",
    "plt.xlabel('Time (ms)')  # X-axis: time\n",
    "plt.ylabel('Neuron index')  # Y-axis: neuron index\n",
    "plt.show()"
   ],
   "id": "194cd88405441b76",
   "outputs": [],
   "execution_count": 24
  },
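  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "Beyond the raster plot, the spike matrix collected by `for_loop` (shape `(n_steps, n_neurons)`, i.e. `(1000, 4000)` here) can be summarized numerically, for example to check the sparse-activity claim in the results below. The helper `firing_stats` is a hypothetical NumPy sketch, not part of `braincell`: it returns the mean firing rate per neuron in Hz, the fraction of neurons that never fired, and the population rate at each time step. It is exercised on a small synthetic matrix; for the network it could be called as `firing_stats(np.asarray(spikes), dt_ms=0.1)`."
   ],
   "id": "f6c9d2e5a8b13f51"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "import numpy as np\n",
    "\n",
    "\n",
    "def firing_stats(spike_matrix, dt_ms):\n",
    "    # spike_matrix: array of shape (n_steps, n_neurons), nonzero where a neuron spiked\n",
    "    s = np.asarray(spike_matrix) > 0\n",
    "    n_steps, n_neurons = s.shape\n",
    "    duration_s = n_steps * dt_ms * 1e-3\n",
    "    counts = s.sum(axis=0)  # spikes per neuron\n",
    "    mean_rate_hz = counts.sum() / (n_neurons * duration_s)\n",
    "    silent_fraction = np.mean(counts == 0)\n",
    "    pop_rate_hz = s.sum(axis=1) / (n_neurons * dt_ms * 1e-3)  # population rate per step\n",
    "    return mean_rate_hz, silent_fraction, pop_rate_hz\n",
    "\n",
    "\n",
    "# Synthetic example: 4 neurons, 1000 steps of 0.1 ms; only neuron 0 fires (10 spikes)\n",
    "demo = np.zeros((1000, 4))\n",
    "demo[::100, 0] = 1.0\n",
    "rate, silent, pop = firing_stats(demo, dt_ms=0.1)\n",
    "print(f'mean rate = {rate:.1f} Hz, silent fraction = {silent:.2f}')"
   ],
   "id": "f6c9d2e5a8b13f52",
   "outputs": [],
   "execution_count": null
  },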
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "## Results Interpretation\n",
    "\n",
    "After running the code, you will see a raster plot, as shown below:\n",
    "\n",
    "![raster plot](../../_static/cobahh2007.png)\n",
    "\n",
    "Each dot represents one spike fired by one neuron at a specific time. Typical dynamics of such an E-I network include:\n",
    "- Asynchronous firing: neurons fire at dispersed times, without an obvious synchronous rhythm.\n",
    "- Sparse activity: most neurons fire only a few times within the 100 ms simulation.\n",
    "- No burst synchronization: strong recurrent inhibition (67 nS inhibitory vs. 6 nS excitatory synaptic weights) counteracts excitation and prevents large-scale synchronous bursting.\n",
    "\n",
    "These features indicate that the network maintains stable, irregular, cortex-like activity through a dynamic balance between excitatory and inhibitory input.\n",
    "\n",
    "## Extended Exercises\n",
    "\n",
    "- Adjust the inhibitory synaptic weight (`conn_weight` of the inhibitory projection, 67 nS here) and observe whether the network exhibits excessive synchronization or bursting activity.\n",
    "- Increase the proportion of excitatory neurons (`n_exc` relative to `n_inh`) to analyze the impact of E-I imbalance on network dynamics.\n",
    "- Extend the simulation duration (the stop value of `times`) to check whether the network maintains stable asynchronous activity.\n",
    "\n",
    "Through these extensions, you can gain deeper insights into the critical role of E-I balance in maintaining neural network function, and the flexibility of `braincell` in modeling complex networks.\n"
   ],
   "id": "55ef3973da8a763e"
  }
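  ,
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "For a quantitative check of the asynchronous-firing characteristic, one option (our choice here, not something the original example prescribes) is a variance-based synchrony index often attributed to Golomb: bin each neuron's spikes, then take χ² as the variance of the population-averaged count divided by the average single-neuron variance. χ is 1 when all neurons fire together and falls roughly as 1/√N for N independent neurons. The helper `synchrony_chi` and the 50-step (5 ms) bins are illustrative assumptions; with the network above it could be applied as `synchrony_chi(np.asarray(spikes), bin_steps=50)`."
   ],
   "id": "a3d6e9f2b5c80a61"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "import numpy as np\n",
    "\n",
    "\n",
    "def synchrony_chi(spike_matrix, bin_steps):\n",
    "    # spike_matrix: (n_steps, n_neurons); spikes are counted in bins of bin_steps steps\n",
    "    s = (np.asarray(spike_matrix) > 0).astype(float)\n",
    "    n_bins = s.shape[0] // bin_steps\n",
    "    counts = s[:n_bins * bin_steps].reshape(n_bins, bin_steps, -1).sum(axis=1)\n",
    "    var_pop = np.var(counts.mean(axis=1))     # variance of the population average\n",
    "    var_each = np.var(counts, axis=0).mean()  # mean single-neuron variance\n",
    "    if var_each == 0:\n",
    "        return 0.0  # no variability at all: index undefined, reported as 0 by convention here\n",
    "    return float(np.sqrt(var_pop / var_each))\n",
    "\n",
    "\n",
    "rng = np.random.default_rng(0)\n",
    "n_steps, n_neurons = 1000, 400\n",
    "independent = rng.random((n_steps, n_neurons)) < 0.01  # independent random spike trains\n",
    "shared = rng.random((n_steps, 1)) < 0.01\n",
    "synchronous = np.repeat(shared, n_neurons, axis=1)     # every neuron fires together\n",
    "\n",
    "chi_ind = synchrony_chi(independent, bin_steps=50)\n",
    "chi_sync = synchrony_chi(synchronous, bin_steps=50)\n",
    "print(f'independent: chi = {chi_ind:.3f}, synchronous: chi = {chi_sync:.3f}')"
   ],
   "id": "a3d6e9f2b5c80a62",
   "outputs": [],
   "execution_count": null
  }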
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
