# Copyright 2026 BrainX Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# -*- coding: utf-8 -*-
from typing import Callable, Iterable
import brainstate
import braintools
import saiunit as u
import jax
import jax.numpy as jnp
import numpy as np
from brainstate.typing import ArrayLike, Size
from ._base import NESTNeuron
from ._utils import is_tracer
from .iaf_psc_alpha import iaf_psc_alpha
# Public names exported by a star-import of this module.
__all__ = [
'iaf_psc_alpha_ps',
]
class iaf_psc_alpha_ps(NESTNeuron):
r"""NEST-compatible ``iaf_psc_alpha_ps`` with precise spike timing.
Description
-----------
``iaf_psc_alpha_ps`` is a current-based leaky integrate-and-fire neuron
with alpha-shaped excitatory/inhibitory postsynaptic currents (PSCs),
fixed absolute refractoriness, and off-grid spike/event timing. The
implementation matches NEST ``models/iaf_psc_alpha_ps.{h,cpp}`` semantics:
event-driven mini-step splitting inside each global ``dt`` interval,
exact linear propagators for alpha states, and bisection-based sub-step
threshold-time localization.
**1. Continuous-Time Model and Alpha Current State-Space**
Define :math:`U = V_m - E_L` and :math:`I_\mathrm{syn}=I_\mathrm{ex}+I_\mathrm{in}`.
Subthreshold dynamics are
.. math::
\frac{dU}{dt} = -\frac{U}{\tau_m} + \frac{I_\mathrm{syn} + I_e + y_\mathrm{in}}{C_m}.
For each channel :math:`X\in\{\mathrm{ex},\mathrm{in}\}`, alpha PSCs use
a two-state system:
.. math::
\frac{d\,dI_X}{dt} = -\frac{dI_X}{\tau_{\mathrm{syn},X}}, \qquad
\frac{dI_X}{dt} = dI_X - \frac{I_X}{\tau_{\mathrm{syn},X}}.
This realizes normalized kernel
.. math::
i_X(t) = \frac{e}{\tau_{\mathrm{syn},X}} t e^{-t/\tau_{\mathrm{syn},X}} \Theta(t),
so a spike weight :math:`w` (pA) is injected into derivative states as
:math:`dI_\mathrm{ex}\leftarrow dI_\mathrm{ex}+\frac{e}{\tau_{\mathrm{syn,ex}}}w`
for :math:`w\ge 0` and
:math:`dI_\mathrm{in}\leftarrow dI_\mathrm{in}+\frac{e}{\tau_{\mathrm{syn,in}}}w`
for :math:`w<0` (inhibitory channel stays negative by sign convention).
**2. Exact Mini-Step Propagation and Precise Threshold Crossing**
For each local interval :math:`\Delta t` between two ordered event offsets,
the code uses exact closed-form updates:
.. math::
dI_X(t+\Delta t) = e^{-\Delta t/\tau_{\mathrm{syn},X}} dI_X(t),
.. math::
I_X(t+\Delta t) = e^{-\Delta t/\tau_{\mathrm{syn},X}}
\big(I_X(t) + \Delta t\, dI_X(t)\big),
.. math::
U(t+\Delta t) = U(t) + \left(e^{-\Delta t/\tau_m}-1\right)U(t)
+ P_{30}(I_e+y_\mathrm{in})
+ \sum_X \left(P_{31,X} dI_X(t) + P_{32,X} I_X(t)\right),
with :math:`P_{30}=\tau_m(1-e^{-\Delta t/\tau_m})/C_m` and
:math:`P_{31,X}, P_{32,X}` evaluated by
:meth:`iaf_psc_alpha._alpha_propagator_p31_p32` (including stable handling
near :math:`\tau_m\approx\tau_{\mathrm{syn},X}`).
If :math:`U` crosses :math:`U_{th}=V_{th}-E_L` inside a mini-step, the
crossing time solves :math:`f(\delta)=U(\delta)-U_{th}=0` using bounded
bisection (64 iterations), producing off-grid spike offset
``spike_off = dt - (local_time + delta)``.
**3. Event Ordering, Refractory Pseudo-Event, and Timing Convention**
Off-grid events are sorted by ``offset`` in descending order, where
``offset`` is measured from the right boundary of the current step
(:math:`0` at step end, :math:`dt` at step start). Each neuron can also
insert a refractory-release pseudo-event at stored ``last_spike_offset``
when ``step_idx + 1 - last_spike_step == ceil(t_ref / dt)``.
On spike emission:
- membrane state is reset to ``V_reset``,
- refractory flag is set,
- ``last_spike_step``, ``last_spike_offset``, ``last_spike_time`` are
updated with precise sub-step timing.
**4. Assumptions, Constraints, and Computational Implications**
- Construction constraints enforce ``C_m > 0``, ``tau_m > 0``,
``tau_syn_ex > 0``, ``tau_syn_in > 0``, and ``V_reset < V_th``.
- If ``V_min`` is set, ``V_reset >= V_min`` is required.
- Runtime requires ``ceil(t_ref / dt) >= 1``; otherwise update fails.
- ``x`` is ring-buffered current input: values supplied at step ``n`` are
consumed as ``y_input`` in step ``n+1``.
- Update is vectorized over ``self.varshape`` using array operations.
With ``K`` within-step events, cost is
:math:`O(|\mathrm{varshape}| \cdot K)`, plus root-search work when
threshold is crossed.
Parameters
----------
in_size : Size
Population shape specification. All model parameters are broadcast to
``self.varshape`` derived from ``in_size``.
E_L : ArrayLike, optional
Resting potential :math:`E_L` in mV. Scalar or array-like broadcastable
to ``self.varshape``. Default is ``-70. * u.mV``.
C_m : ArrayLike, optional
Membrane capacitance :math:`C_m` in pF. Must be strictly positive after
broadcasting to ``self.varshape``. Default is ``250. * u.pF``.
tau_m : ArrayLike, optional
Membrane time constant :math:`\tau_m` in ms. Must be strictly positive.
Default is ``10. * u.ms``.
t_ref : ArrayLike, optional
Absolute refractory time :math:`t_{ref}` in ms. Converted at runtime to
grid steps via ``ceil(t_ref / dt)``. Must yield at least one step.
Default is ``2. * u.ms``.
V_th : ArrayLike, optional
Spike threshold :math:`V_{th}` in mV, broadcastable to ``self.varshape``.
Default is ``-55. * u.mV``.
V_reset : ArrayLike, optional
Post-spike reset potential :math:`V_{reset}` in mV. Must satisfy
``V_reset < V_th`` elementwise. Default is ``-70. * u.mV``.
tau_syn_ex : ArrayLike, optional
Excitatory alpha time constant :math:`\tau_{\mathrm{syn,ex}}` in ms.
Strictly positive. Default is ``2. * u.ms``.
tau_syn_in : ArrayLike, optional
Inhibitory alpha time constant :math:`\tau_{\mathrm{syn,in}}` in ms.
Strictly positive. Default is ``2. * u.ms``.
I_e : ArrayLike, optional
Constant external current :math:`I_e` in pA, broadcastable to
``self.varshape``. Added in each mini-step membrane update.
Default is ``0. * u.pA``.
V_min : ArrayLike or None, optional
Optional lower voltage clamp :math:`V_{min}` in mV. When provided,
membrane candidates are clipped by ``max(V, V_min)`` before threshold
tests. ``None`` disables clipping. Default is ``None``.
V_initializer : Callable, optional
Initializer for membrane state ``V`` used in :meth:`init_state`.
Must return values unit-compatible with mV. Default is
``braintools.init.Constant(-70. * u.mV)``.
spk_fun : Callable, optional
Surrogate spike function used by :meth:`get_spike` and returned by
:meth:`update`. It receives normalized threshold distance and returns a
spike-like array broadcastable to neuron shape.
Default is ``braintools.surrogate.ReluGrad()``.
spk_reset : str, optional
Reset policy passed to :class:`~brainpy_state._base.Neuron`.
``'hard'`` matches NEST hard-reset behavior. Default is ``'hard'``.
ref_var : bool, optional
If ``True``, creates exposed state ``self.refractory`` mirroring
``self.is_refractory`` for inspection. Default is ``False``.
name : str or None, optional
Optional node name.
Parameter Mapping
-----------------
.. list-table:: Parameter mapping to model symbols
:header-rows: 1
:widths: 17 28 14 16 35
* - Parameter
- Type / shape / unit
- Default
- Math symbol
- Semantics
* - ``in_size``
- :class:`~brainstate.typing.Size`; scalar/tuple
- required
- --
- Defines population shape ``self.varshape``.
* - ``E_L``
- ArrayLike, broadcastable to ``self.varshape`` (mV)
- ``-70. * u.mV``
- :math:`E_L`
- Resting potential; membrane offset origin.
* - ``C_m``
- ArrayLike, broadcastable (pF), ``> 0``
- ``250. * u.pF``
- :math:`C_m`
- Membrane capacitance in all propagators.
* - ``tau_m``
- ArrayLike, broadcastable (ms), ``> 0``
- ``10. * u.ms``
- :math:`\tau_m`
- Membrane leak time constant.
* - ``t_ref``
- ArrayLike, broadcastable (ms), runtime ``ceil(t_ref/dt) >= 1``
- ``2. * u.ms``
- :math:`t_{ref}`
- Absolute refractory duration.
* - ``V_th`` and ``V_reset``
- ArrayLike, broadcastable (mV), with ``V_reset < V_th``
- ``-55. * u.mV``, ``-70. * u.mV``
- :math:`V_{th}`, :math:`V_{reset}`
- Threshold and reset levels.
* - ``tau_syn_ex`` and ``tau_syn_in``
- ArrayLike, broadcastable (ms), each ``> 0``
- ``2. * u.ms``
- :math:`\tau_{\mathrm{syn,ex}}`, :math:`\tau_{\mathrm{syn,in}}`
- Alpha PSC decay constants.
* - ``I_e``
- ArrayLike, broadcastable (pA)
- ``0. * u.pA``
- :math:`I_e`
- Constant injected current.
* - ``V_min``
- ArrayLike broadcastable (mV) or ``None``
- ``None``
- :math:`V_{min}`
- Optional lower membrane bound.
* - ``V_initializer``
- Callable returning mV-compatible values
- ``Constant(-70. * u.mV)``
- --
- Initial membrane state initializer.
* - ``spk_fun``
- Callable
- ``ReluGrad()``
- --
- Surrogate spike output nonlinearity.
* - ``spk_reset``
- str
- ``'hard'``
- --
- Reset mode inherited from base ``Neuron``.
* - ``ref_var``
- bool
- ``False``
- --
- Allocate exposed ``refractory`` mirror state.
* - ``name``
- str | None
- ``None``
- --
- Optional node name.
Raises
------
ValueError
If parameter constraints are violated (for example ``C_m <= 0``,
``tau_m <= 0``, ``tau_syn_ex <= 0``, ``tau_syn_in <= 0``,
``V_reset >= V_th``, ``V_reset < V_min``), if refractory duration in
steps is below one, or if any ``spike_events`` offset is outside
``[0, dt]``.
TypeError
If supplied quantities are not unit-compatible with expected units
(mV, ms, pA, pF) during conversion.
KeyError
If simulation context keys such as ``t`` or ``dt`` are missing when
:meth:`update` is called.
AttributeError
If :meth:`update` is called before :meth:`init_state` creates required
states (for example ``V`` or synaptic buffers).
Notes
-----
- ``spike_events`` accepts ``(offset, weight)`` tuples or
``{'offset': ..., 'weight': ...}`` dicts. Offsets are in ms and measured
from the right step boundary (NEST convention).
- Positive event weights update the excitatory derivative state; negative
event weights update inhibitory derivative state.
- The implementation computes all internal propagators in unitless NumPy
space using the environment float dtype (``brainstate.environ.dftype()``)
and writes back BrainUnit states afterward.
- ``last_spike_time`` stores precise absolute spike time in ms and is
stop-gradient wrapped.
Examples
--------
.. code-block:: python
>>> import brainpy
>>> import brainstate
>>> import saiunit as u
>>> with brainstate.environ.context(dt=0.1 * u.ms):
... neu = brainpy.state.iaf_psc_alpha_ps(in_size=(2,), I_e=220.0 * u.pA)
... neu.init_state()
... with brainstate.environ.context(t=1.0 * u.ms):
... spk = neu.update()
... _ = spk.shape
.. code-block:: python
>>> import brainpy
>>> import brainstate
>>> import saiunit as u
>>> with brainstate.environ.context(dt=0.1 * u.ms):
... neu = brainpy.state.iaf_psc_alpha_ps(in_size=1)
... neu.init_state()
... ev = [{'offset': 0.08 * u.ms, 'weight': 120.0 * u.pA}]
... with brainstate.environ.context(t=0.0 * u.ms):
... _ = neu.update(spike_events=ev)
References
----------
.. [1] NEST source: ``models/iaf_psc_alpha_ps.h`` and
``models/iaf_psc_alpha_ps.cpp``.
.. [2] Rotter S, Diesmann M (1999). Exact simulation of time-invariant linear
systems with applications to neuronal modeling. Biological Cybernetics
81:381-402. DOI: https://doi.org/10.1007/s004220050570
.. [3] Morrison A, Straube S, Plesser HE, Diesmann M (2007). Exact
subthreshold integration with continuous spike times in discrete time
neural network simulations. Neural Computation 19(1):47-79.
DOI: https://doi.org/10.1162/neco.2007.19.1.47
.. [4] Hanuschkin A, Kunkel S, Helias M, Morrison A, Diesmann M (2010).
A general and efficient method for incorporating exact spike times in
globally time-driven simulations. Frontiers in Neuroinformatics.
DOI: https://doi.org/10.3389/fninf.2010.00113
"""
# Override the class's reported module path to 'brainpy.state'
# (presumably the public import location of this model -- confirm).
__module__ = 'brainpy.state'
def __init__(
    self,
    in_size: Size,
    E_L: ArrayLike = -70. * u.mV,
    C_m: ArrayLike = 250. * u.pF,
    tau_m: ArrayLike = 10. * u.ms,
    t_ref: ArrayLike = 2. * u.ms,
    V_th: ArrayLike = -55. * u.mV,
    V_reset: ArrayLike = -70. * u.mV,
    tau_syn_ex: ArrayLike = 2. * u.ms,
    tau_syn_in: ArrayLike = 2. * u.ms,
    I_e: ArrayLike = 0. * u.pA,
    V_min: ArrayLike = None,
    V_initializer: Callable = braintools.init.Constant(-70. * u.mV),
    spk_fun: Callable = braintools.surrogate.ReluGrad(),
    spk_reset: str = 'hard',
    ref_var: bool = False,
    name: str = None,
):
    """Store the model parameters, broadcast to the population shape.

    Each array-like parameter is passed through ``braintools.init.param``
    together with ``self.varshape``. ``V_min`` stays ``None`` when it is not
    supplied. Construction ends with a call to
    :meth:`_validate_parameters`.
    """
    super().__init__(in_size, name=name, spk_fun=spk_fun, spk_reset=spk_reset)

    population_shape = self.varshape
    broadcast = braintools.init.param

    # Assigned in this fixed order; each entry is (attribute name, raw value).
    required_parameters = (
        ('E_L', E_L),
        ('C_m', C_m),
        ('tau_m', tau_m),
        ('t_ref', t_ref),
        ('V_th', V_th),
        ('V_reset', V_reset),
        ('tau_syn_ex', tau_syn_ex),
        ('tau_syn_in', tau_syn_in),
        ('I_e', I_e),
    )
    for attribute, raw_value in required_parameters:
        setattr(self, attribute, broadcast(raw_value, population_shape))

    # The lower voltage bound is optional; ``None`` means "no clamp".
    if V_min is None:
        self.V_min = None
    else:
        self.V_min = broadcast(V_min, population_shape)

    self.V_initializer = V_initializer
    self.ref_var = ref_var

    self._validate_parameters()
