# Copyright 2026 BrainX Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# -*- coding: utf-8 -*-
from typing import Callable, Iterable
import brainstate
import braintools
import jax
import jax.numpy as jnp
import numpy as np
import saiunit as u
from brainstate.typing import ArrayLike, Size
from ._base import NESTNeuron
from ._utils import is_tracer, propagator_exp
# Public API of this module: only the neuron class is exported.
__all__ = ['iaf_psc_exp_multisynapse']
class iaf_psc_exp_multisynapse(NESTNeuron):
    r"""NEST-compatible ``iaf_psc_exp_multisynapse`` neuron model.

    Current-based leaky integrate-and-fire neuron with an arbitrary number of
    receptor-indexed exponential postsynaptic current channels.

    Description
    -----------

    ``iaf_psc_exp_multisynapse`` mirrors NEST
    ``models/iaf_psc_exp_multisynapse.{h,cpp}`` and generalizes
    :class:`iaf_psc_exp` from two fixed excitatory/inhibitory channels to
    ``n_receptors`` independently parameterized current ports, each with its
    own exponential decay time constant.

    Each receptor ``k`` (1-based, NEST convention) carries its own decay
    constant ``tau_syn[k-1]``. Synaptic weights are signed currents in pA;
    positive values are depolarizing and negative values are hyperpolarizing.

    **1. Continuous-Time Dynamics and Receptor States**

    Define :math:`V_{\mathrm{rel}} = V_m - E_L`. For receptor :math:`k`, the
    synaptic current decays exponentially:

    .. math::

       \frac{dI_k}{dt} = -\frac{I_k}{\tau_{\mathrm{syn},k}}.

    The membrane equation couples all receptor currents additively:

    .. math::

       \frac{dV_{\mathrm{rel}}}{dt}
       = -\frac{V_{\mathrm{rel}}}{\tau_m}
       + \frac{\sum_k I_k + I_e + I_0}{C_m},

    where :math:`I_0` is the one-step delayed continuous-current buffer (NEST
    ring-buffer semantics). The assumptions match NEST's current-based model:
    receptor currents add linearly, parameters are constant within one
    simulation step, and ``dt`` is fixed so the propagator coefficients are
    exact.

    **2. Exact Discrete Propagator, Derivation Constraints, and Stability**

    For step size :math:`h = dt` (ms), receptor currents are integrated
    exactly:

    .. math::

       I_{k,n+1} = P_{11,k}\, I_{k,n} + w_{k,n},
       \qquad P_{11,k} = e^{-h/\tau_{\mathrm{syn},k}},

    where :math:`w_{k,n}` is the total weight arriving at receptor :math:`k`
    during step :math:`n`.

    The membrane update uses the exact propagator:

    .. math::

       V_{\mathrm{rel},n+1}
       = P_{22}\, V_{\mathrm{rel},n}
       + P_{20}(I_e + I_{0,n})
       + \sum_k P_{21,k}\, I_{k,n},

    with propagator coefficients

    .. math::

       P_{22} = e^{-h/\tau_m}, \qquad
       P_{20} = \frac{\tau_m}{C_m}(1 - P_{22}),

    .. math::

       P_{21,k}
       = \frac{\tau_{\mathrm{syn},k}\,\tau_m}
              {C_m\,(\tau_m - \tau_{\mathrm{syn},k})}
         \left(e^{-h/\tau_m} - e^{-h/\tau_{\mathrm{syn},k}}\right).

    :func:`propagator_exp` (from ``_utils``) evaluates :math:`P_{21,k}` with a
    singular-limit fallback :math:`(h / C_m)\,e^{-h/\tau_m}` when
    :math:`\tau_{\mathrm{syn},k} \approx \tau_m`. This avoids catastrophic
    cancellation in the denominator :math:`(\tau_m - \tau_{\mathrm{syn},k})`.
    Construction additionally rejects ``np.isclose(tau_syn, tau_m)``, which
    keeps the problem well conditioned and avoids near-degenerate
    parameterizations.

    **3. Update Order per Simulation Step (NEST Semantics)**

    Per-step execution order:

    1. Integrate the membrane with the exact propagator for neurons that are
       not refractory (:math:`r = 0`).
    2. Decrement refractory counters for refractory neurons (:math:`r > 0`).
    3. Decay all receptor currents :math:`I_k` by :math:`P_{11,k}`.
    4. Inject receptor-specific spike weights :math:`w_{k,n}`, including the
       default delta input mapped to receptor 1 when ``n_receptors > 0``.
    5. Apply the threshold test, hard reset and refractory assignment; record
       the spike time; and store the buffered continuous current ``x`` for
       step :math:`n+1`.

    **4. Assumptions, Constraints, and Computational Implications**

    - ``C_m > 0``, ``tau_m > 0``, all ``tau_syn > 0``,
      ``not isclose(tau_syn, tau_m)``, ``t_ref >= 0``, and
      ``V_reset < V_th`` are enforced at construction.
    - ``update(x=...)`` uses one-step delayed current buffering: current
      provided at step ``n`` contributes through ``i_const`` at step ``n+1``,
      matching NEST ring-buffer event semantics.
    - The update path is fully vectorized over ``self.varshape``. Its cost per
      call is :math:`O(\prod \mathrm{varshape} \times n\_receptors)`.
    - Propagator coefficients are computed in NumPy ``float64`` and then cast
      to the environment float dtype (``brainstate.environ.dftype()``).
      Per-step arithmetic runs in that dtype.
    - When ``n_receptors == 0``:
      - the default delta input from ``sum_delta_inputs`` is discarded;
      - every explicit spike event raises ``ValueError``, because no receptor
        index is valid.

    Parameters
    ----------
    in_size : Size
        Population shape specification. Per-neuron parameters and state
        variables are broadcast/initialized over ``self.varshape``, which is
        derived from ``in_size``.
    E_L : ArrayLike, optional
        Resting potential :math:`E_L` in mV; scalar or array broadcastable to
        ``self.varshape``. Default is ``-70. * u.mV``.
    C_m : ArrayLike, optional
        Membrane capacitance :math:`C_m` in pF; broadcastable to
        ``self.varshape`` and strictly positive. Default is ``250. * u.pF``.
    tau_m : ArrayLike, optional
        Membrane time constant :math:`\tau_m` in ms; broadcastable and
        strictly positive. Default is ``10. * u.ms``.
    t_ref : ArrayLike, optional
        Absolute refractory period :math:`t_\mathrm{ref}` in ms; broadcastable
        and nonnegative. Converted to integer grid steps via
        ``ceil(t_ref / dt)``. Default is ``2. * u.ms``.
    V_th : ArrayLike, optional
        Spike threshold :math:`V_\mathrm{th}` in mV; broadcastable to
        ``self.varshape``. Default is ``-55. * u.mV``.
    V_reset : ArrayLike, optional
        Post-spike reset potential :math:`V_\mathrm{reset}` in mV;
        broadcastable and constrained by ``V_reset < V_th`` elementwise.
        Default is ``-70. * u.mV``.
    tau_syn : ArrayLike, optional
        Synaptic decay constants in ms for all receptor ports. Converted to a
        1-D array of shape ``(n_receptors,)`` in the environment float dtype
        via ``np.asarray(...).reshape(-1)``. Every entry must be strictly
        positive and must not be numerically equal to any ``tau_m`` value
        under ``np.isclose``. The number of entries defines ``n_receptors``.
        Default is ``(2.0,) * u.ms`` (one receptor).
    I_e : ArrayLike, optional
        Constant injected current :math:`I_e` in pA; scalar or array
        broadcastable to ``self.varshape``. Default is ``0. * u.pA``.
    V_initializer : Callable, optional
        Initializer for membrane state ``V`` used by :meth:`init_state`.
        Default is ``braintools.init.Constant(-70. * u.mV)``.
    spk_fun : Callable, optional
        Surrogate spike function used by :meth:`get_spike`; its output is
        what :meth:`update` returns. Default is
        ``braintools.surrogate.ReluGrad()``.
    spk_reset : str, optional
        Reset policy forwarded to the parent class. The :meth:`update`
        implemented here always performs a hard reset to ``V_reset``.
        Default is ``'hard'``.
    ref_var : bool, optional
        If ``True``, allocates the optional boolean state ``self.refractory``
        for external refractory inspection. Default is ``False``.
    name : str or None, optional
        Optional node name passed to the parent module. Default is ``None``.

    Parameter Mapping
    -----------------

    .. list-table:: Parameter mapping to model symbols
       :header-rows: 1
       :widths: 17 25 15 20 43

       * - Parameter
         - Type / shape / unit
         - Default
         - Math symbol
         - Semantics
       * - ``in_size``
         - :class:`~brainstate.typing.Size`; scalar or tuple
         - required
         - --
         - Defines population/state shape ``self.varshape``.
       * - ``E_L``
         - ArrayLike, broadcastable to ``self.varshape`` (mV)
         - ``-70. * u.mV``
         - :math:`E_L`
         - Leak reversal (resting) potential.
       * - ``C_m``
         - ArrayLike, broadcastable (pF), ``> 0``
         - ``250. * u.pF``
         - :math:`C_m`
         - Membrane capacitance in subthreshold integration.
       * - ``tau_m``
         - ArrayLike, broadcastable (ms), ``> 0``
         - ``10. * u.ms``
         - :math:`\tau_m`
         - Membrane leak time constant.
       * - ``t_ref``
         - ArrayLike, broadcastable (ms), ``>= 0``
         - ``2. * u.ms``
         - :math:`t_\mathrm{ref}`
         - Absolute refractory duration in physical time.
       * - ``V_th`` and ``V_reset``
         - ArrayLike, broadcastable (mV), with ``V_reset < V_th``
         - ``-55. * u.mV``, ``-70. * u.mV``
         - :math:`V_\mathrm{th}`, :math:`V_\mathrm{reset}`
         - Threshold and post-spike reset levels.
       * - ``tau_syn``
         - ArrayLike, flattened to ``(n_receptors,)`` (ms), each ``> 0`` and
           not ``isclose`` to ``tau_m``
         - ``(2.0,) * u.ms``
         - :math:`\tau_{\mathrm{syn},k}`
         - Receptor-specific exponential PSC decay constants; the number of
           entries defines ``n_receptors``.
       * - ``I_e``
         - ArrayLike, broadcastable (pA)
         - ``0. * u.pA``
         - :math:`I_e`
         - Constant current added in every update step.
       * - ``V_initializer``
         - Callable
         - ``Constant(-70. * u.mV)``
         - --
         - Initializer for membrane state ``V``.
       * - ``spk_fun``
         - Callable
         - ``ReluGrad()``
         - --
         - Surrogate nonlinearity used for spike output.
       * - ``spk_reset``
         - str
         - ``'hard'``
         - --
         - Reset mode forwarded to the parent class.
       * - ``ref_var``
         - bool
         - ``False``
         - --
         - If ``True``, exposes boolean state ``self.refractory``.
       * - ``name``
         - str | None
         - ``None``
         - --
         - Optional node name.

    Raises
    ------
    ValueError
        Raised at initialization or update time if any of the following holds:

        - ``V_reset >= V_th``.
        - ``C_m <= 0``, ``tau_m <= 0``, any ``tau_syn <= 0``, or ``t_ref < 0``.
        - Any ``tau_syn`` is numerically equal to ``tau_m`` under
          ``np.isclose``.
        - A spike event receptor index is outside ``[1, n_receptors]``.
    TypeError
        If parameters or inputs are not unit-compatible with the expected
        conversions (mV, ms, pF, pA).
    KeyError
        If simulation context entries (for example ``t`` or ``dt``) are
        missing when :meth:`update` is called.
    AttributeError
        If :meth:`update` is called before :meth:`init_state` has created the
        required state holders.

    Attributes
    ----------
    V : brainstate.HiddenState
        Membrane potential in mV; shape ``self.varshape``.
    i_syn : brainstate.ShortTermState
        Per-receptor synaptic currents in pA; shape
        ``self.varshape + (n_receptors,)``.
    i_const : brainstate.ShortTermState
        Buffered continuous current (pA) applied on the next simulation step.
        Shape ``self.varshape``.
    refractory_step_count : brainstate.ShortTermState
        Integer countdown of remaining refractory steps (environment integer
        dtype). Shape ``self.varshape``.
    last_spike_time : brainstate.ShortTermState
        Simulation time of the most recent spike (ms). Shape
        ``self.varshape``.
    refractory : brainstate.ShortTermState
        Boolean refractory mask; only present when ``ref_var=True``.

    Notes
    -----
    - This implementation integrates the linear subthreshold ODE exactly
      (analytically) using pre-computed propagator coefficients. It
      therefore matches NEST's update precision for fixed-step simulation.
    - The continuous current input ``x`` is combined with any additional
      current sources registered via :meth:`sum_current_inputs`, and the
      result is buffered for one step (NEST ring-buffer semantics). ``I_e`` is
      not buffered; it is applied directly in every step.
    - Spike weights from ``spike_events`` and ``sum_delta_inputs`` are signed
      currents in pA: positive weights are depolarizing and negative weights
      are hyperpolarizing.
    - The default delta input from ``sum_delta_inputs`` is routed to
      receptor 1 when ``n_receptors > 0``, replicating NEST's default-port
      behavior.
    - If ``n_receptors == 0``, ``sum_delta_inputs`` is discarded, and any
      explicit spike event raises ``ValueError``.

    Examples
    --------

    .. code-block:: python

        >>> import brainstate
        >>> import saiunit as u
        >>> from brainpy_state._nest.iaf_psc_exp_multisynapse import (
        ...     iaf_psc_exp_multisynapse,
        ... )
        >>> with brainstate.environ.context(dt=0.1 * u.ms):
        ...     neu = iaf_psc_exp_multisynapse(
        ...         in_size=2,
        ...         tau_syn=(2.0, 8.0) * u.ms,
        ...         I_e=180.0 * u.pA,
        ...     )
        ...     neu.init_state()
        ...     with brainstate.environ.context(t=0.0 * u.ms):
        ...         spk = neu.update(
        ...             spike_events=[{'receptor_type': 2, 'weight': 35.0 * u.pA}]
        ...         )
        ...     _ = spk.shape

    .. code-block:: python

        >>> import brainstate
        >>> import saiunit as u
        >>> from brainpy_state._nest.iaf_psc_exp_multisynapse import (
        ...     iaf_psc_exp_multisynapse,
        ... )
        >>> with brainstate.environ.context(dt=0.1 * u.ms):
        ...     neu = iaf_psc_exp_multisynapse(in_size=1, tau_syn=(2.0,) * u.ms)
        ...     neu.init_state()
        ...     with brainstate.environ.context(t=0.0 * u.ms):
        ...         _ = neu.update(x=250.0 * u.pA)
        ...     with brainstate.environ.context(t=0.1 * u.ms):
        ...         spk_next = neu.update()
        ...     _ = spk_next

    References
    ----------
    .. [1] Rotter S, Diesmann M (1999). Exact simulation of time-invariant
           linear systems with applications to neuronal modeling. Biological
           Cybernetics 81:381-402.
           DOI: https://doi.org/10.1007/s004220050570
    .. [2] Diesmann M, Gewaltig M-O, Rotter S, Aertsen A (2001). State space
           analysis of synchronous spiking in cortical neural networks.
           Neurocomputing 38-40:565-571.
           DOI: https://doi.org/10.1016/S0925-2312(01)00409-X
    .. [3] Morrison A, Straube S, Plesser HE, Diesmann M (2007). Exact
           subthreshold integration with continuous spike times in discrete
           time neural network simulations. Neural Computation 19(1):47-79.
           DOI: https://doi.org/10.1162/neco.2007.19.1.47

    See Also
    --------
    iaf_psc_exp : LIF with two fixed exponential PSC channels (exc/inh)
    iaf_psc_alpha_multisynapse : Multisynapse variant with alpha-shaped PSCs
    iaf_psc_delta : LIF neuron with delta-function PSCs (voltage-jump synapses)
    LIF : Leaky integrate-and-fire (brainpy parameterization)
    """
    __module__ = 'brainpy.state'

    def __init__(
        self,
        in_size: Size,
        E_L: ArrayLike = -70. * u.mV,
        C_m: ArrayLike = 250. * u.pF,
        tau_m: ArrayLike = 10. * u.ms,
        t_ref: ArrayLike = 2. * u.ms,
        V_th: ArrayLike = -55. * u.mV,
        V_reset: ArrayLike = -70. * u.mV,
        tau_syn: ArrayLike = (2.0,) * u.ms,
        I_e: ArrayLike = 0. * u.pA,
        V_initializer: Callable = braintools.init.Constant(-70. * u.mV),
        spk_fun: Callable = braintools.surrogate.ReluGrad(),
        spk_reset: str = 'hard',
        ref_var: bool = False,
        name: str = None,
    ):
        super().__init__(in_size, name=name, spk_fun=spk_fun, spk_reset=spk_reset)
        self.E_L = braintools.init.param(E_L, self.varshape)
        self.C_m = braintools.init.param(C_m, self.varshape)
        self.tau_m = braintools.init.param(tau_m, self.varshape)
        self.t_ref = braintools.init.param(t_ref, self.varshape)
        self.V_th = braintools.init.param(V_th, self.varshape)
        self.V_reset = braintools.init.param(V_reset, self.varshape)
        self.I_e = braintools.init.param(I_e, self.varshape)
        dftype = brainstate.environ.dftype()
        # Receptor decay constants are stored as a unitless 1-D array in ms;
        # its length defines ``n_receptors``.
        self.tau_syn = np.asarray(u.math.asarray(tau_syn / u.ms), dtype=dftype).reshape(-1)
        self.V_initializer = V_initializer
        self.ref_var = ref_var
        self._validate_parameters()
        # Pre-compute refractory step count (matches aeif_cond_alpha pattern).
        self.ref_count = self._compute_ref_count()

    @property
    def n_receptors(self):
        r"""Number of independent synaptic receptor ports.

        Returns
        -------
        out : int
            Length of ``self.tau_syn``; equals ``len(tau_syn)`` as supplied
            at construction.
        """
        return int(self.tau_syn.size)

    def _compute_ref_count(self):
        r"""Return the refractory period in integer grid steps for the current ``dt``.

        Returns
        -------
        out : array-like of int
            ``ceil(t_ref / dt)`` cast to the environment integer dtype.
        """
        ditype = brainstate.environ.ditype()
        dt = brainstate.environ.get_dt()
        return u.math.asarray(u.math.ceil(self.t_ref / dt), dtype=ditype)

    def _validate_parameters(self):
        r"""Check parameter constraints and raise ``ValueError`` on violation.

        Validates the following conditions (all checked at construction time):

        - ``V_reset < V_th`` elementwise.
        - ``C_m > 0`` elementwise.
        - ``tau_m > 0`` elementwise.
        - All entries in ``tau_syn`` are ``> 0``.
        - No entry in ``tau_syn`` is numerically equal to any ``tau_m`` value
          under ``np.isclose`` (prevents near-singular propagator evaluation).
        - ``t_ref >= 0`` elementwise.

        Validation is skipped entirely when any checked parameter is a JAX
        tracer, because concrete comparisons are impossible on abstract
        values.

        Raises
        ------
        ValueError
            On the first violated constraint, with a descriptive message.
        """
        # Skip validation when parameters are JAX tracers (e.g. during jit).
        checked = (self.V_reset, self.V_th, self.C_m, self.tau_m, self.t_ref)
        if any(is_tracer(v) for v in checked):
            return
        if np.any(self.V_reset >= self.V_th):
            raise ValueError('Reset potential must be smaller than threshold.')
        if np.any(self.C_m <= 0.0 * u.pF):
            raise ValueError('Capacitance must be > 0.')
        if np.any(self.tau_m <= 0.0 * u.ms):
            raise ValueError('Membrane time constant must be strictly positive.')
        if np.any(self.tau_syn <= 0.0):
            raise ValueError('All synaptic time constants must be strictly positive.')
        # Compare every receptor time constant with every neuron's tau_m.
        # A trailing receptor axis lets a per-neuron tau_m of shape varshape
        # broadcast against tau_syn of shape (n_receptors,). A plain
        # broadcast of the two shapes would be wrong or would fail.
        tau_m_ms = np.asarray(u.get_mantissa(self.tau_m / u.ms), dtype=np.float64)
        if np.any(np.isclose(self.tau_syn, tau_m_ms[..., np.newaxis])):
            raise ValueError('Membrane and synapse time constants must differ.')
        if np.any(self.t_ref < 0.0 * u.ms):
            raise ValueError('Refractory time must not be negative.')

    def init_state(self, **kwargs):
        r"""Initialize membrane potential and all synaptic/refractory states.

        Also pre-computes the propagator coefficients for the current ``dt``
        (see :meth:`_precompute_propagators`).

        Parameters
        ----------
        **kwargs : Any
            Unused compatibility arguments, accepted for interface consistency
            with other nodes.

        Raises
        ------
        ValueError
            If ``V_initializer`` output cannot be broadcast to the target
            state shape.
        TypeError
            If initializer values are incompatible with required
            numeric/unit conversions.
        """
        ditype = brainstate.environ.ditype()
        dftype = brainstate.environ.dftype()
        V = braintools.init.param(self.V_initializer, self.varshape)
        self.V = brainstate.HiddenState(V)
        self.i_syn = brainstate.ShortTermState(
            u.math.zeros(self.varshape + (self.n_receptors,), dtype=dftype) * u.pA
        )
        self.i_const = brainstate.ShortTermState(
            u.math.zeros(self.varshape, dtype=dftype) * u.pA
        )
        # Sentinel "long ago" spike time.
        self.last_spike_time = brainstate.ShortTermState(
            u.math.full(self.varshape, -1e7 * u.ms)
        )
        self.refractory_step_count = brainstate.ShortTermState(
            u.math.full(self.varshape, 0, dtype=ditype)
        )
        if self.ref_var:
            refractory = braintools.init.param(braintools.init.Constant(False), self.varshape)
            self.refractory = brainstate.ShortTermState(refractory)
        self._precompute_propagators()

    def _precompute_propagators(self):
        r"""Pre-compute NEST propagator coefficients from ``dt`` and model parameters.

        Called once during ``init_state``, so ``update`` never recomputes
        exponentials and remains JIT-compatible. The refractory step count is
        refreshed here as well, so it uses the same ``dt`` as the propagators.
        """
        dt = brainstate.environ.get_dt()
        h = float(u.math.asarray(dt / u.ms))
        dftype = brainstate.environ.dftype()
        tau_m_ms = np.asarray(u.get_mantissa(self.tau_m / u.ms), dtype=np.float64)
        C_m_pF = np.asarray(u.get_mantissa(self.C_m / u.pF), dtype=np.float64)

        # Membrane propagators.
        P22 = np.exp(-h / tau_m_ms)
        self._P22 = P22.astype(dftype)
        self._P20 = (tau_m_ms / C_m_pF * (1.0 - P22)).astype(dftype)

        # Synaptic decay, one factor per receptor.
        self._P11_syn = np.exp(-h / self.tau_syn).astype(dftype)

        # Per-receptor membrane coupling, stacked on a trailing receptor axis.
        P21_list = [
            propagator_exp(tau_s * np.ones(self.varshape), tau_m_ms, C_m_pF, h).astype(dftype)
            for tau_s in self.tau_syn
        ]
        if P21_list:
            self._P21_syn = np.stack(P21_list, axis=-1)
        else:
            # np.stack rejects an empty list. With zero receptors, use an
            # empty trailing axis so the coupling sum in ``update`` is 0.
            self._P21_syn = np.zeros(self.varshape + (0,), dtype=dftype)

        # Pre-compute constant voltage and current values.
        self._E_L_mV = np.asarray(u.get_mantissa(self.E_L / u.mV), dtype=dftype)
        self._theta_mV = np.asarray(u.get_mantissa((self.V_th - self.E_L) / u.mV), dtype=dftype)
        self._V_reset_rel_mV = np.asarray(u.get_mantissa((self.V_reset - self.E_L) / u.mV), dtype=dftype)
        self._I_e_pA = np.asarray(u.get_mantissa(self.I_e / u.pA), dtype=dftype)

        # Keep the refractory count consistent with the dt used above.
        self.ref_count = self._compute_ref_count()

    def get_spike(self, V: ArrayLike = None):
        r"""Evaluate surrogate spike activation for a voltage tensor.

        Scales the voltage relative to threshold and reset, and passes the
        resulting dimensionless argument to the surrogate nonlinearity
        ``self.spk_fun``:

        .. math::

           \text{out} = \mathrm{spk\_fun}\!\left(
               \frac{V - V_\mathrm{th}}{V_\mathrm{th} - V_\mathrm{reset}}
           \right).

        Parameters
        ----------
        V : ArrayLike or None, optional
            Membrane voltage in mV, broadcast-compatible with
            ``self.varshape``. If ``None``, ``self.V.value`` is used.

        Returns
        -------
        out : jax.Array
            Surrogate spike output from ``self.spk_fun`` with the same shape
            as ``V`` (or ``self.V.value`` when ``V is None``). The argument to
            ``spk_fun`` is positive when :math:`V > V_\mathrm{th}`.

        Raises
        ------
        TypeError
            If ``V`` cannot take part in arithmetic with the membrane
            parameters because of an incompatible dtype or unit.
        """
        V = self.V.value if V is None else V
        v_scaled = (V - self.V_th) / (self.V_th - self.V_reset)
        return self.spk_fun(v_scaled)

    def _parse_spike_events(self, spike_events: Iterable, v_shape):
        r"""Parse spike event descriptors into a per-receptor weight array.

        Converts a heterogeneous iterable of spike events into a NumPy array
        (environment float dtype) that can be added directly to ``i_syn``.

        Parameters
        ----------
        spike_events : iterable or None
            Events to parse. Each entry must be one of:

            - A ``(receptor_type, weight)`` tuple, where ``receptor_type`` is
              a 1-based integer in ``[1, n_receptors]`` and ``weight`` is a
              scalar or array in pA broadcastable to ``v_shape``.
            - A ``dict`` with keys ``'receptor_type'`` (or ``'receptor'``)
              and ``'weight'``. A missing receptor defaults to 1. A missing
              weight defaults to ``0 pA``.

            Multiple events for the same receptor are accumulated additively.
            ``None`` is treated as an empty sequence.
        v_shape : tuple of int
            Shape of the neuron population state (``self.V.value.shape``).

        Returns
        -------
        out : np.ndarray
            Array of shape ``v_shape + (n_receptors,)``. Entry ``[..., k]`` is
            the total weight (in pA, unitless) arriving at receptor ``k+1``
            this step.

        Raises
        ------
        ValueError
            If any ``receptor_type`` is outside ``[1, n_receptors]``.
        TypeError
            If a weight value is not unit-compatible with pA.
        """
        dftype = brainstate.environ.dftype()
        out = np.zeros(v_shape + (self.n_receptors,), dtype=dftype)
        if spike_events is None:
            return out
        for ev in spike_events:
            if isinstance(ev, dict):
                receptor = int(ev.get('receptor_type', ev.get('receptor', 1)))
                # Default must carry pA units so the division below is
                # dimensionless.
                weight = ev.get('weight', 0.0 * u.pA)
            else:
                receptor, weight = ev
                receptor = int(receptor)
            if receptor < 1 or receptor > self.n_receptors:
                raise ValueError(f'Receptor type {receptor} out of range [1, {self.n_receptors}].')
            w_np = np.asarray(u.math.asarray(weight / u.pA), dtype=dftype)
            out[..., receptor - 1] += np.broadcast_to(w_np, v_shape)
        return out

    def update(self, x=0. * u.pA, spike_events=None, w_by_rec=None):
        r"""Advance the neuron state by one simulation step.

        Executes the full NEST-compatible per-step update:

        - exact membrane propagation for non-refractory neurons;
        - receptor current decay and spike injection;
        - threshold, reset and refractory logic;
        - buffered current storage.

        Parameters
        ----------
        x : ArrayLike, optional
            Continuous current input in pA for this step. ``x`` is accumulated
            through :meth:`sum_current_inputs`, which also adds any registered
            projection currents. The result is stored in ``i_const`` and used
            on the **next** step, matching NEST ring-buffer semantics. Scalar
            or array broadcastable to ``self.varshape``. Default is
            ``0. * u.pA``.
        spike_events : iterable or None, optional
            Receptor-indexed spike weight events to inject this step. Each
            entry must be either:

            - A ``(receptor_type, weight)`` tuple, where ``receptor_type`` is
              a 1-based integer in ``[1, n_receptors]`` and ``weight`` is a
              scalar or array in pA broadcastable to ``self.varshape``.
            - A ``dict`` with keys ``'receptor_type'`` (or ``'receptor'``)
              and ``'weight'``.

            Multiple events for the same receptor are accumulated additively.
            ``None`` injects no receptor spike events. Default is ``None``.
            Ignored when ``w_by_rec`` is provided.
        w_by_rec : array-like or None, optional
            Pre-computed per-receptor spike weights, given as unitless
            numbers interpreted as pA. The shape must be broadcastable to
            ``self.varshape + (n_receptors,)``. When provided, this bypasses
            both ``spike_events`` parsing and ``sum_delta_inputs``, which
            makes the update JIT-compatible inside
            ``brainstate.transform.for_loop``. Default is ``None``.

        Returns
        -------
        out : jax.Array
            Surrogate spike output from :meth:`get_spike` with shape
            ``self.V.value.shape``. For neurons that fire this step, the
            voltage argument passed to :meth:`get_spike` is set to
            :math:`\theta + E_L + 10^{-12}` mV (nominally just above
            threshold). This is intended to yield a positive surrogate
            activation even after the hard voltage reset.

        Raises
        ------
        ValueError
            If any receptor index in ``spike_events`` is outside
            ``[1, n_receptors]``, or if the provided inputs cannot be
            broadcast to the internal state shape.
        KeyError
            If the simulation environment context does not supply ``t`` or
            ``dt``.
        AttributeError
            If state variables are missing because :meth:`init_state` has not
            been called before ``update``.
        TypeError
            If ``x`` or stored states are not unit-compatible with the
            expected pA / mV arithmetic.
        """
        t = brainstate.environ.get('t')
        dt = brainstate.environ.get_dt()
        ditype = brainstate.environ.ditype()

        # Read state variables with their natural units.
        V = self.V.value  # mV
        i_syn = self.i_syn.value  # pA, shape varshape + (n_receptors,)
        i_const = self.i_const.value  # pA
        r = self.refractory_step_count.value  # int

        # Use pre-computed constants (avoids recomputing exponentials each step).
        I_e_pA = self._I_e_pA
        E_L_mV = self._E_L_mV
        theta_mV = self._theta_mV
        V_reset_rel_mV = self._V_reset_rel_mV
        P22 = self._P22
        P20 = self._P20
        P11_syn = self._P11_syn
        P21_syn = self._P21_syn

        # Strip units (JAX-compatible via u.get_mantissa).
        i_const_pA = u.get_mantissa(i_const / u.pA)
        V_rel_mV = u.get_mantissa((V - self.E_L) / u.mV)
        i_syn_pA = u.get_mantissa(i_syn / u.pA)

        # Build per-receptor spike weight array.
        if w_by_rec is None:
            # Python-level path: parses spike_events dicts/tuples (not JIT-compatible).
            dftype = brainstate.environ.dftype()
            v_shape = self.V.value.shape
            w_val = self._parse_spike_events(spike_events, v_shape)
            w_delta = np.asarray(
                u.get_mantissa(self.sum_delta_inputs(0. * u.pA) / u.pA),
                dtype=dftype,
            )
            w_delta = np.broadcast_to(w_delta, v_shape)
            # Default delta input goes to receptor 1 (NEST default port);
            # it is discarded when there are no receptors.
            if self.n_receptors > 0:
                w_val = w_val.copy()
                w_val[..., 0] += w_delta
        else:
            # JAX-array path: caller supplies pre-computed weights, JIT-compatible.
            w_val = w_by_rec

        # Current input for next step (one-step delay).
        new_i_const = self.sum_current_inputs(x, self.V.value)  # pA

        # 1. Membrane integration for non-refractory neurons. Receptor
        #    currents enter with their pre-decay values (NEST ordering).
        not_refractory = r == 0
        V_candidate = (
            V_rel_mV * P22
            + (I_e_pA + i_const_pA) * P20
            + jnp.sum(P21_syn * i_syn_pA, axis=-1)
        )
        V_rel_mV = jnp.where(not_refractory, V_candidate, V_rel_mV)

        # 2. Decrement refractory counters.
        r = jnp.where(not_refractory, r, r - 1)

        # 3. Decay receptor currents and inject spike weights.
        i_syn_pA = i_syn_pA * P11_syn
        i_syn_pA = i_syn_pA + w_val

        # 4. Threshold test, reset, refractory assignment.
        spike_cond = V_rel_mV >= theta_mV
        r = jnp.where(spike_cond, jnp.asarray(u.get_mantissa(self.ref_count), dtype=ditype), r)
        V_before_reset = V_rel_mV
        V_rel_mV = jnp.where(spike_cond, V_reset_rel_mV, V_rel_mV)

        # Write back state.
        self.V.value = (V_rel_mV + E_L_mV) * u.mV
        self.i_syn.value = i_syn_pA * u.pA
        # Adding zeros broadcasts a scalar input current to varshape.
        self.i_const.value = new_i_const + u.math.zeros(self.varshape) * u.pA
        self.refractory_step_count.value = jnp.asarray(u.get_mantissa(r), dtype=ditype)
        # NEST stamps the spike at the end of the step (t + dt).
        last_spike_time = u.math.where(spike_cond, t + dt, self.last_spike_time.value)
        self.last_spike_time.value = jax.lax.stop_gradient(last_spike_time)
        if self.ref_var:
            self.refractory.value = jax.lax.stop_gradient(self.refractory_step_count.value > 0)

        # NOTE(review): when dftype is float32, the 1e-12 mV offset is below
        # the representable resolution at typical mV magnitudes. V_out then
        # equals V_th exactly for spiking neurons. Confirm that spk_fun maps
        # an argument of 0 to a spike.
        V_out = jnp.where(spike_cond, theta_mV + E_L_mV + 1e-12, V_before_reset + E_L_mV)
        return self.get_spike(V_out * u.mV)