# Copyright 2026 BrainX Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# -*- coding: utf-8 -*-
from typing import Sequence
import brainstate
import saiunit as u
import numpy as np
from brainstate.typing import Size
from ._base import NESTNeuron
from ._utils import is_tracer
# Public API of this module: only the population model class is exported.
__all__ = ['gif_pop_psc_exp']
class gif_pop_psc_exp(NESTNeuron):
    r"""Population of generalized integrate-and-fire neurons (GIF) with
    exponential postsynaptic currents and adaptation.

    ``gif_pop_psc_exp`` simulates a population of spike-response model neurons
    with multi-timescale adaptation and exponential postsynaptic currents.
    It directly models the population activity (sum of all spikes) without
    explicitly representing each individual neuron, following the algorithm of
    Schwalger et al. (2017) [1]_.

    This is a brainpy.state re-implementation of the NEST simulator model of
    the same name, using NEST-standard parameterization.

    **1. Mathematical Model**

    The single neuron model is defined by the hazard function:

    .. math::

        h(t) = \lambda_0 \exp\left(\frac{V_m(t) - E_{\mathrm{sfa}}(t)}{\Delta_V}\right)

    After each spike, the membrane potential :math:`V_m` is reset to
    :math:`V_{\mathrm{reset}}`. Spike-frequency adaptation is implemented by a
    set of exponentially decaying traces. The effective threshold
    :math:`E_{\mathrm{sfa}}` is the baseline :math:`V_{T^*}` plus the sum of
    these traces. Upon a spike, each adaptation trace is incremented by its
    amplitude :math:`q_{\mathrm{sfa}}` and then decays with its time constant
    :math:`\tau_{\mathrm{sfa}}`.

    **2. Population Dynamics Algorithm**

    The model uses the algorithm from Figures 11 and 12 of [1]_ to simulate
    the population activity directly:

    1. **Membrane update**: Exact exponential integration of the subthreshold
       membrane dynamics with exponential postsynaptic currents:

       .. math::

           C_m \frac{dV}{dt} = -\frac{C_m}{\tau_m}(V - E_L) + I_{\mathrm{syn,ex}}
           + I_{\mathrm{syn,in}} + I_e + I_{\mathrm{stim}}

    2. **Adaptation update**: The threshold of free (non-refractory) neurons
       is

       .. math::

           \hat{\theta} = V_{T^*} + \sum_j Q_{30K,j} \cdot g_j

       Here :math:`g_j` is an exponentially filtered population rate (decay
       factor :math:`e^{-h/\tau_{\mathrm{sfa},j}}` per step). It is fed by
       the oldest entry of the spike-count history, i.e. by spikes that leave
       the history window. The factor
       :math:`Q_{30K,j} = q_{\mathrm{sfa},j} \tau_{\mathrm{sfa},j} e^{-hK/\tau_{\mathrm{sfa},j}}`
       uses the kernel length :math:`K`.

    3. **Population escape rate**: For each refractory cohort, compute the
       hazard-based escape (firing) probability with the trapezoidal rule.
       Free (non-refractory) neurons and neurons emerging from refractoriness
       are tracked separately.

    4. **Stochastic spike generation**: The expected number of spikes is
       computed from the population escape rates. The actual spike count is
       drawn from either a binomial or a Poisson distribution.

    5. **History buffer rotation**: Spike counts, survival counts, mean
       membrane potentials, and variances of each refractory cohort are
       maintained in rotating circular buffers.

    **3. Synaptic Currents**

    Exponential postsynaptic currents follow first-order dynamics:

    .. math::

        \tau_{\mathrm{syn}} \frac{dy}{dt} = -y

    Integrated exactly per time step:

    .. math::

        y(t+h) = y_{\infty} + (y(t) - y_{\infty}) \cdot e^{-h/\tau_{\mathrm{syn}}}

    The membrane contribution from synaptic currents is computed via the exact
    propagator for the coupled linear system of membrane and synaptic dynamics.

    Connecting two population models corresponds to full connectivity of every
    neuron in each population. An approximation of random connectivity can be
    implemented by connecting populations using a ``bernoulli_synapse``.

    Parameters
    ----------
    in_size : int, sequence of int
        Shape of the population. Typically a scalar for 1D populations.
    N : int, optional
        Number of neurons represented by this population node. Default: 100.
    tau_m : float, optional
        Membrane time constant in milliseconds. Default: 20.0.
    C_m : float, optional
        Membrane capacitance in picofarads. Default: 250.0.
    t_ref : float, optional
        Duration of absolute refractory period in milliseconds. Default: 4.0.
    lambda_0 : float, optional
        Firing rate at threshold in 1/s (Hz). Default: 10.0.
    Delta_V : float, optional
        Noise level (voltage sensitivity) of the escape rate in millivolts.
        Determines how sharply the firing rate increases with membrane potential.
        Default: 2.0.
    E_L : float, optional
        Resting (leak) potential in millivolts. Default: 0.0.
    V_reset : float, optional
        Reset potential after spike in millivolts. Default: 0.0.
    V_T_star : float, optional
        Baseline threshold level in millivolts. The effective threshold is
        :math:`V_{T^*}` plus the summed adaptation traces. Default: 15.0.
    I_e : float, optional
        Constant external DC input current in picoamperes. Default: 0.0.
    tau_syn_ex : float, optional
        Excitatory synaptic time constant in milliseconds. Default: 3.0.
    tau_syn_in : float, optional
        Inhibitory synaptic time constant in milliseconds. Default: 6.0.
    tau_sfa : sequence of float, optional
        Adaptation time constants in milliseconds. Multiple timescales can be
        specified as a tuple. Default: (300.0,).
    q_sfa : sequence of float, optional
        Adaptation kernel amplitudes in millivolts. Must have the same length
        as `tau_sfa`. Default: (0.5,).
    len_kernel : int, optional
        History kernel length in time steps. Any value smaller than 1 (the
        default is -1) makes the length be computed automatically from the
        adaptation kernel. Default: -1.
    BinoRand : bool, optional
        If True, use binomial distribution for spike generation. If False,
        use Poisson distribution. Default: True.
    rng_seed : int, optional
        Seed of the NumPy ``RandomState`` used for spike generation; the
        generator is re-seeded on every ``init_state``/``reset_state``.
        Default: 0.
    name : str, optional
        Name of the population. Default: None.

    Parameter Mapping
    -----------------

    =================== ============= =============================
    gif_pop_psc_exp     gif_psc_exp   relation
    =================== ============= =============================
    tau_m               g_L           tau_m = C_m / g_L
    N                   ---           use N gif_psc_exp neurons
    =================== ============= =============================

    State Variables
    ---------------

    The model exposes the following read-only properties:

    ========================== ===========================================
    **State variable**         **Description**
    ========================== ===========================================
    ``V_m``                    Mean membrane potential (mV)
    ``I_syn_ex``               Excitatory synaptic current (pA)
    ``I_syn_in``               Inhibitory synaptic current (pA)
    ``n_spikes``               Number of spikes in current step
    ``n_expect``               Expected number of spikes
    ``theta_hat``              Adaptive threshold (mV)
    ``y0``                     External current input state (pA)
    ========================== ===========================================

    Notes
    -----
    - As ``gif_pop_psc_exp`` represents many neurons in one node, it may
      generate many spikes per time step. The ``n_spikes`` state variable
      records the number of spikes emitted by the population each step.
    - The computational cost of this model is largely independent of the
      number N of neurons represented.
    - Defaults follow NEST C++ source for ``gif_pop_psc_exp``.
    - ``lambda_0`` is specified in 1/s (as in NEST's Python interface). The
      conversion to 1/ms is folded into the constant 0.0005 of the
      trapezoidal escape probability: 0.5 (average of two rates) times
      0.001 (1/s to 1/ms), exactly as in NEST.
    - Synaptic spike weights are interpreted in current units. Positive
      weights go to the excitatory channel, all others to the inhibitory
      channel.
    - The ``x`` argument of ``update`` and all registered current inputs are
      summed into ``y0``. ``y0`` acts as a DC current from the next time
      step on; it is not routed to the synaptic channels.
    - The model does not inherit from ``Neuron`` because it represents a
      population, not a single spiking neuron with a surrogate gradient.
    - Parameters are stored as plain floats (unitless) to match NEST's
      internal representation.

    References
    ----------
    .. [1] Schwalger T, Deger M, Gerstner W (2017). Towards a theory of
           cortical columns: From spiking neurons to interacting neural
           populations of finite size. PLoS Computational Biology, 13(4),
           e1005507. https://doi.org/10.1371/journal.pcbi.1005507
    .. [2] NEST Simulator ``gif_pop_psc_exp`` model documentation and C++
           source: ``models/gif_pop_psc_exp.h`` and
           ``models/gif_pop_psc_exp.cpp``.

    See Also
    --------
    gif_psc_exp, gif_cond_exp

    Examples
    --------
    State initialization needs the simulation time step, so it is done
    inside a ``dt`` context.

    Create a population of 100 neurons with default parameters:

    .. code-block:: python

        >>> import brainstate
        >>> import brainpy.state as bp
        >>> import saiunit as u
        >>> pop = bp.gif_pop_psc_exp(1, N=100)
        >>> with brainstate.environ.context(dt=0.1 * u.ms):
        ...     pop.init_all_states()

    Create a population with custom adaptation parameters:

    .. code-block:: python

        >>> pop = bp.gif_pop_psc_exp(
        ...     1,
        ...     N=200,
        ...     tau_sfa=(100.0, 500.0),
        ...     q_sfa=(0.3, 0.6),
        ...     lambda_0=5.0
        ... )
        >>> with brainstate.environ.context(dt=0.1 * u.ms):
        ...     pop.init_all_states()

    Simulate the population with a constant external current:

    .. code-block:: python

        >>> with brainstate.environ.context(dt=0.1 * u.ms):
        ...     pop.init_all_states()
        ...     counts = [pop.update(x=100.0) for _ in range(1000)]  # 100 pA
        >>> print(f"Total spikes: {sum(counts)}, V_m: {pop.V_m:.2f} mV")
    """
    __module__ = 'brainpy.state'

    def __init__(
        self,
        in_size: Size,
        N: int = 100,
        tau_m: float = 20.0,  # ms
        C_m: float = 250.0,  # pF
        t_ref: float = 4.0,  # ms
        lambda_0: float = 10.0,  # 1/s
        Delta_V: float = 2.0,  # mV
        E_L: float = 0.0,  # mV
        V_reset: float = 0.0,  # mV
        V_T_star: float = 15.0,  # mV
        I_e: float = 0.0,  # pA
        tau_syn_ex: float = 3.0,  # ms
        tau_syn_in: float = 6.0,  # ms
        tau_sfa: Sequence[float] = (300.0,),  # ms
        q_sfa: Sequence[float] = (0.5,),  # mV
        len_kernel: int = -1,
        BinoRand: bool = True,
        rng_seed: int = 0,
        name: str = None,
    ):
        super().__init__(in_size, name=name)
        # Store parameters as plain floats (NEST uses raw doubles, no units)
        self.N = int(N)
        self.tau_m = float(tau_m)
        self.C_m = float(C_m)
        self.t_ref = float(t_ref)
        self.lambda_0 = float(lambda_0)
        self.Delta_V = float(Delta_V)
        self.E_L = float(E_L)
        self.V_reset = float(V_reset)
        self.V_T_star = float(V_T_star)
        self.I_e = float(I_e)
        self.tau_syn_ex = float(tau_syn_ex)
        self.tau_syn_in = float(tau_syn_in)
        self.tau_sfa = tuple(float(x) for x in tau_sfa)
        self.q_sfa = tuple(float(x) for x in q_sfa)
        self.len_kernel = int(len_kernel)
        self.BinoRand = bool(BinoRand)
        self.rng_seed = int(rng_seed)
        self._validate_parameters()

    def _validate_parameters(self):
        r"""Validate model parameters for physical consistency.

        Raises
        ------
        ValueError
            If any parameter fails validation: length mismatch between tau_sfa
            and q_sfa; non-positive capacitance, time constants, N, or Delta_V;
            negative lambda_0 or t_ref.
        """
        # Skip validation when parameters are JAX tracers (e.g. during jit).
        # Defensive only: __init__ already coerces these values with float().
        if any(is_tracer(v) for v in (self.C_m, self.tau_m, self.Delta_V)):
            return
        if len(self.tau_sfa) != len(self.q_sfa):
            raise ValueError(
                f"'tau_sfa' and 'q_sfa' must have the same length. "
                f"Got {len(self.tau_sfa)} and {len(self.q_sfa)}."
            )
        if self.C_m <= 0:
            raise ValueError('Capacitance must be strictly positive.')
        if self.tau_m <= 0:
            raise ValueError('The membrane time constant must be strictly positive.')
        if self.tau_syn_ex <= 0 or self.tau_syn_in <= 0:
            raise ValueError('The synaptic time constants must be strictly positive.')
        for tau in self.tau_sfa:
            if tau <= 0:
                raise ValueError('All adaptation time constants must be strictly positive.')
        if self.N <= 0:
            raise ValueError('Number of neurons must be positive.')
        if self.lambda_0 < 0:
            raise ValueError('lambda_0 must not be negative.')
        if self.Delta_V <= 0:
            raise ValueError('Delta_V must be strictly positive.')
        if self.t_ref < 0:
            raise ValueError('Absolute refractory period cannot be negative.')

    def _adaptation_kernel(self, k, h: float):
        r"""Compute the adaptation kernel value at lag k time steps.

        Implements the adaptation kernel theta(k*h) as the sum of exponentially
        decaying components (see below Eq. (87) of [1]_). This kernel determines
        how past spikes contribute to the current adaptive threshold.

        Parameters
        ----------
        k : int or ndarray of int
            Time lag in units of time steps. An integer array is evaluated
            element-wise with the same arithmetic as a scalar lag.
        h : float
            Integration time step in milliseconds.

        Returns
        -------
        float or ndarray
            Adaptation kernel value in millivolts: the sum of all adaptation
            components at time lag k*h. With an empty ``tau_sfa`` this is the
            scalar 0.0 regardless of the shape of ``k``.

        Notes
        -----
        No division by tau is applied because the result must be in voltage
        units, matching q_sfa (mV).
        """
        theta_tmp = 0.0
        for j in range(len(self.tau_sfa)):
            theta_tmp += self.q_sfa[j] * np.exp(-k * h / self.tau_sfa[j])
        return theta_tmp

    def _get_history_size(self, h: float) -> int:
        r"""Automatically determine a suitable history kernel size.

        Implements Procedure GetHistoryLength from Fig. 11 of [1]_. Starting
        from ``k = int(tmax / h)`` (tmax = 20000 ms), the lag is decreased
        until ``kernel(k*h) / Delta_V`` is no longer below 0.1, or until
        ``k`` reaches ``kmin = int(5 * tau_m / h)``.

        Parameters
        ----------
        h : float
            Integration time step in milliseconds.

        Returns
        -------
        int
            History kernel length in time steps. It always exceeds
            ``t_ref / h``. It is at least ``kmin`` unless ``int(tmax / h)`` is
            itself smaller than ``kmin``.

        Notes
        -----
        The descending search is evaluated in vectorized chunks instead of
        one scalar kernel evaluation per step. The search order and the
        per-lag arithmetic are the same as the sequential loop, so the result
        is identical, and memory use stays bounded for very small ``h``.
        """
        tmax = 20000.0  # ms, maximum automatic kernel length
        k = int(tmax / h)
        kmin = int(5 * self.tau_m / h)
        chunk = 1 << 16  # lags examined per vectorized batch
        while k > kmin:
            lo = max(kmin, k - chunk)
            lags = np.arange(k, lo, -1)  # k, k-1, ..., lo+1 (descending)
            # broadcast_to handles the scalar 0.0 returned for empty tau_sfa.
            kernel = np.broadcast_to(self._adaptation_kernel(lags, h), lags.shape)
            # `~(x < 0.1)` (not `x >= 0.1`) mirrors the sequential loop's
            # stopping rule exactly, including for NaN values.
            stop = np.flatnonzero(~(kernel / self.Delta_V < 0.1))
            if stop.size:
                k = int(lags[stop[0]])
                break
            k = lo
        if k * h <= self.t_ref:
            # The kernel must be longer than the refractory period.
            k = int(self.t_ref / h) + 1
        return k

    def _escrate(self, x):
        r"""Escape rate (hazard function).

        Computes the instantaneous firing rate as an exponential function of
        the distance from threshold.

        Parameters
        ----------
        x : float
            Distance from effective threshold in millivolts, computed as
            V_m - theta_effective.

        Returns
        -------
        float
            Escape rate (hazard function) in 1/s (Hz). Equals lambda_0 when
            x = 0 (at threshold).

        Notes
        -----
        h(t) = lambda_0 * exp(x / Delta_V)

        This exponential hazard is the key ingredient of the GIF model,
        allowing for stochastic spike generation with a soft threshold.
        """
        return self.lambda_0 * np.exp(x / self.Delta_V)

    def init_state(self, **kwargs):
        r"""Initialize all population state variables and history buffers.

        Allocates and initializes circular buffers for tracking refractory
        cohorts, computes integration constants, and sets initial conditions
        following Procedure InitPopulations from Fig. 11 of [1]_.

        Parameters
        ----------
        **kwargs
            Unused compatibility parameters accepted by the base-state API.

        Notes
        -----
        - Computes integration time step h from the global brainstate dt, so
          it must be called inside a ``dt`` context.
        - Determines history kernel length (automatic if len_kernel < 1).
        - Places all N neurons in the oldest history slot, so they start as
          survivors of the oldest cohort.
        - Precomputes adaptation kernel values and propagator constants.
        - Resets the random number generator to rng_seed.
        - History buffers use the brainstate default float dtype.
        """
        dftype = brainstate.environ.dftype()
        dt = brainstate.environ.get_dt()
        h = float(u.math.asarray(dt / u.ms))
        self._h = h

        # Integration constants
        self._R = self.tau_m / self.C_m  # membrane resistance
        self._P22 = np.exp(-h / self.tau_m)
        self._P20 = self.tau_m / self.C_m * (1.0 - self._P22)
        self._P11_ex = np.exp(-h / self.tau_syn_ex)
        self._P11_in = np.exp(-h / self.tau_syn_in)

        # Determine kernel length
        len_kernel = self.len_kernel
        if len_kernel < 1:
            len_kernel = self._get_history_size(h)
        self._len_kernel = len_kernel

        # Refractory period in time steps (NEST uses Time::ms -> get_steps)
        # This matches NEST: V_.k_ref_ = Time( Time::ms( P_.t_ref_ ) ).get_steps()
        self._k_ref = int(round(self.t_ref / h))

        # Initialize population state variables
        self._lambda_free = 0.0

        # History buffers (rotating), length = len_kernel
        self._n = np.zeros(len_kernel, dtype=dftype)  # spike counts
        self._m = np.zeros(len_kernel, dtype=dftype)  # survival
        self._v_buf = np.zeros(len_kernel, dtype=dftype)  # variance of survivors
        self._u = np.zeros(len_kernel, dtype=dftype)  # mean of survivors
        self._lambda_buf = np.zeros(len_kernel, dtype=dftype)  # escape rates

        # Adaptation kernel values (vectorized): entry k holds the kernel at
        # lag (len_kernel - k), so index 0 is the oldest cohort.
        ks = np.arange(len_kernel)
        reverse_ks = len_kernel - ks
        theta_vals = np.zeros(len_kernel, dtype=dftype)
        for j in range(len(self.tau_sfa)):
            theta_vals += self.q_sfa[j] * np.exp(-reverse_ks * h / self.tau_sfa[j])
        self._theta = theta_vals
        self._theta_tld = (
            self.Delta_V * (1.0 - np.exp(-theta_vals / self.Delta_V))
            / float(self.N)
        )

        # InitPopulations, line 7: last entry gets N
        self._n[len_kernel - 1] = float(self.N)
        self._m[len_kernel - 1] = float(self.N)

        # InitPopulations, line 8
        self._x = 0.0
        self._z = 0.0
        self._k0 = 0  # rotating index

        # Adaptation variables (vectorized)
        self._Q30 = np.array([np.exp(-h / tau) for tau in self.tau_sfa], dtype=dftype)
        self._Q30K = np.array([
            self.q_sfa[j] * self.tau_sfa[j] * np.exp(-h * len_kernel / self.tau_sfa[j])
            for j in range(len(self.tau_sfa))
        ], dtype=dftype)
        self._g = np.zeros(len(self.tau_sfa), dtype=dftype)

        # Observable state variables
        self._V_m = 0.0  # mV
        self._I_syn_ex = 0.0  # pA
        self._I_syn_in = 0.0  # pA
        self._y0 = 0.0  # pA (external current state)
        self._n_expect = 0.0
        self._theta_hat = 0.0
        self._n_spikes = 0

        # RNG
        self._rng = np.random.RandomState(self.rng_seed)

    def reset_state(self, **kwargs):
        r"""Reset all population state variables to initial conditions.

        Equivalent to calling ``init_state()`` again: re-initializes all
        history buffers, observable states, and the random number generator
        to their values immediately after construction.

        Parameters
        ----------
        **kwargs
            Unused compatibility parameters accepted by the base-state API.
        """
        self.init_state(**kwargs)

    def _draw_binomial(self, n_expect: float) -> int:
        r"""Draw a binomial random number of spikes, matching NEST.

        Each of the N neurons fires independently with probability p = n_expect / N.

        Parameters
        ----------
        n_expect : float
            Expected number of spikes (may exceed N).

        Returns
        -------
        int
            Actual number of spikes drawn from Binomial(N, p), clipped to [0, N].
            Returns N if p >= 1, returns 0 if p <= 0.

        Notes
        -----
        This implementation replicates NEST's binomial spike generation
        logic for gif_pop_psc_exp.
        """
        p_bino = n_expect / self.N
        if p_bino >= 1.0:
            return self.N
        elif p_bino <= 0.0:
            return 0
        else:
            return int(self._rng.binomial(self.N, p_bino))

    def _draw_poisson(self, n_expect: float) -> int:
        r"""Draw a Poisson random number of spikes, matching NEST.

        When n_expect is very small, switches to Bernoulli to avoid numerical
        issues with the Poisson distribution.

        Parameters
        ----------
        n_expect : float
            Expected number of spikes (may exceed N).

        Returns
        -------
        int
            Actual number of spikes drawn from Poisson(n_expect), clipped to [0, N].
            Returns N if n_expect > N. Uses Bernoulli fallback when probability
            of multiple spikes is numerically indistinguishable from zero.

        Notes
        -----
        This implementation replicates NEST's Poisson spike generation
        logic, including the numerical safeguards for very small rates.
        The expression 1.0 - (n_expect + 1.0) * exp(-n_expect) equals
        P(k >= 2). When it is not above the smallest positive normal double,
        at most one spike is drawn.
        """
        min_double = np.finfo(np.float64).tiny
        if n_expect > self.N:
            return self.N
        elif n_expect > min_double:
            # If probability of any spike is indistinguishable from that of
            # one spike, use Bernoulli instead of Poisson
            if 1.0 - (n_expect + 1.0) * np.exp(-n_expect) > min_double:
                n_t = int(self._rng.poisson(n_expect))
            else:
                n_t = int(self._rng.random() < n_expect)
            # Clip to [0, N]
            return max(0, min(n_t, self.N))
        else:
            return 0

    def update(self, x=0.0):
        r"""Advance the population model by one time step.

        Implements the full population update algorithm from Figures 11 and 12
        of [1]_, including:

        1. Exact exponential integration of membrane potential and synaptic currents
        2. Multi-timescale adaptation state update
        3. Escape rate computation for each refractory cohort
        4. Stochastic spike generation (binomial or Poisson)
        5. Circular buffer rotation for history tracking

        Parameters
        ----------
        x : float, optional
            External input current in picoamperes. It is summed with all
            registered current inputs into ``y0``. ``y0`` drives the membrane
            as a DC current from the *next* time step on, like NEST's current
            ring buffer. It is not routed to the synaptic channels.
            Default: 0.0.

        Returns
        -------
        int
            Number of spikes emitted by the population in this time step. Ranges
            from 0 to N. Stored in the ``n_spikes`` property.

        Notes
        -----
        - Delta inputs (spike events) are collected from ``_delta_inputs``.
          Positive values are summed into the excitatory channel and all other
          values into the inhibitory channel. Both sums are divided by ``h`` and
          drive the exponential synaptic currents in this step.
        - Current inputs are collected from ``_current_inputs`` and, together
          with ``x``, become the ``y0`` state used in the next time step.
        - The membrane potential ``V_m``, synaptic currents ``I_syn_ex`` and
          ``I_syn_in``, adaptive threshold ``theta_hat``, and expected spike
          count ``n_expect`` are all updated and accessible via properties.
        - The actual spike count is drawn stochastically from the expected count
          using either a binomial or Poisson distribution (controlled by ``BinoRand``).
        """
        h = self._h

        # ================================================================
        # Main update routine, see Fig. 11 of [1]
        # ================================================================

        # line 6: membrane and synapse update; y0 is the current set in the
        # previous step.
        h_tot = (self.I_e + self._y0) * self._P20 + self.E_L

        # Get input spikes from the ring buffer / external input
        # In NEST, spikes come from ring buffer weighted by synaptic weight.
        # Here we receive the total weighted input directly.
        x_val = float(x)

        # Collect inputs: current inputs go to y0 for next step
        # Delta inputs (spikes) go to excitatory/inhibitory channels
        ex_input = 0.0
        in_input = 0.0
        if self._delta_inputs is not None:
            for key in tuple(self._delta_inputs.keys()):
                out = self._delta_inputs[key]
                if callable(out):
                    out = out()
                else:
                    # Non-callable entries are one-shot inputs.
                    self._delta_inputs.pop(key)
                val = float(out)
                if val > 0.0:
                    ex_input += val
                else:
                    in_input += val  # negative

        JNA_ex = ex_input / h
        JNA_in = in_input / h

        # Rescale inputs to voltage scale used in [1]
        JNA_ex *= self.tau_syn_ex / self.C_m
        JNA_in *= self.tau_syn_in / self.C_m

        # Translate synaptic currents into [1]'s definition
        JNy_ex = self._I_syn_ex / self.C_m
        JNy_in = self._I_syn_in / self.C_m

        # Membrane update (line 10 of [1])
        h_ex_tmpvar = (
            self.tau_syn_ex * self._P11_ex * (JNy_ex - JNA_ex)
            - self._P22 * (self.tau_syn_ex * JNy_ex - self.tau_m * JNA_ex)
        )
        h_in_tmpvar = (
            self.tau_syn_in * self._P11_in * (JNy_in - JNA_in)
            - self._P22 * (self.tau_syn_in * JNy_in - self.tau_m * JNA_in)
        )
        h_ex = self.tau_m * (JNA_ex + h_ex_tmpvar / (self.tau_syn_ex - self.tau_m))
        h_in = self.tau_m * (JNA_in + h_in_tmpvar / (self.tau_syn_in - self.tau_m))
        h_tot += h_ex + h_in

        # Update EPSCs & IPSCs (line 11 of [1])
        JNy_ex = JNA_ex + (JNy_ex - JNA_ex) * self._P11_ex
        JNy_in = JNA_in + (JNy_in - JNA_in) * self._P11_in

        # Store the updated currents, translated back
        self._I_syn_ex = JNy_ex * self.C_m
        self._I_syn_in = JNy_in * self.C_m

        # Set new input current (for next step)
        # Current inputs go to y0
        new_y0 = float(x_val)
        if self._current_inputs is not None:
            for key in tuple(self._current_inputs.keys()):
                out = self._current_inputs[key]
                if callable(out):
                    out = out()
                else:
                    self._current_inputs.pop(key)
                new_y0 += float(out)
        self._y0 = new_y0

        # ================================================================
        # Begin procedure UpdatePopulation, see Fig. 12 of [1]
        # ================================================================
        W_ = 0.0
        X_ = 0.0
        Y_ = 0.0
        Z_ = 0.0  # line 2
        self._theta_hat = self.V_T_star  # line 2, initialize theta

        # line 3: membrane potential update
        self._V_m = (self._V_m - self.E_L) * self._P22 + h_tot

        # Compute free adaptation state (vectorized, lines 4-6); the oldest
        # spike count n[k0] is folded into the filtered rates g.
        n_k0 = self._n[self._k0]
        g_j_tmp = (1.0 - self._Q30) * n_k0 / (float(self.N) * h)
        self._g = self._g * self._Q30 + g_j_tmp
        self._theta_hat += float(np.sum(self._Q30K * self._g))

        # Compute free escape rate (line 8)
        lambda_tld = self._escrate(self._V_m - self._theta_hat)

        # line 9: trapezoidal escape probability for free neurons
        # (0.0005 = 0.5 trapezoid weight * 0.001 for lambda_0 in 1/s -> 1/ms)
        P_free = 1.0 - np.exp(-0.0005 * (self._lambda_free + lambda_tld) * h)
        self._lambda_free = lambda_tld  # line 10
        # line 11.
        # NOTE(review): this uses the fixed index 0, whereas the g update above
        # and the cohort loop below index the oldest cohort by the rotating k0.
        # Kept unchanged; confirm against NEST's gif_pop_psc_exp.cpp.
        self._theta_hat -= self._n[0] * self._theta_tld[0]

        # line 12: sum up all surviving neurons
        X_ = float(np.sum(self._m))

        # Use a local theta_hat to preserve S_.theta_hat_ for recording
        theta_hat_local = self._theta_hat

        # lines 13-27: loop over non-refractory cohorts (vectorized with numpy)
        n_non_ref = self._len_kernel - self._k_ref
        if n_non_ref > 0:
            # Rotating indices for all non-refractory cohorts (line 14)
            k_arr = (self._k0 + np.arange(n_non_ref)) % self._len_kernel

            # Compute per-cohort thresholds (lines 15-16).
            # theta_hat_local has a sequential cumulative dependency:
            #   theta[i] = _theta[i] + theta_hat_local_init + sum(_n[k_arr[j]]*_theta_tld[j] for j<i)
            # We use an exclusive prefix-sum to vectorize this.
            n_contributions = self._n[k_arr] * self._theta_tld[:n_non_ref]
            cumsum_excl = np.empty(n_non_ref, dtype=n_contributions.dtype)
            cumsum_excl[0] = 0.0
            if n_non_ref > 1:
                cumsum_excl[1:] = np.cumsum(n_contributions[:-1])
            theta_arr = self._theta[:n_non_ref] + theta_hat_local + cumsum_excl  # line 15

            # Update mean survivor membrane potentials (line 17, no cross-dependencies)
            self._u[k_arr] = (self._u[k_arr] - self.E_L) * self._P22 + h_tot

            # Escape rates (line 18) and trapezoidal probabilities (lines 19-20);
            # the linear approximation is kept for small P, as in NEST.
            lambda_tld_arr = self.lambda_0 * np.exp(
                (self._u[k_arr] - theta_arr) / self.Delta_V
            )
            P_lambda_arr = 0.0005 * (lambda_tld_arr + self._lambda_buf[k_arr]) * h
            P_lambda_arr = np.where(
                P_lambda_arr > 0.01,
                1.0 - np.exp(-P_lambda_arr),
                P_lambda_arr,
            )
            self._lambda_buf[k_arr] = lambda_tld_arr  # line 21

            # Accumulate sums (lines 22-24)
            Y_ = float(np.sum(P_lambda_arr * self._v_buf[k_arr]))
            Z_ = float(np.sum(self._v_buf[k_arr]))
            W_ = float(np.sum(P_lambda_arr * self._m[k_arr]))

            # Update survival and variance buffers (lines 25-26)
            ompl_arr = 1.0 - P_lambda_arr
            self._v_buf[k_arr] = (
                ompl_arr * ompl_arr * self._v_buf[k_arr] + P_lambda_arr * self._m[k_arr]
            )
            self._m[k_arr] = ompl_arr * self._m[k_arr]

        # line 28
        if (Z_ + self._z) > 0.0:
            P_Lambda_ = (Y_ + P_free * self._z) / (Z_ + self._z)
        else:
            P_Lambda_ = 0.0

        # line 29: expected number of spikes
        self._n_expect = W_ + P_free * self._x + P_Lambda_ * (self.N - X_ - self._x)

        # Draw random spike count
        if self.BinoRand:
            self._n_spikes = self._draw_binomial(self._n_expect)
        else:
            self._n_spikes = self._draw_poisson(self._n_expect)

        # line 31: update z
        ompf = 1.0 - P_free
        self._z = ompf * ompf * self._z + self._x * P_free + self._v_buf[self._k0]

        # line 32: update x
        self._x = self._x * ompf + self._m[self._k0]

        # lines 33-36: the oldest slot is recycled for the newest cohort
        self._n[self._k0] = float(self._n_spikes)
        self._m[self._k0] = float(self._n_spikes)
        self._v_buf[self._k0] = 0.0
        self._u[self._k0] = self.V_reset
        self._lambda_buf[self._k0] = 0.0

        # End procedure UpdatePopulation

        # Shift rotating index (line 17 of Fig. 11)
        self._k0 = (self._k0 + 1) % self._len_kernel

        return self._n_spikes

    @property
    def V_m(self) -> float:
        r"""Mean membrane potential in millivolts.

        Returns
        -------
        float
            Population-averaged membrane potential of all neurons (mV).
            Updated each time step by exact exponential integration.
        """
        return self._V_m

    @property
    def I_syn_ex(self) -> float:
        r"""Excitatory synaptic current in picoamperes.

        Returns
        -------
        float
            Current excitatory synaptic current (pA), evolving according to
            exponential dynamics with time constant tau_syn_ex.
        """
        return self._I_syn_ex

    @property
    def I_syn_in(self) -> float:
        r"""Inhibitory synaptic current in picoamperes.

        Returns
        -------
        float
            Current inhibitory synaptic current (pA), evolving according to
            exponential dynamics with time constant tau_syn_in.
        """
        return self._I_syn_in

    @property
    def n_spikes(self) -> int:
        r"""Number of spikes emitted in the current time step.

        Returns
        -------
        int
            Spike count drawn stochastically from the expected rate, ranging
            from 0 to N.
        """
        return self._n_spikes

    @property
    def n_expect(self) -> float:
        r"""Expected number of spikes in the current time step.

        Returns
        -------
        float
            Mean spike count computed from the population escape rates before
            stochastic sampling. May exceed N in principle.
        """
        return self._n_expect

    @property
    def theta_hat(self) -> float:
        r"""Adaptive threshold for non-refractory neurons in millivolts.

        Returns
        -------
        float
            V_T_star plus the adaptation contribution of spikes that have left
            the history window, minus the line-11 correction term (mV).
        """
        return self._theta_hat

    @property
    def y0(self) -> float:
        r"""External current input state in picoamperes.

        Returns
        -------
        float
            Accumulated external current input (pA) that will be applied in
            the next time step.
        """
        return self._y0