def _validate_parameters(self):
    r"""Validate model parameters against NEST constraints.

    The checks are skipped entirely when any parameter that takes part in a
    comparison is a JAX tracer, because the comparisons below need concrete
    values.

    Raises
    ------
    ValueError
        If parameter inequalities or positivity constraints are violated.
    """
    # Skip validation when parameters are JAX tracers (e.g. during jit).
    # Every parameter compared below is inspected, not only a subset, so a
    # traced ``V_th``/``V_min``/``tau_syn_*`` cannot reach ``np.any``.
    compared = (
        self.V_reset,
        self.V_th,
        self.V_min,
        self.C_m,
        self.tau_m,
        self.tau_syn_ex,
        self.tau_syn_in,
    )
    if any(is_tracer(v) for v in compared if v is not None):
        return

    if np.any(self.V_reset >= self.V_th):
        raise ValueError('Reset potential must be smaller than threshold.')
    if self.V_min is not None and np.any(self.V_reset < self.V_min):
        raise ValueError('Reset potential must be greater equal minimum potential.')
    if np.any(self.C_m <= 0.0 * u.pF):
        raise ValueError('Capacitance must be strictly positive.')
    if np.any(self.tau_m <= 0.0 * u.ms):
        raise ValueError('Membrane time constant must be strictly positive.')
    if np.any(self.tau_syn_ex <= 0.0 * u.ms) or np.any(self.tau_syn_in <= 0.0 * u.ms):
        raise ValueError('All time constants must be strictly positive.')
