# Copyright 2026 BrainX Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# -*- coding: utf-8 -*-
from typing import Callable, Optional, Sequence
import brainstate
import braintools
import saiunit as u
import jax
import jax.numpy as jnp
import numpy as np
from brainstate.typing import ArrayLike, Size
from ._base import NESTNeuron
from ._utils import is_tracer
# Public API of this module: only the neuron model class is exported.
__all__ = ['pp_psc_delta']
class pp_psc_delta(NESTNeuron):
    r"""Point process neuron with leaky integration of delta-shaped PSCs.

    ``pp_psc_delta`` is an implementation of a leaky integrator where the
    potential jumps on each spike arrival. It produces spikes stochastically
    according to a transfer function operating on the membrane potential, and
    supports spike-frequency adaptation with multiple exponential kernels.

    This is a brainpy.state re-implementation of the NEST simulator model of
    the same name, using NEST-standard parameterization and exact integration.

    Parameters
    ----------
    in_size : int, tuple of int
        Population shape. Defines the number of neurons in the population.
    tau_m : Quantity, optional
        Membrane time constant. Must be a positive quantity with time units.
        Default: 10.0 ms.
    C_m : Quantity, optional
        Membrane capacitance. Must be a positive quantity with capacitance units.
        Default: 250.0 pF.
    dead_time : float, optional
        Duration of the dead time (absolute refractory period) in milliseconds,
        given as a plain float (no units). If set to 0, the model operates in
        Poisson mode with potentially multiple spikes per time step. If
        ``dead_time`` is nonzero but smaller than the simulation resolution, it
        is clamped to the resolution. Must be non-negative. Default: 1.0 (ms).
    dead_time_random : bool, optional
        Whether to draw random dead time after each spike from a gamma distribution.
        If True, ``dead_time`` becomes the mean of the gamma distribution with
        shape parameter ``dead_time_shape``. Default: False.
    dead_time_shape : int, optional
        Shape parameter of the gamma distribution for random dead times. Must be
        at least 1. Default: 1.
    with_reset : bool, optional
        Whether to reset the membrane potential to 0 after each spike. Default: True.
    tau_sfa : tuple of float, optional
        Adaptive threshold time constants in milliseconds. Each element defines
        the decay time constant of one adaptation kernel. Must be a sequence of
        positive values with the same length as ``q_sfa``. Default: () (no adaptation).
    q_sfa : tuple of float, optional
        Adaptive threshold jump sizes in millivolts. Each element defines the
        increment added to the corresponding adaptation kernel on each spike.
        Must be a sequence with the same length as ``tau_sfa``. Default: () (no adaptation).
    c_1 : float, optional
        Slope of the linear part of the transfer function in Hz/mV. Default: 0.0.
    c_2 : float, optional
        Prefactor of the exponential part of the transfer function in Hz. Can be
        used as an offset spike rate when ``c_3 = 0``. Default: 1.238 Hz.
    c_3 : float, optional
        Coefficient of exponential nonlinearity in 1/mV. Must be non-negative.
        Set to 0 for purely linear transfer function. Default: 0.25 1/mV.
    I_e : Quantity, optional
        Constant external input current. Must be a quantity with current units.
        Default: 0.0 pA.
    t_ref_remaining : float, optional
        Remaining dead time at simulation start in milliseconds. Must be non-negative.
        Default: 0.0 ms.
    rng_key : jax.Array, optional
        JAX PRNG key for stochastic spike generation. If None,
        ``jax.random.PRNGKey(0)`` is used, so every instance created without an
        explicit key draws the same random stream (and every call to
        :meth:`init_state` restarts it). Pass distinct keys to obtain
        statistically independent populations. Default: None.
    V_initializer : Callable, optional
        Initializer for the membrane potential (relative to resting potential).
        Default: ``Constant(0.0 * u.mV)``.
    spk_fun : Callable, optional
        Surrogate spike function used by :meth:`get_spike`. Default:
        ``ReluGrad()``.
    spk_reset : str, optional
        Reset mode forwarded to the base class. Options: ``'hard'`` (stop
        gradient), ``'soft'`` (V -= V_th). :meth:`update` does not consult it;
        the membrane reset of this model is controlled by ``with_reset``.
        Default: ``'hard'``.
    name : str, optional
        Name of the neuron population. Default: None.

    Raises
    ------
    ValueError
        If ``C_m <= 0`` (capacitance must be strictly positive).
    ValueError
        If ``tau_m <= 0`` (membrane time constant must be strictly positive).
    ValueError
        If ``dead_time < 0`` (dead time must be non-negative).
    ValueError
        If ``dead_time_shape < 1`` (gamma shape parameter must be at least 1).
    ValueError
        If ``t_ref_remaining < 0`` (remaining refractory time must be non-negative).
    ValueError
        If ``c_3 < 0`` (exponential coefficient must be non-negative).
    ValueError
        If any element of ``tau_sfa <= 0`` (adaptation time constants must be positive).
    ValueError
        If ``len(tau_sfa) != len(q_sfa)`` (adaptation parameter lists must match).

    See Also
    --------
    iaf_psc_delta : Integrate-and-fire neuron with delta PSCs
    gif_psc_exp : Generalized integrate-and-fire with exponential PSCs

    Parameter Mapping
    -----------------

    ========================= ====================== ===============================================
    **NEST Parameter**        **brainpy.state**      **Notes**
    ========================= ====================== ===============================================
    ``tau_m``                 ``tau_m``              Membrane time constant
    ``C_m``                   ``C_m``                Membrane capacitance
    ``dead_time``             ``dead_time``          Refractory period duration
    ``dead_time_random``      ``dead_time_random``   Enable random dead time
    ``dead_time_shape``       ``dead_time_shape``    Gamma distribution shape parameter
    ``with_reset``            ``with_reset``         Reset ``V_m`` after spike
    ``tau_sfa``               ``tau_sfa``            Adaptation time constants (list)
    ``q_sfa``                 ``q_sfa``              Adaptation jump sizes (list)
    ``c_1``                   ``c_1``                Linear transfer function coefficient
    ``c_2``                   ``c_2``                Exponential transfer function prefactor
    ``c_3``                   ``c_3``                Exponential transfer function exponent
    ``I_e``                   ``I_e``                External input current
    ``t_ref_remaining``       ``t_ref_remaining``    Initial refractory time
    ``V_m``                   ``V.value``            Membrane potential (relative to rest)
    ``E_sfa``                 ``_q_val``             Sum of all adaptation elements
    ========================= ====================== ===============================================

    **1. Mathematical Model**

    **1.1. Membrane Dynamics**

    The membrane potential :math:`V_\mathrm{m}` (relative to resting potential)
    evolves according to a leaky integrator:

    .. math::

        C_\mathrm{m} \frac{dV_\mathrm{m}}{dt} = -\frac{C_\mathrm{m} V_\mathrm{m}}{\tau_\mathrm{m}}
        + I_\mathrm{e} + I_\mathrm{syn}(t)

    where:

    - :math:`C_\mathrm{m}` is the membrane capacitance
    - :math:`\tau_\mathrm{m}` is the membrane time constant
    - :math:`I_\mathrm{e}` is the constant external input current
    - :math:`I_\mathrm{syn}(t)` is the synaptic input current

    The exact (analytic) integration over one time step :math:`h` gives:

    .. math::

        V_\mathrm{m}(t + h) = P_{33} \cdot V_\mathrm{m}(t)
        + P_{30} \cdot (I_0 + I_\mathrm{e})
        + w_\mathrm{syn}

    where:

    - :math:`P_{33} = \exp(-h / \tau_\mathrm{m})`
    - :math:`P_{30} = \frac{\tau_\mathrm{m}}{C_\mathrm{m}}(1 - P_{33})`
    - :math:`I_0` is the buffered current from the previous step (ring buffer)
    - :math:`w_\mathrm{syn}` is the sum of all incoming delta-shaped PSP jumps (in mV)

    **1.2. Transfer Function**

    The instantaneous firing rate is computed from the effective membrane potential
    :math:`V' = V_\mathrm{m} - E_\mathrm{sfa}` using a flexible transfer function:

    .. math::

        \text{rate}(t) = \text{Rect}\!\left[
        c_1 \cdot V'(t) + c_2 \cdot \exp(c_3 \cdot V'(t))
        \right]

    where :math:`\text{Rect}(x) = \max(0, x)` ensures non-negative rates.
    By adjusting ``c_1``, ``c_2``, and ``c_3``, the transfer function can be:

    - Linear: Set ``c_3 = 0``, ``c_1 > 0`` -- :math:`\text{rate} = c_1 V' + c_2`
    - Exponential: Set ``c_1 = 0`` -- :math:`\text{rate} = c_2 \exp(c_3 V')`
    - Mixed: All coefficients nonzero -- linear + exponential

    **1.3. Spike-Frequency Adaptation**

    The adaptive threshold :math:`E_\mathrm{sfa}` is the sum of multiple exponential
    kernels, each with its own time constant and jump size:

    .. math::

        \tau_{\mathrm{sfa},i} \frac{dE_{\mathrm{sfa},i}}{dt} = -E_{\mathrm{sfa},i}

    .. math::

        E_{\mathrm{sfa},i}(t) \to E_{\mathrm{sfa},i}(t) + q_{\mathrm{sfa},i}
        \quad \text{(on spike)}

    .. math::

        E_\mathrm{sfa}(t) = \sum_{i=1}^{n} E_{\mathrm{sfa},i}(t)

    The adaptation kernels decay exponentially with exact propagators:

    .. math::

        E_{\mathrm{sfa},i}(t + h) = E_{\mathrm{sfa},i}(t) \exp(-h / \tau_{\mathrm{sfa},i})

    **1.4. Stochastic Spike Generation**

    - With dead time (``dead_time > 0``): At most one spike per time step.
      A uniform random number :math:`u \sim \mathcal{U}(0,1)` is compared to
      the spike probability:

      .. math::

          P(\text{spike}) = 1 - \exp(-\text{rate} \cdot h \cdot 10^{-3})

      A spike is generated if :math:`u \le P(\text{spike})`.

    - Without dead time (``dead_time = 0``): Multiple spikes per step are
      possible. The number of spikes is drawn from a Poisson distribution:

      .. math::

          n_{\text{spikes}} \sim \text{Poisson}(\text{rate} \cdot h \cdot 10^{-3})

    The factor :math:`10^{-3}` converts the product of a rate in Hz and a step
    :math:`h` in ms into a dimensionless expected spike count.

    **1.5. Dead Time (Refractory Period)**

    After each spike, the neuron enters a dead time during which it cannot spike:

    - Fixed dead time: ``dead_time_random = False``. The neuron is refractory
      for exactly ``dead_time`` milliseconds, converted to grid steps.
    - Random dead time: ``dead_time_random = True``. The dead time is drawn
      from a gamma distribution with shape ``dead_time_shape`` and mean ``dead_time``
      (and is never shorter than one grid step).

    If ``dead_time`` is nonzero but smaller than the simulation resolution :math:`h`,
    it is clamped to :math:`h`.

    **2. Numerical Integration and Update Order**

    The discrete-time update per simulation step follows this order:

    1. **Update membrane potential** via exact propagator (including external
       current and synaptic delta inputs).
    2. **Decay adaptation elements** and compute total :math:`E_\mathrm{sfa}`.
    3. **Spike check**:

       - If not refractory: compute effective potential
         :math:`V' = V_\mathrm{m} - E_\mathrm{sfa}`,
         compute instantaneous rate, draw random number and potentially emit spike(s).
         If spike occurs:

         - Jump all adaptation elements by ``q_sfa``
         - Optionally reset :math:`V_\mathrm{m}` to 0 (if ``with_reset = True``)
         - Set dead time counter

       - If refractory: decrement dead time counter

    4. **Buffer external current** for the next step (ring buffer semantics).

    **3. Important Implementation Notes**

    - Relative membrane potential: The membrane potential :math:`V_\mathrm{m}`
      is stored relative to the resting potential (resting potential = 0 mV).
      This differs from ``iaf_psc_delta``, which uses absolute potentials.
    - Stochastic reproducibility: Because spiking is stochastic (random number
      drawn each step), exact spike-time reproducibility requires matching the
      random number generator state. For deterministic testing, set ``rng_key``
      explicitly.
    - Dead time < dt clamping: If ``dead_time`` is nonzero but smaller than
      the simulation resolution, it is internally clamped to the resolution to
      match NEST behavior.
    - Poisson mode performance: For non-refractory neurons (``dead_time = 0``),
      Poisson random draws are used, which are slower than uniform random draws.
      For typical firing rates (well below one spike per time step), a small
      nonzero ``dead_time`` (e.g., 1e-8 ms) is faster and nearly equivalent.
      Note that such a value is clamped to the resolution, so the neuron stays
      silent for one step after every spike and emits at most one spike per step.

    **4. State Variables**

    ============================== ================= ==========================================
    **State Variable**             **Type**          **Description**
    ============================== ================= ==========================================
    ``V``                          HiddenState       Membrane potential (relative to rest)
    ``refractory_step_count``      ShortTermState    Remaining dead time grid steps
    ``I_stim``                     ShortTermState    Buffered current applied in next step
    ``last_spike_time``            ShortTermState    Last spike time (for recording)
    ``_q_elems``                   HiddenState       Adaptation kernel elements (``None`` if no adaptation)
    ``_q_val``                     ShortTermState    Total :math:`E_\mathrm{sfa}` (internal)
    ``_rng_state``                 ShortTermState    JAX PRNG key (random number generator state)
    ============================== ================= ==========================================

    **5. Further Notes**

    - Default parameter values match NEST C++ source for ``pp_psc_delta``,
      which are based on Jolivet et al. (2006) [2]_.
    - ``tau_sfa`` and ``q_sfa`` default to empty tuples (no adaptation).
      In NEST, the C++ defaults of ``tau_sfa=34.0`` and ``q_sfa=0.0`` are
      immediately cleared in the constructor, resulting in empty vectors.
    - The recordable ``V_m`` in NEST corresponds to ``self.V.value`` in brainpy.state.
    - The recordable ``E_sfa`` in NEST corresponds to ``self._q_val`` (the sum of
      all adaptation elements). Here ``_q_val`` already includes the jump caused
      by a spike emitted in the same step.

    References
    ----------
    .. [1] Cardanobile S, Rotter S (2010). Multiplicatively interacting point
           processes and applications to neural modeling. Journal of
           Computational Neuroscience 28(2):267-284.
           DOI: https://doi.org/10.1007/s10827-009-0204-0
    .. [2] Jolivet R, Rauch A, Luescher H-R, Gerstner W (2006). Predicting
           spike timing of neocortical pyramidal neurons by simple threshold
           models. Journal of Computational Neuroscience 21:35-49.
           DOI: https://doi.org/10.1007/s10827-006-7074-5
    .. [3] Pozzorini C, Naud R, Mensi S, Gerstner W (2013). Temporal whitening
           by power-law adaptation in neocortical neurons. Nature Neuroscience
           16:942-948.
           DOI: https://doi.org/10.1038/nn.3431
    .. [4] Grytskyy D, Tetzlaff T, Diesmann M, Helias M (2013). A unified view
           on weakly correlated recurrent networks. Frontiers in Computational
           Neuroscience, 7:131.
           DOI: https://doi.org/10.3389/fncom.2013.00131
    .. [5] Deger M, Schwalger T, Naud R, Gerstner W (2014). Fluctuations and
           information filtering in coupled populations of spiking neurons with
           adaptation. Physical Review E 90:6, 062704.
           DOI: https://doi.org/10.1103/PhysRevE.90.062704
    .. [6] Gerstner W, Kistler WM, Naud R, Paninski L (2014). Neuronal
           Dynamics: From single neurons to networks and models of cognition.
           Cambridge University Press.
    .. [7] NEST Simulator ``pp_psc_delta`` model documentation and C++ source:
           ``models/pp_psc_delta.h`` and ``models/pp_psc_delta.cpp``.

    Examples
    --------
    Basic usage with default parameters:

    .. code-block:: python

        >>> import brainpy.state as bst
        >>> import saiunit as u
        >>> neurons = bst.pp_psc_delta(100)
        >>> neurons.init_all_states()

    Exponential transfer function (default):

    .. code-block:: python

        >>> neurons = bst.pp_psc_delta(
        ...     100,
        ...     c_1=0.0,    # no linear part
        ...     c_2=1.238,  # exponential prefactor
        ...     c_3=0.25    # exponential coefficient
        ... )

    Linear transfer function with offset:

    .. code-block:: python

        >>> neurons = bst.pp_psc_delta(
        ...     100,
        ...     c_1=10.0,  # linear slope (Hz/mV)
        ...     c_2=5.0,   # offset rate (Hz)
        ...     c_3=0.0    # disable exponential
        ... )

    With spike-frequency adaptation:

    .. code-block:: python

        >>> neurons = bst.pp_psc_delta(
        ...     100,
        ...     tau_sfa=(100.0, 1000.0),  # two adaptation kernels
        ...     q_sfa=(5.0, 10.0)         # jump sizes in mV
        ... )

    Poisson mode (no dead time):

    .. code-block:: python

        >>> neurons = bst.pp_psc_delta(
        ...     100,
        ...     dead_time=0.0  # multiple spikes per step possible
        ... )

    Random dead time:

    .. code-block:: python

        >>> neurons = bst.pp_psc_delta(
        ...     100,
        ...     dead_time=2.0,          # mean dead time (ms)
        ...     dead_time_random=True,  # enable random dead time
        ...     dead_time_shape=2       # gamma distribution shape
        ... )

    Reproducible stochastic behavior:

    .. code-block:: python

        >>> import jax
        >>> key = jax.random.PRNGKey(42)
        >>> neurons = bst.pp_psc_delta(100, rng_key=key)
    """
    __module__ = 'brainpy.state'

    def __init__(
        self,
        in_size: Size,
        tau_m: ArrayLike = 10.0 * u.ms,
        C_m: ArrayLike = 250.0 * u.pF,
        dead_time: float = 1.0,  # ms, plain float as in NEST
        dead_time_random: bool = False,
        dead_time_shape: int = 1,
        with_reset: bool = True,
        tau_sfa: Sequence[float] = (),  # ms values
        q_sfa: Sequence[float] = (),  # mV values
        c_1: float = 0.0,  # Hz/mV
        c_2: float = 1.238,  # Hz
        c_3: float = 0.25,  # 1/mV
        I_e: ArrayLike = 0.0 * u.pA,
        t_ref_remaining: float = 0.0,  # ms
        rng_key: Optional[jax.Array] = None,
        V_initializer: Callable = braintools.init.Constant(0.0 * u.mV),
        spk_fun: Callable = braintools.surrogate.ReluGrad(),
        spk_reset: str = 'hard',
        name: Optional[str] = None,
    ):
        r"""Store and validate the model parameters (see the class docstring).

        Array-valued membrane parameters are broadcast to ``self.varshape``;
        scalar configuration values are coerced to plain Python types so they
        act as static constants inside compiled updates.

        Raises
        ------
        ValueError
            If ``tau_sfa`` and ``q_sfa`` differ in length, or if any check in
            :meth:`_validate_parameters` fails.
        """
        super().__init__(in_size, name=name, spk_fun=spk_fun, spk_reset=spk_reset)

        # Membrane parameters (broadcast to the population shape)
        self.tau_m = braintools.init.param(tau_m, self.varshape)
        self.C_m = braintools.init.param(C_m, self.varshape)
        self.I_e = braintools.init.param(I_e, self.varshape)

        # Dead time parameters (stored as plain Python scalars)
        self.dead_time = float(dead_time)
        self.dead_time_random = bool(dead_time_random)
        self.dead_time_shape = int(dead_time_shape)
        self.with_reset = bool(with_reset)

        # Transfer function coefficients
        self.c_1 = float(c_1)
        self.c_2 = float(c_2)
        self.c_3 = float(c_3)

        # Initial dead time remaining (ms)
        self.t_ref_remaining = float(t_ref_remaining)

        # Adaptation parameters (stored as plain Python tuples of floats)
        self.tau_sfa = tuple(float(x) for x in tau_sfa)
        self.q_sfa = tuple(float(x) for x in q_sfa)
        if len(self.tau_sfa) != len(self.q_sfa):
            raise ValueError(
                f"'tau_sfa' and 'q_sfa' must have the same length. "
                f"Got {len(self.tau_sfa)} and {len(self.q_sfa)}."
            )

        # RNG key for stochastic spiking (None -> PRNGKey(0) in init_state)
        self._rng_key = rng_key

        # Initializers
        self.V_initializer = V_initializer

        self._validate_parameters()

    def _validate_parameters(self):
        r"""Validate all model parameters.

        Raises
        ------
        ValueError
            If any parameter is outside its valid range.

        Notes
        -----
        Validation checks:

        - ``C_m > 0`` (capacitance must be positive)
        - ``tau_m > 0`` (membrane time constant must be positive)
        - ``dead_time >= 0`` (dead time must be non-negative)
        - ``dead_time_shape >= 1`` (gamma shape must be at least 1)
        - ``t_ref_remaining >= 0`` (remaining refractory time must be non-negative)
        - ``c_3 >= 0`` (exponential coefficient must be non-negative)
        - All elements of ``tau_sfa > 0`` (adaptation time constants must be positive)
        - ``len(tau_sfa) == len(q_sfa)`` (adaptation parameter lists must match)

        The ``C_m`` / ``tau_m`` checks are skipped when those arrays are JAX
        tracers (e.g. the model is constructed inside ``jit``) because their
        values cannot be inspected then. All other parameters are plain Python
        scalars and are always checked.
        """
        if not any(is_tracer(v) for v in (self.C_m, self.tau_m)):
            if np.any(self.C_m <= 0.0 * u.pF):
                raise ValueError('Capacitance must be strictly positive.')
            if np.any(self.tau_m <= 0.0 * u.ms):
                raise ValueError('Membrane time constant must be strictly positive.')
        if self.dead_time < 0.0:
            raise ValueError('Dead time must not be negative.')
        if self.dead_time_shape < 1:
            raise ValueError('Shape of the dead time gamma distribution must not be smaller than 1.')
        if self.t_ref_remaining < 0.0:
            raise ValueError('Remaining refractory time must not be negative.')
        if self.c_3 < 0.0:
            raise ValueError('c_3 must not be negative.')
        if any(tau <= 0.0 for tau in self.tau_sfa):
            raise ValueError('All SFA time constants must be strictly positive.')
        if len(self.tau_sfa) != len(self.q_sfa):
            raise ValueError(
                f"'tau_sfa' and 'q_sfa' must have the same length. "
                f"Got {len(self.tau_sfa)} and {len(self.q_sfa)}."
            )

    def _precompute_constants(self, state_shape):
        r"""Pre-compute dt-dependent propagator constants for JIT compatibility.

        Called from :meth:`init_state` while ``dt`` is still a concrete Python
        value (not a JAX abstract tracer). Storing these as plain Python floats
        or static JAX arrays avoids ``ConcretizationTypeError`` inside
        ``jax.lax.scan`` / ``brainstate.transform.for_loop``.

        Parameters
        ----------
        state_shape : tuple
            Shape of the membrane-potential state array (``V.shape`` after
            initialization). Used to pre-shape the adaptation-decay and
            q_sfa-jump arrays for broadcasting against ``(n_sfa, *state_shape)``.

        Notes
        -----
        Sets ``_h_ms``, ``_P33``, ``_P30``, ``_dead_time_eff``,
        ``_dead_time_counts``, ``_P_sfa`` and ``_q_sfa_arr`` (the last two are
        ``None`` when there is no adaptation). The constants are not refreshed
        if ``dt`` changes after :meth:`init_state`.
        """
        dt_q = brainstate.environ.get_dt()
        dftype = brainstate.environ.dftype()

        # h in ms as a concrete Python float (dt is a concrete Quantity here).
        h_ms = float(dt_q / u.ms)
        self._h_ms = h_ms

        # Exact membrane propagators, computed once and reused every step.
        self._P33 = u.math.exp(-dt_q / self.tau_m)
        self._P30 = (1.0 / self.C_m) * (1.0 - self._P33) * self.tau_m

        # Effective dead time: a nonzero value below the resolution is clamped
        # to one step (NEST behaviour); 0 keeps Poisson mode.
        dead_time = self.dead_time
        if dead_time != 0.0 and dead_time < h_ms:
            dead_time = h_ms
        self._dead_time_eff = dead_time
        self._dead_time_counts = int(round(dead_time / h_ms)) if dead_time > 0.0 else 0

        # Adaptation decay factors and jump sizes, shaped (n_sfa, 1, ..., 1) so
        # they broadcast against q_elems of shape (n_sfa, *state_shape).
        if self.tau_sfa:
            bcast_shape = (len(self.tau_sfa),) + (1,) * len(state_shape)
            decay = [np.exp(-h_ms / tau) for tau in self.tau_sfa]
            self._P_sfa = jnp.array(decay, dtype=dftype).reshape(bcast_shape)
            self._q_sfa_arr = jnp.array(self.q_sfa, dtype=dftype).reshape(bcast_shape)
        else:
            self._P_sfa = None
            self._q_sfa_arr = None

    def init_state(self, batch_size=None, **kwargs):
        r"""Initialize all state variables.

        Allocates and initializes membrane potential, spike times, refractory
        counters, buffered currents, adaptation kernels, and random number
        generator state, and pre-computes the dt-dependent constants.

        Parameters
        ----------
        batch_size : int or None, optional
            If provided, states are created with shape ``(batch_size, *varshape)``
            to support batched simulation. If None, states have shape ``varshape``.
        **kwargs : dict, optional
            Additional keyword arguments (ignored).

        Notes
        -----
        - Membrane potential is initialized using ``V_initializer``.
        - Last spike time is initialized to -1e7 ms (sufficiently in the past).
        - Refractory counter is initialized to ``round(t_ref_remaining / dt)``
          grid steps (0 when ``t_ref_remaining`` is 0).
        - Adaptation kernels (``_q_elems``) are initialized to zero, or set to
          ``None`` when there is no adaptation.
        - Random number generator state is initialized from ``rng_key`` or,
          if that is None, from ``jax.random.PRNGKey(0)``.
        """
        ditype = brainstate.environ.ditype()
        dftype = brainstate.environ.dftype()

        V = braintools.init.param(self.V_initializer, self.varshape, batch_size)
        state_shape = V.shape

        # dt-dependent constants first: the initial dead-time count needs _h_ms.
        self._precompute_constants(state_shape)
        if self.t_ref_remaining > 0.0:
            r_init = int(round(self.t_ref_remaining / self._h_ms))
        else:
            r_init = 0

        self.V = brainstate.HiddenState(V)
        self.last_spike_time = brainstate.ShortTermState(u.math.full(state_shape, -1e7 * u.ms))
        self.refractory_step_count = brainstate.ShortTermState(
            u.math.full(state_shape, r_init, dtype=ditype)
        )
        self.I_stim = brainstate.ShortTermState(u.math.full(state_shape, 0.0 * u.pA, dtype=dftype))

        # Adaptation state: one kernel per (tau_sfa, q_sfa) pair, in mV.
        n_sfa = len(self.tau_sfa)
        if n_sfa > 0:
            self._q_elems = brainstate.HiddenState(
                u.math.zeros((n_sfa, *state_shape), dtype=dftype) * u.mV
            )
        else:
            self._q_elems = None
        self._q_val = brainstate.ShortTermState(
            u.math.zeros(state_shape, dtype=dftype) * u.mV
        )

        # RNG state wrapped in ShortTermState so for_loop carries it correctly.
        # NOTE: the fallback key is the same for every instance, so populations
        # built without ``rng_key`` share one random stream.
        rng = self._rng_key if self._rng_key is not None else jax.random.PRNGKey(0)
        self._rng_state = brainstate.ShortTermState(rng)

    def get_spike(self, V: Optional[ArrayLike] = None):
        r"""Compute surrogate gradient spike output for backpropagation.

        This method is used for computing differentiable spike outputs during
        training. For a stochastic point process neuron, the true spike output
        is random and computed in ``update()``. This method provides a surrogate
        gradient based on the membrane potential.

        Parameters
        ----------
        V : ArrayLike, optional
            Membrane potential (with units). If None, uses the current state
            ``self.V.value``. Default: None.

        Returns
        -------
        spike : jax.Array
            Differentiable spike output. Shape matches ``V``.

        Notes
        -----
        - This method is primarily used for gradient-based optimization.
        - The membrane potential is expressed in units of 1 mV (made
          dimensionless) and passed through ``spk_fun`` (e.g., ``ReluGrad``).
        - The true stochastic spike output is computed in ``update()`` and is
          not directly differentiable.
        """
        if V is None:
            V = self.V.value
        # The model has no deterministic threshold; V (relative to rest) is fed
        # to the surrogate directly after stripping the mV unit.
        return self.spk_fun(V / (1.0 * u.mV))

    def update(self, x=0.0 * u.pA):
        r"""Update neuron state for one simulation step.

        Performs the complete update sequence: (1) updates membrane potential
        via exact propagator, (2) decays adaptation kernels, (3) computes
        instantaneous firing rate and stochastically generates spikes, (4) buffers
        input current for the next step.

        Parameters
        ----------
        x : Quantity, optional
            External current input (with current units). This input is added to
            the sum of all registered current inputs via projections and is
            applied in the *next* step. Default: 0.0 pA.

        Returns
        -------
        spike : jax.Array
            Spike indicator with the shape of ``V`` (``varshape`` or
            ``(batch_size, *varshape)``) and dtype ``dftype``. Values are 1.0
            where the neuron fired during this step and 0.0 otherwise. In Poisson
            mode (``dead_time = 0``) several spikes may be drawn in one step; the
            returned value is still 1.0, and the spike count is used only for the
            adaptation jump (``q_sfa * n_spikes``).

        Notes
        -----
        **Update order per time step:**

        1. **Membrane potential update**: Apply exact propagator to update
           :math:`V_\mathrm{m}` using buffered current from the previous step,
           constant external current, and delta-shaped synaptic inputs.
        2. **Adaptation decay**: Decay all adaptation kernel elements using
           exponential propagators. Compute total :math:`E_\mathrm{sfa}`.
        3. **Spike generation**:

           - If not refractory: compute effective potential :math:`V' = V_\mathrm{m} - E_\mathrm{sfa}`,
             compute instantaneous rate from transfer function, draw random
             number(s), and potentially emit spike(s).
           - If spike occurs: jump adaptation elements by ``q_sfa``, optionally
             reset :math:`V_\mathrm{m}` to 0, set dead time counter.
           - If refractory: decrement dead time counter.

        4. **Buffer input**: Store external current input for the next step
           (ring buffer semantics, matching NEST).

        **Spike generation modes:**

        - With dead time (``dead_time > 0``): At most one spike per step.
          Uses uniform random numbers and spike probability.
        - Without dead time (``dead_time = 0``): Poisson-distributed spikes.
          Multiple spikes per step are possible.

        **Failure modes:**

        - If ``C_m`` or ``tau_m`` contain invalid values (NaN, Inf), membrane
          potential update will fail silently (produces NaN).
        - The exponent ``c_3 * V'`` is clipped to ``[-500, 500]`` before
          ``exp()``; in single precision ``exp(500)`` still overflows to
          infinity. A NaN rate never produces a spike (``rate > 0`` is False).
        - If random number generator state is corrupted, spike generation will
          produce undefined results.

        **Performance considerations:**

        - Poisson mode (``dead_time = 0``) is slower due to Poisson random draws.
        - Setting a small ``dead_time`` (e.g., 1e-8 ms) uses faster uniform
          random numbers and is nearly equivalent for typical firing rates; it is
          clamped to one step of dead time.
        - Random dead time (``dead_time_random = True``) draws a full array of
          gamma samples every step.
        """
        t = brainstate.environ.get('t')
        dt_q = brainstate.environ.get_dt()
        dftype = brainstate.environ.dftype()
        ditype = brainstate.environ.ditype()
        # Pre-computed in init_state() while dt was concrete; safe inside JIT/scan.
        h_ms = self._h_ms

        # Read state variables
        V = self.V.value  # mV
        r = self.refractory_step_count.value  # remaining dead-time steps (int)
        i_stim = self.I_stim.value  # pA, buffered in the previous step
        # (batch_size, *varshape) or varshape. Every state written back below
        # must keep this shape so loop carries (for_loop/scan) stay consistent.
        state_shape = V.shape

        # ---- Step 1: Update membrane potential via exact propagator ----
        delta_v = self.sum_delta_inputs(u.math.zeros(self.varshape) * u.mV)
        V = self._P30 * (i_stim + self.I_e) + self._P33 * V + delta_v

        # ---- Step 2: Decay adaptation elements and compute total E_sfa ----
        n_sfa = len(self.tau_sfa)
        if n_sfa > 0 and self._q_elems is not None:
            # q_elems: (n_sfa, *state_shape) in mV; _P_sfa is pre-shaped to broadcast.
            q_elems = self._q_elems.value * self._P_sfa
            q_total = u.math.sum(q_elems, axis=0)  # state_shape, mV
        else:
            q_elems = None
            # Use the full state shape (not varshape) so _q_val keeps its batch axis.
            q_total = u.math.zeros(state_shape, dtype=dftype) * u.mV

        # ---- Step 3: Spike check / refractory ----
        not_refractory = r == 0

        # Transfer function: rate = rect(c_1 * V_eff + c_2 * exp(c_3 * V_eff))
        V_eff_raw = (V - q_total) / u.mV  # unitless
        exp_arg = jnp.clip(self.c_3 * V_eff_raw, -500.0, 500.0)  # guard exp overflow
        rate = jnp.maximum(self.c_1 * V_eff_raw + self.c_2 * jnp.exp(exp_arg), 0.0)

        # Advance RNG state for this step
        rng_state, subkey = jax.random.split(self._rng_state.value)
        self._rng_state.value = rng_state

        dead_time = self._dead_time_eff  # Python float constant
        if dead_time > 0.0:
            # With dead time: at most 1 spike per step.
            # spike_prob = 1 - exp(-rate * h * 1e-3) = -expm1(-rate * h * 1e-3)
            spike_prob = -jnp.expm1(-rate * h_ms * 1e-3)
            if self.dead_time_random:
                # Independent keys for the uniform and gamma draws; a key that has
                # been consumed by one sampler must not be split again.
                uniform_key, gamma_key = jax.random.split(subkey)
            else:
                uniform_key, gamma_key = subkey, None
            rand_vals = jax.random.uniform(uniform_key, shape=state_shape, dtype=dftype)
            spike_now = not_refractory & (rate > 0.0) & (rand_vals <= spike_prob)

            if self.dead_time_random:
                # Gamma(shape, scale=1) / (shape / dead_time) has mean dead_time (ms).
                gamma_samples = jax.random.gamma(
                    gamma_key, self.dead_time_shape, shape=state_shape, dtype=dftype
                )
                dt_rate = self.dead_time_shape / dead_time
                new_r_random = jnp.maximum(
                    1, jnp.round(gamma_samples / dt_rate / h_ms).astype(ditype)
                )
                new_r = jnp.where(spike_now, new_r_random, r)
            else:
                new_r = jnp.where(spike_now, self._dead_time_counts, r)
            n_spikes = jnp.where(spike_now, 1, 0).astype(ditype)
        else:
            # Without dead time (Poisson mode): multiple spikes per step possible
            lam_poisson = rate * h_ms * 1e-3
            n_spikes_raw = jax.random.poisson(subkey, lam_poisson, shape=state_shape, dtype=ditype)
            n_spikes = jnp.where(not_refractory & (rate > 0.0), n_spikes_raw, 0)
            spike_now = n_spikes > 0
            new_r = r  # no dead time to set

        # Neurons that were refractory (and therefore could not spike) count down.
        new_r = jnp.where((r > 0) & ~spike_now, r - 1, new_r)

        # Jump adaptation elements by q_sfa * n_spikes (multiplicity-aware).
        if q_elems is not None:
            n_spikes_float = jnp.expand_dims(n_spikes.astype(dftype), axis=0)
            q_elems = q_elems + (self._q_sfa_arr * n_spikes_float) * u.mV
            # NOTE(review): NEST records E_sfa before this jump; here _q_val
            # includes it. Confirm if exact recording parity with NEST matters.
            q_total = u.math.sum(q_elems, axis=0)

        # Reset membrane potential if applicable
        if self.with_reset:
            V = u.math.where(spike_now, 0.0 * u.mV, V)

        # ---- Step 4: Get external current for NEXT step (NEST ring buffer semantics) ----
        new_i_stim = self.sum_current_inputs(x, self.V.value)

        # ---- Write back state ----
        self.V.value = V
        self.refractory_step_count.value = jnp.asarray(u.get_mantissa(new_r), dtype=ditype)
        # Broadcast to the full state shape so a scalar input does not drop the
        # batch axis of I_stim.
        self.I_stim.value = new_i_stim + u.math.zeros(state_shape, dtype=dftype) * u.pA
        last_spike_time = u.math.where(spike_now, t + dt_q, self.last_spike_time.value)
        self.last_spike_time.value = jax.lax.stop_gradient(last_spike_time)
        if q_elems is not None:
            self._q_elems.value = q_elems
        self._q_val.value = q_total

        # Binary spike indicator (in both modes spike_now == (n_spikes > 0)).
        return u.math.asarray(spike_now, dtype=dftype)