# Copyright 2026 BrainX Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# -*- coding: utf-8 -*-
from typing import Callable, Optional, Sequence
import brainstate
import braintools
import saiunit as u
import jax
import jax.numpy as jnp
import numpy as np
from brainstate.typing import ArrayLike, Size
from ._base import NESTNeuron
from ._utils import is_tracer
# Public API of this module: only the neuron model class is exported.
__all__ = ['gif_psc_exp']
class _AdaptElemsRow:
    """Writable view of a single row of an adaptation-element state array.

    Instances are created by ``_AdaptElems.__getitem__``. Reading ``row[j]``
    returns element ``j`` of row ``row_idx``; writing ``row[j] = val`` builds a
    new array with JAX's functional ``.at[row_idx, j].set(val)`` and assigns it
    back to ``state.value`` (JAX arrays are immutable, so the state receives a
    fresh array rather than being modified in place).
    """

    def __init__(self, state, row_idx):
        # Only `state.value` is accessed; presumably a brainstate.ShortTermState
        # holding a JAX array (that is how gif_psc_exp uses it).
        self._state = state
        self._idx = row_idx

    def __getitem__(self, idx):
        row = self._state.value[self._idx]
        return row[idx]

    def __setitem__(self, idx, value):
        updated = self._state.value.at[self._idx, idx].set(value)
        self._state.value = updated
class _AdaptElems:
    """Writable wrapper around a state holding adaptation elements.

    The wrapped array has shape ``(n_elems, *varshape)``. ``elems[i]`` returns an
    :class:`_AdaptElemsRow` view supporting ``elems[i][j]`` reads and writes, and
    ``elems[i] = row`` replaces a whole row via ``.at[i].set(row)``.

    A ``None`` state (e.g. a ``gif_psc_exp`` configured with no STC or SFA
    elements) behaves as an empty collection.
    """

    def __init__(self, state):
        self._state = state

    def __getitem__(self, idx):
        return _AdaptElemsRow(self._state, idx)

    def __setitem__(self, idx, value):
        self._state.value = self._state.value.at[idx].set(value)

    def __len__(self):
        # No backing state means zero elements were configured.
        if self._state is None:
            return 0
        return self._state.value.shape[0]

    def __iter__(self):
        # Explicit, bounded iteration. Without this, Python falls back to calling
        # __getitem__(0), __getitem__(1), ... until IndexError, which never comes
        # because __getitem__ builds a lazy row view without bounds checking.
        for i in range(len(self)):
            yield self[i]