[docs]
def init_state(self, **kwargs):
    r"""Create all state variables of the neuron population.

    The membrane potential comes from ``self.V_initializer``; synaptic
    currents, their derivative states and the buffered input start at zero;
    the refractory flag starts ``False``; ``last_spike_step`` starts at
    ``-1`` and ``last_spike_time`` at ``-1e7`` ms.

    Parameters
    ----------
    **kwargs
        Unused compatibility parameters accepted by the base-state API.

    Raises
    ------
    ValueError
        If an initializer cannot be broadcast to requested shape.
    TypeError
        If initializer outputs have incompatible units/dtypes for the
        corresponding state variables.
    """
    shape = self.varshape
    int_dtype = brainstate.environ.ditype()
    float_dtype = brainstate.environ.dftype()
    short_term = brainstate.ShortTermState

    def all_false():
        # Fresh boolean array of the population shape, filled with False.
        return braintools.init.param(braintools.init.Constant(False), shape)

    v_init = braintools.init.param(self.V_initializer, shape)
    # Current-like states share the dtype of the initial membrane potential.
    zero_like_v = u.math.zeros(shape, dtype=v_init.dtype)

    self.V = brainstate.HiddenState(v_init)
    self.I_syn_ex = short_term(zero_like_v * u.pA)
    self.dI_syn_ex = short_term(u.math.zeros(shape, dtype=float_dtype))
    self.I_syn_in = short_term(zero_like_v * u.pA)
    self.dI_syn_in = short_term(u.math.zeros(shape, dtype=float_dtype))
    self.y_input = short_term(zero_like_v * u.pA)
    self.is_refractory = short_term(all_false())
    self.last_spike_step = short_term(u.math.full(shape, -1, dtype=int_dtype))
    self.last_spike_offset = short_term(u.math.zeros(shape, dtype=float_dtype) * u.ms)
    self.last_spike_time = short_term(u.math.full(shape, -1e7 * u.ms))

    if not self.ref_var:
        return
    self.refractory = short_term(all_false())
