# Copyright 2026 BrainX Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# -*- coding: utf-8 -*-
import math
from typing import Callable, Iterable, Optional, Sequence, Tuple, Union
import brainstate
import braintools
import jax
import jax.numpy as jnp
import numpy as np
import saiunit as u
from brainstate.typing import ArrayLike, Size
from ._base import NESTNeuron
from ._utils import is_tracer
# Public names exported by ``from <this module> import *``.
__all__ = ['iaf_psc_delta_ps']
class iaf_psc_delta_ps(NESTNeuron):
    r"""NEST-compatible ``iaf_psc_delta_ps`` with precise spike timing.

    Description
    -----------

    ``iaf_psc_delta_ps`` is a current-based leaky integrate-and-fire neuron
    with delta-shaped synaptic jumps (weights in mV), exact linear
    subthreshold integration, and precise off-grid spike timing inside each
    global simulation step. The implementation follows NEST
    ``models/iaf_psc_delta_ps.{h,cpp}`` semantics, including event ordering by
    within-step offsets, analytic threshold-crossing localization for
    current-driven spikes, and optional accumulation of refractory-time inputs.

    **1. Linear Membrane Dynamics and Exact Closed-Form Propagator**

    The subthreshold membrane potential dynamics are

    .. math::

       \frac{dV_m}{dt} = -\frac{V_m - E_L}{\tau_m}
                         + \frac{I_\mathrm{ext}(t) + I_e}{C_m},

    with piecewise-constant :math:`I_\mathrm{ext}` over each simulation step.
    Defining :math:`U = V_m - E_L`, :math:`U_{th} = V_{th} - E_L`,
    :math:`R = \tau_m / C_m`, and a constant current over an interval
    :math:`\Delta t`, exact integration gives

    .. math::

       U(t + \Delta t) = U(t)e^{-\Delta t/\tau_m}
                         + R(I_\mathrm{ext}+I_e)\left(1 - e^{-\Delta t/\tau_m}\right).

    The code evaluates this update with ``expm1``-based algebra for numerical
    stability when :math:`\Delta t/\tau_m` is small, which reduces
    cancellation error in fine-step simulations.

    **2. Spike Generation Mechanisms and Precise Spike-Time Derivation**

    Two spike mechanisms are implemented:

    - **Instantaneous event-driven spikes**: if an incoming delta event at
      offset :math:`\delta` pushes :math:`U \ge U_{th}`, the spike time is the
      event time exactly.
    - **Current-driven spikes**: if propagation yields :math:`U \ge U_{th}`,
      the spike offset is solved analytically from the exact trajectory:

      .. math::

         \Delta t_\mathrm{cross}
         = -\tau_m \log\frac{V_\infty - U}{V_\infty - U_{th}},
         \quad
         V_\infty = R(I_\mathrm{ext}+I_e).

    The model stores:

    - ``last_spike_time``: absolute spike time in ms,
    - ``last_spike_offset``: off-grid offset relative to the right border of
      the current grid step (NEST semantics),
    - ``last_spike_step``: on-grid step index used internally for refractory
      logic.

    **3. Refractory Handling and Deferred Refractory-Input Accumulation**

    After a spike, the membrane potential is reset to ``V_reset`` and clamped
    during the absolute refractory period.

    In NEST ``iaf_psc_delta_ps``, the refractory duration in steps is derived
    as ``floor(t_ref / dt)`` (via ``Time(...).get_steps()``) and must be at
    least one simulation step. This implementation enforces the same runtime
    constraint.

    By default, spikes arriving during refractoriness are discarded. If
    ``refractory_input=True``, they are accumulated and exponentially damped
    until the end of refractoriness, then applied once at refractory release,
    matching NEST.

    **4. Event Ordering, Assumptions, Constraints, and Computational Implications**

    For each simulation step the update proceeds as follows:

    1. Optional immediate spike if the state starts super-threshold (placed
       at the start of the step, offset ``dt * (1 - eps)``).
    2. Process within-step events in offset order (start to end of step):

       - propagate to the event time (if non-refractory), then apply the
         optional ``V_min`` lower bound,
       - check for a current-driven crossing,
       - apply the event jump and check for an instant crossing.

    3. Propagate the remaining interval (if any).
    4. Store the new external current input buffer for the next step.

    Assumptions and constraints used by the implementation:

    - Parameter tensors are scalar or broadcastable to ``self.varshape``.
    - Required physical inequalities are validated at construction:
      ``V_reset < V_th``, ``C_m > 0``, ``tau_m > 0``, ``t_ref >= 0``, and if
      ``V_min`` is provided then ``V_reset >= V_min``.
    - Runtime requires ``floor(t_ref / dt) >= 1`` and ``dt > 0``.
    - Every precise event offset must satisfy ``0 <= offset <= dt``.

    Two execution paths exist. When the simulation time ``t`` is a JAX tracer
    (e.g. inside ``brainstate.transform.for_loop``) and no ``spike_events``
    are given, a vectorized JAX path is used. That path handles only on-grid
    delta inputs and does not implement ``refractory_input`` buffering.
    Otherwise the concrete-time path iterates scalar-wise over
    ``np.ndindex`` across the full state shape and processes all local events
    in each cell, so its cost is :math:`O(|\mathrm{state}| \cdot K)` per step
    for ``K`` events (excluding input aggregation).

    Parameters
    ----------
    in_size : Size
        Population shape specification used to derive ``self.varshape``.
        Scalar integer for 1D populations or tuple for multi-dimensional.
    E_L : ArrayLike, optional
        Resting membrane potential :math:`E_L` in mV. Scalar or array-like
        broadcastable to ``self.varshape``. Default is ``-70. * u.mV``.
    C_m : ArrayLike, optional
        Membrane capacitance :math:`C_m` in pF, broadcastable to
        ``self.varshape``. Must be strictly positive elementwise.
        Default is ``250. * u.pF``.
    tau_m : ArrayLike, optional
        Membrane time constant :math:`\tau_m` in ms, broadcastable to
        ``self.varshape``. Must be strictly positive elementwise.
        Default is ``10. * u.ms``.
    t_ref : ArrayLike, optional
        Absolute refractory duration :math:`t_{ref}` in ms, broadcastable to
        ``self.varshape``. At runtime converted to steps by
        ``floor(t_ref / dt)`` and must produce at least one step.
        Default is ``2. * u.ms``.
    V_th : ArrayLike, optional
        Spike threshold :math:`V_{th}` in mV, broadcastable to
        ``self.varshape``. Default is ``-55. * u.mV``.
    V_reset : ArrayLike, optional
        Reset potential :math:`V_{reset}` in mV, broadcastable to
        ``self.varshape``. Must satisfy ``V_reset < V_th`` elementwise.
        Default is ``-70. * u.mV``.
    I_e : ArrayLike, optional
        Constant external current :math:`I_e` in pA, broadcastable to
        ``self.varshape``. Added to the buffered current in each propagation
        segment. Default is ``0. * u.pA``.
    V_min : ArrayLike or None, optional
        Optional lower membrane bound :math:`V_{min}` in mV, broadcastable to
        ``self.varshape``, applied after each propagation segment. ``None``
        disables lower clipping (uses ``-inf``). Default is ``None``.
    V_initializer : Callable, optional
        Initializer used by :meth:`init_state` to create membrane state ``V``.
        Must return values unit-compatible with mV and shape-compatible with
        ``self.varshape`` (and optional batch prefix). Default is
        ``braintools.init.Constant(-70. * u.mV)``.
    spk_fun : Callable, optional
        Surrogate spike nonlinearity used by :meth:`get_spike` and returned by
        :meth:`update`. Receives the normalized threshold-distance tensor.
        Default is ``braintools.surrogate.ReluGrad()``.
    spk_reset : str, optional
        Reset mode forwarded to the base neuron class.
        ``'hard'`` matches NEST hard-reset behavior. Default is ``'hard'``.
    refractory_input : bool, optional
        If ``False``, delta events received during refractoriness are ignored.
        If ``True``, they are exponentially weighted into
        ``refractory_spike_buffer`` and applied at refractory release
        (concrete-time path only). Default is ``False``.
    ref_var : bool, optional
        If ``True``, exposes additional state ``self.refractory`` mirroring
        ``self.is_refractory`` for introspection. Default is ``False``.
    name : str or None, optional
        Optional node name.

    Parameter Mapping
    -----------------

    .. list-table:: Parameter mapping to model symbols
       :header-rows: 1
       :widths: 17 28 14 16 35

       * - Parameter
         - Type / shape / unit
         - Default
         - Math symbol
         - Semantics
       * - ``in_size``
         - :class:`~brainstate.typing.Size`; scalar/tuple
         - required
         - --
         - Defines ``self.varshape`` for all state/parameter broadcasts.
       * - ``E_L``
         - ArrayLike, broadcastable to ``self.varshape`` (mV)
         - ``-70. * u.mV``
         - :math:`E_L`
         - Resting membrane potential and origin of transformed state ``U``.
       * - ``C_m``
         - ArrayLike, broadcastable (pF), ``> 0``
         - ``250. * u.pF``
         - :math:`C_m`
         - Converts current to membrane-rate contribution.
       * - ``tau_m``
         - ArrayLike, broadcastable (ms), ``> 0``
         - ``10. * u.ms``
         - :math:`\tau_m`
         - Leak/relaxation time constant in exact propagator.
       * - ``t_ref``
         - ArrayLike, broadcastable (ms), runtime ``floor(t_ref/dt) >= 1``
         - ``2. * u.ms``
         - :math:`t_{ref}`
         - Absolute refractory duration.
       * - ``V_th`` and ``V_reset``
         - ArrayLike, broadcastable (mV), with ``V_reset < V_th``
         - ``-55. * u.mV``, ``-70. * u.mV``
         - :math:`V_{th}`, :math:`V_{reset}`
         - Threshold and post-spike reset levels.
       * - ``I_e``
         - ArrayLike, broadcastable (pA)
         - ``0. * u.pA``
         - :math:`I_e`
         - Constant injected current term.
       * - ``V_min``
         - ArrayLike broadcastable (mV) or ``None``
         - ``None``
         - :math:`V_{min}`
         - Optional lower clip on membrane potential.
       * - ``V_initializer``
         - Callable returning mV-compatible values
         - ``Constant(-70. * u.mV)``
         - --
         - Initializes membrane state ``V``.
       * - ``spk_fun``
         - Callable
         - ``ReluGrad()``
         - --
         - Surrogate spike output function.
       * - ``spk_reset``
         - str
         - ``'hard'``
         - --
         - Reset policy inherited from base neuron class.
       * - ``refractory_input``
         - bool
         - ``False``
         - --
         - Controls treatment of refractory-time delta events.
       * - ``ref_var``
         - bool
         - ``False``
         - --
         - Exposes persistent refractory state variable.
       * - ``name``
         - str | None
         - ``None``
         - --
         - Optional node identifier.

    Raises
    ------
    ValueError
        If validated construction/runtime constraints fail, including invalid
        parameter inequalities (for example ``V_reset >= V_th``), non-positive
        time constants/capacitance, ``dt <= 0``, invalid event offsets, or
        ``floor(t_ref / dt) < 1``.
    TypeError
        If provided arguments are incompatible with expected unit arithmetic
        (mV, pA, pF, ms) or callable interfaces, or if ``spike_events`` are
        passed while the simulation time is traced (JIT).
    KeyError
        If required simulation context entries (``t`` and/or ``dt``) are
        missing when :meth:`update` is called.
    AttributeError
        If :meth:`update` is called before :meth:`init_state` creates required
        state variables.

    Attributes
    ----------
    V : brainstate.HiddenState
        Membrane potential state in mV, shape ``self.varshape`` (or with
        leading batch dimension when ``batch_size`` is specified).
    I_stim : brainstate.ShortTermState
        One-step buffered continuous current input in pA. Applied in the
        *next* update call (NEST ring-buffer semantics).
    last_spike_time : brainstate.ShortTermState
        Absolute precise spike time (ms) for the latest emitted spike.
        Initialized to ``-1e7 * u.ms`` (far past) to indicate no prior spike.
    last_spike_step : brainstate.ShortTermState
        Integer step index associated with the latest emitted spike.
        Initialized to ``-1``.
    last_spike_offset : brainstate.ShortTermState
        Precise within-step offset (ms) measured from the step right boundary
        (NEST convention: ``0`` at step end, ``dt`` at step start).
    is_refractory : brainstate.ShortTermState
        Boolean mask indicating which neurons are currently in the absolute
        refractory period.
    refractory_spike_buffer : brainstate.ShortTermState
        Deferred refractory-time delta contribution (mV). Non-zero only when
        ``refractory_input=True``; accumulates exponentially decayed delta
        events and is released at the end of refractoriness.
    refractory : brainstate.ShortTermState
        Mirror of ``is_refractory`` exposed for external inspection. Present
        only when ``ref_var=True``.

    Notes
    -----
    - ``x`` passed to ``update(x=...)`` is buffered into ``I_stim`` and applied
      on the *next* step, mirroring NEST ring-buffer semantics for current
      events.
    - Delta inputs from ``add_delta_input`` are interpreted as on-grid events
      at step end (offset ``0``).
    - Additional within-step precise events can be supplied through
      ``update(spike_events=...)`` where each event is ``(offset, weight)``
      or ``{'offset': ..., 'weight': ...}`` in units of ms and mV. This is
      only supported when the simulation time is concrete (not under JIT).
    - This model uses ``floor(t_ref / dt)`` for refractory step conversion
      (matching NEST ``iaf_psc_delta_ps``), whereas ``iaf_psc_delta`` uses
      ``ceil(t_ref / dt)``.

    Examples
    --------

    Basic usage with constant current drive:

    .. code-block:: python

       >>> import brainpy
       >>> import brainstate
       >>> import saiunit as u
       >>> with brainstate.environ.context(dt=0.1 * u.ms):
       ...     neu = brainpy.state.iaf_psc_delta_ps(in_size=2, t_ref=2.0 * u.ms)
       ...     neu.init_state()
       ...     with brainstate.environ.context(t=0.0 * u.ms):
       ...         spk = neu.update(x=200.0 * u.pA)
       ...     _ = spk.shape

    Precise within-step spike events with ``refractory_input=True``:

    .. code-block:: python

       >>> import brainpy
       >>> import brainstate
       >>> import saiunit as u
       >>> with brainstate.environ.context(dt=0.1 * u.ms):
       ...     neu = brainpy.state.iaf_psc_delta_ps(in_size=1, refractory_input=True)
       ...     neu.init_state()
       ...     with brainstate.environ.context(t=0.0 * u.ms):
       ...         _ = neu.update(spike_events=[(0.04 * u.ms, 2.5 * u.mV)])

    References
    ----------
    .. [1] Rotter S, Diesmann M (1999). Exact simulation of time-invariant
           linear systems with applications to neuronal modeling. Biological
           Cybernetics 81:381-402. DOI: https://doi.org/10.1007/s004220050570
    .. [2] Diesmann M, Gewaltig M-O, Rotter S, Aertsen A (2001). State space
           analysis of synchronous spiking in cortical neural networks.
           Neurocomputing 38-40:565-571.
           DOI: https://doi.org/10.1016/S0925-2312(01)00409-X
    .. [3] Morrison A, Straube S, Plesser HE, Diesmann M (2007).
           Exact subthreshold integration with continuous spike times in
           discrete-time neural network simulations. Neural Computation
           19(1):47-79. DOI: https://doi.org/10.1162/neco.2007.19.1.47
    .. [4] Hanuschkin A, Kunkel S, Helias M, Morrison A, Diesmann M (2010).
           A general and efficient method for incorporating exact spike times
           in globally time-driven simulations. Frontiers in Neuroinformatics
           4:113. DOI: https://doi.org/10.3389/fninf.2010.00113

    See Also
    --------
    iaf_psc_delta : Current-based LIF with delta synapses (on-grid spike times)
    iaf_cond_exp : Conductance-based LIF with exponential synapses
    """
    __module__ = 'brainpy.state'

    def __init__(
        self,
        in_size: Size,
        E_L: ArrayLike = -70. * u.mV,
        C_m: ArrayLike = 250. * u.pF,
        tau_m: ArrayLike = 10. * u.ms,
        t_ref: ArrayLike = 2. * u.ms,
        V_th: ArrayLike = -55. * u.mV,
        V_reset: ArrayLike = -70. * u.mV,
        I_e: ArrayLike = 0. * u.pA,
        V_min: Optional[ArrayLike] = None,
        V_initializer: Callable = braintools.init.Constant(-70. * u.mV),
        spk_fun: Callable = braintools.surrogate.ReluGrad(),
        spk_reset: str = 'hard',
        refractory_input: bool = False,
        ref_var: bool = False,
        name: Optional[str] = None,
    ):
        super().__init__(in_size, name=name, spk_fun=spk_fun, spk_reset=spk_reset)
        self.E_L = braintools.init.param(E_L, self.varshape)
        self.C_m = braintools.init.param(C_m, self.varshape)
        self.tau_m = braintools.init.param(tau_m, self.varshape)
        self.t_ref = braintools.init.param(t_ref, self.varshape)
        self.V_th = braintools.init.param(V_th, self.varshape)
        self.V_reset = braintools.init.param(V_reset, self.varshape)
        self.I_e = braintools.init.param(I_e, self.varshape)
        self.V_min = None if V_min is None else braintools.init.param(V_min, self.varshape)
        self.V_initializer = V_initializer
        self.refractory_input = refractory_input
        self.ref_var = ref_var
        self._validate_parameters()

        # Precompute refractory step count (uses floor, matching NEST
        # iaf_psc_delta_ps). This cached value is read by the JIT path
        # (``_update_jax``); the concrete-time path in ``update`` recomputes it
        # from the current ``dt`` on every call.
        ditype = brainstate.environ.ditype()
        dt = brainstate.environ.get_dt()
        self.refr_steps = u.math.asarray(u.math.floor(self.t_ref / dt), dtype=ditype)

    def _validate_parameters(self):
        r"""Validate model parameters against NEST constraints.

        Validation is skipped when parameters are JAX tracers, because their
        values are not available at trace time.

        Raises
        ------
        ValueError
            If parameter inequalities or positivity constraints are violated.
        """
        # Skip validation when parameters are JAX tracers (e.g. during jit).
        if any(is_tracer(v) for v in (self.V_reset, self.C_m, self.tau_m)):
            return
        if np.any(self.V_reset >= self.V_th):
            raise ValueError('Reset potential must be smaller than threshold.')
        if self.V_min is not None and np.any(self.V_reset < self.V_min):
            raise ValueError('Reset potential must be greater or equal to minimum potential.')
        if np.any(self.C_m <= 0.0 * u.pF):
            raise ValueError('Capacitance must be strictly positive.')
        if np.any(self.tau_m <= 0.0 * u.ms):
            raise ValueError('All time constants must be strictly positive.')
        if np.any(self.t_ref < 0.0 * u.ms):
            raise ValueError('Refractory time must not be negative.')

    def init_state(self, batch_size=None, **kwargs):
        r"""Initialize membrane, timing, and refractory runtime states.

        Parameters
        ----------
        batch_size : int or None, optional
            Optional batch dimension prepended to ``self.varshape`` for all
            state arrays. ``None`` keeps unbatched state. Default is ``None``.
        **kwargs
            Unused compatibility parameters accepted by the base-state API.

        Raises
        ------
        ValueError
            If initializer outputs cannot be broadcast to target state shape.
        TypeError
            If initializer values are not unit-compatible with mV/pA/ms states.
        """
        ditype = brainstate.environ.ditype()
        dftype = brainstate.environ.dftype()
        batch_shape = ((batch_size,) + tuple(self.varshape)) if batch_size is not None else self.varshape
        V = braintools.init.param(self.V_initializer, self.varshape, batch_size)
        self.V = brainstate.HiddenState(V)
        self.I_stim = brainstate.ShortTermState(u.math.zeros(batch_shape, dtype=dftype) * u.pA)
        # -1e7 ms marks "no spike yet" (far in the past).
        self.last_spike_time = brainstate.ShortTermState(u.math.full(batch_shape, -1e7 * u.ms))
        self.last_spike_step = brainstate.ShortTermState(u.math.full(batch_shape, -1, dtype=ditype))
        self.last_spike_offset = brainstate.ShortTermState(u.math.zeros(batch_shape, dtype=dftype) * u.ms)
        self.is_refractory = brainstate.ShortTermState(
            braintools.init.param(braintools.init.Constant(False), self.varshape, batch_size)
        )
        self.refractory_spike_buffer = brainstate.ShortTermState(
            u.math.zeros(batch_shape, dtype=dftype) * u.mV
        )
        if self.ref_var:
            refractory = braintools.init.param(braintools.init.Constant(False), self.varshape, batch_size)
            self.refractory = brainstate.ShortTermState(refractory)

    def get_spike(self, V: Optional[ArrayLike] = None):
        r"""Evaluate surrogate spike activation for a voltage tensor.

        Parameters
        ----------
        V : ArrayLike or None, optional
            Voltage values in mV, broadcast-compatible with ``self.varshape``
            (or current state shape when batched). If ``None``, uses
            ``self.V.value``.

        Returns
        -------
        out : ArrayLike
            Output of ``self.spk_fun`` evaluated on the normalized threshold
            distance ``(V - V_th) / (V_th - V_reset)``, with the same shape as
            ``V``.

        Raises
        ------
        TypeError
            If ``V`` cannot participate in unit-compatible arithmetic.
        """
        V = self.V.value if V is None else V
        v_scaled = (V - self.V_th) / (self.V_th - self.V_reset)
        return self.spk_fun(v_scaled)

    @staticmethod
    def _canonicalize_spike_events(
        spike_events: Optional[Union[Sequence, dict, Tuple]],
    ) -> Sequence:
        r"""Normalize accepted spike-event container variants.

        Parameters
        ----------
        spike_events : Sequence or dict or tuple or None
            Event specification accepted by :meth:`update`:
            ``None``, a single ``{'offset', 'weight'}`` dict, one
            ``(offset, weight)`` tuple, or a sequence (list or tuple) of these
            entries.

        Returns
        -------
        out : Sequence
            Sequence-like iterable of event records. Single dict/tuple inputs
            are wrapped into a one-element list; ``None`` returns ``[]``.
        """
        if spike_events is None:
            return []
        if isinstance(spike_events, dict):
            return [spike_events]
        # A 2-tuple is a single ``(offset, weight)`` event unless its first
        # element is itself an event record, in which case it is a tuple of
        # two events and is returned unchanged.
        if (
            isinstance(spike_events, tuple)
            and len(spike_events) == 2
            and not isinstance(spike_events[0], (tuple, list, dict))
        ):
            return [spike_events]
        return spike_events

    def _parse_spike_events(
        self,
        spike_events: Optional[Union[Sequence, dict, Tuple]],
        shape,
    ) -> Sequence[Tuple[float, np.ndarray]]:
        r"""Parse precise spike events into numeric offsets and broadcast weights.

        Parameters
        ----------
        spike_events : Sequence or dict or tuple or None
            Event specification in one of these forms:
            ``(offset, weight)``, ``{'offset': ..., 'weight': ...}``, or a
            sequence of such entries. ``offset`` is interpreted in ms and
            ``weight`` in mV; plain ``int``/``float`` values are promoted to
            these units.
        shape : tuple of int
            Target state shape used to broadcast each event weight.

        Returns
        -------
        out : list of tuple of (float, numpy.ndarray)
            Parsed events ``[(offset_ms, weight_np), ...]`` where ``offset_ms``
            is a ``float`` and ``weight_np`` is a (read-only, broadcast)
            ``numpy.ndarray`` of dtype ``brainstate.environ.dftype()`` with
            shape ``shape`` (unit: mV).

        Raises
        ------
        ValueError
            If an event dictionary does not contain both ``offset`` and
            ``weight``, or if an event record has unsupported structure.
        TypeError
            If offsets/weights are not convertible to ms/mV-compatible arrays.
        """
        dftype = brainstate.environ.dftype()
        parsed = []
        for ev in self._canonicalize_spike_events(spike_events):
            if isinstance(ev, dict):
                if 'offset' not in ev or 'weight' not in ev:
                    raise ValueError('Each spike event dict must contain "offset" and "weight".')
                offset, weight = ev['offset'], ev['weight']
            else:
                if not isinstance(ev, Iterable):
                    raise ValueError(f'Unsupported spike event format: {ev}.')
                offset, weight = ev
            offset_q = offset * u.ms if isinstance(offset, (int, float)) else offset
            offset_ms = float(u.math.asarray(offset_q / u.ms))
            weight_q = weight * u.mV if isinstance(weight, (int, float)) else weight
            weight_np = np.broadcast_to(
                np.asarray(u.math.asarray(weight_q / u.mV), dtype=dftype), shape
            )
            parsed.append((offset_ms, weight_np))
        return parsed

    def _update_jax(self, x, t_q, dt_q):
        r"""JAX-vectorized update step for JIT-compatible simulation (no spike_events).

        Uses JAX operations throughout (no ``float()`` calls on traced
        values), so this method can be used inside
        ``brainstate.transform.for_loop``.

        Handles: super-threshold start, refractory clamping, refractory
        release, current-driven precise spike timing, and on-grid delta events
        at step end. ``refractory_input`` buffering is not supported in this
        path: refractory-time inputs are dropped and the buffer is zeroed.

        Parameters
        ----------
        x : ArrayLike
            External current input (pA), buffered into ``I_stim`` for the next
            step.
        t_q : Quantity
            Current simulation time (ms); may be a traced value.
        dt_q : Quantity
            Simulation step size (ms).

        Returns
        -------
        out : ArrayLike
            Surrogate spike output from :meth:`get_spike`; spiking neurons are
            evaluated slightly above threshold.

        Raises
        ------
        ValueError
            If the cached ``floor(t_ref / dt)`` is concrete and smaller than
            one step.
        """
        # ``release_now`` below compares against ``refr_steps``; with zero
        # steps a spiking neuron would never be released. Reject that whenever
        # the value is concrete, mirroring the concrete-time path.
        refr_steps = self.refr_steps
        if not is_tracer(refr_steps) and np.any(np.asarray(refr_steps) < 1):
            raise ValueError('Refractory time must be at least one time step.')

        dt_ms = u.math.asarray(dt_q / u.ms)
        t_ms = u.math.asarray(t_q / u.ms)
        step_idx = jnp.round(t_ms / dt_ms).astype(jnp.int32)

        E_L = u.math.asarray(self.E_L / u.mV)
        tau_m = u.math.asarray(self.tau_m / u.ms)
        C_m = u.math.asarray(self.C_m / u.pF)
        I_e = u.math.asarray(self.I_e / u.pA)
        U_th = u.math.asarray((self.V_th - self.E_L) / u.mV)
        U_reset = u.math.asarray((self.V_reset - self.E_L) / u.mV)
        r_mem = tau_m / C_m

        U = u.math.asarray(self.V.value / u.mV) - E_L
        v_shape = U.shape  # (varshape) or (batch_size,) + varshape
        I_stim = u.math.asarray(self.I_stim.value / u.pA)
        is_refr = self.is_refractory.value
        last_step = self.last_spike_step.value
        last_off = u.math.asarray(self.last_spike_offset.value / u.ms)
        last_spike_t = u.math.asarray(self.last_spike_time.value / u.ms)

        # Broadcast to v_shape so scan carry shapes remain stable (sum_* may
        # return 0-d scalars when no projections are connected).
        new_i_stim = jnp.broadcast_to(
            u.math.asarray(self.sum_current_inputs(x, self.V.value) / u.pA), v_shape
        )
        on_grid_delta = jnp.broadcast_to(
            u.math.asarray(self.sum_delta_inputs(0. * u.mV) / u.mV), v_shape
        )

        steps_since_spike = (step_idx + 1) - last_step
        release_now = is_refr & jnp.equal(steps_since_spike, refr_steps)
        still_refr = is_refr & ~release_now

        v_inf = r_mem * (I_stim + I_e)

        # Super-threshold at step start (spike before any events).
        super_thresh = (~is_refr) & (U >= U_th)

        # Propagation duration:
        #   super_thresh or still_refr -> 0 (no propagation)
        #   release_now                -> last_off (time left after refractory ends)
        #   otherwise                  -> dt_ms (full step)
        prop_dt = jnp.where(
            super_thresh | still_refr,
            jnp.zeros_like(dt_ms),
            jnp.where(release_now, last_off, dt_ms),
        )
        expm1_prop = jnp.expm1(-prop_dt / tau_m)
        U_start = jnp.where(super_thresh, U_reset, U)
        # NEST-stable propagator: u_new = -v_inf * expm1 + u * expm1 + u
        U_prop = -v_inf * expm1_prop + U_start * expm1_prop + U_start
        if self.V_min is not None:
            U_min = u.math.asarray((self.V_min - self.E_L) / u.mV)
            U_prop = jnp.maximum(U_prop, U_min)

        # Current-driven spike: possible if not super_thresh and not still_refr.
        can_spike = ~super_thresh & ~still_refr
        current_spike = can_spike & (U_prop >= U_th)

        # Precise spike offset for current-driven spikes (from step right edge).
        # The denominator is kept away from zero so the log stays finite.
        tiny = jnp.finfo(U_prop.dtype).tiny
        safe_denom = jnp.where(
            jnp.abs(v_inf - U_th) > 1e-10,
            v_inf - U_th,
            jnp.where(v_inf >= U_th, jnp.full_like(v_inf, 1e-10), jnp.full_like(v_inf, -1e-10)),
        )
        safe_ratio = jnp.clip((v_inf - U_prop) / safe_denom, tiny, 1.0)
        spike_off_current = jnp.clip(-tau_m * jnp.log(safe_ratio), 0.0, dt_ms)
        spike_time_current = t_ms + dt_ms - spike_off_current

        # On-grid delta event at step end (offset = 0).
        spiked_before_delta = super_thresh | current_spike
        U_with_delta = jnp.where(spiked_before_delta | still_refr, U_prop, U_prop + on_grid_delta)
        delta_spike = ~spiked_before_delta & ~still_refr & (U_with_delta >= U_th)
        spike_time_delta = t_ms + dt_ms

        # Super-threshold spike: placed at the start of the step, i.e. offset
        # dt * (1 - eps) from the right edge (time t + dt * eps), the same
        # placement used by the concrete-time path in ``update``. The stored
        # offset determines how much of the release step is propagated.
        eps = np.finfo(np.float64).eps
        spike_off_super = dt_ms * (1.0 - eps)
        spike_time_super = t_ms + dt_ms - spike_off_super

        spiked = super_thresh | current_spike | delta_spike
        spike_off_chosen = jnp.where(
            super_thresh, spike_off_super,
            jnp.where(current_spike, spike_off_current, jnp.zeros_like(last_off)),
        )
        spike_time_chosen = jnp.where(
            super_thresh, spike_time_super,
            jnp.where(current_spike, spike_time_current, spike_time_delta),
        )
        new_spike_time = jax.lax.stop_gradient(
            jnp.where(spiked, spike_time_chosen, last_spike_t)
        )

        U_final = jnp.where(spiked, U_reset, U_with_delta)
        V_final = U_final + E_L
        # Spiking neurons are reported slightly above threshold so the
        # surrogate registers the spike despite the hard reset.
        V_for_spike = jnp.where(spiked, U_th + E_L + 1e-12, V_final)
        new_is_refr = jnp.where(spiked, True, is_refr & ~release_now)
        new_last_step = jnp.where(spiked, (step_idx + 1).astype(last_step.dtype), last_step)
        new_last_off = jnp.where(spiked, spike_off_chosen, last_off)

        self.V.value = V_final * u.mV
        self.I_stim.value = new_i_stim * u.pA
        self.last_spike_step.value = new_last_step
        self.last_spike_offset.value = new_last_off * u.ms
        self.is_refractory.value = new_is_refr
        self.refractory_spike_buffer.value = jnp.zeros_like(V_final) * u.mV
        self.last_spike_time.value = new_spike_time * u.ms
        if self.ref_var:
            self.refractory.value = jax.lax.stop_gradient(new_is_refr)
        return self.get_spike(V_for_spike * u.mV)

    def update(self, x=0. * u.pA, spike_events: Optional[Union[Sequence, dict, Tuple]] = None):
        r"""Advance one simulation step with optional precise within-step events.

        Parameters
        ----------
        x : ArrayLike, optional
            External current input in pA. This value is buffered into
            ``self.I_stim`` and applied in the *next* update call, matching
            NEST ring-buffer current semantics.
        spike_events : Sequence or dict or tuple or None, optional
            Optional precise delta events applied in the current step.
            Accepted formats are ``(offset, weight)``,
            ``{'offset': ..., 'weight': ...}``, or a sequence of such events.
            ``offset`` is in ms measured from the step right boundary with
            NEST convention (``0`` at step end, ``dt`` at step start).
            ``weight`` is a voltage jump in mV and may be scalar or
            broadcastable to the neuron state shape. Only supported when the
            simulation time ``t`` is concrete (not traced).

        Returns
        -------
        out : jax.Array
            Surrogate spike output from :meth:`get_spike` with shape
            ``self.V.value.shape``. Elements corresponding to neurons that
            spiked in this step are forced slightly above threshold before
            surrogate evaluation to encode emitted spikes after hard reset.

        Raises
        ------
        ValueError
            If ``dt <= 0``, if ``floor(t_ref / dt) < 1``, if event offsets are
            outside ``[0, dt]``, or if event structures are invalid.
        TypeError
            If ``spike_events`` is given while ``t`` is a JAX tracer, or if
            inputs or internal values are not unit-compatible with expected
            pA/mV/ms arithmetic.
        KeyError
            If the simulation context does not provide required ``t``/``dt``.
        AttributeError
            If state variables are unavailable because :meth:`init_state` was
            not called before :meth:`update`.

        Notes
        -----
        ``V_min`` (if set) is applied after every propagation segment, in both
        the concrete-time and the JIT path.
        """
        t_q = brainstate.environ.get('t')
        dt_q = brainstate.environ.get_dt()

        # Dispatch to the JAX-vectorized path when under JIT (t is traced).
        # This check MUST come before any float() call on environment values.
        t_raw = u.math.asarray(t_q / u.ms)
        if is_tracer(t_raw):
            if spike_events is None:
                return self._update_jax(x, t_q, dt_q)
            # The per-neuron event path needs concrete times; fail clearly
            # instead of with an opaque tracer-conversion error.
            raise TypeError(
                'Precise spike_events require a concrete simulation time and are '
                'not supported under JIT; pass spike_events=None or run without tracing.'
            )

        # Python path: float() conversions are safe here (not under JIT).
        dt_ms = float(u.math.asarray(dt_q / u.ms))
        t_ms = float(t_raw)
        if dt_ms <= 0.0:
            raise ValueError('Simulation time step must be positive.')

        dftype = brainstate.environ.dftype()
        ditype = brainstate.environ.ditype()
        v_shape = self.V.value.shape

        # Convert all parameters and states to unitless numpy arrays of the
        # state shape.
        E_L = np.broadcast_to(np.asarray(u.math.asarray(self.E_L / u.mV), dtype=dftype), v_shape)
        V = np.broadcast_to(np.asarray(u.math.asarray(self.V.value / u.mV), dtype=dftype), v_shape)
        U = V - E_L
        C_m = np.broadcast_to(np.asarray(u.math.asarray(self.C_m / u.pF), dtype=dftype), v_shape)
        tau_m = np.broadcast_to(np.asarray(u.math.asarray(self.tau_m / u.ms), dtype=dftype), v_shape)
        t_ref = np.broadcast_to(np.asarray(u.math.asarray(self.t_ref / u.ms), dtype=dftype), v_shape)
        U_th = np.broadcast_to(np.asarray(u.math.asarray((self.V_th - self.E_L) / u.mV), dtype=dftype), v_shape)
        U_reset = np.broadcast_to(np.asarray(u.math.asarray((self.V_reset - self.E_L) / u.mV), dtype=dftype), v_shape)
        # -inf lower bound makes the V_min clamp a no-op when V_min is None.
        U_min = -np.inf * np.ones(v_shape, dtype=dftype)
        if self.V_min is not None:
            U_min = np.broadcast_to(
                np.asarray(u.math.asarray((self.V_min - self.E_L) / u.mV), dtype=dftype), v_shape
            )
        I_e = np.broadcast_to(np.asarray(u.math.asarray(self.I_e / u.pA), dtype=dftype), v_shape)
        I_stim = np.broadcast_to(np.asarray(u.math.asarray(self.I_stim.value / u.pA), dtype=dftype), v_shape)
        last_step = np.broadcast_to(
            np.asarray(u.math.asarray(self.last_spike_step.value), dtype=ditype), v_shape
        )
        last_offset = np.broadcast_to(
            np.asarray(u.math.asarray(self.last_spike_offset.value / u.ms), dtype=dftype), v_shape
        )
        is_refractory = np.broadcast_to(
            np.asarray(u.math.asarray(self.is_refractory.value), dtype=bool), v_shape
        )
        refr_buffer = np.broadcast_to(
            np.asarray(u.math.asarray(self.refractory_spike_buffer.value / u.mV), dtype=dftype), v_shape
        )
        last_spike_time_prev = np.broadcast_to(
            np.asarray(u.math.asarray(self.last_spike_time.value / u.ms), dtype=dftype), v_shape
        )

        refr_steps = np.floor(t_ref / dt_ms).astype(np.int64)
        if np.any(refr_steps < 1):
            raise ValueError('Refractory time must be at least one time step.')

        on_grid_delta = np.broadcast_to(
            np.asarray(u.math.asarray(self.sum_delta_inputs(0. * u.mV) / u.mV), dtype=dftype), v_shape
        )
        new_i_stim = np.broadcast_to(
            np.asarray(u.math.asarray(self.sum_current_inputs(x, self.V.value) / u.pA), dtype=dftype), v_shape
        )

        # On-grid delta inputs are treated as an event at the step end (offset 0).
        parsed_events = self._parse_spike_events(spike_events, v_shape)
        parsed_events.append((0.0, on_grid_delta))
        parsed_events = sorted(parsed_events, key=lambda z: z[0], reverse=True)
        if any((ev_off < 0.0 or ev_off > dt_ms) for ev_off, _ in parsed_events):
            raise ValueError('All spike event offsets must satisfy 0 <= offset <= dt.')

        step_idx = int(round(t_ms / dt_ms))
        eps = np.finfo(np.float64).eps

        # Output buffers use the same dtype as every other state in this path.
        V_next = np.empty(v_shape, dtype=dftype)
        last_step_next = np.empty_like(last_step)
        last_offset_next = np.empty_like(last_offset)
        is_refractory_next = np.empty_like(is_refractory)
        refr_buffer_next = np.empty_like(refr_buffer)
        last_spike_time_next = np.empty_like(last_spike_time_prev)
        V_for_spike = np.empty_like(V)

        for idx in np.ndindex(v_shape):
            u_i = U[idx]
            i_i = I_stim[idx]
            tau_i = tau_m[idx]
            t_ref_i = t_ref[idx]
            c_m_i = C_m[idx]
            i_e_i = I_e[idx]
            u_th_i = U_th[idx]
            u_reset_i = U_reset[idx]
            u_min_i = U_min[idx]
            refr_steps_i = int(refr_steps[idx])
            last_step_i = int(last_step[idx])
            last_offset_i = float(last_offset[idx])
            is_refr_i = bool(is_refractory[idx])
            refr_buf_i = float(refr_buffer[idx])
            spike_time_i = float(last_spike_time_prev[idx])
            r_mem = tau_i / c_m_i
            did_spike = False

            def _propagate(delta_t_ms: float):
                # Exact propagation of U over ``delta_t_ms``, followed by the
                # optional V_min lower bound (mirrors the JIT path's clamp).
                nonlocal u_i
                if delta_t_ms <= 0.0:
                    return
                expm1_dt = math.expm1(-delta_t_ms / tau_i)
                v_inf = r_mem * (i_i + i_e_i)
                # Numerically stable arrangement used in NEST.
                u_i = -v_inf * expm1_dt + u_i * expm1_dt + u_i
                if u_i < u_min_i:
                    u_i = u_min_i

            def _emit_spike(offset_u_ms: float):
                # Current-driven spike: locate the crossing analytically,
                # measured backwards from ``offset_u_ms`` where U was evaluated.
                nonlocal did_spike, last_step_i, last_offset_i, is_refr_i, u_i, spike_time_i
                v_inf = r_mem * (i_i + i_e_i)
                denom = v_inf - u_th_i
                # Degenerate v_inf == U_th (reachable only through rounding)
                # would produce nan/inf and push the spike to the step start;
                # treat the crossing as happening at ``offset_u_ms`` instead.
                ratio = (v_inf - u_i) / denom if denom != 0.0 else 1.0
                ratio = min(1.0, max(np.finfo(np.float64).tiny, ratio))
                dt_cross = -tau_i * math.log(ratio)
                spike_off = offset_u_ms + dt_cross
                spike_off = min(dt_ms, max(0.0, spike_off))
                did_spike = True
                last_step_i = step_idx + 1
                last_offset_i = spike_off
                is_refr_i = True
                u_i = u_reset_i
                spike_time_i = t_ms + dt_ms - spike_off

            def _emit_instant_spike(spike_off_ms: float):
                # Event-driven spike: the spike happens exactly at the given offset.
                nonlocal did_spike, last_step_i, last_offset_i, is_refr_i, u_i, spike_time_i
                spike_off = min(dt_ms, max(0.0, spike_off_ms))
                did_spike = True
                last_step_i = step_idx + 1
                last_offset_i = spike_off
                is_refr_i = True
                u_i = u_reset_i
                spike_time_i = t_ms + dt_ms - spike_off

            # Super-threshold at step start: spike at t + epsilon.
            if (not is_refr_i) and (u_i >= u_th_i):
                _emit_instant_spike(dt_ms * (1.0 - eps))

            local_events = [(off, w[idx], False) for off, w in parsed_events]
            if is_refr_i and (step_idx + 1 - last_step_i == refr_steps_i):
                # Pseudo-event marking the end of the refractory period.
                local_events.append((last_offset_i, 0.0, True))
            # Offsets run from dt (step start) down to 0 (step end).
            local_events.sort(key=lambda z: z[0], reverse=True)

            t_cursor = dt_ms
            for ev_offset, ev_weight, end_of_refract in local_events:
                if is_refr_i:
                    t_cursor = ev_offset
                    if not end_of_refract:
                        if self.refractory_input:
                            # Damp the input by the time left until refractory release.
                            expo = -(
                                ((last_step_i - step_idx - 1) * dt_ms)
                                - (last_offset_i - ev_offset)
                                + t_ref_i
                            ) / tau_i
                            refr_buf_i += ev_weight * math.exp(expo)
                    else:
                        is_refr_i = False
                        if self.refractory_input:
                            u_i += refr_buf_i
                            refr_buf_i = 0.0
                        if u_i >= u_th_i:
                            _emit_instant_spike(t_cursor)
                    continue
                _propagate(t_cursor - ev_offset)
                t_cursor = ev_offset
                if u_i >= u_th_i:
                    _emit_spike(t_cursor)
                    continue
                u_i += ev_weight
                if u_i >= u_th_i:
                    _emit_instant_spike(t_cursor)

            # Propagate over whatever remains of the step after the last event
            # (this also covers the case of no events at all).
            if (not is_refr_i) and (t_cursor > 0.0):
                _propagate(t_cursor)
                if u_i >= u_th_i:
                    _emit_spike(0.0)

            v_i = u_i + E_L[idx]
            V_next[idx] = v_i
            last_step_next[idx] = last_step_i
            last_offset_next[idx] = last_offset_i
            is_refractory_next[idx] = is_refr_i
            refr_buffer_next[idx] = refr_buf_i
            last_spike_time_next[idx] = spike_time_i
            # Spiking neurons are reported slightly above threshold so the
            # surrogate registers the spike despite the hard reset.
            V_for_spike[idx] = (E_L[idx] + u_th_i + 1e-12) if did_spike else v_i

        self.V.value = V_next * u.mV
        self.I_stim.value = new_i_stim * u.pA
        self.last_spike_step.value = jnp.asarray(last_step_next, dtype=ditype)
        self.last_spike_offset.value = last_offset_next * u.ms
        self.is_refractory.value = jnp.asarray(is_refractory_next, dtype=bool)
        self.refractory_spike_buffer.value = refr_buffer_next * u.mV
        self.last_spike_time.value = jax.lax.stop_gradient(last_spike_time_next * u.ms)
        if self.ref_var:
            self.refractory.value = jax.lax.stop_gradient(self.is_refractory.value)
        return self.get_spike(V_for_spike * u.mV)