# Source code for brainpy_state._nest.gif_psc_exp_multisynapse

# Copyright 2026 BrainX Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

# -*- coding: utf-8 -*-

r"""Current-based GIF neuron with multiple synaptic time constants.

This module implements ``gif_psc_exp_multisynapse``, the multisynapse
extension of :class:`gif_psc_exp`.  It is a faithful re-implementation of
the identically named NEST model
(``models/gif_psc_exp_multisynapse.{h,cpp}``), preserving update ordering,
exact (analytic) propagator integration, stochastic firing, and all default
parameter values.

The key difference from :class:`gif_psc_exp` is that instead of having two
fixed synaptic channels (excitatory and inhibitory), this model supports an
arbitrary number of receptor ports, each with its own exponential synaptic
time constant.  Incoming spike events specify which receptor port they
target (1-based indexing, as in NEST).

Mathematical model
------------------

Membrane potential ODE:

.. math::

   C_m \frac{dV}{dt} = -g_L (V - E_L)
       - \sum_j \eta_j(t)
       + \sum_k I_{\mathrm{syn},k}(t)
       + I_e + I_{\mathrm{stim}}(t)

Synaptic currents (one per receptor port *k*):

.. math::

   \frac{dI_{\mathrm{syn},k}}{dt} = -\frac{I_{\mathrm{syn},k}}{\tau_{\mathrm{syn},k}}

Spike-triggered currents (STC):

.. math::

   \tau_{\eta_j} \frac{d\eta_j}{dt} = -\eta_j, \qquad
   \eta_j \to \eta_j + q_{\eta_j} \;\text{on spike}

Spike-frequency adaptation (SFA) threshold:

.. math::

   V_T(t) = V_{T^*} + \sum_i \gamma_i(t), \qquad
   \tau_{\gamma_i} \frac{d\gamma_i}{dt} = -\gamma_i, \qquad
   \gamma_i \to \gamma_i + q_{\gamma_i} \;\text{on spike}

Stochastic spiking via exponential escape rate:

.. math::

   \lambda(t) = \lambda_0 \exp\!\bigl((V(t) - V_T(t)) / \Delta_V\bigr),
   \qquad P_{\text{spike}} = 1 - \exp(-\lambda \, dt)

References
----------
.. [1] Mensi S, Naud R, Pozzorini C, Avermann M, Petersen CC, Gerstner W
       (2012). Parameter extraction and classification of three cortical
       neuron types reveals two distinct adaptation mechanisms.
       *J. Neurophysiol.*, 107(6):1756-1775.
.. [2] Pozzorini C, Mensi S, Hagens O, Naud R, Koch C, Gerstner W (2015).
       Automated high-throughput characterization of single neurons by means
       of simplified spiking models. *PLoS Comput. Biol.*, 11(6), e1004275.
.. [3] NEST Simulator ``gif_psc_exp_multisynapse`` model,
       ``models/gif_psc_exp_multisynapse.h`` and
       ``models/gif_psc_exp_multisynapse.cpp``.
"""

from typing import Callable, Optional, Sequence

import brainstate
import braintools
import saiunit as u
import jax
import jax.numpy as jnp
import numpy as np
from brainstate.typing import ArrayLike, Size

from ._base import NESTNeuron
from ._utils import is_tracer, propagator_exp

__all__ = [
    'gif_psc_exp_multisynapse',
]


# ---------------------------------------------------------------------------
# Proxy classes for mutable [i][j] access to ShortTermState arrays
# ---------------------------------------------------------------------------

class _RowProxy:
    """Read/write proxy for row ``row`` of a stacked ShortTermState array.

    The wrapped ``state`` holds a value of shape ``(n_elems, *varshape)``
    carrying the unit ``unit``.  ``proxy[j]`` returns the unitless mantissa
    of entry ``(row, j)`` as a Python ``float``; ``proxy[j] = val`` writes
    ``float(val)`` into that entry and re-attaches ``unit`` to the whole
    array.

    ``j`` may be an integer (1-D populations) or a tuple of integers that
    addresses a single neuron in a multi-dimensional population, e.g.
    ``proxy[1, 2]``.
    """
    __slots__ = ('_state', '_row', '_unit')

    def __init__(self, state, row, unit):
        self._state = state
        self._row = row
        self._unit = unit

    def _index(self, j):
        # Splice a tuple ``j`` into the index so it addresses neuron
        # coordinates; ``raw[row, (a, b)]`` would otherwise be treated by
        # NumPy as a fancy index selecting entries a and b along axis 1.
        if isinstance(j, tuple):
            return (self._row,) + j
        return (self._row, j)

    def __getitem__(self, j):
        raw = np.asarray(u.get_mantissa(self._state.value))
        return float(raw[self._index(j)])

    def __setitem__(self, j, val):
        # Copy first: np.asarray over a JAX array may yield a read-only view.
        raw = np.asarray(u.get_mantissa(self._state.value)).copy()
        raw[self._index(j)] = float(val)
        self._state.value = raw * self._unit


class _AdaptProxy:
    """Two-level indexer over a stacked adaptation ``ShortTermState``.

    ``proxy[i]`` returns a :class:`_RowProxy` bound to element ``i`` of the
    state's leading axis, so ``proxy[i][j]`` reads or writes one entry,
    with values expressed as mantissas of ``unit``.
    """
    __slots__ = ('_state', '_unit')

    def __init__(self, state, unit):
        self._state, self._unit = state, unit

    def __getitem__(self, i):
        state, unit = self._state, self._unit
        return _RowProxy(state, i, unit)


# ---------------------------------------------------------------------------
# Main model class
# ---------------------------------------------------------------------------

class gif_psc_exp_multisynapse(NESTNeuron):
    r"""Current-based generalized integrate-and-fire neuron (GIF) model
    with multiple synaptic time constants.

    This model implements the multisynapse extension of the generalized
    integrate-and-fire neuron according to Mensi et al. (2012) [1]_ and
    Pozzorini et al. (2015) [2]_, with exponential postsynaptic currents
    and an arbitrary number of receptor ports. It is a faithful
    re-implementation of the NEST simulator's ``gif_psc_exp_multisynapse``
    model, preserving exact (analytic) propagator integration, stochastic
    firing dynamics, update ordering, and all default parameter values.

    The model combines four key features:

    1. **Multiple receptor ports**: Each with independent exponential
       synaptic time constants (``tau_syn`` parameter)
    2. **Spike-triggered currents (STC)**: Post-spike current injection
       with multiple time scales (``tau_stc``, ``q_stc`` parameters)
    3. **Spike-frequency adaptation (SFA)**: Dynamic threshold modulation
       after each spike (``tau_sfa``, ``q_sfa`` parameters)
    4. **Stochastic spiking**: Exponential escape-rate firing with
       parameter ``lambda_0`` and threshold noise ``Delta_V``

    Mathematical Model
    ------------------

    **1. Membrane Potential Dynamics**

    The subthreshold membrane potential :math:`V(t)` evolves according to:

    .. math::

       C_m \frac{dV}{dt} = -g_L (V - E_L) - \sum_j \eta_j(t)
           + \sum_k I_{\mathrm{syn},k}(t) + I_e + I_{\mathrm{stim}}(t)

    where:

      - :math:`g_L (V - E_L)` is the passive leak current
      - :math:`\eta_j(t)` are spike-triggered currents (STCs)
      - :math:`I_{\mathrm{syn},k}(t)` are synaptic currents for each receptor port :math:`k`
      - :math:`I_e` is a constant external bias current
      - :math:`I_{\mathrm{stim}}(t)` is time-varying external input

    **2. Synaptic Currents (Multi-Receptor)**

    Each receptor port :math:`k` has an independent exponential synaptic current:

    .. math::

       \frac{dI_{\mathrm{syn},k}}{dt} = -\frac{I_{\mathrm{syn},k}}{\tau_{\mathrm{syn},k}}

    The number of receptor ports is determined by ``len(tau_syn)``. When
    connecting projections, specify ``receptor_type`` (1-based indexing,
    matching NEST convention) to target a specific port. Both excitatory
    and inhibitory connections can target any receptor port (weights can
    be positive or negative).

    **3. Spike-Triggered Currents (STC)**

    Each STC element :math:`\eta_j` evolves as:

    .. math::

       \tau_{\eta_j} \frac{d\eta_j}{dt} = -\eta_j

    Upon spike emission at time :math:`t_{\mathrm{sp}}`:

    .. math::

       \eta_j(t_{\mathrm{sp}}^+) = \eta_j(t_{\mathrm{sp}}^-) + q_{\eta_j}

    The total STC contribution is :math:`\sum_j \eta_j(t)`. STCs can model
    post-spike currents such as afterhyperpolarization (AHP) or
    afterdepolarization (ADP) depending on the sign of ``q_stc``.

    **4. Spike-Frequency Adaptation (SFA)**

    The firing threshold is dynamic, consisting of a base threshold
    :math:`V_{T^*}` plus adaptive components:

    .. math::

       V_T(t) = V_{T^*} + \sum_i \gamma_i(t)

    Each SFA element :math:`\gamma_i` evolves as:

    .. math::

       \tau_{\gamma_i} \frac{d\gamma_i}{dt} = -\gamma_i

    Upon spike emission:

    .. math::

       \gamma_i(t_{\mathrm{sp}}^+) = \gamma_i(t_{\mathrm{sp}}^-) + q_{\gamma_i}

    Positive ``q_sfa`` values increase the threshold after each spike,
    leading to spike-frequency adaptation.

    **5. Stochastic Spiking Mechanism**

    The neuron fires stochastically with an exponential escape-rate
    intensity:

    .. math::

       \lambda(t) = \lambda_0 \exp\!\left(\frac{V(t) - V_T(t)}{\Delta_V}\right)

    The probability of firing within a time step :math:`dt` is:

    .. math::

       P_{\mathrm{spike}} = 1 - \exp(-\lambda(t) \cdot dt)

    At each non-refractory time step, a uniform random number
    :math:`r \in [0, 1)` is drawn. If :math:`r < P_{\mathrm{spike}}`, a
    spike is emitted. The stochasticity level :math:`\Delta_V` controls
    the sharpness of the firing threshold (smaller values → more
    deterministic).

    **6. Refractory Period**

    After a spike, the neuron enters an absolute refractory period of
    duration :math:`t_{\mathrm{ref}}`. During this period:

    * The membrane potential is clamped to :math:`V_{\mathrm{reset}}`
    * No spike can be emitted (firing intensity check is skipped)
    * Synaptic currents continue to evolve and receive inputs
    * STC and SFA elements continue to decay

    **Numerical Integration**

    The model uses **exact (analytic) integration** for all linear ODEs,
    matching NEST's propagator-based integration scheme. For each variable
    with dynamics :math:`\tau \frac{dx}{dt} = -x + f(t)`, the update over
    one time step :math:`h` is:

    .. math::

       x(t + h) = e^{-h/\tau} x(t) + \int_0^h e^{-(h-s)/\tau} f(t+s) \, ds

    For constant :math:`f`, this yields exact propagator coefficients. The
    membrane potential propagator accounts for coupling between :math:`V`
    and synaptic currents with potentially different time constants.

    **Update Order (Matching NEST)**

    Each simulation step follows this exact sequence (matching NEST's
    ``gif_psc_exp_multisynapse::update`` implementation):

    **Step 1: Adaptation Decay**
      - Compute total STC: :math:`\mathrm{stc\_total} = \sum_j \eta_j(t)`
      - Compute total threshold: :math:`V_T(t) = V_{T^*} + \sum_i \gamma_i(t)`
      - Decay all STC elements: :math:`\eta_j \leftarrow \eta_j \cdot e^{-dt/\tau_{\eta_j}}`
      - Decay all SFA elements: :math:`\gamma_i \leftarrow \gamma_i \cdot e^{-dt/\tau_{\gamma_i}}`

    **Step 2: Synaptic Current Processing (per receptor)**
      For each receptor port :math:`k`:
        - Compute propagated contribution to :math:`V`:
          :math:`\Delta V_k = P_{21,k} \cdot I_{\mathrm{syn},k}(t)`
        - Decay synaptic current:
          :math:`I_{\mathrm{syn},k} \leftarrow I_{\mathrm{syn},k} \cdot e^{-dt/\tau_{\mathrm{syn},k}}`
        - Add incoming spike weights:
          :math:`I_{\mathrm{syn},k} \leftarrow I_{\mathrm{syn},k} + w_k`

    **Step 3: Membrane Update and Spike Check**
      If **not refractory**:
        - Update membrane potential using exact propagator:

          .. math::

             V(t+dt) = P_{33} V(t) + P_{31} E_L + P_{30}(I_{\mathrm{stim}}(t) + I_e - \mathrm{stc\_total})
                 + \sum_k \Delta V_k

        - Compute firing intensity: :math:`\lambda = \lambda_0 \exp((V - V_T)/\Delta_V)`
        - Compute spike probability: :math:`p = 1 - \exp(-\lambda \cdot dt)`
        - Draw random number :math:`r \sim \mathrm{Uniform}(0, 1)`
        - If :math:`r < p`:
            * Emit spike
            * Jump STC elements: :math:`\eta_j \leftarrow \eta_j + q_{\eta_j}`
            * Jump SFA elements: :math:`\gamma_i \leftarrow \gamma_i + q_{\gamma_i}`
            * Set refractory counter: :math:`r_{\mathrm{count}} \leftarrow \lceil t_{\mathrm{ref}} / dt \rceil`
      If **refractory**:
        - Decrement refractory counter: :math:`r_{\mathrm{count}} \leftarrow r_{\mathrm{count}} - 1`
        - Clamp membrane potential: :math:`V \leftarrow V_{\mathrm{reset}}`

    **Step 4: Buffer External Current**
      Store :math:`I_{\mathrm{stim}}(t+dt)` for use in the next step
      (NEST ring-buffer semantics: one-step delay).

    **Differences from gif_psc_exp**

    Unlike :class:`gif_psc_exp` which has exactly two fixed synaptic
    channels (excitatory and inhibitory with ``tau_syn_ex``,
    ``tau_syn_in``), this model supports an arbitrary number of receptor
    ports specified by the ``tau_syn`` parameter. This enables:

    * Multi-receptor modeling (AMPA, NMDA, GABA_A, GABA_B, etc.)
    * Heterogeneous synaptic time constants within the same neuron
    * Flexible connectivity patterns with receptor-specific routing

    All spike weights are applied to the receptor port specified in the
    connection's ``receptor_type`` field (1-based indexing). Positive or
    negative weights are both allowed on any receptor.

    Parameters
    ----------
    in_size : int, tuple of int
        Shape of the neuron population.
    g_L : Quantity, ArrayLike, optional
        Leak conductance (nanosiemens). Default: 4.0 nS.
    E_L : Quantity, ArrayLike, optional
        Leak reversal potential (millivolts). Default: -70.0 mV.
    C_m : Quantity, ArrayLike, optional
        Membrane capacitance (picofarads). Default: 80.0 pF.
    V_reset : Quantity, ArrayLike, optional
        Reset potential (millivolts). Default: -55.0 mV.
    Delta_V : Quantity, ArrayLike, optional
        Voltage scale of stochastic firing (millivolts). Default: 0.5 mV.
    V_T_star : Quantity, ArrayLike, optional
        Base firing threshold (millivolts). Default: -35.0 mV.
    lambda_0 : float, optional
        Stochastic firing intensity at threshold (1/s). Default: 1.0 /s.
    t_ref : Quantity, ArrayLike, optional
        Absolute refractory period (milliseconds). Default: 4.0 ms.
    tau_syn : sequence of float, optional
        Synaptic time constants (milliseconds), one per receptor port.
        Specified as bare floats (not Quantities). Default: ``(2.0,)``.
    I_e : Quantity, ArrayLike, optional
        Constant external bias current (picoamperes). Default: 0.0 pA.
    tau_sfa : sequence of float, optional
        SFA time constants (milliseconds). Default: ``()`` (no adaptation).
    q_sfa : sequence of float, optional
        SFA jump values (millivolts). Default: ``()`` (no adaptation).
    tau_stc : sequence of float, optional
        STC time constants (milliseconds). Default: ``()`` (no STC).
    q_stc : sequence of float, optional
        STC jump values (picoamperes). Default: ``()`` (no STC).
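    rng_key : jax.Array, optional
        JAX PRNG key for stochastic spike generation. Default: None, in
        which case ``init_state`` falls back to ``jax.random.PRNGKey(0)``,
        so every population created without an explicit key starts from
        the same fixed key.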
    V_initializer : Callable, optional
        Initializer for membrane potential. Default: ``Constant(-70.0 mV)``.
    spk_fun : Callable, optional
        Surrogate gradient function. Default: ``ReluGrad()``.
    spk_reset : str, optional
        Spike reset mode (``'hard'`` or ``'soft'``). Default: ``'hard'``.
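    name : str, optional
        Name of the neuron population. Default: None.
    gsl_error_tol : float, optional
        Accepted only for backward compatibility and ignored; the model
        integrates its linear ODEs exactly. Default: 1e-6.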

    State Variables
    ---------------
    V : HiddenState, shape ``(*in_size,)``
        Membrane potential in millivolts.
    i_syn : ShortTermState, shape ``(*in_size, n_receptors)``
        Synaptic currents in picoamperes.
    refractory_step_count : ShortTermState, shape ``(*in_size,)``
        Remaining refractory steps (int).
    I_stim : ShortTermState, shape ``(*in_size,)``
        Buffered external current (one-step delay).
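    last_spike_time : ShortTermState, shape ``(*in_size,)``
        Time of last spike (milliseconds).
    stc_elems : ShortTermState or None, shape ``(n_stc, *in_size)``
        Spike-triggered current elements in picoamperes; ``None`` when
        ``tau_stc`` is empty.
    sfa_elems : ShortTermState or None, shape ``(n_sfa, *in_size)``
        Threshold-adaptation elements in millivolts; ``None`` when
        ``tau_sfa`` is empty.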

    See Also
    --------
    gif_psc_exp : Two-receptor GIF model
    iaf_psc_exp_multisynapse : Multi-receptor IAF model without adaptation

    References
    ----------
    .. [1] Mensi S et al. (2012). J. Neurophysiol., 107(6):1756-1775.
    .. [2] Pozzorini C et al. (2015). PLoS Comput. Biol., 11(6), e1004275.
    .. [3] NEST Simulator ``gif_psc_exp_multisynapse`` model documentation.
    """
    __module__ = 'brainpy.state'

    def __init__(
        self,
        in_size: Size,
        g_L: ArrayLike = 4.0 * u.nS,
        E_L: ArrayLike = -70.0 * u.mV,
        C_m: ArrayLike = 80.0 * u.pF,
        V_reset: ArrayLike = -55.0 * u.mV,
        Delta_V: ArrayLike = 0.5 * u.mV,
        V_T_star: ArrayLike = -35.0 * u.mV,
        lambda_0: float = 1.0,  # 1/s, as in NEST Python interface
        t_ref: ArrayLike = 4.0 * u.ms,
        tau_syn: Sequence[float] = (2.0,),  # ms values
        I_e: ArrayLike = 0.0 * u.pA,
        tau_sfa: Sequence[float] = (),  # ms values
        q_sfa: Sequence[float] = (),    # mV values
        tau_stc: Sequence[float] = (),  # ms values
        q_stc: Sequence[float] = (),    # pA values
        rng_key: Optional[jax.Array] = None,
        V_initializer: Callable = braintools.init.Constant(-70.0 * u.mV),
        spk_fun: Callable = braintools.surrogate.ReluGrad(),
        spk_reset: str = 'hard',
        name: str = None,
        # Accepted for backward compatibility but unused:
        gsl_error_tol: ArrayLike = 1e-6,
    ):
        r"""Create a ``gif_psc_exp_multisynapse`` population.

        All parameters are described in the class docstring.  Membrane
        parameters are passed through ``braintools.init.param`` with
        ``self.varshape``; ``tau_syn`` and the adaptation sequences are
        stored as unitless floats (ms, mV or pA), and ``lambda_0`` is
        converted from 1/s to 1/ms.

        Raises
        ------
        ValueError
            If ``tau_syn`` is empty, if ``tau_sfa``/``q_sfa`` or
            ``tau_stc``/``q_stc`` differ in length, or if
            ``_validate_parameters`` rejects a value.
        """
        super().__init__(in_size, name=name, spk_fun=spk_fun, spk_reset=spk_reset)

        # Membrane parameters
        self.g_L = braintools.init.param(g_L, self.varshape)
        self.E_L = braintools.init.param(E_L, self.varshape)
        self.C_m = braintools.init.param(C_m, self.varshape)
        self.V_reset = braintools.init.param(V_reset, self.varshape)
        self.Delta_V = braintools.init.param(Delta_V, self.varshape)
        self.V_T_star = braintools.init.param(V_T_star, self.varshape)
        self.t_ref = braintools.init.param(t_ref, self.varshape)
        self.I_e = braintools.init.param(I_e, self.varshape)

        # tau_syn: stored as plain numpy array of ms values (no units)
        if len(tau_syn) == 0:
            raise ValueError("'tau_syn' must have at least one element.")
        dftype = brainstate.environ.dftype()
        self.tau_syn = np.asarray([float(x) for x in tau_syn], dtype=dftype)

        # Stochastic spiking: lambda_0 in 1/s → store as 1/ms
        self.lambda_0 = lambda_0 / 1000.0

        # Adaptation parameters (stored as Python tuples of bare floats)
        self.tau_sfa = tuple(float(x) for x in tau_sfa)
        self.q_sfa = tuple(float(x) for x in q_sfa)
        self.tau_stc = tuple(float(x) for x in tau_stc)
        self.q_stc = tuple(float(x) for x in q_stc)

        if len(self.tau_sfa) != len(self.q_sfa):
            raise ValueError(
                f"'tau_sfa' and 'q_sfa' must have the same length. "
                f"Got {len(self.tau_sfa)} and {len(self.q_sfa)}."
            )
        if len(self.tau_stc) != len(self.q_stc):
            raise ValueError(
                f"'tau_stc' and 'q_stc' must have the same length. "
                f"Got {len(self.tau_stc)} and {len(self.q_stc)}."
            )

        self.n_stc = len(self.tau_stc)
        self.n_sfa = len(self.tau_sfa)

        # RNG key for stochastic spiking (None → fixed key in init_state)
        self._rng_key = rng_key

        # Initializer
        self.V_initializer = V_initializer

        self._validate_parameters()

        # Pre-compute the refractory step count as ceil(t_ref / dt).  The
        # ratio is first snapped to the nearest integer when it lies within
        # floating-point noise of one (e.g. 2.2 / 0.1 evaluates to
        # 22.000000000000004 in float64); a bare ``ceil`` would then add a
        # spurious extra refractory step.  The relative tolerance also covers
        # float32 rounding when x64 is disabled.
        ditype = brainstate.environ.ditype()
        dt = brainstate.environ.get_dt()
        t_ref_ms = jnp.asarray(u.get_mantissa(self.t_ref / u.ms))
        dt_ms = u.get_mantissa(dt / u.ms)
        ratio = t_ref_ms / dt_ms
        nearest = jnp.round(ratio)
        tol = 1e-6 * jnp.maximum(1.0, jnp.abs(nearest))
        n_steps = jnp.where(jnp.abs(ratio - nearest) <= tol, nearest, jnp.ceil(ratio))
        self.ref_count = u.math.asarray(n_steps, dtype=ditype)

    @property
    def n_receptors(self):
        r"""Count of synaptic receptor ports, i.e. the length of ``tau_syn``."""
        return int(np.size(self.tau_syn, axis=0))

    def _validate_parameters(self):
        r"""Check parameter ranges, raising ``ValueError`` on the first violation.

        Checks strict positivity of ``C_m``, ``g_L``, ``Delta_V`` and of every
        ``tau_sfa``/``tau_stc``/``tau_syn`` entry, and non-negativity of
        ``t_ref`` and ``lambda_0``.  Validation is skipped entirely when any
        of the array-valued parameters compared below is a JAX tracer
        (e.g. during ``jit``), since those comparisons need concrete values.

        Raises
        ------
        ValueError
            If any parameter is outside its valid range.
        """
        # t_ref is compared with np.any below as well, so a traced t_ref must
        # also short-circuit validation instead of failing bool conversion.
        traced = (self.C_m, self.g_L, self.Delta_V, self.t_ref)
        if any(is_tracer(v) for v in traced):
            return
        if np.any(self.C_m <= 0.0 * u.pF):
            raise ValueError('Capacitance must be strictly positive.')
        if np.any(self.g_L <= 0.0 * u.nS):
            raise ValueError('Membrane conductance must be strictly positive.')
        if np.any(self.Delta_V <= 0.0 * u.mV):
            raise ValueError('Delta_V must be strictly positive.')
        if np.any(self.t_ref < 0.0 * u.ms):
            raise ValueError('Refractory time must not be negative.')
        if self.lambda_0 < 0.0:
            raise ValueError('lambda_0 must not be negative.')

        # (label used in the message, attribute name, values) — checked in
        # the same order as before: SFA, STC, then synaptic time constants.
        time_constants = (
            ('SFA', 'tau_sfa', self.tau_sfa),
            ('STC', 'tau_stc', self.tau_stc),
            ('synaptic', 'tau_syn', self.tau_syn),
        )
        for label, attr, taus in time_constants:
            for i, tau in enumerate(taus):
                if tau <= 0.0:
                    raise ValueError(
                        f'All {label} time constants must be strictly positive '
                        f'({attr}[{i}]={tau}).'
                    )

[docs] def init_state(self, **kwargs): r"""Initialize all state variables. Creates membrane potential, synaptic currents, refractory counters, adaptation elements, buffered current, and internal RNG state. """ dftype = brainstate.environ.dftype() ditype = brainstate.environ.ditype() V = braintools.init.param(self.V_initializer, self.varshape) self.V = brainstate.HiddenState(V) v_shape = V.shape # Synaptic currents: shape (*v_shape, n_receptors) — float64 for precision syn_shape = v_shape + (self.n_receptors,) self.i_syn = brainstate.ShortTermState( np.zeros(syn_shape, dtype=np.float64) * u.pA ) # STC elements: shape (n_stc, *v_shape) in pA — float64 for precision if self.n_stc > 0: stc_shape = (self.n_stc,) + v_shape self.stc_elems = brainstate.ShortTermState( np.zeros(stc_shape, dtype=np.float64) * u.pA ) else: self.stc_elems = None # SFA elements: shape (n_sfa, *v_shape) in mV — float64 for precision if self.n_sfa > 0: sfa_shape = (self.n_sfa,) + v_shape self.sfa_elems = brainstate.ShortTermState( np.zeros(sfa_shape, dtype=np.float64) * u.mV ) else: self.sfa_elems = None self.last_spike_time = brainstate.ShortTermState( u.math.full(v_shape, -1e7 * u.ms) ) self.refractory_step_count = brainstate.ShortTermState( u.math.full(v_shape, 0, dtype=ditype) ) self.I_stim = brainstate.ShortTermState( u.math.full(v_shape, 0.0 * u.pA, dtype=dftype) ) # Caches for pre-decay totals (accessed via _stc_val / _sfa_val) V_T_star_mV = float(np.asarray(u.get_mantissa(self.V_T_star))) self._stc_val_cache = np.zeros(v_shape, dtype=np.float64) self._sfa_val_cache = np.full(v_shape, V_T_star_mV, dtype=np.float64) # RNG state — stored as ShortTermState so brainstate.transform.for_loop can track it rng_init = self._rng_key if self._rng_key is not None else jax.random.PRNGKey(0) self._rng_state = brainstate.ShortTermState(rng_init)
[docs] def get_spike(self, V: ArrayLike = None): r"""Compute spike output via surrogate gradient function.""" V = self.V.value if V is None else V v_scaled = (V - self.V_reset) / self.Delta_V return self.spk_fun(v_scaled)
def _parse_spike_events(self, spike_events, v_shape): r"""Parse spike event descriptors into a per-receptor weight array. Parameters ---------- spike_events : iterable or None Events as ``(receptor_type, weight)`` tuples or dicts with ``'receptor_type'`` and ``'weight'`` keys. ``None`` → no events. v_shape : tuple of int Shape of the neuron population state. Returns ------- out : np.ndarray, shape ``v_shape + (n_receptors,)``, dtype float64 Total weight (pA) arriving at each receptor this step. Raises ------ ValueError If any ``receptor_type`` is outside ``[1, n_receptors]``. """ out = np.zeros(v_shape + (self.n_receptors,), dtype=np.float64) if spike_events is None: return out for ev in spike_events: if isinstance(ev, dict): receptor = int(ev.get('receptor_type', ev.get('receptor', 1))) weight = ev.get('weight', 0.0) else: receptor, weight = ev receptor = int(receptor) if receptor < 1 or receptor > self.n_receptors: raise ValueError( f'Receptor type {receptor} out of range [1, {self.n_receptors}].' 
) w_np = np.asarray(u.math.asarray(weight / u.pA), dtype=np.float64) out[..., receptor - 1] += np.broadcast_to(w_np, v_shape) return out # ------------------------------------------------------------------ # Adaptation element proxy properties # ------------------------------------------------------------------ @property def _stc_elems(self): """Proxy for ``stc_elems[i][j]`` read/write access (units: pA).""" if self.stc_elems is None: raise AttributeError('No STC elements configured (tau_stc is empty).') return _AdaptProxy(self.stc_elems, u.pA) @property def _sfa_elems(self): """Proxy for ``sfa_elems[i][j]`` read/write access (units: mV).""" if self.sfa_elems is None: raise AttributeError('No SFA elements configured (tau_sfa is empty).') return _AdaptProxy(self.sfa_elems, u.mV) @property def _stc_val(self): """Pre-decay STC totals (pA) from the last completed update step.""" return self._stc_val_cache @property def _sfa_val(self): """Pre-decay SFA threshold totals (mV) from the last completed update step.""" return self._sfa_val_cache # ------------------------------------------------------------------ # Main update method # ------------------------------------------------------------------
    def update(self, x=0.0 * u.pA, spike_events=None, receptor_weights=None):
        r"""Advance the neuron by one simulation step.

        Mirrors the ordering of NEST's ``gif_psc_exp_multisynapse::update``:

        1. Compute the pre-decay STC total and threshold
           ``V_T_star + sum(sfa_elems)``, then decay every STC/SFA element by
           ``exp(-dt / tau)``.
        2. Per receptor: compute the contribution ``P21 * I_syn`` to ``V``
           from the pre-decay current, decay ``I_syn``, then add this step's
           incoming weights (``spike_events``, delta inputs on receptor 1,
           and ``receptor_weights``).
        3. Non-refractory neurons: exact-propagator membrane update, then a
           stochastic spike draw; spiking neurons get their STC/SFA elements
           jumped by ``q_stc``/``q_sfa`` and their refractory counter set to
           ``ref_count``.  Refractory neurons: counter decremented.  ``V`` is
           set to ``V_reset`` for both spiking and refractory neurons.
        4. Store this step's external current for use on the next step
           (one-step delay).

        Parameters
        ----------
        x : Quantity, optional
            External current (pA), passed through ``sum_current_inputs`` and
            applied on the *next* step. Default: 0 pA.
        spike_events : iterable or None, optional
            Receptor-indexed events, each a ``(receptor_type, weight)`` tuple
            or a dict with ``'receptor_type'`` (or ``'receptor'``) and
            ``'weight'`` keys; ``receptor_type`` is 1-based. Default: None.
        receptor_weights : jax.Array or None, optional
            Pre-computed per-receptor weights of shape
            ``v_shape + (n_receptors,)``, added to the synaptic currents after
            decay (same semantics as ``spike_events``).  They are added to the
            unitless pA mantissa of ``i_syn``, so they must be unitless pA
            values.  Useful inside ``brainstate.transform.for_loop`` where
            Python-level ``spike_events`` iteration is not traceable.
            Default: None.

        Returns
        -------
        jax.Array
            Spike indicator (1.0 where a spike was emitted, else 0.0) with
            dtype ``brainstate.environ.dftype()`` and shape
            ``self.V.value.shape``.

        Raises
        ------
        ValueError
            If a ``spike_events`` entry targets a receptor outside
            ``[1, n_receptors]``.
        """
        t = brainstate.environ.get('t')
        dt = brainstate.environ.get_dt()
        h = float(u.get_mantissa(dt / u.ms))  # step in ms (concrete Python float)
        dftype = brainstate.environ.dftype()
        ditype = brainstate.environ.ditype()
        v_shape = self.V.value.shape

        # ---- Strip units from parameters (concrete Python values via get_mantissa) ----
        tau_m_ms = np.asarray(u.get_mantissa(self.C_m / self.g_L / u.ms))
        C_m_pF = np.asarray(u.get_mantissa(self.C_m / u.pF))
        E_L_mV = np.asarray(u.get_mantissa(self.E_L / u.mV))
        V_reset_mV = np.asarray(u.get_mantissa(self.V_reset / u.mV))
        Delta_V_mV = np.asarray(u.get_mantissa(self.Delta_V / u.mV))
        V_T_star_mV = np.asarray(u.get_mantissa(self.V_T_star / u.mV))
        I_e_pA = np.asarray(u.get_mantissa(self.I_e / u.pA))

        # ---- Read state (JAX arrays, compatible with for_loop tracing) ----
        V_mV = u.get_mantissa(self.V.value / u.mV)
        i_syn_pA = u.get_mantissa(self.i_syn.value / u.pA)
        r = self.refractory_step_count.value
        I_stim_pA = u.get_mantissa(self.I_stim.value / u.pA)

        # ---- Propagator coefficients ----
        # Exact one-step solution of the linear membrane ODE with
        # tau_m = C_m / g_L: P33 decays V, P31 (= 1 - P33) weights E_L, and
        # P30 maps a constant current (pA) to a voltage increment (mV).
        P33 = np.exp(-h / tau_m_ms)
        P30 = -1.0 / C_m_pF * np.expm1(-h / tau_m_ms) * tau_m_ms
        P31 = -np.expm1(-h / tau_m_ms)
        P11_syn = np.exp(-h / self.tau_syn)  # shape (n_receptors,)
        # P21: per-receptor coupling from synaptic current to V over one step,
        # obtained from the project helper ``propagator_exp``.
        P21_syn = np.stack([
            propagator_exp(ts * np.ones(v_shape), tau_m_ms, C_m_pF, h)
            for ts in self.tau_syn
        ], axis=-1)  # shape v_shape + (n_receptors,)

        # ---- Step 1: Adaptation — compute pre-decay totals, then decay ----
        if self.n_stc > 0:
            stc_elems_pA = jnp.asarray(u.get_mantissa(self.stc_elems.value / u.pA))
            stc_total_pA = jnp.sum(stc_elems_pA, axis=0)  # pre-decay total
            P_stc = np.exp(-h / np.array(self.tau_stc, dtype=np.float64))
            # (n_stc,) -> (n_stc, 1, ..., 1) so each element row gets its own factor.
            P_stc_bc = P_stc.reshape((-1,) + (1,) * len(v_shape))
            stc_elems_pA = stc_elems_pA * P_stc_bc  # decay
        else:
            stc_total_pA = jnp.zeros(v_shape)
            stc_elems_pA = None

        if self.n_sfa > 0:
            sfa_elems_mV = jnp.asarray(u.get_mantissa(self.sfa_elems.value / u.mV))
            sfa_total_mV = V_T_star_mV + jnp.sum(sfa_elems_mV, axis=0)  # pre-decay
            P_sfa = np.exp(-h / np.array(self.tau_sfa, dtype=np.float64))
            P_sfa_bc = P_sfa.reshape((-1,) + (1,) * len(v_shape))
            sfa_elems_mV = sfa_elems_mV * P_sfa_bc  # decay
        else:
            sfa_total_mV = jnp.broadcast_to(jnp.asarray(V_T_star_mV), v_shape)
            sfa_elems_mV = None

        # Cache pre-decay totals for external inspection
        # NOTE(review): these are plain attributes, not states; when update is
        # traced (jit / for_loop) they will hold tracers — confirm _stc_val /
        # _sfa_val are only read after eager calls.
        self._stc_val_cache = stc_total_pA
        self._sfa_val_cache = sfa_total_mV

        # ---- Step 2: Synaptic currents (propagate, decay, inject) ----
        # Propagate contribution to V using pre-decay i_syn
        sum_syn_pot = jnp.sum(P21_syn * i_syn_pA, axis=-1)  # shape v_shape
        # Decay each receptor current
        i_syn_pA = i_syn_pA * P11_syn
        # Parse spike events and add delta inputs (both go to receptors);
        # generic delta inputs are routed to receptor 1 (index 0).
        # NOTE(review): np.asarray/np.broadcast_to need concrete values, so
        # traced delta inputs here would fail — receptor_weights is the
        # traceable path.
        w_by_rec = self._parse_spike_events(spike_events, v_shape)
        w_default_pA = u.get_mantissa(self.sum_delta_inputs(0.0 * u.pA) / u.pA)
        w_by_rec[..., 0] = w_by_rec[..., 0] + np.broadcast_to(np.asarray(w_default_pA), v_shape)
        i_syn_pA = i_syn_pA + jnp.asarray(w_by_rec)  # add spike weights AFTER decay
        if receptor_weights is not None:
            i_syn_pA = i_syn_pA + receptor_weights

        # Buffer current for NEXT step (NEST ring-buffer semantics)
        new_I_stim = self.sum_current_inputs(x, self.V.value)

        # ---- Step 3: Membrane update and stochastic spike check ----
        not_refractory = (r == 0)

        # Candidate V for non-refractory neurons (NEST propagator update);
        # uses the current buffered from the previous step and the pre-decay
        # STC total.
        V_candidate_mV = (
            P30 * (I_stim_pA + I_e_pA - stc_total_pA)
            + P33 * V_mV
            + P31 * E_L_mV
            + sum_syn_pot
        )

        # Stochastic spike check (only for non-refractory).  A random number
        # is drawn for every neuron each step; refractory ones ignore it.
        new_rng, subkey = jax.random.split(self._rng_state.value)
        self._rng_state.value = new_rng
        rand_vals = jax.random.uniform(subkey, shape=v_shape)
        # Bound the exponent before exp to limit overflow.
        exp_arg = jnp.clip(
            (V_candidate_mV - sfa_total_mV) / Delta_V_mV, -500.0, 500.0
        )
        lam = self.lambda_0 * jnp.exp(exp_arg)  # 1/ms
        # P(spike in h) = 1 - exp(-lam * h), computed via expm1 for accuracy.
        spike_prob = jnp.clip(-jnp.expm1(-lam * h), 0.0, 1.0)
        spike_mask = not_refractory & (rand_vals < spike_prob)

        # Final V: spike or refractory → V_reset; else → V_candidate
        # NOTE(review): this clamps V on the spike step itself; confirm this
        # matches the NEST reference, which may only apply the reset in the
        # refractory branch on subsequent steps.
        new_V_mV = jnp.where(
            not_refractory & ~spike_mask,
            V_candidate_mV,
            V_reset_mV,
        )

        # STC jumps on spike (applied to already-decayed elements)
        if self.n_stc > 0:
            q_stc_arr = np.array(self.q_stc, dtype=np.float64)
            stc_elems_pA = jnp.asarray(stc_elems_pA)
            for i in range(self.n_stc):
                stc_elems_pA = stc_elems_pA.at[i].set(
                    jnp.where(spike_mask, stc_elems_pA[i] + q_stc_arr[i], stc_elems_pA[i])
                )

        # SFA jumps on spike
        if self.n_sfa > 0:
            q_sfa_arr = np.array(self.q_sfa, dtype=np.float64)
            sfa_elems_mV = jnp.asarray(sfa_elems_mV)
            for i in range(self.n_sfa):
                sfa_elems_mV = sfa_elems_mV.at[i].set(
                    jnp.where(spike_mask, sfa_elems_mV[i] + q_sfa_arr[i], sfa_elems_mV[i])
                )

        # Update refractory counter:
        #   spike      → ref_count
        #   refractory → r - 1
        #   otherwise  → r (keep 0)
        ref_count_jax = u.get_mantissa(self.ref_count)
        new_r = jnp.where(
            spike_mask,
            ref_count_jax,
            jnp.where(not_refractory, r, r - 1),
        )

        # ---- Write back state ----
        self.V.value = new_V_mV * u.mV
        self.i_syn.value = i_syn_pA * u.pA
        if self.n_stc > 0:
            self.stc_elems.value = stc_elems_pA * u.pA
        if self.n_sfa > 0:
            self.sfa_elems.value = sfa_elems_mV * u.mV
        self.refractory_step_count.value = new_r.astype(ditype)
        # Adding zeros broadcasts the buffered current to v_shape.
        self.I_stim.value = new_I_stim + u.math.zeros(v_shape) * u.pA
        # Spike time is stamped at the end of the step (t + dt).
        last_spike_time = u.math.where(
            spike_mask, t + dt, self.last_spike_time.value
        )
        self.last_spike_time.value = jax.lax.stop_gradient(last_spike_time)
        return jnp.asarray(spike_mask, dtype=dftype)