# Copyright 2026 BrainX Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# -*- coding: utf-8 -*-
from typing import Callable, Sequence
import brainstate
import braintools
import saiunit as u
import jax
import jax.numpy as jnp
import numpy as np
from brainstate.typing import ArrayLike, Size
from ._base import NESTNeuron
from ._utils import is_tracer, alpha_propagator_p31_p32
# Public API of this module: only the neuron model class is exported.
__all__ = ["glif_psc"]
class glif_psc(NESTNeuron):
    r"""Current-based generalized leaky integrate-and-fire (GLIF) neuron model.

    The ``glif_psc`` model implements the five-level GLIF model hierarchy
    from the Allen Institute [1]_, featuring alpha-function shaped synaptic
    currents, after-spike currents (ASC), spike-dependent threshold adaptation,
    and voltage-dependent threshold modulation. Exact integration via
    propagator matrices ensures numerical stability and matches NEST's
    implementation.

    **Model Hierarchy**

    The five GLIF models are:

    * **GLIF Model 1** (LIF) — Traditional leaky integrate-and-fire
    * **GLIF Model 2** (LIF_R) — LIF with biologically defined reset rules
    * **GLIF Model 3** (LIF_ASC) — LIF with after-spike currents
    * **GLIF Model 4** (LIF_R_ASC) — LIF with reset rules and after-spike
      currents
    * **GLIF Model 5** (LIF_R_ASC_A) — LIF with reset rules, after-spike
      currents, and a voltage-dependent threshold

    Model mechanism selection is based on three boolean parameters:

    +--------+---------------------------+----------------------+--------------------+
    | Model  | spike_dependent_threshold | after_spike_currents | adapting_threshold |
    +========+===========================+======================+====================+
    | GLIF1  | False                     | False                | False              |
    +--------+---------------------------+----------------------+--------------------+
    | GLIF2  | True                      | False                | False              |
    +--------+---------------------------+----------------------+--------------------+
    | GLIF3  | False                     | True                 | False              |
    +--------+---------------------------+----------------------+--------------------+
    | GLIF4  | True                      | True                 | False              |
    +--------+---------------------------+----------------------+--------------------+
    | GLIF5  | True                      | True                 | True               |
    +--------+---------------------------+----------------------+--------------------+

    Any other combination is rejected with a ``ValueError``.

    Mathematical Formulation
    ------------------------

    **1. Membrane Dynamics**

    The membrane potential :math:`U` (stored relative to :math:`E_L`) evolves
    according to exact integration (linear dynamics):

    .. math::

        U(t+dt) = U(t) \cdot P_{33} + (I_e + I + I_\mathrm{ASC,sum}) \cdot P_{30}
        + \sum_k \left( P_{31,k} \cdot y_{1,k} + P_{32,k} \cdot y_{2,k} \right)

    where :math:`I` is the external current buffered from the previous step
    and the propagator matrix elements are:

    .. math::

        P_{33} = \exp\left(-\frac{dt}{\tau_m}\right), \quad
        P_{30} = \frac{\tau_m}{C_m} \left(1 - P_{33}\right), \quad
        \tau_m = \frac{C_m}{g}

    :math:`P_{31,k}` and :math:`P_{32,k}` are computed via the
    ``IAFPropagatorAlpha`` algorithm, which handles the singularity when
    :math:`\tau_m \approx \tau_{\mathrm{syn},k}`.

    **2. Synaptic Currents (Alpha Function)**

    Each receptor port has a current modeled by an alpha function with two
    state variables :math:`y_{1,k}` and :math:`y_{2,k}`:

    .. math::

        y_{2,k}(t+dt) = P_{21,k} \cdot y_{1,k}(t) + P_{22,k} \cdot y_{2,k}(t)

    .. math::

        y_{1,k}(t+dt) = P_{11,k} \cdot y_{1,k}(t)

    where:

    .. math::

        P_{11,k} = P_{22,k} = \exp(-dt / \tau_{\mathrm{syn},k}), \quad
        P_{21,k} = dt \cdot P_{11,k}

    On a presynaptic spike of weight :math:`w`:

    .. math::

        y_{1,k} \leftarrow y_{1,k} + w \cdot \frac{e}{\tau_{\mathrm{syn},k}}

    The alpha function is normalized such that an event of weight 1.0 results
    in a peak current of 1 pA at :math:`t = \tau_\mathrm{syn}`.

    **3. After-Spike Currents (GLIF3/4/5)**

    After-spike currents (ASC) are modeled as exponentially decaying currents
    with exact integration. Each ASC component :math:`I_j` decays with rate
    :math:`k_j`:

    .. math::

        I_j(t+dt) = I_j(t) \cdot \exp(-k_j \cdot dt)

    The time-averaged ASC over a step uses the stable coefficient:

    .. math::

        \bar{I}_j = \frac{1 - \exp(-k_j \cdot dt)}{k_j \cdot dt} \cdot I_j(t)

    On spike, the (already decayed) ASC values are reset:

    .. math::

        I_j \leftarrow \Delta I_j + I_j \cdot r_j \cdot \exp(-k_j \cdot t_\mathrm{ref})

    **4. Spike-Dependent Threshold (GLIF2/4/5)**

    The spike component of the threshold decays exponentially:

    .. math::

        \theta_s(t+dt) = \theta_s(t) \cdot \exp(-b_s \cdot dt)

    On spike, after refractory decay:

    .. math::

        \theta_s \leftarrow \theta_s \cdot \exp(-b_s \cdot t_\mathrm{ref})
        + \Delta\theta_s

    Voltage reset (with spike-dependent threshold):

    .. math::

        U \leftarrow f_v \cdot U_\mathrm{old} + V_\mathrm{add}

    Without spike-dependent threshold (GLIF1/3) the reset is
    :math:`U \leftarrow V_\mathrm{reset} - E_L`.

    **5. Voltage-Dependent Threshold (GLIF5)**

    The voltage component of the threshold evolves according to:

    .. math::

        \theta_v(t+dt) = \phi \cdot (U_\mathrm{old} - \beta) \cdot P_\mathrm{decay}
        + \frac{1}{P_{\theta,v}} \cdot \left(\theta_v(t)
        - \phi \cdot (U_\mathrm{old} - \beta)
        - \frac{a_v}{b_v} \cdot \beta \right)
        + \frac{a_v}{b_v} \cdot \beta

    where :math:`\phi = a_v / (b_v - g/C_m)`,
    :math:`P_\mathrm{decay} = \exp(-g \cdot dt / C_m)`,
    :math:`P_{\theta,v} = \exp(b_v \cdot dt)`,
    and :math:`\beta = (I_e + I + I_\mathrm{ASC,sum}) / g`.

    Overall threshold:

    .. math::

        \theta = \theta_\infty + \theta_s + \theta_v

    Spike condition (checked after voltage update):

    .. math::

        U > \theta

    **6. Numerical Integration and Update Order**

    NEST uses exact integration for the linear subthreshold dynamics (via
    propagator matrices). The discrete-time update order per simulation step
    is:

    1. Record :math:`U_\mathrm{old}` (relative to :math:`E_L`).
    2. If not refractory:

       a. Decay spike threshold component.
       b. Compute time-averaged ASC and decay ASC values.
       c. Update membrane potential:
          :math:`U = U_\mathrm{old} P_{33} + (I_e + I + I_\mathrm{ASC,sum}) P_{30} + \sum_k (P_{31,k} y_{1,k} + P_{32,k} y_{2,k})`.
       d. Compute voltage-dependent threshold component (using :math:`U_\mathrm{old}`).
       e. Update total threshold.
       f. If :math:`U > \theta`: emit spike, apply reset rules.

    3. If refractory: decrement counter, hold U at :math:`U_\mathrm{old}`.
    4. Update synaptic current state variables:
       :math:`y_2 = P_{21} y_1 + P_{22} y_2`, then :math:`y_1 = P_{11} y_1`.
    5. Add incoming spike current jumps (scaled by :math:`e / \tau_\mathrm{syn}`).
    6. Buffer the external current :math:`I` for the next step.
    7. Write back all state variables.

    Parameters
    ----------
    in_size : Size
        Shape of the neuron population. Can be tuple of ints or single int.
    g : ArrayLike, optional
        Membrane (leak) conductance. Default: 9.43 nS.
    E_L : ArrayLike, optional
        Resting membrane potential. Default: -78.85 mV.
    V_th : ArrayLike, optional
        Instantaneous threshold voltage (absolute). Default: -51.68 mV.
    C_m : ArrayLike, optional
        Membrane capacitance. Default: 58.72 pF.
    t_ref : ArrayLike, optional
        Absolute refractory period. Default: 3.75 ms.
    V_reset : ArrayLike, optional
        Reset potential (absolute; used for GLIF1/3). Default: -78.85 mV.
    th_spike_add : float, optional
        Threshold additive constant after spike (mV). Default: 0.37.
    th_spike_decay : float, optional
        Spike threshold decay rate (/ms). Default: 0.009.
    voltage_reset_fraction : float, optional
        Voltage fraction coefficient after spike. Default: 0.20.
    voltage_reset_add : float, optional
        Voltage additive constant after spike (mV). Default: 18.51.
    th_voltage_index : float, optional
        Voltage-dependent threshold leak rate (/ms). Default: 0.005.
    th_voltage_decay : float, optional
        Voltage-dependent threshold decay rate (/ms). Default: 0.09.
    asc_init : Sequence[float], optional
        Initial values of after-spike currents (pA). Default: (0.0, 0.0).
    asc_decay : Sequence[float], optional
        ASC decay rates (/ms). Default: (0.003, 0.1).
    asc_amps : Sequence[float], optional
        ASC amplitudes added on spike (pA). Default: (-9.18, -198.94).
    asc_r : Sequence[float], optional
        ASC fraction coefficients (dimensionless). Default: (1.0, 1.0).
    tau_syn : Sequence[float], optional
        Synaptic alpha-function time constants (ms), one per receptor port.
        Default: (2.0,).
    spike_dependent_threshold : bool, optional
        Enable biologically defined reset rules (GLIF2/4/5). Default: False.
    after_spike_currents : bool, optional
        Enable after-spike currents (GLIF3/4/5). Default: False.
    adapting_threshold : bool, optional
        Enable voltage-dependent threshold (GLIF5). Default: False.
    I_e : ArrayLike, optional
        Constant external current. Default: 0.0 pA.
    gsl_error_tol : ArrayLike, optional
        Unitless, strictly positive tolerance. It is validated and stored for
        API parity with RKF45-based GLIF models, but the exact-propagator
        update of this model does not use it. Default: 1e-6.
    V_initializer : Callable, optional
        Membrane potential initializer. Default: Constant(E_L).
    spk_fun : Callable, optional
        Surrogate gradient function for spike generation. Default: ReluGrad().
    spk_reset : str, optional
        Spike reset mode: 'hard' or 'soft'. Default: 'hard'.
    ref_var : bool, optional
        If ``True``, allocate and expose ``self.refractory`` state.
    name : str, optional
        Name of the neuron group.

    Parameter Mapping
    -----------------

    =============================== =================== ========================================== =====================================================
    **Parameter**                   **Default**         **Math equivalent**                        **Description**
    =============================== =================== ========================================== =====================================================
    ``in_size``                     (required)                                                     Population shape
    ``g``                           9.43 nS             :math:`g`                                  Membrane (leak) conductance
    ``E_L``                         -78.85 mV           :math:`E_L`                                Resting membrane potential
    ``V_th``                        -51.68 mV           :math:`V_\mathrm{th}`                      Instantaneous threshold (absolute)
    ``C_m``                         58.72 pF            :math:`C_\mathrm{m}`                       Membrane capacitance
    ``t_ref``                       3.75 ms             :math:`t_\mathrm{ref}`                     Absolute refractory period
    ``V_reset``                     -78.85 mV           :math:`V_\mathrm{reset}`                   Reset potential (absolute; GLIF1/3)
    ``th_spike_add``                0.37 mV             :math:`\Delta\theta_s`                     Threshold additive constant after spike
    ``th_spike_decay``              0.009 /ms           :math:`b_s`                                Spike threshold decay rate
    ``voltage_reset_fraction``      0.20                :math:`f_v`                                Voltage fraction after spike
    ``voltage_reset_add``           18.51 mV            :math:`V_\mathrm{add}`                     Voltage additive after spike
    ``th_voltage_index``            0.005 /ms           :math:`a_v`                                Voltage-dependent threshold leak
    ``th_voltage_decay``            0.09 /ms            :math:`b_v`                                Voltage-dependent threshold decay rate
    ``asc_init``                    (0.0, 0.0) pA                                                  Initial values of ASC
    ``asc_decay``                   (0.003, 0.1) /ms    :math:`k_j`                                ASC time constants (decay rates)
    ``asc_amps``                    (-9.18, -198.94) pA :math:`\Delta I_j`                         ASC amplitudes on spike
    ``asc_r``                       (1.0, 1.0)          :math:`r_j`                                ASC fraction coefficient
    ``tau_syn``                     (2.0,) ms           :math:`\tau_{\mathrm{syn},k}`              Synaptic alpha-function time constants
    ``spike_dependent_threshold``   False                                                          Enable biologically defined reset (GLIF2/4/5)
    ``after_spike_currents``        False                                                          Enable after-spike currents (GLIF3/4/5)
    ``adapting_threshold``          False                                                          Enable voltage-dependent threshold (GLIF5)
    ``I_e``                         0.0 pA              :math:`I_e`                                Constant external current
    ``gsl_error_tol``                 1e-6                --                                         Validated but unused (exact integration)
    ``V_initializer``               Constant(E_L)                                                  Membrane potential initializer
    ``spk_fun``                     ReluGrad()                                                     Surrogate spike function
    ``spk_reset``                   ``'hard'``                                                     Reset mode
    ``ref_var``                     False                                                          If True, expose boolean refractory state
    =============================== =================== ========================================== =====================================================

    Attributes
    ----------
    V : HiddenState
        Membrane potential :math:`V_\mathrm{m}` (absolute, mV).
    y1 : list of HiddenState
        Alpha-function auxiliary states (pA/ms), one per receptor port.
    y2 : list of HiddenState
        Synaptic current states (pA), one per receptor port.
    last_spike_time : ShortTermState
        Time of the last emitted spike for each neuron (ms); starts at -1e7 ms.
    refractory_step_count : ShortTermState
        Remaining refractory grid steps (integer dtype).
    I_stim : ShortTermState
        Buffered external current applied in the next step (pA).
    ref_count : Array
        Number of grid steps covered by ``t_ref``.
    _asc_states : list of HiddenState
        After-spike current values (pA, stored as plain floats), one per ASC
        component.
    _threshold : jax.Array
        Total threshold (relative to E_L, in mV). Shape: ``varshape``.
    _threshold_spike : jax.Array
        Spike component of threshold (mV). Shape: ``varshape``.
    _threshold_voltage : jax.Array
        Voltage component of threshold (mV). Shape: ``varshape``.
    refractory : ShortTermState
        Optional boolean refractory indicator, available only when
        ``ref_var=True``.

    Raises
    ------
    ValueError
        If invalid model mechanism combination is specified.
    ValueError
        If V_reset >= V_th (reset must be below threshold).
    ValueError
        If capacitance, conductance, or time constants are not positive.
    ValueError
        If voltage_reset_fraction not in [0, 1].
    ValueError
        If asc_r values not in [0, 1].
    ValueError
        If ASC parameter arrays have mismatched lengths.
    ValueError
        If gsl_error_tol is not strictly positive.

    Notes
    -----
    - Default parameter values are from GLIF Model 5 of Cell 490626718 from the
      `Allen Cell Type Database <https://celltypes.brain-map.org>`_.
    - Parameters ``V_th`` and ``V_reset`` are specified in absolute mV.
      Internally, membrane potential is tracked relative to ``E_L``, matching
      NEST's convention.
    - ``g``, ``E_L``, ``V_th``, ``C_m``, ``t_ref``, ``V_reset`` and ``I_e`` are
      converted to Python floats in ``init_state``/``update``, so each must
      hold a single value (population-wide scalar).
    - For models with spike-dependent threshold (GLIF2/4/5), the reset
      condition should satisfy:

      .. math::

          E_L + f_v \cdot (V_{th} - E_L) + V_{add} < V_{th} + \Delta\theta_s

      Otherwise the neuron may spike continuously.
    - Unlike ``glif_cond`` which uses an RKF45 ODE integrator, ``glif_psc``
      uses exact integration via propagator matrices for the linear
      subthreshold dynamics, matching NEST's implementation.
    - If ``tau_m`` is very close to ``tau_syn``, the model numerically behaves
      as if they are equal, to avoid numerical instabilities (see NEST
      IAF_Integration_Singularity notebook).
    - Synaptic inputs are delivered to receptor ports starting from port 0.
      Register inputs with keys like 'receptor_0', 'receptor_1', etc., via
      the ``add_delta_input`` method. Inputs without a receptor label default
      to receptor port 0.

    Examples
    --------
    **GLIF Model 1 (Basic LIF)**:

    .. code-block:: python

        >>> import brainpy.state as bp
        >>> import saiunit as u
        >>> with u.context(dt=0.1 * u.ms):
        ...     model = bp.glif_psc(100, spike_dependent_threshold=False,
        ...                         after_spike_currents=False, adapting_threshold=False)
        ...     model.init_all_states()
        ...     output = model(350 * u.pA)

    **GLIF Model 5 (Full Model with Adaptation)**:

    .. code-block:: python

        >>> import brainpy.state as bp
        >>> import saiunit as u
        >>> with u.context(dt=0.1 * u.ms):
        ...     model = bp.glif_psc(100, spike_dependent_threshold=True,
        ...                         after_spike_currents=True, adapting_threshold=True)
        ...     model.init_all_states()
        ...     output = model(200 * u.pA)

    **Multi-Receptor Configuration**:

    .. code-block:: python

        >>> import brainpy.state as bp
        >>> import saiunit as u
        >>> with u.context(dt=0.1 * u.ms):
        ...     model = bp.glif_psc(100, tau_syn=(2.0, 5.0, 10.0))
        ...     model.init_all_states()
        ...     # Register inputs to different receptor ports
        ...     model.add_delta_input('exc_receptor_0', lambda: 10 * u.pA)
        ...     model.add_delta_input('inh_receptor_1', lambda: -5 * u.pA)

    References
    ----------
    .. [1] Teeter C, Iyer R, Menon V, Gouwens N, Feng D, Berg J, Szafer A,
           Cain N, Zeng H, Hawrylycz M, Koch C, & Mihalas S (2018).
           Generalized leaky integrate-and-fire models classify multiple neuron
           types. Nature Communications 9:709.
    .. [2] Meffin H, Burkitt AN, Grayden DB (2004). An analytical model for
           the large, fluctuating synaptic conductance state typical of
           neocortical neurons in vivo. J. Comput. Neurosci. 16:159-175.
    .. [3] NEST Simulator ``glif_psc`` model documentation and C++ source:
           ``models/glif_psc.h`` and ``models/glif_psc.cpp``.

    See Also
    --------
    glif_cond : Conductance-based GLIF model with RKF45 integration.
    gif_psc_exp_multisynapse : Generalized IF with exponential synapses.
    """
    __module__ = 'brainpy.state'

    def __init__(
        self,
        in_size: Size,
        g: ArrayLike = 9.43 * u.nS,
        E_L: ArrayLike = -78.85 * u.mV,
        V_th: ArrayLike = -51.68 * u.mV,
        C_m: ArrayLike = 58.72 * u.pF,
        t_ref: ArrayLike = 3.75 * u.ms,
        V_reset: ArrayLike = -78.85 * u.mV,
        th_spike_add: float = 0.37,  # mV
        th_spike_decay: float = 0.009,  # 1/ms
        voltage_reset_fraction: float = 0.20,
        voltage_reset_add: float = 18.51,  # mV
        th_voltage_index: float = 0.005,  # 1/ms
        th_voltage_decay: float = 0.09,  # 1/ms
        asc_init: Sequence[float] = (0.0, 0.0),  # pA
        asc_decay: Sequence[float] = (0.003, 0.1),  # 1/ms
        asc_amps: Sequence[float] = (-9.18, -198.94),  # pA
        asc_r: Sequence[float] = (1.0, 1.0),
        tau_syn: Sequence[float] = (2.0,),  # ms
        spike_dependent_threshold: bool = False,
        after_spike_currents: bool = False,
        adapting_threshold: bool = False,
        I_e: ArrayLike = 0.0 * u.pA,
        gsl_error_tol: ArrayLike = 1e-6,
        V_initializer: Callable = None,
        spk_fun: Callable = braintools.surrogate.ReluGrad(),
        spk_reset: str = 'hard',
        ref_var: bool = False,
        name: str = None,
    ):
        super().__init__(in_size, name=name, spk_fun=spk_fun, spk_reset=spk_reset)

        # Membrane parameters (unit-carrying quantities).
        self.g_m = braintools.init.param(g, self.varshape)
        self.E_L = braintools.init.param(E_L, self.varshape)
        self.C_m = braintools.init.param(C_m, self.varshape)
        self.t_ref = braintools.init.param(t_ref, self.varshape)
        self.I_e = braintools.init.param(I_e, self.varshape)

        # V_th and V_reset are stored as absolute potentials; the threshold
        # relative to E_L (NEST's ``th_inf_``) is derived in ``init_state``.
        self.V_th = braintools.init.param(V_th, self.varshape)
        self.V_reset = braintools.init.param(V_reset, self.varshape)

        # Scalar GLIF parameters (unitless floats in NEST units: mV, 1/ms).
        self.th_spike_add = float(th_spike_add)
        self.th_spike_decay = float(th_spike_decay)
        self.voltage_reset_fraction = float(voltage_reset_fraction)
        self.voltage_reset_add = float(voltage_reset_add)
        self.th_voltage_index = float(th_voltage_index)
        self.th_voltage_decay = float(th_voltage_decay)

        # ASC parameters, one entry per after-spike current component.
        self.asc_init = tuple(float(x) for x in asc_init)
        self.asc_decay = tuple(float(x) for x in asc_decay)
        self.asc_amps = tuple(float(x) for x in asc_amps)
        self.asc_r = tuple(float(x) for x in asc_r)

        # One alpha-function time constant (ms) per receptor port.
        self.tau_syn = tuple(float(x) for x in tau_syn)

        # Mechanism flags selecting the GLIF variant.
        self.has_theta_spike = bool(spike_dependent_threshold)
        self.has_asc = bool(after_spike_currents)
        self.has_theta_voltage = bool(adapting_threshold)

        # The membrane potential starts at the resting potential by default.
        if V_initializer is None:
            V_initializer = braintools.init.Constant(E_L)
        self.V_initializer = V_initializer

        self._n_receptors = len(self.tau_syn)
        self.gsl_error_tol = gsl_error_tol
        self.ref_var = ref_var

        self._validate_parameters()

        # Refractory period in grid steps. The ratio is shrunk by a tiny
        # relative tolerance before ``ceil`` so that floating-point noise
        # (e.g. 1.1 / 0.1 == 11.000000000000002) does not add a spurious
        # extra step; NEST rounds t_ref to its tic grid before converting.
        ditype = brainstate.environ.ditype()
        dt = brainstate.environ.get_dt()
        ref_ratio = u.get_mantissa(self.t_ref / u.ms) / u.get_mantissa(dt / u.ms)
        self.ref_count = u.math.asarray(
            u.math.ceil(ref_ratio * (1.0 - 1e-6)), dtype=ditype
        )

    @property
    def n_receptors(self):
        r"""Number of synaptic receptor ports.

        Returns
        -------
        int
            Number of independent receptor ports, determined by the length
            of the ``tau_syn`` parameter. Each receptor port has its own
            synaptic time constant and independent alpha-function dynamics.
        """
        return self._n_receptors

    def _validate_parameters(self):
        r"""Validate model parameters against NEST constraints.

        Python-scalar parameters (mechanism flags, GLIF coefficients, ASC
        lists and synaptic time constants) are always checked. Array-valued
        parameters are only checked when they are concrete, because
        comparisons on JAX tracers (e.g. under ``jit``) cannot be evaluated
        eagerly.

        Raises
        ------
        ValueError
            If the mechanism combination is invalid, or if parameter
            inequalities or positivity constraints are violated.
        """
        # --- Mechanism combination -------------------------------------
        combo = (self.has_theta_spike, self.has_asc, self.has_theta_voltage)
        valid_combos = (
            (False, False, False),  # GLIF1
            (True, False, False),  # GLIF2
            (False, True, False),  # GLIF3
            (True, True, False),  # GLIF4
            (True, True, True),  # GLIF5
        )
        if combo not in valid_combos:
            raise ValueError(
                "Incorrect model mechanism combination. "
                "Valid combinations: GLIF1(FFF), GLIF2(TFF), GLIF3(FTF), "
                "GLIF4(TTF), GLIF5(TTT). Got spike_dependent_threshold=%s, "
                "after_spike_currents=%s, adapting_threshold=%s." % combo
            )

        # --- Python-scalar parameters (never traced) --------------------
        if self.has_theta_spike:
            if self.th_spike_decay <= 0.0:
                raise ValueError("Spike induced threshold time constant must be strictly positive.")
            if not (0.0 <= self.voltage_reset_fraction <= 1.0):
                raise ValueError("Voltage fraction coefficient following spike must be within [0.0, 1.0].")

        if self.has_asc:
            n = len(self.asc_decay)
            if not (len(self.asc_init) == n and len(self.asc_amps) == n and len(self.asc_r) == n):
                raise ValueError(
                    "All after spike current parameters (asc_init, asc_decay, asc_amps, asc_r) "
                    "must have the same size."
                )
            if any(k_val <= 0.0 for k_val in self.asc_decay):
                raise ValueError("After-spike current time constant must be strictly positive.")
            if any(not (0.0 <= r_val <= 1.0) for r_val in self.asc_r):
                raise ValueError(
                    "After spike current fraction coefficients r must be within [0.0, 1.0]."
                )

        if self.has_theta_voltage and self.th_voltage_decay <= 0.0:
            raise ValueError("Voltage-induced threshold time constant must be strictly positive.")

        if any(tau <= 0.0 for tau in self.tau_syn):
            raise ValueError("All synaptic time constants must be strictly positive.")

        # --- Array-valued parameters (skip when traced, e.g. during jit) ---
        array_params = (
            self.V_reset, self.V_th, self.E_L, self.C_m, self.g_m, self.t_ref,
            self.gsl_error_tol,
        )
        if any(is_tracer(param) for param in array_params):
            return

        # V_reset < V_th, both compared relative to E_L (NEST convention).
        E_L_val = self.E_L
        V_reset_rel = self.V_reset - E_L_val
        V_th_rel = self.V_th - E_L_val
        if np.any(V_reset_rel >= V_th_rel):
            raise ValueError("Reset potential must be smaller than threshold.")
        if np.any(self.C_m <= 0.0 * u.pF):
            raise ValueError("Capacitance must be strictly positive.")
        if np.any(self.g_m <= 0.0 * u.nS):
            raise ValueError("Membrane conductance must be strictly positive.")
        if np.any(self.t_ref <= 0.0 * u.ms):
            raise ValueError("Refractory time constant must be strictly positive.")
        # np.asarray lets list/tuple tolerances be compared element-wise.
        if np.any(np.asarray(self.gsl_error_tol) <= 0.0):
            raise ValueError('The gsl_error_tol must be strictly positive.')

    def init_state(self, **kwargs):
        r"""Initialize persistent and short-term state variables.

        All GLIF-specific state variables are stored as JAX ``HiddenState``
        arrays, and the pre-computed decay constants and propagators are
        stored as host-side scalars (recomputed on every call, so they follow
        the current ``dt``).

        Parameters
        ----------
        **kwargs
            Unused compatibility parameters accepted by the base-state API.
        """
        ditype = brainstate.environ.ditype()
        dftype = brainstate.environ.dftype()
        dt = brainstate.environ.get_dt()
        dt_ms = float(np.asarray(u.get_mantissa(dt / u.ms)))

        V = braintools.init.param(self.V_initializer, self.varshape)
        self.V = brainstate.HiddenState(V)

        # Per-receptor alpha-function current states: y1 (rate, pA/ms), y2 (current, pA)
        self.y1 = [
            brainstate.HiddenState(
                braintools.init.param(braintools.init.Constant(0.0 * u.pA / u.ms), self.varshape)
            )
            for _ in range(self._n_receptors)
        ]
        self.y2 = [
            brainstate.HiddenState(
                braintools.init.param(braintools.init.Constant(0.0 * u.pA), self.varshape)
            )
            for _ in range(self._n_receptors)
        ]

        self.last_spike_time = brainstate.ShortTermState(u.math.full(self.varshape, -1e7 * u.ms))
        self.refractory_step_count = brainstate.ShortTermState(u.math.full(self.varshape, 0, dtype=ditype))
        self.I_stim = brainstate.ShortTermState(u.math.full(self.varshape, 0.0 * u.pA, dtype=dftype))

        # GLIF-specific state as HiddenState (JAX-traceable, compatible with for_loop).
        # Slicing (rather than indexing by position) avoids an IndexError when
        # ASCs are disabled and the unvalidated ASC lists differ in length.
        n_asc = len(self.asc_decay)
        self._asc_states = [
            brainstate.HiddenState(jnp.full(self.varshape, init_val, dtype=dftype))
            for init_val in self.asc_init[:n_asc]
        ]

        # Threshold components (relative to E_L) as HiddenState.
        E_L_mV = float(np.asarray(u.get_mantissa(self.E_L / u.mV)))
        th_inf = float(np.asarray(u.get_mantissa(self.V_th / u.mV))) - E_L_mV
        self._th_inf = th_inf
        self._threshold_spike_state = brainstate.HiddenState(
            jnp.zeros(self.varshape, dtype=dftype)
        )
        self._threshold_voltage_state = brainstate.HiddenState(
            jnp.zeros(self.varshape, dtype=dftype)
        )
        self._threshold_state = brainstate.HiddenState(
            jnp.full(self.varshape, th_inf, dtype=dftype)
        )

        # Pre-compute decay rates (host-side constants).
        G = float(np.asarray(u.get_mantissa(self.g_m / u.nS)))
        C_m_val = float(np.asarray(u.get_mantissa(self.C_m / u.pF)))
        t_ref_ms = float(np.asarray(u.get_mantissa(self.t_ref / u.ms)))

        if self.has_theta_spike:
            self._decay_spike = np.exp(-self.th_spike_decay * dt_ms)
            self._decay_spike_refr = np.exp(-self.th_spike_decay * t_ref_ms)

        if self.has_asc:
            self._asc_decay_rates = [np.exp(-self.asc_decay[a] * dt_ms) for a in range(n_asc)]
            # Exact mean of an exponentially decaying current over one step.
            self._asc_stable_coeff = [
                ((1.0 / self.asc_decay[a]) / dt_ms) * (1.0 - self._asc_decay_rates[a])
                for a in range(n_asc)
            ]
            self._asc_refr_decay_rates = [
                self.asc_r[a] * np.exp(-self.asc_decay[a] * t_ref_ms)
                for a in range(n_asc)
            ]

        if self.has_theta_voltage:
            self._potential_decay_rate = np.exp(-G * dt_ms / C_m_val)
            self._theta_voltage_decay_rate_inv = 1.0 / np.exp(self.th_voltage_decay * dt_ms)
            self._phi = self.th_voltage_index / (self.th_voltage_decay - G / C_m_val)
            self._abpara_ratio = self.th_voltage_index / self.th_voltage_decay

        # Pre-compute exact propagator matrices (NEST IAFPropagatorAlpha scheme).
        tau_m = C_m_val / G  # membrane time constant in ms
        self._P33 = np.exp(-dt_ms / tau_m)
        self._P30 = (1.0 / C_m_val) * (1.0 - self._P33) * tau_m  # mV/pA

        self._P11 = []
        self._P21 = []
        self._P22 = []
        self._P31 = []
        self._P32 = []
        self._PSCInitialValues = []
        for k in range(self._n_receptors):
            p11 = np.exp(-dt_ms / self.tau_syn[k])
            self._P11.append(p11)
            self._P22.append(p11)
            self._P21.append(dt_ms * p11)
            p31, p32 = alpha_propagator_p31_p32(self.tau_syn[k], tau_m, C_m_val, dt_ms)
            self._P31.append(float(p31))
            self._P32.append(float(p32))
            # Normalizes the alpha kernel to a 1 pA peak for unit weight.
            self._PSCInitialValues.append(np.e / self.tau_syn[k])

        if self.ref_var:
            refractory = braintools.init.param(braintools.init.Constant(False), self.varshape)
            self.refractory = brainstate.ShortTermState(refractory)

    # Backward-compatible read-only views of the threshold components.
    @property
    def _threshold(self):
        return self._threshold_state.value

    @property
    def _threshold_spike(self):
        return self._threshold_spike_state.value

    @property
    def _threshold_voltage(self):
        return self._threshold_voltage_state.value

    def get_spike(self, V: ArrayLike = None):
        r"""Generate spike output via surrogate gradient function.

        Applies the surrogate gradient function to a normalized voltage signal.
        The voltage is linearly scaled such that ``V_th`` maps to 0 and
        ``V_th + (V_th - V_reset)`` maps to 1, providing a normalized input
        for the surrogate function.

        Parameters
        ----------
        V : ArrayLike, optional
            Membrane potential (with units). If None, uses current ``self.V.value``.

        Returns
        -------
        spike : jax.numpy.ndarray
            Output of ``spk_fun`` applied to the scaled voltage. Shape matches
            the neuron population.

        Notes
        -----
        - This method is typically not invoked directly by users; ``update``
          returns spikes from the exact threshold comparison instead.
        - The surrogate function enables gradient-based learning by providing
          a differentiable approximation to the Heaviside step function.
        """
        V = self.V.value if V is None else V
        v_scaled = (V - self.V_th) / (self.V_th - self.V_reset)
        return self.spk_fun(v_scaled)

    def _collect_receptor_delta_inputs(self):
        r"""Collect delta inputs per receptor port using label-based routing.

        Returns a list of current jumps (pA), one per receptor port, gathered
        with ``sum_delta_inputs`` under the label ``'receptor_<k>'``.
        """
        dftype = brainstate.environ.dftype()
        return [
            self.sum_delta_inputs(
                jnp.zeros(self.varshape, dtype=dftype) * u.pA,
                label=f'receptor_{k}',
            )
            for k in range(self._n_receptors)
        ]

    def update(self, x=0.0 * u.pA):
        r"""Perform a single simulation step using exact propagator matrices.

        Implements the NEST ``glif_psc`` update using the exact
        IAFPropagatorAlpha integration scheme. All GLIF-specific discrete
        updates (threshold decay, ASC, voltage-dependent threshold) are
        applied as vectorised JAX operations, making this method compatible
        with ``brainstate.transform.for_loop``.

        Parameters
        ----------
        x : ArrayLike, optional
            External current input (pA), applied with one-step delay. Default: 0.0 pA.

        Returns
        -------
        spike : jax.Array
            Binary spike tensor (float32), shape ``(*varshape)``.
        """
        t = brainstate.environ.get('t')
        dt_q = brainstate.environ.get_dt()
        dftype = brainstate.environ.dftype()
        ditype = brainstate.environ.ditype()

        # Python-level constants (concrete, not JAX-traced).
        E_L_mV = float(np.asarray(u.get_mantissa(self.E_L / u.mV)))
        I_e_pA = float(np.asarray(u.get_mantissa(self.I_e / u.pA)))
        V_reset_rel = float(np.asarray(u.get_mantissa(self.V_reset / u.mV))) - E_L_mV
        G_nS = float(np.asarray(u.get_mantissa(self.g_m / u.nS)))

        # JAX state (traced under for_loop).
        r = self.refractory_step_count.value  # int array, varshape
        i_stim_pA = u.get_mantissa(self.I_stim.value / u.pA)  # float array, varshape
        # Membrane potential before this step's update, in mV relative to E_L.
        V_rel = jax.lax.stop_gradient(
            u.get_mantissa(self.V.value / u.mV) - E_L_mV
        )

        # Buffer new external current (one-step delay).
        new_i_stim_q = self.sum_current_inputs(x, self.V.value)

        is_refractory = r > 0
        i_ext = I_e_pA + i_stim_pA  # pA, plain JAX array
        n_asc = len(self.asc_decay)

        # 1. Spike threshold decay (non-refractory only).
        if self.has_theta_spike:
            tspk = self._threshold_spike_state.value
            tspk = jnp.where(is_refractory, tspk, tspk * self._decay_spike)
        else:
            tspk = jnp.zeros(self.varshape, dtype=dftype)

        # 2. ASC stable-coeff sum + decay (non-refractory only).
        if self.has_asc:
            asc_sum_new = jnp.zeros(self.varshape, dtype=dftype)
            asc_decayed = []
            for a in range(n_asc):
                asc_a = self._asc_states[a].value
                asc_sum_new = asc_sum_new + self._asc_stable_coeff[a] * asc_a
                asc_decayed.append(asc_a * self._asc_decay_rates[a])
            asc_sum = jnp.where(is_refractory, jnp.zeros(self.varshape, dtype=dftype), asc_sum_new)
        else:
            asc_sum = jnp.zeros(self.varshape, dtype=dftype)
            asc_decayed = []

        # 3. Voltage-dependent threshold (non-refractory only, using old V_rel).
        if self.has_theta_voltage:
            tvlt = self._threshold_voltage_state.value
            beta = (i_ext + asc_sum) / G_nS  # pA/nS = mV
            tvlt_new = (
                self._phi * (V_rel - beta) * self._potential_decay_rate
                + self._theta_voltage_decay_rate_inv * (
                    tvlt
                    - self._phi * (V_rel - beta)
                    - self._abpara_ratio * beta
                )
                + self._abpara_ratio * beta
            )
            tvlt = jnp.where(is_refractory, tvlt, tvlt_new)
        else:
            tvlt = jnp.zeros(self.varshape, dtype=dftype)

        # 4. Total threshold.
        threshold = tspk + tvlt + self._th_inf

        # 5. V update via exact propagator.
        y1_old = [u.get_mantissa(self.y1[k].value / (u.pA / u.ms)) for k in range(self._n_receptors)]
        y2_old = [u.get_mantissa(self.y2[k].value / u.pA) for k in range(self._n_receptors)]
        v_new = V_rel * self._P33 + (i_ext + asc_sum) * self._P30
        for k in range(self._n_receptors):
            v_new = v_new + self._P31[k] * y1_old[k] + self._P32[k] * y2_old[k]
        # Refractory neurons are clamped to the old V_rel.
        v_new = jnp.where(is_refractory, V_rel, v_new)

        # 6. Spike check (non-refractory only).
        spiked = (v_new > threshold) & ~is_refractory

        # 7. ASC reset on spike.
        if self.has_asc:
            for a in range(n_asc):
                asc_a = self._asc_states[a].value
                asc_reset = self.asc_amps[a] + asc_decayed[a] * self._asc_refr_decay_rates[a]
                self._asc_states[a].value = jnp.where(
                    spiked, asc_reset,
                    jnp.where(is_refractory, asc_a, asc_decayed[a])
                )

        # 8. Voltage reset on spike.
        if not self.has_theta_spike:
            # GLIF1/3: simple reset.
            V_final_rel = jnp.where(spiked, V_reset_rel, v_new)
        else:
            # GLIF2/4/5: biologically defined reset.
            V_reset_bio = self.voltage_reset_fraction * V_rel + self.voltage_reset_add
            V_final_rel = jnp.where(spiked, V_reset_bio, v_new)
            # 9. Theta_spike reset on spike (decay over t_ref, then add).
            tspk_reset = tspk * self._decay_spike_refr + self.th_spike_add
            tspk = jnp.where(spiked, tspk_reset, tspk)
            threshold = jnp.where(spiked, tspk + tvlt + self._th_inf, threshold)

        # 10. Refractory counter.
        r_new = jnp.where(
            spiked, self.ref_count,
            jnp.where(is_refractory, r - 1, r)
        )

        # 11. Y1/Y2 propagator update (unconditional — all neurons, including refractory).
        y1_new = [self._P11[k] * y1_old[k] for k in range(self._n_receptors)]
        y2_new = [self._P21[k] * y1_old[k] + self._P22[k] * y2_old[k]
                  for k in range(self._n_receptors)]

        # 12. Collect and apply synaptic delta inputs to y1.
        dy_input = self._collect_receptor_delta_inputs()
        for k in range(self._n_receptors):
            w_k = u.get_mantissa(dy_input[k] / u.pA)  # weight in pA
            y1_new[k] = y1_new[k] + self._PSCInitialValues[k] * w_k

        # ---- Write back all state ----
        self.V.value = (V_final_rel + E_L_mV) * u.mV
        for k in range(self._n_receptors):
            self.y1[k].value = y1_new[k] * (u.pA / u.ms)
            self.y2[k].value = y2_new[k] * u.pA
        self._threshold_spike_state.value = tspk
        self._threshold_voltage_state.value = tvlt
        self._threshold_state.value = threshold
        self.refractory_step_count.value = jnp.asarray(r_new, dtype=ditype)
        self.I_stim.value = new_i_stim_q + u.math.zeros(self.varshape) * u.pA
        last_spike_time = u.math.where(spiked, t + dt_q, self.last_spike_time.value)
        self.last_spike_time.value = jax.lax.stop_gradient(last_spike_time)
        if self.ref_var:
            self.refractory.value = jax.lax.stop_gradient(self.refractory_step_count.value > 0)

        return jnp.asarray(spiked, dtype=jnp.float32)