class gif_psc_exp(NESTNeuron):
    r"""Current-based generalized integrate-and-fire neuron (GIF) model.

    This is a brainpy.state re-implementation of the NEST simulator's ``gif_psc_exp``
    model according to Mensi et al. (2012) [1]_ and Pozzorini et al. (2015) [2]_, using
    NEST-standard parameterization and exact integration.

    The GIF model features both spike-triggered adaptation currents and a dynamic
    firing threshold for spike-frequency adaptation. It generates spikes stochastically
    based on a point process with intensity that depends on the distance between the
    membrane potential and the adaptive threshold.

    **1. Mathematical Model**

    **1.1 Membrane Dynamics**

    The membrane potential :math:`V` is governed by:

    .. math::

        C_\mathrm{m} \frac{dV(t)}{dt} = -g_\mathrm{L}(V(t) - E_\mathrm{L})
        - \eta_1(t) - \eta_2(t) - \ldots - \eta_n(t) + I(t)

    where:

    - :math:`C_\mathrm{m}` is the membrane capacitance
    - :math:`g_\mathrm{L}` is the leak conductance
    - :math:`E_\mathrm{L}` is the leak reversal potential
    - :math:`\eta_i(t)` are spike-triggered currents (stc)
    - :math:`I(t) = I_\mathrm{syn,ex}(t) + I_\mathrm{syn,in}(t) + I_\mathrm{e} + I_\mathrm{stim}(t)`

    **1.2 Synaptic Currents**

    Synaptic currents decay exponentially:

    .. math::

        \frac{dI_{\mathrm{syn,ex}}}{dt} = -\frac{I_{\mathrm{syn,ex}}}{\tau_{\mathrm{syn,ex}}},
        \qquad
        \frac{dI_{\mathrm{syn,in}}}{dt} = -\frac{I_{\mathrm{syn,in}}}{\tau_{\mathrm{syn,in}}}

    Incoming spike weights (in pA) are routed by sign: positive weights to
    :math:`I_{\mathrm{syn,ex}}`, negative to :math:`I_{\mathrm{syn,in}}`.

    **1.3 Spike-Triggered Currents (STC)**

    Each spike-triggered current element :math:`\eta_i` evolves as:

    .. math::

        \tau_{\eta_i} \frac{d\eta_i}{dt} = -\eta_i

    On spike emission:

    .. math::

        \eta_i \leftarrow \eta_i + q_{\eta_i}

    **1.4 Spike-Frequency Adaptation (SFA)**

    The neuron fires stochastically with intensity:

    .. math::

        \lambda(t) = \lambda_0 \cdot \exp\left(\frac{V(t) - V_T(t)}{\Delta_V}\right)

    where the dynamic threshold :math:`V_T(t)` is:

    .. math::

        V_T(t) = V_{T^*} + \gamma_1(t) + \gamma_2(t) + \ldots + \gamma_m(t)

    Each adaptation element :math:`\gamma_i` evolves as:

    .. math::

        \tau_{\gamma_i} \frac{d\gamma_i}{dt} = -\gamma_i

    On spike emission:

    .. math::

        \gamma_i \leftarrow \gamma_i + q_{\gamma_i}

    **1.5 Stochastic Spiking**

    The probability of firing within a time step :math:`dt` is:

    .. math::

        P(\text{spike}) = 1 - \exp(-\lambda(t) \cdot dt)

    A uniformly distributed random number is drawn for every neuron on every time
    step; it is compared to this probability only for neurons that are not
    refractory.

    **1.6 Refractory Period**

    After a spike, the neuron enters an absolute refractory period of duration
    :math:`t_\mathrm{ref}`, converted to ``ceil(t_ref / dt)`` simulation steps
    (a ratio within floating-point noise of an integer is snapped to that
    integer, so e.g. 1.1 ms at 0.1 ms resolution gives 11 steps). During this
    period:

    - The refractory counter decrements each step
    - :math:`V_\mathrm{m}` is clamped to :math:`V_\mathrm{reset}`
    - Synaptic currents continue to decay and receive inputs
    - No spike checks are performed

    **2. Numerical Integration**

    The model uses exact matrix-exponential integration, following NEST's update
    order. The discrete-time update per simulation step is:

    1. **STC/SFA totals**: Sum adaptation elements (before decay), then decay them.
    2. **Synaptic decay**: :math:`I_\mathrm{syn} \leftarrow I_\mathrm{syn} \cdot e^{-dt/\tau}`.
    3. **Spike weights**: Add arriving spike weights to :math:`I_\mathrm{syn,ex}` / :math:`I_\mathrm{syn,in}`.
    4. **V update**: If not refractory, apply the exact propagator using the post-weight
       synaptic currents. If refractory, clamp :math:`V` to :math:`V_\mathrm{reset}` and
       decrement the counter.
    5. **Spike check**: Non-refractory neurons spike with probability
       :math:`P(\text{spike})`; on a spike the STC/SFA jumps are added and the
       refractory counter is set. :math:`V` is not reset on the spike step itself.
    6. **Store I_stim**: Buffer external input for the next step (NEST ring buffer semantics).

    Parameters
    ----------
    in_size : int, tuple of int
        Shape of the neuron population.
    g_L : ArrayLike, optional
        Leak conductance. Default: 4.0 nS.
    E_L : ArrayLike, optional
        Leak reversal potential. Default: -70.0 mV.
    C_m : ArrayLike, optional
        Membrane capacitance. Default: 80.0 pF.
    V_reset : ArrayLike, optional
        Reset potential after spike. Default: -55.0 mV.
    Delta_V : ArrayLike, optional
        Stochasticity level. Default: 0.5 mV.
    V_T_star : ArrayLike, optional
        Base firing threshold. Default: -35.0 mV.
    lambda_0 : float, optional
        Stochastic intensity at threshold in 1/s. Default: 1.0 /s.
    t_ref : ArrayLike, optional
        Absolute refractory period. Default: 4.0 ms.
    tau_syn_ex : ArrayLike, optional
        Excitatory synaptic time constant. Default: 2.0 ms.
    tau_syn_in : ArrayLike, optional
        Inhibitory synaptic time constant. Default: 2.0 ms.
    I_e : ArrayLike, optional
        Constant external current. Default: 0.0 pA.
    tau_sfa : Sequence[float], optional
        SFA time constants in ms. Default: () (no SFA).
    q_sfa : Sequence[float], optional
        SFA jump values in mV. Default: () (no SFA).
    tau_stc : Sequence[float], optional
        STC time constants in ms. Default: () (no STC).
    q_stc : Sequence[float], optional
        STC jump values in pA. Default: () (no STC).
    rng_key : jax.Array, optional
        JAX PRNG key for stochastic spiking. Default: None (seed 0).
    V_initializer : Callable, optional
        Initializer for membrane potential. Default: Constant(-70 mV).
    spk_fun : Callable, optional
        Surrogate gradient function. Default: ReluGrad().
    spk_reset : str, optional
        Spike reset mode. Default: 'hard'.
    ref_var : bool, optional
        If True, expose boolean refractory state. Default: False.
    name : str, optional
        Name of the neuron group. Default: None.

    Raises
    ------
    ValueError
        If ``tau_sfa``/``q_sfa`` or ``tau_stc``/``q_stc`` differ in length, or if
        any parameter violates the NEST constraints checked in
        ``_validate_parameters``.

    References
    ----------
    .. [1] Mensi S et al. (2012). Parameter extraction and classification of three
           cortical neuron types. Journal of Neurophysiology, 107(6):1756-1775.
    .. [2] Pozzorini C et al. (2015). Automated high-throughput characterization of
           single neurons. PLoS Computational Biology, 11(6), e1004275.
    .. [3] NEST Simulator ``gif_psc_exp`` model: ``models/gif_psc_exp.h``.

    See Also
    --------
    gif_cond_exp : Conductance-based GIF model.
    iaf_psc_exp : Simple IAF neuron with exponential synapses.
    gif_psc_exp_multisynapse : GIF model with multiple receptor ports.
    """
    __module__ = 'brainpy.state'

    def __init__(
        self,
        in_size: Size,
        g_L: ArrayLike = 4.0 * u.nS,
        E_L: ArrayLike = -70.0 * u.mV,
        C_m: ArrayLike = 80.0 * u.pF,
        V_reset: ArrayLike = -55.0 * u.mV,
        Delta_V: ArrayLike = 0.5 * u.mV,
        V_T_star: ArrayLike = -35.0 * u.mV,
        lambda_0: float = 1.0,  # 1/s, as in NEST Python interface
        t_ref: ArrayLike = 4.0 * u.ms,
        tau_syn_ex: ArrayLike = 2.0 * u.ms,
        tau_syn_in: ArrayLike = 2.0 * u.ms,
        I_e: ArrayLike = 0.0 * u.pA,
        tau_sfa: Sequence[float] = (),  # ms values
        q_sfa: Sequence[float] = (),  # mV values
        tau_stc: Sequence[float] = (),  # ms values
        q_stc: Sequence[float] = (),  # pA values
        rng_key: Optional[jax.Array] = None,
        V_initializer: Callable = braintools.init.Constant(-70.0 * u.mV),
        spk_fun: Callable = braintools.surrogate.ReluGrad(),
        spk_reset: str = 'hard',
        ref_var: bool = False,
        name: str = None,
    ):
        super().__init__(in_size, name=name, spk_fun=spk_fun, spk_reset=spk_reset)

        # Membrane parameters (broadcast to the population shape).
        self.g_L = braintools.init.param(g_L, self.varshape)
        self.E_L = braintools.init.param(E_L, self.varshape)
        self.C_m = braintools.init.param(C_m, self.varshape)
        self.V_reset = braintools.init.param(V_reset, self.varshape)
        self.Delta_V = braintools.init.param(Delta_V, self.varshape)
        self.V_T_star = braintools.init.param(V_T_star, self.varshape)
        self.t_ref = braintools.init.param(t_ref, self.varshape)
        self.I_e = braintools.init.param(I_e, self.varshape)

        # Synaptic parameters.
        self.tau_syn_ex = braintools.init.param(tau_syn_ex, self.varshape)
        self.tau_syn_in = braintools.init.param(tau_syn_in, self.varshape)

        # Stochastic spiking: lambda_0 is given in 1/s, stored internally in 1/ms.
        self.lambda_0 = lambda_0 / 1000.0

        # Adaptation parameters (plain Python tuples of floats in ms / mV / pA).
        self.tau_sfa = tuple(float(x) for x in tau_sfa)
        self.q_sfa = tuple(float(x) for x in q_sfa)
        self.tau_stc = tuple(float(x) for x in tau_stc)
        self.q_stc = tuple(float(x) for x in q_stc)

        if len(self.tau_sfa) != len(self.q_sfa):
            raise ValueError(
                f"'tau_sfa' and 'q_sfa' must have the same length. "
                f"Got {len(self.tau_sfa)} and {len(self.q_sfa)}."
            )
        if len(self.tau_stc) != len(self.q_stc):
            raise ValueError(
                f"'tau_stc' and 'q_stc' must have the same length. "
                f"Got {len(self.tau_stc)} and {len(self.q_stc)}."
            )

        # RNG key for stochastic spiking (None -> PRNGKey(0) in init_state).
        self._rng_key = rng_key

        # Initializers and options.
        self.V_initializer = V_initializer
        self.ref_var = ref_var

        self._validate_parameters()

        # Refractory period as an integer number of steps. Recomputed in
        # _precompute_propagators so it always matches the dt used there.
        self.ref_count = self._refractory_steps(brainstate.environ.get_dt())

    def _refractory_steps(self, dt):
        r"""Convert ``t_ref`` into an integer number of simulation steps.

        The ratio ``t_ref / dt`` is rounded up, except that a ratio within
        floating-point noise of an integer snaps to that integer. Plain ``ceil``
        would turn e.g. ``1.1 / 0.1 == 11.000000000000002`` into 12 steps.

        Parameters
        ----------
        dt : Quantity
            Simulation time step.

        Returns
        -------
        jax.Array
            Step counts with the environment's integer dtype, shaped like ``t_ref``.
        """
        ditype = brainstate.environ.ditype()
        t_ref_ms = jnp.asarray(u.get_mantissa(self.t_ref / u.ms), dtype=jnp.float64)
        dt_ms = jnp.asarray(u.get_mantissa(dt / u.ms), dtype=jnp.float64)
        ratio = t_ref_ms / dt_ms
        nearest = jnp.rint(ratio)
        steps = jnp.where(
            jnp.isclose(ratio, nearest, rtol=1e-9, atol=1e-9),
            nearest,
            jnp.ceil(ratio),
        )
        return u.math.asarray(steps, dtype=ditype)

    def _sum_signed_delta_inputs(self):
        r"""Route delta inputs by sign: positive -> excitatory, negative -> inhibitory.

        Callable entries of ``self.delta_inputs`` are evaluated and kept; plain
        values are read once and then removed from the dict.

        Returns
        -------
        tuple
            ``(w_ex, w_in)`` summed excitatory (>= 0) and inhibitory (<= 0) weights,
            shaped like ``I_syn_ex`` / ``I_syn_in``.
        """
        w_ex = u.math.zeros_like(self.I_syn_ex.value)
        w_in = u.math.zeros_like(self.I_syn_in.value)
        if self.delta_inputs is None:
            return w_ex, w_in
        # Iterate over a snapshot of the keys because entries may be popped.
        for key in tuple(self.delta_inputs.keys()):
            out = self.delta_inputs[key]
            if callable(out):
                out = out()
            else:
                self.delta_inputs.pop(key)
            zero = u.math.zeros_like(out)
            w_ex = w_ex + u.math.maximum(out, zero)
            w_in = w_in + u.math.minimum(out, zero)
        return w_ex, w_in

    def _validate_parameters(self):
        r"""Validate model parameters against NEST constraints.

        Skipped entirely when any of the array parameters compared below is a JAX
        tracer, since concrete comparisons are impossible under tracing.

        Raises
        ------
        ValueError
            If any constraint is violated.
        """
        traced_candidates = (
            self.C_m, self.g_L, self.Delta_V, self.t_ref, self.tau_syn_ex, self.tau_syn_in,
        )
        if any(is_tracer(v) for v in traced_candidates):
            return
        if np.any(self.C_m <= 0.0 * u.pF):
            raise ValueError('Capacitance must be strictly positive.')
        if np.any(self.g_L <= 0.0 * u.nS):
            raise ValueError('Membrane conductance must be strictly positive.')
        if np.any(self.Delta_V <= 0.0 * u.mV):
            raise ValueError('Delta_V must be strictly positive.')
        if np.any(self.t_ref < 0.0 * u.ms):
            raise ValueError('Refractory time must not be negative.')
        if self.lambda_0 < 0.0:
            raise ValueError('lambda_0 must not be negative.')
        if np.any(self.tau_syn_ex <= 0.0 * u.ms) or \
                np.any(self.tau_syn_in <= 0.0 * u.ms):
            raise ValueError('Synapse time constants must be strictly positive.')
        for tau in self.tau_sfa:
            if tau <= 0.0:
                raise ValueError('All SFA time constants must be strictly positive.')
        for tau in self.tau_stc:
            if tau <= 0.0:
                raise ValueError('All STC time constants must be strictly positive.')

    def init_state(self, **kwargs):
        r"""Initialize persistent and short-term state variables.

        Creates the membrane potential, synaptic currents, adaptation elements,
        refractory counter, buffered stimulus current and RNG state, then
        pre-computes the exact-integration propagators for the current ``dt``.
        """
        ditype = brainstate.environ.ditype()
        v_shape = self.varshape
        n_stc = len(self.tau_stc)
        n_sfa = len(self.tau_sfa)

        V = braintools.init.param(self.V_initializer, v_shape)
        # Force float64 precision for V and synaptic currents.
        V_f64 = jnp.asarray(u.get_mantissa(V / u.mV), dtype=jnp.float64) * u.mV
        self.V = brainstate.HiddenState(V_f64)
        self.I_syn_ex = brainstate.ShortTermState(jnp.zeros(v_shape, dtype=jnp.float64) * u.pA)
        self.I_syn_in = brainstate.ShortTermState(jnp.zeros(v_shape, dtype=jnp.float64) * u.pA)

        # Adaptation elements, float64, shape (n_elems, *v_shape); None if unused.
        self._stc_elems_state = (
            brainstate.ShortTermState(jnp.zeros((n_stc, *v_shape), dtype=jnp.float64))
            if n_stc > 0 else None
        )
        self._sfa_elems_state = (
            brainstate.ShortTermState(jnp.zeros((n_sfa, *v_shape), dtype=jnp.float64))
            if n_sfa > 0 else None
        )

        # V_T_star as float64 numpy, used to initialise the effective threshold.
        V_T_star_np = np.asarray(u.get_mantissa(self.V_T_star / u.mV), dtype=np.float64)

        # Total STC current and effective threshold (refreshed at the start of each step).
        self._stc_val_state = brainstate.ShortTermState(jnp.zeros(v_shape, dtype=jnp.float64))
        self._sfa_val_state = brainstate.ShortTermState(
            jnp.zeros(v_shape, dtype=jnp.float64) + jnp.asarray(V_T_star_np, dtype=jnp.float64)
        )

        self.last_spike_time = brainstate.ShortTermState(u.math.full(v_shape, -1e7 * u.ms))
        self.refractory_step_count = brainstate.ShortTermState(u.math.full(v_shape, 0, dtype=ditype))
        self.I_stim = brainstate.ShortTermState(jnp.zeros(v_shape, dtype=jnp.float64) * u.pA)

        # RNG key kept in a ShortTermState so it can be threaded through JIT.
        rng_init = self._rng_key if self._rng_key is not None else jax.random.PRNGKey(0)
        self._rng_state = brainstate.ShortTermState(rng_init)

        if self.ref_var:
            refractory = braintools.init.param(braintools.init.Constant(False), v_shape)
            self.refractory = brainstate.ShortTermState(refractory)

        # Pre-compute exact propagator coefficients (float64, numpy).
        self._precompute_propagators()

    def _precompute_propagators(self):
        """Pre-compute exact matrix-exponential propagator coefficients.

        Matches NEST's IAFPropagatorExp approach. All values are float64 numpy
        arrays computed once here and reused every update step. The refractory
        step count is refreshed with the same ``dt`` so both stay consistent.
        """
        dt = brainstate.environ.get_dt()
        dt_ms = float(np.asarray(u.get_mantissa(dt / u.ms)))
        tau_syn_ex_ms = np.asarray(u.get_mantissa(self.tau_syn_ex / u.ms), dtype=np.float64)
        tau_syn_in_ms = np.asarray(u.get_mantissa(self.tau_syn_in / u.ms), dtype=np.float64)
        C_m_pF = np.asarray(u.get_mantissa(self.C_m / u.pF), dtype=np.float64)
        g_L_nS = np.asarray(u.get_mantissa(self.g_L / u.nS), dtype=np.float64)
        tau_m_ms = C_m_pF / g_L_nS

        self._dt_ms = dt_ms
        self.ref_count = self._refractory_steps(dt)

        # Membrane propagators.
        self._P33 = np.exp(-dt_ms / tau_m_ms)
        self._P30 = -1.0 / C_m_pF * np.expm1(-dt_ms / tau_m_ms) * tau_m_ms
        self._P31 = -np.expm1(-dt_ms / tau_m_ms)

        # Synaptic decay coefficients.
        self._P11_ex = np.exp(-dt_ms / tau_syn_ex_ms)
        self._P11_in = np.exp(-dt_ms / tau_syn_in_ms)

        # Synaptic-to-membrane coupling propagators.
        self._P21_ex = self._propagator_exp(tau_syn_ex_ms, tau_m_ms, C_m_pF, dt_ms)
        self._P21_in = self._propagator_exp(tau_syn_in_ms, tau_m_ms, C_m_pF, dt_ms)

        # Per-element decay factors of the adaptation kernels (shape (n_elems,)).
        self._P_stc = np.exp(-dt_ms / np.asarray(self.tau_stc, dtype=np.float64))
        self._P_sfa = np.exp(-dt_ms / np.asarray(self.tau_sfa, dtype=np.float64))

        # Pre-extracted parameter values as float64 numpy.
        self._E_L_mV_np = np.asarray(u.get_mantissa(self.E_L / u.mV), dtype=np.float64)
        self._V_reset_mV_np = np.asarray(u.get_mantissa(self.V_reset / u.mV), dtype=np.float64)
        self._V_T_star_mV_np = np.asarray(u.get_mantissa(self.V_T_star / u.mV), dtype=np.float64)
        self._Delta_V_mV_np = np.asarray(u.get_mantissa(self.Delta_V / u.mV), dtype=np.float64)
        self._I_e_pA_np = np.asarray(u.get_mantissa(self.I_e / u.pA), dtype=np.float64)

    @property
    def _stc_elems(self):
        """Spike-triggered current elements (n_stc, *varshape), plain float (pA)."""
        return _AdaptElems(self._stc_elems_state)

    @property
    def _sfa_elems(self):
        """Spike-frequency adaptation elements (n_sfa, *varshape), plain float (mV)."""
        return _AdaptElems(self._sfa_elems_state)

    @property
    def _stc_val(self):
        """Total STC current at start of last update step (*varshape), plain float (pA)."""
        return self._stc_val_state.value

    @property
    def _sfa_val(self):
        """Effective firing threshold (V_T_star + sum of sfa) at start of last step (mV)."""
        return self._sfa_val_state.value

    def get_spike(self, V: ArrayLike = None):
        r"""Generate surrogate spike output for gradient-based learning.

        Parameters
        ----------
        V : ArrayLike, optional
            Membrane potential; defaults to the current state ``self.V``.

        Returns
        -------
        ArrayLike
            ``spk_fun((V - V_reset) / Delta_V)``.
        """
        # NOTE(review): the surrogate is centred on V_reset rather than on the
        # adaptive threshold (_sfa_val); confirm this is intended.
        V = self.V.value if V is None else V
        v_scaled = (V - self.V_reset) / (self.Delta_V)
        return self.spk_fun(v_scaled)

    @staticmethod
    def _propagator_exp(tau_syn: np.ndarray, tau_m: np.ndarray, c_m: np.ndarray, h_ms: float):
        r"""Compute the propagator coefficient P21 (I_syn -> V_m) for exact integration.

        Matches NEST's ``IAFPropagatorExp::evaluate()`` with singularity handling:
        where the regular formula is non-finite, subnormal or non-positive
        (e.g. ``tau_m == tau_syn``), the limit ``h / c_m * exp(-h / tau_m)`` is used.

        Parameters
        ----------
        tau_syn : float or ndarray
            Synaptic time constant in ms.
        tau_m : float or ndarray
            Membrane time constant in ms.
        c_m : float or ndarray
            Membrane capacitance in pF.
        h_ms : float
            Time step in ms.

        Returns
        -------
        P21 : float or ndarray
            Propagator coefficient (mV/pA when applied to pA current).
        """
        with np.errstate(divide='ignore', invalid='ignore', over='ignore', under='ignore'):
            beta = tau_syn * tau_m / (tau_m - tau_syn)
            gamma = beta / c_m
            inv_beta = (tau_m - tau_syn) / (tau_syn * tau_m)
            exp_h_tau_syn = np.exp(-h_ms / tau_syn)
            expm1_h_tau = np.expm1(h_ms * inv_beta)
            p32_raw = gamma * exp_h_tau_syn * expm1_h_tau
        normal_min = np.finfo(np.float64).tiny
        regular_mask = np.isfinite(p32_raw) & (np.abs(p32_raw) >= normal_min) & (p32_raw > 0.0)
        p32_singular = h_ms / c_m * np.exp(-h_ms / tau_m)
        return np.where(regular_mask, p32_raw, p32_singular)

    def update(self, x=0.0 * u.pA):
        r"""Advance the neuron by one simulation step.

        Follows NEST's ``gif_psc_exp`` update order:

        1. STC/SFA totals (before decay) + decay elements.
        2. Decay synaptic currents.
        3. Add arriving spike weights.
        4. Update V via exact propagator (non-refractory) or clamp to V_reset (refractory).
        5. Stochastic spike check; on spike: STC/SFA jumps, set refractory counter.
        6. Buffer external current for next step.

        Parameters
        ----------
        x : ArrayLike, optional
            External current input (pA). Buffered for the NEXT time step (NEST ring
            buffer semantics). Default: 0.0 pA.

        Returns
        -------
        spike : jax.Array
            Binary spike output as float array matching population shape.
        """
        t = brainstate.environ.get('t')
        dt = brainstate.environ.get_dt()
        dftype = brainstate.environ.dftype()
        ditype = brainstate.environ.ditype()
        dt_ms = self._dt_ms
        v_shape = self.varshape
        n_dims = len(v_shape)
        n_stc = len(self.tau_stc)
        n_sfa = len(self.tau_sfa)
        # Reshape target that broadcasts a per-element vector over the population.
        elem_bcast = [1] * n_dims

        # Read state as plain float64.
        V_mV = jnp.asarray(u.get_mantissa(self.V.value / u.mV), dtype=jnp.float64)
        I_syn_ex_pA = jnp.asarray(u.get_mantissa(self.I_syn_ex.value / u.pA), dtype=jnp.float64)
        I_syn_in_pA = jnp.asarray(u.get_mantissa(self.I_syn_in.value / u.pA), dtype=jnp.float64)
        r = self.refractory_step_count.value
        # Current buffered during the PREVIOUS step is the one applied now.
        i_stim_pA = jnp.asarray(u.get_mantissa(self.I_stim.value / u.pA), dtype=jnp.float64)

        # Buffer current input for next step (NEST ring-buffer semantics).
        new_i_stim = self.sum_current_inputs(x, self.V.value)
        new_i_stim_pA = jnp.asarray(u.get_mantissa(new_i_stim / u.pA), dtype=jnp.float64)

        # ---- Step 1: stc/sfa totals (before decay) + decay elements ----
        if n_stc > 0:
            stc_elems = self._stc_elems_state.value  # (n_stc, *v_shape) float64
            stc_total_pA = jnp.sum(stc_elems, axis=0)  # (*v_shape) float64
            P_stc = jnp.asarray(self._P_stc, dtype=jnp.float64).reshape(n_stc, *elem_bcast)
            stc_elems_decayed = stc_elems * P_stc
        else:
            stc_total_pA = jnp.zeros(v_shape, dtype=jnp.float64)
            stc_elems_decayed = None

        V_T_star_f64 = jnp.asarray(self._V_T_star_mV_np, dtype=jnp.float64)
        if n_sfa > 0:
            sfa_elems = self._sfa_elems_state.value  # (n_sfa, *v_shape) float64
            sfa_total_mV = V_T_star_f64 + jnp.sum(sfa_elems, axis=0)
            P_sfa = jnp.asarray(self._P_sfa, dtype=jnp.float64).reshape(n_sfa, *elem_bcast)
            sfa_elems_decayed = sfa_elems * P_sfa
        else:
            sfa_total_mV = V_T_star_f64 + jnp.zeros(v_shape, dtype=jnp.float64)
            sfa_elems_decayed = None

        # Store totals for property access.
        self._stc_val_state.value = stc_total_pA
        self._sfa_val_state.value = sfa_total_mV

        # ---- Step 2: Decay synaptic currents ----
        I_syn_ex_pA = I_syn_ex_pA * jnp.asarray(self._P11_ex, dtype=jnp.float64)
        I_syn_in_pA = I_syn_in_pA * jnp.asarray(self._P11_in, dtype=jnp.float64)

        # ---- Step 3: Add arriving spike weights ----
        w_ex, w_in = self._sum_signed_delta_inputs()
        I_syn_ex_pA = I_syn_ex_pA + jnp.asarray(u.get_mantissa(w_ex / u.pA), dtype=jnp.float64)
        I_syn_in_pA = I_syn_in_pA + jnp.asarray(u.get_mantissa(w_in / u.pA), dtype=jnp.float64)

        # ---- Random numbers for the stochastic spike check ----
        # Drawn for every neuron each step; only used where not refractory.
        new_rng, subkey = jax.random.split(self._rng_state.value)
        self._rng_state.value = new_rng
        rand_vals = jax.random.uniform(subkey, shape=v_shape)

        # ---- Step 4: V update via exact propagator ----
        is_refractory = r > 0

        P33 = jnp.asarray(self._P33, dtype=jnp.float64)
        P30 = jnp.asarray(self._P30, dtype=jnp.float64)
        P31 = jnp.asarray(self._P31, dtype=jnp.float64)
        P21_ex = jnp.asarray(self._P21_ex, dtype=jnp.float64)
        P21_in = jnp.asarray(self._P21_in, dtype=jnp.float64)
        E_L_f64 = jnp.asarray(self._E_L_mV_np, dtype=jnp.float64)
        V_reset_f64 = jnp.asarray(self._V_reset_mV_np, dtype=jnp.float64)
        I_e_f64 = jnp.asarray(self._I_e_pA_np, dtype=jnp.float64)
        Delta_V_f64 = jnp.asarray(self._Delta_V_mV_np, dtype=jnp.float64)

        # V_new = P33*V + P30*(I_stim + I_e - stc) + P31*E_L + I_syn*P21
        V_propagated = (
            P33 * V_mV
            + P30 * (i_stim_pA + I_e_f64 - stc_total_pA)
            + P31 * E_L_f64
            + I_syn_ex_pA * P21_ex
            + I_syn_in_pA * P21_in
        )

        # ---- Step 5: Stochastic spike check for non-refractory neurons ----
        # Clipping keeps exp() finite far above/below threshold.
        exp_arg = jnp.clip((V_propagated - sfa_total_mV) / Delta_V_f64, -500.0, 500.0)
        lam = jnp.float64(self.lambda_0) * jnp.exp(exp_arg)  # 1/ms
        spike_prob = jnp.clip(-jnp.expm1(-lam * jnp.float64(dt_ms)), 0.0, 1.0)
        spike_mask = (~is_refractory) & (rand_vals.astype(jnp.float64) < spike_prob)

        # On the spike step V keeps V_propagated (not V_reset), as in NEST; the
        # clamp to V_reset happens on the following refractory steps.
        V_mV = jnp.where(is_refractory, V_reset_f64, V_propagated)

        # Refractory counter: count down, or start a new period on a spike.
        ref_count = jnp.asarray(self.ref_count, dtype=ditype)
        new_r = jnp.where(
            is_refractory,
            r - 1,
            jnp.where(spike_mask & (ref_count > 0), ref_count, r)
        )

        # STC/SFA jumps on spike (applied to the already-decayed elements).
        spike_mask_f64 = spike_mask.astype(jnp.float64)
        if n_stc > 0:
            q_stc_arr = jnp.array(self.q_stc, dtype=jnp.float64).reshape(n_stc, *elem_bcast)
            self._stc_elems_state.value = stc_elems_decayed + q_stc_arr * spike_mask_f64
        if n_sfa > 0:
            q_sfa_arr = jnp.array(self.q_sfa, dtype=jnp.float64).reshape(n_sfa, *elem_bcast)
            self._sfa_elems_state.value = sfa_elems_decayed + q_sfa_arr * spike_mask_f64

        # ---- Step 6: Write back state (including the buffered I_stim) ----
        self.V.value = V_mV * u.mV
        self.I_syn_ex.value = I_syn_ex_pA * u.pA
        self.I_syn_in.value = I_syn_in_pA * u.pA
        self.refractory_step_count.value = jnp.asarray(new_r, dtype=ditype)
        self.I_stim.value = new_i_stim_pA * u.pA
        last_spike_time = u.math.where(spike_mask, t + dt, self.last_spike_time.value)
        self.last_spike_time.value = jax.lax.stop_gradient(last_spike_time)
        if self.ref_var:
            self.refractory.value = jax.lax.stop_gradient(new_r > 0)

        return u.math.asarray(spike_mask, dtype=dftype)