# Copyright 2026 BrainX Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# -*- coding: utf-8 -*-
from typing import Callable, Iterable
import brainstate
import braintools
import saiunit as u
import jax
import jax.numpy as jnp
import numpy as np
from brainstate.typing import ArrayLike, Size
from ._base import NESTNeuron
from ._utils import is_tracer
from .iaf_psc_alpha import iaf_psc_alpha
# Public names re-exported by ``from <module> import *``.
__all__ = ['iaf_psc_alpha_multisynapse']
class iaf_psc_alpha_multisynapse(NESTNeuron):
    r"""NEST-compatible ``iaf_psc_alpha_multisynapse`` neuron model.

    Current-based leaky integrate-and-fire neuron with an arbitrary number of
    receptor-indexed alpha-shaped synaptic current channels.

    Description
    -----------

    ``iaf_psc_alpha_multisynapse`` mirrors NEST
    ``models/iaf_psc_alpha_multisynapse.{h,cpp}`` and generalizes
    :class:`iaf_psc_alpha` from two fixed excitatory/inhibitory channels to
    ``n_receptors`` independently parameterized current ports.

    Each receptor ``k`` (1-based, NEST convention) carries its own alpha time
    constant ``tau_syn[k-1]``. Synaptic weights are signed currents in pA;
    positive values are depolarizing and negative values are hyperpolarizing.

    **1. Continuous-Time Dynamics and Receptor States**

    The membrane dynamics are

    .. math::

       \frac{dV_m}{dt} = -\frac{V_m - E_L}{\tau_m}
       + \frac{\sum_k I_k + I_e + I_0}{C_m},

    where :math:`I_0` is the one-step delayed continuous current buffer
    (NEST ring-buffer semantics) and :math:`I_k` are the per-receptor alpha
    currents.

    For each receptor :math:`k`, the alpha current kernel is represented by a
    two-state linear system (``y1[k]``, ``y2[k]``):

    .. math::

       \frac{d\,y1_k}{dt} = -\frac{y1_k}{\tau_{\mathrm{syn},k}}, \qquad
       \frac{d\,y2_k}{dt} = y1_k - \frac{y2_k}{\tau_{\mathrm{syn},k}}.

    The effective synaptic current for receptor :math:`k` is :math:`I_k = y2_k`.
    An incoming spike with weight :math:`w_k` (pA) is injected into ``y1[k]``
    with the NEST alpha normalization factor:

    .. math::

       y1_k \leftarrow y1_k + \frac{e}{\tau_{\mathrm{syn},k}} w_k.

    This normalization ensures that a single spike with weight :math:`w_k`
    produces a current kernel that peaks exactly at :math:`w_k` at
    :math:`t = \tau_{\mathrm{syn},k}`:

    .. math::

       I_k(t) = w_k \frac{t}{\tau_{\mathrm{syn},k}}
       \exp\!\left(1 - \frac{t}{\tau_{\mathrm{syn},k}}\right), \quad t \ge 0.

    **2. Exact Discrete Propagator, Derivation Constraints, and Stability**

    With a fixed step :math:`h = dt`, the linear subsystem is propagated
    exactly by its matrix exponential. For each receptor :math:`k`:

    .. math::

       y1_{k,n+1} = P_{11,k}\,y1_{k,n} + \frac{e}{\tau_{\mathrm{syn},k}} w_{k,n},

    .. math::

       y2_{k,n+1} = P_{21,k}\,y1_{k,n} + P_{22,k}\,y2_{k,n},

    where :math:`P_{11,k} = P_{22,k} = e^{-h/\tau_{\mathrm{syn},k}}` and
    :math:`P_{21,k} = h\,e^{-h/\tau_{\mathrm{syn},k}}`.

    The relative membrane voltage :math:`y_3 = V_m - E_L` is updated as

    .. math::

       y_{3,n+1} = P_{33}\,y_{3,n} + P_{30}(I_{0,n} + I_e)
       + \sum_k \left(P_{31,k}\,y1_{k,n} + P_{32,k}\,y2_{k,n}\right),

    with :math:`P_{33} = e^{-h/\tau_m}` and
    :math:`P_{30} = \tau_m(1 - e^{-h/\tau_m})/C_m`.

    The coefficients :math:`P_{31,k}` and :math:`P_{32,k}` are computed via
    :meth:`iaf_psc_alpha._alpha_propagator_p31_p32`. That method applies the
    stable near-singular limit for :math:`\tau_m \approx \tau_{\mathrm{syn},k}`:

    .. math::

       P_{32}^{\mathrm{sing}} = \frac{h}{C_m} e^{-h/\tau_m}, \qquad
       P_{31}^{\mathrm{sing}} = \frac{h^2}{2C_m} e^{-h/\tau_m},

    which prevents catastrophic cancellation when
    :math:`\tau_m = \tau_{\mathrm{syn},k}`.

    **3. Update Order per Simulation Step (NEST Semantics)**

    Per-step execution order:

    1. Integrate the membrane with the exact propagator for neurons that are
       not refractory (:math:`r = 0`).
    2. Decrement refractory counters for neurons that are currently
       refractory (:math:`r > 0`).
    3. Propagate all receptor alpha states ``y1``, ``y2`` forward by one step.
    4. Inject receptor-specific spike weights into ``y1``. This includes the
       default delta input, which is mapped to receptor 1 when
       ``n_receptors > 0`` and is skipped when ``w_by_rec`` is supplied.
    5. Apply the threshold test, hard reset, refractory assignment, and spike
       emission.
    6. Store the buffered continuous current for the next step.

    **4. Assumptions, Constraints, and Computational Implications**

    - ``C_m > 0``, ``tau_m > 0``, all ``tau_syn > 0``, ``t_ref >= 0``, and
      ``V_reset < V_th`` are enforced at construction. The check is skipped
      when parameters are JAX tracers.
    - ``update(x=...)`` uses one-step delayed current buffering. Current
      provided at step ``n`` contributes through ``i_const`` at step ``n+1``,
      matching NEST ring-buffer event semantics.
    - The update path is fully vectorized over ``self.varshape`` and scales as
      :math:`O(\prod \mathrm{varshape} \times n\_receptors)` per call.
    - Propagator coefficients are computed once, in :meth:`init_state`, using
      NumPy ``float64`` arithmetic, and are then cast to the environment float
      dtype. The per-step update works on unit-stripped magnitudes and writes
      the results back to BrainUnit-typed states.
    - When ``n_receptors == 0``, all spike event inputs are silently ignored.

    Parameters
    ----------
    in_size : Size
        Population shape specification. Per-neuron parameters and state
        variables are broadcast/initialized over ``self.varshape``, which is
        derived from ``in_size``.
    E_L : ArrayLike, optional
        Resting potential :math:`E_L` in mV; scalar or array broadcastable to
        ``self.varshape``. Default is ``-70. * u.mV``.
    C_m : ArrayLike, optional
        Membrane capacitance :math:`C_m` in pF; broadcastable to
        ``self.varshape`` and strictly positive. Default is ``250. * u.pF``.
    tau_m : ArrayLike, optional
        Membrane time constant :math:`\tau_m` in ms; broadcastable and
        strictly positive. Default is ``10. * u.ms``.
    t_ref : ArrayLike, optional
        Absolute refractory period :math:`t_\mathrm{ref}` in ms; broadcastable
        and nonnegative. Converted to integer grid steps by
        ``ceil(t_ref / dt)`` at construction, and recomputed in
        :meth:`init_state`. Default is ``2. * u.ms``.
    V_th : ArrayLike, optional
        Spike threshold :math:`V_\mathrm{th}` in mV; broadcastable to
        ``self.varshape``. Default is ``-55. * u.mV``.
    V_reset : ArrayLike, optional
        Post-spike reset potential :math:`V_\mathrm{reset}` in mV;
        broadcastable and constrained elementwise by ``V_reset < V_th``.
        Default is ``-70. * u.mV``.
    tau_syn : ArrayLike, optional
        Receptor alpha time constants in ms. Values are converted to a
        1-D array with shape ``(n_receptors,)``; every entry must be strictly
        positive. The number of entries defines ``n_receptors``.
        Default is ``(2.0,) * u.ms`` (one receptor).
    I_e : ArrayLike, optional
        Constant injected current :math:`I_e` in pA; scalar or array
        broadcastable to ``self.varshape``. Default is ``0. * u.pA``.
    V_min : ArrayLike or None, optional
        Optional lower clamp :math:`V_\mathrm{min}` in mV, applied to the
        candidate membrane update before thresholding. ``None`` disables
        clamping. Default is ``None``.
    V_initializer : Callable, optional
        Initializer for the membrane state ``V``, used by :meth:`init_state`.
        Default is ``braintools.init.Constant(-70. * u.mV)``.
    spk_fun : Callable, optional
        Surrogate spike function used by :meth:`get_spike` and returned by
        :meth:`update`. Default is ``braintools.surrogate.ReluGrad()``.
    spk_reset : str, optional
        Reset policy forwarded to the base class. :meth:`update` always
        applies a hard reset to ``V_reset``, reproducing NEST behavior.
        Default is ``'hard'``.
    ref_var : bool, optional
        If ``True``, allocates the optional boolean state ``self.refractory``
        for external refractory inspection. Default is ``False``.
    name : str or None, optional
        Optional node name.

    Parameter Mapping
    -----------------

    .. list-table:: Parameter mapping to model symbols
       :header-rows: 1
       :widths: 17 25 15 20 43

       * - Parameter
         - Type / shape / unit
         - Default
         - Math symbol
         - Semantics
       * - ``in_size``
         - :class:`~brainstate.typing.Size`; scalar or tuple
         - required
         - --
         - Defines population/state shape ``self.varshape``.
       * - ``E_L``
         - ArrayLike, broadcastable to ``self.varshape`` (mV)
         - ``-70. * u.mV``
         - :math:`E_L`
         - Leak reversal (resting) potential.
       * - ``C_m``
         - ArrayLike, broadcastable (pF), ``> 0``
         - ``250. * u.pF``
         - :math:`C_m`
         - Membrane capacitance in subthreshold integration.
       * - ``tau_m``
         - ArrayLike, broadcastable (ms), ``> 0``
         - ``10. * u.ms``
         - :math:`\tau_m`
         - Membrane leak time constant.
       * - ``t_ref``
         - ArrayLike, broadcastable (ms), ``>= 0``
         - ``2. * u.ms``
         - :math:`t_\mathrm{ref}`
         - Absolute refractory duration in physical time.
       * - ``V_th`` and ``V_reset``
         - ArrayLike, broadcastable (mV), with ``V_reset < V_th``
         - ``-55. * u.mV``, ``-70. * u.mV``
         - :math:`V_\mathrm{th}`, :math:`V_\mathrm{reset}`
         - Threshold and post-spike reset levels.
       * - ``tau_syn``
         - ArrayLike, flattened to ``(n_receptors,)`` (ms), each ``> 0``
         - ``(2.0,) * u.ms``
         - :math:`\tau_{\mathrm{syn},k}`
         - Receptor-specific alpha time constants; the length defines
           ``n_receptors``.
       * - ``I_e``
         - ArrayLike, broadcastable (pA)
         - ``0. * u.pA``
         - :math:`I_e`
         - Constant current added at each update step.
       * - ``V_min``
         - ArrayLike broadcastable (mV) or ``None``
         - ``None``
         - :math:`V_\mathrm{min}`
         - Optional lower clamp on the candidate membrane voltage.
       * - ``V_initializer``
         - Callable
         - ``Constant(-70. * u.mV)``
         - --
         - Initializer for membrane state ``V``.
       * - ``spk_fun``
         - Callable
         - ``ReluGrad()``
         - --
         - Surrogate nonlinearity used for spike output.
       * - ``spk_reset``
         - str
         - ``'hard'``
         - --
         - Reset mode forwarded to the base class.
       * - ``ref_var``
         - bool
         - ``False``
         - --
         - If ``True``, exposes boolean state ``self.refractory``.
       * - ``name``
         - str | None
         - ``None``
         - --
         - Optional node name.

    Raises
    ------
    ValueError
        Raised at initialization or update time if any of the following holds:

        - ``C_m <= 0``, ``tau_m <= 0``, any ``tau_syn <= 0``, ``t_ref < 0``,
          or ``V_reset >= V_th``.
        - A spike event receptor index is outside ``[1, n_receptors]``. This
          is checked only when ``n_receptors > 0``.
    TypeError
        If parameters or inputs are not unit-compatible with the expected
        conversions (mV, ms, pF, pA).
    KeyError
        If simulation context entries (for example ``t`` or ``dt``) are
        missing when :meth:`update` is called.
    AttributeError
        If :meth:`update` is called before :meth:`init_state` has created the
        required state holders.

    Notes
    -----
    - State variables are ``V``, ``y1_syn``, ``y2_syn``, ``i_const``,
      ``refractory_step_count``, and ``last_spike_time``. ``refractory`` is
      added only when ``ref_var=True``.
    - Spike weights from ``spike_events`` and ``sum_delta_inputs`` are signed
      currents in pA: positive for depolarizing input, negative for
      hyperpolarizing input. This differs from conductance-based multisynapse
      models, where weights must be non-negative.
    - ``update(x=...)`` stores the summed current input
      (``sum_current_inputs(x, V)``) into ``i_const`` for use on the next
      step, matching NEST current-event buffering semantics.
    - If ``n_receptors == 0``, all spike event inputs are silently ignored and
      ``sum_delta_inputs`` is discarded.
    - The default delta input from ``sum_delta_inputs`` is routed to
      receptor 1 when ``n_receptors > 0``, replicating the NEST default port
      behavior.

    Examples
    --------

    .. code-block:: python

       >>> import brainstate
       >>> import saiunit as u
       >>> from brainpy_state._nest.iaf_psc_alpha_multisynapse import (
       ...     iaf_psc_alpha_multisynapse,
       ... )
       >>> with brainstate.environ.context(dt=0.1 * u.ms):
       ...     neu = iaf_psc_alpha_multisynapse(
       ...         in_size=3,
       ...         tau_syn=(1.5, 3.0) * u.ms,
       ...         I_e=180.0 * u.pA,
       ...     )
       ...     neu.init_state()
       ...     with brainstate.environ.context(t=0.0 * u.ms):
       ...         spk = neu.update(
       ...             spike_events=[{'receptor_type': 2, 'weight': 40.0 * u.pA}]
       ...         )
       ...     _ = spk.shape

    .. code-block:: python

       >>> import brainstate
       >>> import saiunit as u
       >>> from brainpy_state._nest.iaf_psc_alpha_multisynapse import (
       ...     iaf_psc_alpha_multisynapse,
       ... )
       >>> with brainstate.environ.context(dt=0.1 * u.ms):
       ...     neu = iaf_psc_alpha_multisynapse(in_size=1, tau_syn=(2.0,) * u.ms)
       ...     neu.init_state()
       ...     with brainstate.environ.context(t=0.0 * u.ms):
       ...         _ = neu.update(x=250.0 * u.pA)
       ...     with brainstate.environ.context(t=0.1 * u.ms):
       ...         spk_next = neu.update()
       ...     _ = spk_next

    References
    ----------
    .. [1] NEST source: ``models/iaf_psc_alpha_multisynapse.h`` and
           ``models/iaf_psc_alpha_multisynapse.cpp``.
    .. [2] Rotter S, Diesmann M (1999). Exact simulation of time-invariant
           linear systems with applications to neuronal modeling. Biological
           Cybernetics 81:381-402.
           DOI: https://doi.org/10.1007/s004220050570
    .. [3] Diesmann M, Gewaltig M-O, Rotter S, Aertsen A (2001). State space
           analysis of synchronous spiking in cortical neural networks.
           Neurocomputing 38-40:565-571.
           DOI: https://doi.org/10.1016/S0925-2312(01)00409-X
    .. [4] Morrison A, Straube S, Plesser HE, Diesmann M (2007). Exact
           subthreshold integration with continuous spike times in discrete
           time neural network simulations. Neural Computation 19(1):47-79.
           DOI: https://doi.org/10.1162/neco.2007.19.1.47
    """
    __module__ = 'brainpy.state'

    def __init__(
        self,
        in_size: Size,
        E_L: ArrayLike = -70. * u.mV,
        C_m: ArrayLike = 250. * u.pF,
        tau_m: ArrayLike = 10. * u.ms,
        t_ref: ArrayLike = 2. * u.ms,
        V_th: ArrayLike = -55. * u.mV,
        V_reset: ArrayLike = -70. * u.mV,
        tau_syn: ArrayLike = (2.0,) * u.ms,
        I_e: ArrayLike = 0. * u.pA,
        V_min: ArrayLike = None,
        V_initializer: Callable = braintools.init.Constant(-70. * u.mV),
        spk_fun: Callable = braintools.surrogate.ReluGrad(),
        spk_reset: str = 'hard',
        ref_var: bool = False,
        name: str = None,
    ):
        super().__init__(in_size, name=name, spk_fun=spk_fun, spk_reset=spk_reset)

        # Per-neuron parameters, broadcast to the population shape.
        self.E_L = braintools.init.param(E_L, self.varshape)
        self.C_m = braintools.init.param(C_m, self.varshape)
        self.tau_m = braintools.init.param(tau_m, self.varshape)
        self.t_ref = braintools.init.param(t_ref, self.varshape)
        self.V_th = braintools.init.param(V_th, self.varshape)
        self.V_reset = braintools.init.param(V_reset, self.varshape)
        self.I_e = braintools.init.param(I_e, self.varshape)
        self.V_min = None if V_min is None else braintools.init.param(V_min, self.varshape)

        # Receptor time constants: a flat, unitless array in ms with one entry
        # per receptor port.
        dftype = brainstate.environ.dftype()
        self.tau_syn = np.asarray(u.math.asarray(tau_syn / u.ms), dtype=dftype).reshape(-1)

        self.V_initializer = V_initializer
        self.ref_var = ref_var

        self._validate_parameters()

        # Refractory period in whole simulation steps. init_state recomputes
        # it so that it uses the same dt as the propagators.
        self.ref_count = self._refractory_counts()

    @property
    def n_receptors(self):
        """Number of receptor ports, i.e. the number of ``tau_syn`` entries.

        Returns
        -------
        int
            ``self.tau_syn.size``.
        """
        return int(self.tau_syn.size)

    def _validate_parameters(self):
        """Validate model parameters against NEST constraints.

        Validation is skipped when ``V_reset`` or ``C_m`` are JAX tracers,
        because concrete comparisons are impossible under tracing.

        Raises
        ------
        ValueError
            If parameter inequalities or positivity constraints are violated.
        """
        # Skip validation when parameters are JAX tracers (e.g. during jit).
        if any(is_tracer(v) for v in (self.V_reset, self.C_m)):
            return
        if np.any(self.C_m <= 0.0 * u.pF):
            raise ValueError('Capacitance must be strictly positive.')
        if np.any(self.tau_m <= 0.0 * u.ms):
            raise ValueError('Membrane time constant must be strictly positive.')
        if np.any(self.tau_syn <= 0.0):
            raise ValueError('All synaptic time constants must be strictly positive.')
        if np.any(self.t_ref < 0.0 * u.ms):
            raise ValueError('Refractory time must not be negative.')
        if np.any(self.V_reset >= self.V_th):
            raise ValueError('Reset potential must be smaller than threshold.')

    def init_state(self, **kwargs):
        r"""Initialize runtime states for membrane, synaptic, and refractory variables.

        This also pre-computes the exact-integration propagators (and the
        refractory step count) for the ``dt`` of the current environment.

        Parameters
        ----------
        **kwargs
            Unused compatibility parameters accepted by the base-state API.

        Raises
        ------
        ValueError
            If initializers cannot broadcast to ``self.varshape``.
        TypeError
            If initializer outputs are incompatible with the expected
            unit/array conversions for voltage, current, or integer
            refractory states.
        """
        ditype = brainstate.environ.ditype()
        dftype = brainstate.environ.dftype()
        V = braintools.init.param(self.V_initializer, self.varshape)
        syn_shape = self.varshape + (self.n_receptors,)
        self.V = brainstate.HiddenState(V)
        # y1 is kept unitless; y2 is the receptor current in pA.
        self.y1_syn = brainstate.ShortTermState(u.math.zeros(syn_shape, dtype=dftype))
        self.y2_syn = brainstate.ShortTermState(u.math.zeros(syn_shape, dtype=dftype) * u.pA)
        self.i_const = brainstate.ShortTermState(u.math.full(self.varshape, 0.0 * u.pA, dtype=dftype))
        self.refractory_step_count = brainstate.ShortTermState(u.math.full(self.varshape, 0, dtype=ditype))
        self.last_spike_time = brainstate.ShortTermState(u.math.full(self.varshape, -1e7 * u.ms))
        if self.ref_var:
            self.refractory = brainstate.ShortTermState(
                braintools.init.param(braintools.init.Constant(False), self.varshape)
            )
        # Pre-compute propagator coefficients (constant for a given dt).
        self._precompute_propagators()

    def _precompute_propagators(self):
        """Pre-compute the NEST propagator coefficients from dt and model parameters.

        This is called from :meth:`init_state`, so :meth:`update` never has to
        recompute exponentials and remains JIT-compatible. Coefficients are
        evaluated in NumPy ``float64`` and cast to the environment float
        dtype. The refractory step count is refreshed here too, so that it is
        derived from the same ``dt`` as the propagators.
        """
        dt = brainstate.environ.get_dt()
        h = float(u.math.asarray(dt / u.ms))
        dftype = brainstate.environ.dftype()
        n_rec = self.n_receptors

        tau_m = np.asarray(u.get_mantissa(self.tau_m / u.ms), dtype=np.float64)
        C_m = np.asarray(u.get_mantissa(self.C_m / u.pF), dtype=np.float64)
        # Upcast so the receptor propagators really are computed in float64.
        tau_syn = self.tau_syn.astype(np.float64)  # (n_receptors,)

        # Receptor alpha-state propagators, shape (n_receptors,).
        P11 = np.exp(-h / tau_syn)
        self._P11 = P11.astype(dftype)
        self._P22 = self._P11
        self._P21 = (h * P11).astype(dftype)
        # Spike-injection factor e / tau_syn: a weight-w spike yields a
        # current kernel that peaks at w pA.
        self._psc_init = (np.e / tau_syn).astype(dftype)

        # Membrane propagators, shape varshape.
        P33 = np.exp(-h / tau_m)
        self._P33 = P33.astype(dftype)
        self._P30 = ((1.0 - P33) * tau_m / C_m).astype(dftype)

        # Synapse-to-membrane propagators, shape varshape + (n_receptors,).
        # Preallocating instead of np.stack-ing a list keeps n_receptors == 0
        # valid (np.stack raises on an empty sequence).
        P31 = np.zeros(self.varshape + (n_rec,), dtype=dftype)
        P32 = np.zeros(self.varshape + (n_rec,), dtype=dftype)
        for k, tau_s in enumerate(tau_syn):
            p31, p32 = iaf_psc_alpha._alpha_propagator_p31_p32(
                np.full(self.varshape, tau_s, dtype=np.float64),
                tau_m,
                C_m,
                h,
            )
            P31[..., k] = p31
            P32[..., k] = p32
        self._P31 = P31
        self._P32 = P32

        self.ref_count = self._refractory_counts()

    def get_spike(self, V: ArrayLike = None):
        r"""Evaluate the surrogate spike output for a voltage tensor.

        Parameters
        ----------
        V : ArrayLike or None, optional
            Voltage input in mV, broadcast-compatible with ``self.varshape``.
            If ``None``, the current membrane state ``self.V.value`` is used.

        Returns
        -------
        out : ArrayLike
            Output of ``self.spk_fun`` evaluated on the scaled voltage
            ``(V - V_th) / (V_th - V_reset)``, with the broadcast shape of
            ``V`` (or of ``self.V.value`` when ``V`` is ``None``). The scaling
            makes the argument positive for suprathreshold voltages.
        """
        V = self.V.value if V is None else V
        v_scaled = (V - self.V_th) / (self.V_th - self.V_reset)
        return self.spk_fun(v_scaled)

    def _refractory_counts(self):
        """Return the refractory period as integer steps, ``ceil(t_ref / dt)``.

        ``dt`` is read from the current brainstate environment.
        """
        dt = brainstate.environ.get_dt()
        ditype = brainstate.environ.ditype()
        return u.math.asarray(u.math.ceil(self.t_ref / dt), dtype=ditype)

    def _parse_spike_events(self, spike_events: Iterable, v_shape):
        """Convert receptor-indexed spike events into a dense weight array.

        Parameters
        ----------
        spike_events : iterable or None
            Events given either as ``(receptor_type, weight)`` tuples or as
            dicts. Dicts use the key ``'receptor_type'`` (or ``'receptor'``,
            default ``1``) and the key ``'weight'`` (default ``0 pA``).
        v_shape : tuple of int
            Population shape; each weight is broadcast to it.

        Returns
        -------
        numpy.ndarray
            Array of shape ``v_shape + (n_receptors,)`` that holds the summed
            weights per receptor as plain pA magnitudes.

        Raises
        ------
        ValueError
            If ``n_receptors > 0`` and a receptor index lies outside
            ``[1, n_receptors]``, or if a weight cannot broadcast to
            ``v_shape``.
        """
        dftype = brainstate.environ.dftype()
        out = np.zeros(v_shape + (self.n_receptors,), dtype=dftype)
        # Without receptor ports, spike events are documented to be ignored.
        if spike_events is None or self.n_receptors == 0:
            return out
        for ev in spike_events:
            if isinstance(ev, dict):
                receptor = ev.get('receptor_type', ev.get('receptor', 1))
                # A pA-typed default keeps the unit conversion below valid.
                weight = ev.get('weight', 0. * u.pA)
            else:
                receptor, weight = ev
            receptor = int(receptor)
            if receptor < 1 or receptor > self.n_receptors:
                raise ValueError(f'Receptor type {receptor} out of range [1, {self.n_receptors}].')
            w_np = np.asarray(u.math.asarray(weight / u.pA), dtype=dftype)
            out[..., receptor - 1] += np.broadcast_to(w_np, v_shape)
        return out

    def update(self, x=0. * u.pA, spike_events=None, w_by_rec=None):
        r"""Advance the neuron by one simulation step.

        Parameters
        ----------
        x : ArrayLike, optional
            Continuous current input in pA for this step. ``x`` is accumulated
            through :meth:`sum_current_inputs` and stored in ``i_const`` for
            use on the next call (one-step delayed buffering, matching NEST
            ring-buffer semantics). Default is ``0. * u.pA``.
        spike_events : iterable or None, optional
            Receptor-indexed spike weight events to inject this step. Each
            entry must be one of:

            - A ``(receptor_type, weight)`` tuple, where ``receptor_type`` is
              a 1-based integer in ``[1, n_receptors]`` and ``weight`` is a
              scalar or array in pA (broadcastable to ``self.varshape``).
            - A ``dict`` with keys ``'receptor_type'`` (or ``'receptor'``)
              and ``'weight'``.

            Multiple events for the same receptor are accumulated additively.
            ``None`` injects no receptor spike events. Events are ignored when
            ``n_receptors == 0``, and also when ``w_by_rec`` is provided.
            Default is ``None``.
        w_by_rec : array-like or None, optional
            Pre-computed per-receptor spike weights, given as plain numbers
            interpreted as pA. The shape must be broadcastable to
            ``self.varshape + (n_receptors,)``. When provided, it bypasses
            ``spike_events`` parsing and ``sum_delta_inputs``, which makes the
            update JIT-compatible for use inside
            ``brainstate.transform.for_loop``. Default is ``None``.

        Returns
        -------
        out : jax.Array
            Spike output tensor from :meth:`get_spike`, with shape
            ``self.V.value.shape``. On threshold crossings, the voltage
            presented to ``spk_fun`` is nudged above threshold by ``1e-12``
            mV to keep the surrogate activation positive.

        Raises
        ------
        ValueError
            If any receptor index in ``spike_events`` is outside
            ``[1, n_receptors]`` (checked only when ``n_receptors > 0``).
        KeyError
            If the simulation context does not provide ``t`` or ``dt``.
        AttributeError
            If required states are missing because :meth:`init_state` was not
            called.
        TypeError
            If ``x`` or the stored states are not unit-compatible with the
            expected pA / mV conversions.
        """
        t = brainstate.environ.get('t')
        dt_q = brainstate.environ.get_dt()
        ditype = brainstate.environ.ditype()

        # Read state variables with their natural units.
        V = self.V.value  # mV
        y1_syn = self.y1_syn.value  # unitless, varshape + (n_receptors,)
        y2_syn = self.y2_syn.value  # pA
        i_const = self.i_const.value  # pA
        r = self.refractory_step_count.value  # int

        # Current input for the next step (one-step delay).
        new_i_const = self.sum_current_inputs(x, V)  # pA

        # Build the per-receptor spike weight array (pA magnitudes, unitless).
        if w_by_rec is None:
            # Python-level path: parses spike_events dicts/tuples (not JIT-compatible).
            dftype = brainstate.environ.dftype()
            v_shape = self.varshape
            w_val = self._parse_spike_events(spike_events, v_shape)  # numpy
            w_delta = np.asarray(
                u.get_mantissa(self.sum_delta_inputs(0. * u.pA) / u.pA),
                dtype=dftype,
            )
            w_delta = np.broadcast_to(w_delta, v_shape)
            # Default (unlabelled) delta input goes to receptor 1, as in NEST.
            if self.n_receptors > 0:
                w_val = w_val.copy()
                w_val[..., 0] += w_delta
        else:
            # JAX-array path: the caller supplies pre-computed weights (JIT-compatible).
            w_val = w_by_rec  # broadcastable to varshape + (n_receptors,)

        # Strip units from state values using u.get_mantissa (JAX-compatible).
        y1_val = y1_syn  # already unitless
        y2_val = u.get_mantissa(y2_syn / u.pA)
        i_const_val = u.get_mantissa(i_const / u.pA)
        I_e_val = u.get_mantissa(self.I_e / u.pA)
        V_rel = u.get_mantissa((V - self.E_L) / u.mV)

        # Use the pre-computed propagator coefficients.
        P11 = self._P11
        P21 = self._P21
        P22 = self._P22
        P30 = self._P30
        P31 = self._P31
        P32 = self._P32
        P33 = self._P33
        psc_init = self._psc_init

        # 1) Membrane update for non-refractory neurons (uses pre-step y1/y2).
        V_candidate = (
            P30 * (i_const_val + I_e_val)
            + P33 * V_rel
            + jnp.sum(P31 * y1_val + P32 * y2_val, axis=-1)
        )
        if self.V_min is not None:
            lower = u.get_mantissa((self.V_min - self.E_L) / u.mV)
            V_candidate = u.math.maximum(V_candidate, lower)
        not_refractory = r == 0
        V_rel = jnp.where(not_refractory, V_candidate, V_rel)
        r = jnp.where(not_refractory, r, r - 1)

        # 2) Synaptic alpha state propagation, then spike injection into y1.
        y2_val = P21 * y1_val + P22 * y2_val
        y1_val = y1_val * P11 + psc_init * w_val

        # 3) Threshold test, reset, and refractory assignment.
        theta_val = u.get_mantissa((self.V_th - self.E_L) / u.mV)
        V_reset_val = u.get_mantissa((self.V_reset - self.E_L) / u.mV)
        spike_cond = V_rel >= theta_val
        r = jnp.where(
            spike_cond,
            jnp.asarray(u.get_mantissa(self.ref_count), dtype=ditype),
            r,
        )
        V_before_reset = V_rel
        V_rel = jnp.where(spike_cond, V_reset_val, V_rel)

        # Write back state.
        E_L_val = u.get_mantissa(self.E_L / u.mV)
        self.V.value = (V_rel + E_L_val) * u.mV
        self.y1_syn.value = y1_val
        self.y2_syn.value = y2_val * u.pA
        self.i_const.value = new_i_const + u.math.zeros(self.varshape) * u.pA
        self.refractory_step_count.value = jnp.asarray(r, dtype=ditype)
        # NEST stamps the spike at the end of the step, i.e. t + dt.
        last_spike_time = u.math.where(spike_cond, t + dt_q, self.last_spike_time.value)
        self.last_spike_time.value = jax.lax.stop_gradient(last_spike_time)
        if self.ref_var:
            self.refractory.value = jax.lax.stop_gradient(self.refractory_step_count.value > 0)

        # Present a just-above-threshold voltage for spiking neurons so the
        # surrogate sees a positive argument despite the hard reset.
        V_out = jnp.where(spike_cond, theta_val + E_L_val + 1e-12, V_before_reset + E_L_val)
        return self.get_spike(V_out * u.mV)