[docs]
def get_spike(self, V: ArrayLike = None):
r"""Evaluate surrogate spike output from membrane voltage.
Parameters
----------
V : ArrayLike, optional
Voltage values with shape broadcastable to ``self.varshape`` and
units compatible with mV. If ``None``, uses current state
``self.V.value``.
Returns
-------
ArrayLike
Surrogate spike activation produced by
``spk_fun((V - V_th) / (V_th - V_reset))``.
"""
V = self.V.value if V is None else V
v_scaled = (V - self.V_th) / (self.V_th - self.V_reset)
return self.spk_fun(v_scaled)
@staticmethod
def _parse_spike_events(spike_events: Iterable, v_shape):
    """Normalise user-supplied spike events to ``(offset_ms, weight)`` pairs.

    Parameters
    ----------
    spike_events : Iterable or None
        Items are either ``(offset, weight)`` pairs or dicts with optional
        ``'offset'`` / ``'weight'`` keys (missing keys default to zero).
        ``None`` yields an empty list.
    v_shape : tuple
        Shape every weight is broadcast to.

    Returns
    -------
    list of (float, np.ndarray)
        One entry per input event, in input order: the offset as a plain
        float in ms and the weight in pA as an array of shape ``v_shape``.
    """
    if spike_events is None:
        return []

    float_dtype = brainstate.environ.dftype()
    parsed = []
    for item in spike_events:
        if isinstance(item, dict):
            offset = item.get('offset', 0.0 * u.ms)
            weight = item.get('weight', 0.0 * u.pA)
        else:
            offset, weight = item
        offset_ms = float(u.get_mantissa(offset / u.ms))
        weight_pa = np.asarray(u.get_mantissa(weight / u.pA), dtype=float_dtype)
        parsed.append((offset_ms, np.broadcast_to(weight_pa, v_shape)))
    return parsed
def _precompute_constants(self, h, v_shape, dftype, ditype):
    """Cache unitless NumPy versions of the model parameters on ``self``.

    The results are stored in ``self._c_*`` attributes together with
    ``self._c_key = (h, v_shape, dftype, ditype)`` so that :meth:`update`
    can tell whether the cache still matches the current step size, shape
    and dtypes.

    Parameters
    ----------
    h : float
        Step size in ms.
    v_shape : tuple
        State array shape.
    dftype : dtype
        Float dtype for computations.
    ditype : dtype
        Integer dtype for step counters.

    Raises
    ------
    ValueError
        If ``ceil(t_ref / h)`` is below one for any neuron. Nothing is
        cached in that case.
    """

    def strip_unit(quantity, unit):
        # Mantissa of ``quantity`` expressed in ``unit``, broadcast to v_shape.
        mantissa = np.asarray(u.get_mantissa(quantity / unit), dtype=dftype)
        return np.broadcast_to(mantissa, v_shape)

    rest_potential = strip_unit(self.E_L, u.mV)

    # Refractory period expressed in whole grid steps.
    steps_float = np.ceil(strip_unit(self.t_ref, u.ms) / h)
    refractory_steps = np.broadcast_to(np.asarray(steps_float, dtype=ditype), v_shape)
    if np.any(refractory_steps < 1):
        raise ValueError('Refractory time must be at least one time step.')

    self._c_E_L = rest_potential
    self._c_tau_m = strip_unit(self.tau_m, u.ms)
    self._c_tau_ex = strip_unit(self.tau_syn_ex, u.ms)
    self._c_tau_in = strip_unit(self.tau_syn_in, u.ms)
    self._c_c_m = strip_unit(self.C_m, u.pF)
    self._c_i_e = strip_unit(self.I_e, u.pA)
    # Voltages are stored relative to the resting potential.
    self._c_u_th = strip_unit(self.V_th - self.E_L, u.mV)
    self._c_u_reset = strip_unit(self.V_reset - self.E_L, u.mV)
    if self.V_min is None:
        # No clamp requested: use -inf so that max(V, u_min) is a no-op.
        self._c_u_min = np.full(v_shape, -np.inf, dtype=dftype)
    else:
        self._c_u_min = strip_unit(self.V_min - self.E_L, u.mV)
    self._c_refr_steps = refractory_steps
    # Alpha-kernel normalisation e / tau_syn for each channel.
    self._c_psc_norm_ex = np.e / self._c_tau_ex
    self._c_psc_norm_in = np.e / self._c_tau_in

    # Cache key
    self._c_key = (h, v_shape, dftype, ditype)
@staticmethod
def _bisect_root(f, t_hi: float):
"""Find root of f in [0, t_hi] using bisection (64 iterations).
Parameters
----------
f : callable
Scalar function to find root of.
t_hi : float
Upper bound of search interval.
Returns
-------
float
Approximate root location.
"""
lo = 0.0
hi = float(t_hi)
f_lo = f(lo)
f_hi = f(hi)
if not np.isfinite(f_hi):
return hi
if f_lo > 0.0:
return 0.0
if f_hi <= 0.0:
return hi
for _ in range(64):
mid = 0.5 * (lo + hi)
f_mid = f(mid)
if f_mid > 0.0:
hi = mid
else:
lo = mid
return 0.5 * (lo + hi)
def _propagate_vectorized(self, dt_local, V_m, I_ex, dI_ex, I_in, dI_in,
                          y0, tau_m, tau_ex, tau_in, c_m, i_e, u_min,
                          is_refractory):
    """Advance membrane and alpha-current states by ``dt_local``.

    Entries with ``dt_local <= 0`` are returned unchanged. The membrane
    value is only replaced for entries that are stepping and not
    refractory; the synaptic states of every stepping entry are advanced.

    Parameters
    ----------
    dt_local : np.ndarray
        Local time step for each neuron.
    V_m, I_ex, dI_ex, I_in, dI_in : np.ndarray
        State variables at the start of the interval.
    y0 : np.ndarray
        Buffered input current.
    tau_m, tau_ex, tau_in, c_m, i_e : np.ndarray
        Model parameters.
    u_min : np.ndarray
        Lower bound applied to the membrane candidate.
    is_refractory : np.ndarray
        Boolean refractory mask.

    Returns
    -------
    tuple of np.ndarray
        Updated ``(V_m, I_ex, dI_ex, I_in, dI_in)``.
    """
    stepping = dt_local > 0.0

    def advance_alpha(tau_syn, current, slope):
        # Closed-form update of one alpha channel over dt_local.
        decay = np.where(stepping, np.exp(-dt_local / tau_syn), 1.0)
        current_next = np.where(
            stepping, decay * dt_local * slope + decay * current, current
        )
        slope_next = np.where(stepping, decay * slope, slope)
        return current_next, slope_next

    # exp(-dt/tau_m) - 1, zeroed where no step is taken.
    leak_m1 = np.where(stepping, np.expm1(-dt_local / tau_m), 0.0)
    p30 = np.where(stepping, -tau_m / c_m * leak_m1, 0.0)
    p31_ex, p32_ex = iaf_psc_alpha._alpha_propagator_p31_p32(tau_ex, tau_m, c_m, dt_local)
    p31_in, p32_in = iaf_psc_alpha._alpha_propagator_p31_p32(tau_in, tau_m, c_m, dt_local)

    # Terms are accumulated strictly left to right.
    v_next = p30 * (i_e + y0)
    v_next = v_next + p31_ex * dI_ex
    v_next = v_next + p32_ex * I_ex
    v_next = v_next + p31_in * dI_in
    v_next = v_next + p32_in * I_in
    v_next = v_next + V_m * leak_m1
    v_next = v_next + V_m
    v_next = np.maximum(v_next, u_min)
    V_out = np.where(stepping & ~is_refractory, v_next, V_m)

    I_ex_out, dI_ex_out = advance_alpha(tau_ex, I_ex, dI_ex)
    I_in_out, dI_in_out = advance_alpha(tau_in, I_in, dI_in)
    return V_out, I_ex_out, dI_ex_out, I_in_out, dI_in_out
def _threshold_distance_vectorized(self, dt_local, V_before, I_ex_before,
                                   dI_ex_before, I_in_before, dI_in_before,
                                   y0, tau_m, tau_ex, tau_in, c_m, i_e, u_th):
    """Signed distance to threshold after advancing the membrane by ``dt_local``.

    No lower-bound clamp and no refractory masking are applied here; the
    membrane is advanced with the same propagator terms used for regular
    propagation.

    Parameters
    ----------
    dt_local : np.ndarray or float
        Local time step.
    V_before, I_ex_before, dI_ex_before, I_in_before, dI_in_before : np.ndarray
        State variables before propagation.
    y0 : np.ndarray
        Buffered input current.
    tau_m, tau_ex, tau_in, c_m, i_e : np.ndarray
        Model parameters.
    u_th : np.ndarray
        Threshold in relative coordinates.

    Returns
    -------
    np.ndarray
        Propagated membrane value minus ``u_th`` for each neuron.
    """
    leak_m1 = np.expm1(-dt_local / tau_m)
    p30 = -tau_m / c_m * leak_m1
    p31_ex, p32_ex = iaf_psc_alpha._alpha_propagator_p31_p32(tau_ex, tau_m, c_m, dt_local)
    p31_in, p32_in = iaf_psc_alpha._alpha_propagator_p31_p32(tau_in, tau_m, c_m, dt_local)

    # Terms are accumulated strictly left to right.
    v_after = p30 * (i_e + y0)
    v_after = v_after + p31_ex * dI_ex_before
    v_after = v_after + p32_ex * I_ex_before
    v_after = v_after + p31_in * dI_in_before
    v_after = v_after + p32_in * I_in_before
    v_after = v_after + V_before * leak_m1
    v_after = v_after + V_before
    return v_after - u_th
def _bisect_vectorized(self, t_hi, V_before, I_ex_before, dI_ex_before,
I_in_before, dI_in_before, y0, tau_m, tau_ex,
tau_in, c_m, i_e, u_th, mask):
"""Vectorized bisection to find threshold crossing time.
Parameters
----------
t_hi : np.ndarray
Upper bound of search interval for each neuron.
V_before, I_ex_before, dI_ex_before, I_in_before, dI_in_before : np.ndarray
State variables before the ministep.
y0 : np.ndarray
Buffered input current.
tau_m, tau_ex, tau_in, c_m, i_e : np.ndarray
Model parameters.
u_th : np.ndarray
Threshold in relative coordinates.
mask : np.ndarray
Boolean mask of neurons to perform bisection on.
Returns
-------
np.ndarray
Approximate crossing times for each neuron (only valid where mask is True).
"""
lo = np.zeros_like(t_hi)
hi = t_hi.copy()
for _ in range(64):
mid = 0.5 * (lo + hi)
f_mid = self._threshold_distance_vectorized(
mid, V_before, I_ex_before, dI_ex_before,
I_in_before, dI_in_before, y0, tau_m, tau_ex, tau_in, c_m, i_e, u_th,
)
crossed = f_mid > 0.0
hi = np.where(mask & crossed, mid, hi)
lo = np.where(mask & ~crossed, mid, lo)
return 0.5 * (lo + hi)
[docs]
def update(self, x=0. * u.pA, spike_events=None):
    r"""Advance one simulation step with optional precise within-step events.

    The step is split into mini-steps at every event offset (user events,
    the on-grid delta input at offset 0, and refractory-release
    pseudo-events). Each mini-step is propagated in closed form and checked
    for a threshold crossing; crossings are localized by bisection.

    Parameters
    ----------
    x : ArrayLike, optional
        Continuous external current in pA for the current global step.
        Value is accumulated through :meth:`sum_current_inputs` and written
        to ``self.y_input`` for use in the next step (one-step buffering).
        Scalar or array-like broadcastable to ``self.V.value.shape``.
        Default is ``0. * u.pA``.
    spike_events : Iterable[tuple[Any, Any] | dict[str, Any]] or None, optional
        Optional off-grid spike events within this ``dt`` step. Each item is
        either ``(offset, weight)`` or ``{'offset': ..., 'weight': ...}``,
        where ``offset`` is in ms from the right step edge and ``weight`` is
        in pA. ``offset`` must satisfy ``0 <= offset <= dt``.
        Positive weights target excitatory alpha derivative state; negative
        weights target inhibitory alpha derivative state. ``None`` means no
        extra within-step events. On-grid delta inputs collected from
        :meth:`sum_delta_inputs` are still included at ``offset=0``.

    Returns
    -------
    out : jax.Array
        Spike output from :meth:`get_spike` with shape
        ``self.V.value.shape``. Values are surrogate spikes from
        ``self.spk_fun`` evaluated on threshold-scaled membrane potential
        after precise-time integration and event handling.

    Raises
    ------
    ValueError
        If computed refractory steps satisfy ``ceil(t_ref / dt) < 1`` or if
        any event offset is outside ``[0, dt]``.
    KeyError
        If simulation context values ``t`` or ``dt`` are missing.
    TypeError
        If provided quantities are not unit-compatible with ms/pA during
        conversion of ``x`` or ``spike_events``.
    AttributeError
        If called before required states are initialized via
        :meth:`init_state`.

    Notes
    -----
    A synaptic event whose offset coincides with a neuron's
    refractory-release offset is still delivered to that neuron: the
    release only clears the refractory flag, it does not consume the input.
    """
    t = brainstate.environ.get('t')
    dt_q = brainstate.environ.get_dt()
    dftype = brainstate.environ.dftype()
    ditype = brainstate.environ.ditype()
    h = float(u.get_mantissa(dt_q / u.ms))
    t_ms = float(u.get_mantissa(t / u.ms))
    step_idx = int(round(t_ms / h))
    eps = np.finfo(np.float64).eps
    v_shape = self.V.value.shape

    # Use cached constant parameter arrays; recompute only when key changes.
    if not hasattr(self, '_c_key') or self._c_key != (h, v_shape, dftype, ditype):
        self._precompute_constants(h, v_shape, dftype, ditype)
    E_L = self._c_E_L
    tau_m = self._c_tau_m
    tau_ex = self._c_tau_ex
    tau_in = self._c_tau_in
    c_m = self._c_c_m
    i_e = self._c_i_e
    u_th = self._c_u_th
    u_reset = self._c_u_reset
    u_min = self._c_u_min
    refr_steps = self._c_refr_steps
    psc_norm_ex = self._c_psc_norm_ex
    psc_norm_in = self._c_psc_norm_in

    # Convert per-step state arrays to unitless numpy.
    _tn_state = lambda x, unit: np.broadcast_to(
        np.asarray(u.get_mantissa(x / unit), dtype=dftype), v_shape
    )
    # Membrane potential is handled relative to the resting potential.
    V_m = _tn_state(self.V.value, u.mV) - E_L
    I_ex = _tn_state(self.I_syn_ex.value, u.pA)
    dI_ex = np.broadcast_to(np.asarray(self.dI_syn_ex.value, dtype=dftype), v_shape)
    I_in = _tn_state(self.I_syn_in.value, u.pA)
    dI_in = np.broadcast_to(np.asarray(self.dI_syn_in.value, dtype=dftype), v_shape)
    y_input = _tn_state(self.y_input.value, u.pA)
    is_refractory = np.broadcast_to(
        np.asarray(self.is_refractory.value, dtype=bool), v_shape
    ).copy()
    last_spike_step = np.broadcast_to(
        np.asarray(self.last_spike_step.value, dtype=ditype), v_shape
    ).copy()
    last_spike_offset = _tn_state(self.last_spike_offset.value, u.ms).copy()
    last_spike_time_prev = _tn_state(self.last_spike_time.value, u.ms).copy()

    # Parse spike events and add on-grid delta inputs.
    events = self._parse_spike_events(spike_events, v_shape)
    on_grid = np.broadcast_to(
        np.asarray(u.get_mantissa(self.sum_delta_inputs(0. * u.pA) / u.pA), dtype=dftype),
        v_shape,
    )
    events.append((0.0, on_grid))
    events.sort(key=lambda z: z[0], reverse=True)
    for off, _ in events:
        if off < 0.0 or off > h:
            raise ValueError('All spike event offsets must satisfy 0 <= offset <= dt.')

    # Current input for next step (one-step delay).
    y_input_next = np.broadcast_to(
        np.asarray(u.get_mantissa(self.sum_current_inputs(x, self.V.value) / u.pA), dtype=dftype),
        v_shape,
    )

    # Working copies for mutation during event processing.
    V_m = V_m.copy()
    I_ex = I_ex.copy()
    dI_ex = dI_ex.copy()
    I_in = I_in.copy()
    dI_in = dI_in.copy()
    spike_mask = np.zeros(v_shape, dtype=bool)

    # --- Handle neurons already above threshold at start of step ---
    instant_spike = (~is_refractory) & (V_m >= u_th)
    if np.any(instant_spike):
        # Place the spike just inside the left edge of the step.
        spike_off = h * (1.0 - eps)
        last_spike_step = np.where(instant_spike, step_idx + 1, last_spike_step)
        last_spike_offset = np.where(instant_spike, spike_off, last_spike_offset)
        V_m = np.where(instant_spike, u_reset, V_m)
        is_refractory = is_refractory | instant_spike
        last_spike_time_prev = np.where(instant_spike, t_ms + h - spike_off, last_spike_time_prev)
        spike_mask = spike_mask | instant_spike

    # --- Build local events including refractory-release pseudo-event ---
    # Determine which neurons need a refractory-release event.
    refr_release = is_refractory & ((step_idx + 1 - last_spike_step) == refr_steps)
    refr_release_offset = np.where(refr_release, last_spike_offset, -1.0)

    # Combine external events with refractory-release events and sort.
    # Process events from largest offset (step start) to smallest (step end).
    all_offsets = [off for off, _ in events]
    if np.any(refr_release):
        unique_refr_offsets = np.unique(refr_release_offset[refr_release])
        all_offsets = sorted(set(all_offsets) | set(unique_refr_offsets.tolist()), reverse=True)
    else:
        all_offsets = sorted(set(all_offsets), reverse=True)

    # Build event lookup: for each offset, get the weight array (if any).
    # Weights of events sharing the same offset are summed.
    event_weight_map = {}
    for off, w in events:
        if off in event_weight_map:
            event_weight_map[off] = event_weight_map[off] + w
        else:
            event_weight_map[off] = w.copy()

    # Process all events in descending offset order.
    last_off = np.full(v_shape, h, dtype=dftype)
    for ev_off in all_offsets:
        ministep = last_off - ev_off
        # Propagate where ministep > 0.
        propagate_mask = ministep > 0.0
        if np.any(propagate_mask):
            dt_local = np.where(propagate_mask, ministep, 0.0)
            # Snapshot of the pre-step state, needed for root finding.
            V_before = V_m.copy()
            I_ex_before = I_ex.copy()
            dI_ex_before = dI_ex.copy()
            I_in_before = I_in.copy()
            dI_in_before = dI_in.copy()
            V_m, I_ex, dI_ex, I_in, dI_in = self._propagate_vectorized(
                dt_local, V_m, I_ex, dI_ex, I_in, dI_in,
                y_input, tau_m, tau_ex, tau_in, c_m, i_e, u_min, is_refractory,
            )
            # Check for threshold crossing.
            crossed = propagate_mask & (~is_refractory) & (V_m >= u_th)
            if np.any(crossed):
                root = self._bisect_vectorized(
                    dt_local, V_before, I_ex_before, dI_ex_before,
                    I_in_before, dI_in_before, y_input, tau_m, tau_ex,
                    tau_in, c_m, i_e, u_th, crossed,
                )
                # Offset of the spike measured from the right step edge.
                spike_off = h - ((h - last_off) + root)
                spike_off = np.clip(spike_off, 0.0, h)
                last_spike_step = np.where(crossed, step_idx + 1, last_spike_step)
                last_spike_offset = np.where(crossed, spike_off, last_spike_offset)
                V_m = np.where(crossed, u_reset, V_m)
                is_refractory = is_refractory | crossed
                last_spike_time_prev = np.where(crossed, t_ms + h - spike_off, last_spike_time_prev)
                spike_mask = spike_mask | crossed

        # Apply event: refractory release and/or synaptic weight.
        is_refr_release_here = refr_release & (np.abs(refr_release_offset - ev_off) < 1e-15)
        is_refractory = np.where(is_refr_release_here, False, is_refractory)
        if ev_off in event_weight_map:
            ev_w = event_weight_map[ev_off]
            # Deliver the input to every neuron, including those leaving
            # refractoriness at this very offset: the release pseudo-event
            # must not swallow a coinciding synaptic event (this always
            # matters for the on-grid input at offset 0).
            dI_ex = np.where(ev_w >= 0.0, dI_ex + psc_norm_ex * ev_w, dI_ex)
            dI_in = np.where(ev_w < 0.0, dI_in + psc_norm_in * ev_w, dI_in)
        last_off = np.where(propagate_mask | is_refr_release_here, ev_off, last_off)

    # --- Final propagation from last event to step end ---
    final_ministep = last_off
    propagate_final = final_ministep > 0.0
    if np.any(propagate_final):
        dt_local = np.where(propagate_final, final_ministep, 0.0)
        V_before = V_m.copy()
        I_ex_before = I_ex.copy()
        dI_ex_before = dI_ex.copy()
        I_in_before = I_in.copy()
        dI_in_before = dI_in.copy()
        V_m, I_ex, dI_ex, I_in, dI_in = self._propagate_vectorized(
            dt_local, V_m, I_ex, dI_ex, I_in, dI_in,
            y_input, tau_m, tau_ex, tau_in, c_m, i_e, u_min, is_refractory,
        )
        # Check for threshold crossing in final segment.
        crossed = propagate_final & (~is_refractory) & (V_m >= u_th)
        if np.any(crossed):
            root = self._bisect_vectorized(
                dt_local, V_before, I_ex_before, dI_ex_before,
                I_in_before, dI_in_before, y_input, tau_m, tau_ex,
                tau_in, c_m, i_e, u_th, crossed,
            )
            spike_off = h - ((h - last_off) + root)
            spike_off = np.clip(spike_off, 0.0, h)
            last_spike_step = np.where(crossed, step_idx + 1, last_spike_step)
            last_spike_offset = np.where(crossed, spike_off, last_spike_offset)
            V_m = np.where(crossed, u_reset, V_m)
            is_refractory = is_refractory | crossed
            last_spike_time_prev = np.where(crossed, t_ms + h - spike_off, last_spike_time_prev)
            spike_mask = spike_mask | crossed

    # Construct spike output voltage for surrogate gradient: spiking neurons
    # are placed just above threshold, all others are capped just below it
    # (the stored membrane of a spiking neuron has already been reset).
    v_for_spike = np.where(spike_mask, u_th + 1e-12, np.minimum(V_m, u_th - 1e-12))

    # Write back state.
    self.y_input.value = y_input_next * u.pA
    self.I_syn_ex.value = I_ex * u.pA
    self.dI_syn_ex.value = dI_ex
    self.I_syn_in.value = I_in * u.pA
    self.dI_syn_in.value = dI_in
    self.V.value = (V_m + E_L) * u.mV
    self.is_refractory.value = jnp.asarray(is_refractory, dtype=bool)
    self.last_spike_step.value = jnp.asarray(last_spike_step, dtype=ditype)
    self.last_spike_offset.value = last_spike_offset * u.ms
    self.last_spike_time.value = jax.lax.stop_gradient(last_spike_time_prev * u.ms)
    if self.ref_var:
        self.refractory.value = jax.lax.stop_gradient(self.is_refractory.value)
    return self.get_spike((v_for_spike + E_L) * u.mV)