# Copyright 2026 BrainX Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# -*- coding: utf-8 -*-
from typing import Callable, Iterable
import brainstate
import braintools
import jax
import jax.numpy as jnp
import numpy as np
import saiunit as u
from brainstate.typing import ArrayLike, Size
from brainstate.util import DotDict
from ._base import NESTNeuron
from ._utils import is_tracer, validate_aeif_overflow, AdaptiveRungeKuttaStep
# Public names exported by ``from <module> import *``.
__all__ = ['aeif_cond_beta_multisynapse']
class aeif_cond_beta_multisynapse(NESTNeuron):
    r"""NEST-compatible ``aeif_cond_beta_multisynapse`` neuron model.

    Conductance-based adaptive exponential integrate-and-fire neuron with
    beta-shaped synaptic conductances and an arbitrary number of receptor ports.
    Implements NEST's ``aeif_cond_beta_multisynapse`` model with source-level
    parity in update ordering, refractory handling, and spike detection.

    This model extends the adaptive exponential integrate-and-fire (AdEx) framework
    [1]_ with beta-function synaptic conductances instead of exponential or alpha
    shapes. Each receptor port maintains independent rise/decay time constants and
    reversal potentials, enabling multi-receptor networks (e.g., AMPA + GABA_A).

    Parameters
    ----------
    in_size : Size
        Population shape as integer, tuple, or Size object. Required.
    V_peak : ArrayLike, optional
        Spike detection threshold (mV). Used when ``Delta_T > 0``; otherwise
        ``V_th`` is used. Must satisfy ``V_peak >= V_th``. Default: 0.0 mV.
    V_reset : ArrayLike, optional
        Post-spike reset potential (mV). Must satisfy ``V_reset < V_peak``.
        Default: -60.0 mV.
    t_ref : ArrayLike, optional
        Absolute refractory period (ms). During refractory, ``dV/dt = 0`` and
        voltage is clamped to ``V_reset``. Default: 0.0 ms (no refractory).
    g_L : ArrayLike, optional
        Leak conductance (nS). Must be strictly positive. Default: 30.0 nS.
    C_m : ArrayLike, optional
        Membrane capacitance (pF). Must be strictly positive. Default: 281.0 pF.
    E_L : ArrayLike, optional
        Leak reversal potential (mV). Default: -70.6 mV.
    Delta_T : ArrayLike, optional
        Exponential slope factor (mV). Must be non-negative. When ``Delta_T = 0``,
        reduces to LIF-like dynamics. Default: 2.0 mV.
    tau_w : ArrayLike, optional
        Adaptation time constant (ms). Must be strictly positive. Default: 144.0 ms.
    a : ArrayLike, optional
        Subthreshold adaptation coupling (nS). Default: 4.0 nS.
    b : ArrayLike, optional
        Spike-triggered adaptation increment (pA). Added to ``w`` on each spike.
        Default: 80.5 pA.
    V_th : ArrayLike, optional
        Spike initiation threshold (mV) for exponential term. Must satisfy
        ``V_th <= V_peak``. Default: -50.4 mV.
    tau_rise : ArrayLike, optional
        Synaptic rise time constants (ms) per receptor, shape ``(n_receptors,)``.
        Must be strictly positive and satisfy ``tau_rise <= tau_decay`` element-wise.
        Default: (2.0,) ms (single receptor).
    tau_decay : ArrayLike, optional
        Synaptic decay time constants (ms) per receptor, shape ``(n_receptors,)``.
        Must be strictly positive and satisfy ``tau_decay >= tau_rise`` element-wise.
        Default: (20.0,) ms (single receptor).
    E_rev : ArrayLike, optional
        Reversal potentials (mV) per receptor, shape ``(n_receptors,)``.
        Default: (0.0,) mV (excitatory-like).
    I_e : ArrayLike, optional
        Constant external current (pA). Default: 0.0 pA.
    gsl_error_tol : ArrayLike, optional
        RKF45 local error tolerance (unitless). Smaller values improve accuracy
        but increase computational cost. Must be strictly positive. Default: 1e-6.
    V_initializer : Callable, optional
        Membrane potential initializer. Default: Constant(-70.6 mV).
    g_initializer : Callable, optional
        Conductance state initializer with shape ``[..., n_receptors]``.
        Default: Constant(0.0 nS).
    w_initializer : Callable, optional
        Adaptation current initializer. Default: Constant(0.0 pA).
    spk_fun : Callable, optional
        Surrogate gradient function for differentiable spike generation.
        Default: ReluGrad().
    spk_reset : str, optional
        Spike reset mode. ``'hard'`` (stop_gradient, matches NEST) or ``'soft'``
        (subtract threshold). Default: ``'hard'``.
    ref_var : bool, optional
        If True, expose ``refractory`` state variable (boolean indicator).
        Default: False.
    name : str, optional
        Instance name. If None, auto-generated.

    Parameter Mapping
    -----------------

    ======================== ===================== ===============================================
    **BrainPy Parameter**    **NEST Parameter**    **Description**
    ======================== ===================== ===============================================
    ``in_size``              (model count)         Population shape
    ``V_peak``               ``V_peak``            Spike detection threshold
    ``V_reset``              ``V_reset``           Reset potential
    ``t_ref``                ``t_ref``             Refractory period
    ``g_L``                  ``g_L``               Leak conductance
    ``C_m``                  ``C_m``               Membrane capacitance
    ``E_L``                  ``E_L``               Leak reversal
    ``Delta_T``              ``Delta_T``           Slope factor
    ``tau_w``                ``tau_w``             Adaptation time constant
    ``a``                    ``a``                 Subthreshold adaptation
    ``b``                    ``b``                 Spike-triggered adaptation
    ``V_th``                 ``V_th``              Exponential threshold
    ``tau_rise``             ``tau_rise``          Rise time per receptor
    ``tau_decay``            ``tau_decay``         Decay time per receptor
    ``E_rev``                ``E_rev``             Reversal potential per receptor
    ``I_e``                  ``I_e``               Constant external current
    ``gsl_error_tol``        ``gsl_error_tol``     RKF45 tolerance
    ======================== ===================== ===============================================

    Mathematical Model
    ------------------

    **1. Membrane Dynamics**

    The membrane voltage :math:`V` evolves according to:

    .. math::

        C_m \frac{dV}{dt} = -g_L (V - E_L) + g_L \Delta_T \exp\!\left(\frac{V - V_{th}}{\Delta_T}\right)
            + \sum_{k=1}^{n_{\text{rec}}} g_k (E_{\text{rev},k} - V)
            - w + I_e + I_{\text{stim}},

    where:

    - :math:`C_m` -- membrane capacitance (pF)
    - :math:`g_L` -- leak conductance (nS)
    - :math:`E_L` -- leak reversal potential (mV)
    - :math:`\Delta_T` -- exponential slope factor (mV)
    - :math:`V_{th}` -- spike initiation threshold (mV)
    - :math:`g_k` -- conductance of receptor :math:`k` (nS)
    - :math:`E_{\text{rev},k}` -- reversal potential of receptor :math:`k` (mV)
    - :math:`w` -- adaptation current (pA)
    - :math:`I_e` -- constant external current (pA)
    - :math:`I_{\text{stim}}` -- delayed injected current (pA)

    During the refractory period, :math:`dV/dt = 0` and :math:`V` is clamped to
    :math:`V_{\text{reset}}`. Outside the refractory period, the exponential term
    uses :math:`\min(V, V_{\text{peak}})` to prevent numerical overflow.

    **2. Adaptation Dynamics**

    The adaptation current :math:`w` follows:

    .. math::

        \tau_w \frac{dw}{dt} = a (V - E_L) - w,

    where :math:`a` (nS) couples subthreshold membrane voltage fluctuations to
    adaptation. On each spike, :math:`w \leftarrow w + b` implements spike-
    triggered adaptation.

    **3. Beta-Function Synaptic Conductances**

    Each receptor :math:`k` maintains two state variables:

    .. math::

        \frac{d\,dg_k}{dt} = -\frac{dg_k}{\tau_{\text{rise},k}}, \quad
        \frac{dg_k}{dt} = dg_k - \frac{g_k}{\tau_{\text{decay},k}}.

    Incoming spikes with weight :math:`w_k` (nS) increment the auxiliary state:

    .. math::

        dg_k \leftarrow dg_k + g_{0,k} w_k,

    where :math:`g_{0,k}` is the beta normalization factor ensuring unit weight
    produces unit peak conductance:

    .. math::

        g_{0,k} = \frac{1/\tau_{\text{rise},k} - 1/\tau_{\text{decay},k}}{\exp(-t_{\text{peak}}/\tau_{\text{decay},k}) - \exp(-t_{\text{peak}}/\tau_{\text{rise},k})},

    with :math:`t_{\text{peak}} = \tau_{\text{decay},k} \tau_{\text{rise},k} \log(\tau_{\text{decay},k}/\tau_{\text{rise},k}) / (\tau_{\text{decay},k} - \tau_{\text{rise},k})`.
    In the equal-time-constant limit, this reduces to the alpha normalization
    :math:`e / \tau`.

    **4. Spike Detection and Reset**

    A spike is detected when:

    - :math:`V \geq V_{\text{peak}}` if :math:`\Delta_T > 0`
    - :math:`V \geq V_{th}` if :math:`\Delta_T = 0`

    Upon spike detection (within RKF45 substeps):

    1. :math:`V \leftarrow V_{\text{reset}}`
    2. :math:`w \leftarrow w + b`
    3. Refractory counter :math:`r \leftarrow \lceil t_{\text{ref}} / dt \rceil + 1` (if ``t_ref > 0``)

    **5. Update Order (NEST Semantics)**

    Each simulation step :math:`(t, t+dt]` proceeds as:

    1. Integrate ODEs using adaptive RKF45 with internal substeps
    2. Inside integration: apply refractory clamp and spike/reset logic
    3. Decrement refractory counter once (outside integration)
    4. Apply incoming spike events to ``dg`` states with beta normalization
    5. Store continuous current input for next step (one-step delay)

    **Computational Notes**

    - **Numerical integration**: Runge-Kutta-Fehlberg (RKF45) adaptive solver
      with local error tolerance ``gsl_error_tol``. Internal step size adapts
      dynamically and persists across simulation steps.
    - **Refractory handling**: During refractory, effective voltage is clamped
      to ``V_reset`` for all ODE terms, including adaptation coupling.
    - **Overflow protection**: Exponential term uses :math:`\min(V, V_{\text{peak}})`
      outside refractory to prevent :math:`\exp(\cdot)` overflow. Validation
      ensures :math:`(V_{\text{peak}} - V_{th}) / \Delta_T` stays below overflow
      threshold when :math:`\Delta_T > 0`.
    - **Spike event format**: ``spike_events`` must be an iterable of
      ``(receptor_type, weight)`` tuples or dicts with keys ``receptor_type``/
      ``receptor`` and ``weight``. Receptor types are 1-based (NEST convention):
      ``1 <= receptor_type <= n_receptors``. Weights (nS) must be non-negative.
    - **Default input mapping**: ``add_delta_input`` stream is mapped to receptor 1;
      weights must be non-negative.
    - **Instability detection**: Integration raises ``ValueError`` if
      :math:`V < -1000` mV or :math:`|w| > 10^6` pA, indicating numerical collapse.

    Attributes
    ----------
    V : HiddenState
        Membrane potential (mV), shape ``(*in_size,)``.
    w : HiddenState
        Adaptation current (pA), shape ``(*in_size,)``.
    dg : ShortTermState
        Beta auxiliary states (nS/ms, stored unitless), shape ``(*in_size, n_receptors)``.
    g : HiddenState
        Receptor conductances (nS), shape ``(*in_size, n_receptors)``.
    refractory_step_count : ShortTermState
        Remaining refractory steps (int), shape ``(*in_size,)``.
    integration_step : ShortTermState
        Persistent RKF45 step size (ms), shape ``(*in_size,)``.
    I_stim : ShortTermState
        One-step delayed current buffer (pA), shape ``(*in_size,)``.
    last_spike_time : ShortTermState
        Last spike time (ms), shape ``(*in_size,)``. Initialized to -1e7 ms.
    refractory : ShortTermState, optional
        Boolean refractory indicator, shape ``(*in_size,)``. Only present if
        ``ref_var=True``.

    Raises
    ------
    ValueError
        If the sizes of ``tau_rise``, ``tau_decay`` and ``E_rev`` are not all equal.
    ValueError
        If any ``tau_rise <= 0`` or ``tau_decay <= 0``.
    ValueError
        If any ``tau_decay < tau_rise``.
    ValueError
        If any ``V_peak < V_th`` or ``V_reset >= V_peak``.
    ValueError
        If ``Delta_T < 0`` or ``C_m <= 0`` or ``t_ref < 0`` or ``tau_w <= 0``.
    ValueError
        If ``gsl_error_tol <= 0``.
    ValueError
        If :math:`(V_{\text{peak}} - V_{th}) / \Delta_T` exceeds overflow threshold
        (when :math:`\Delta_T > 0`).
    ValueError
        During ``update``, if receptor type out of range ``[1, n_receptors]``.
    ValueError
        During ``update``, if synaptic weight is negative (conductance constraint).
    ValueError
        During ``update``, if numerical instability detected (:math:`V < -1000` mV
        or :math:`|w| > 10^6` pA).

    See Also
    --------
    aeif_cond_alpha_multisynapse : Alpha-function variant
    aeif_cond_exp : Single exponential synapse
    aeif_psc_exp : Current-based AdEx

    Notes
    -----
    - Default ``t_ref = 0`` matches NEST and allows multiple spikes per timestep.
      Set ``t_ref > 0`` to enforce physiological refractory periods.
    - Beta conductances provide more realistic synaptic shapes than single
      exponentials but require two state variables per receptor (``dg`` and ``g``).
    - When ``tau_rise = tau_decay``, normalization degenerates to alpha-function
      limit :math:`e / \tau`.

    References
    ----------
    .. [1] Brette R, Gerstner W (2005). Adaptive exponential integrate-and-fire
           model as an effective description of neuronal activity.
           Journal of Neurophysiology, 94:3637-3642.
           DOI: https://doi.org/10.1152/jn.00686.2005
    .. [2] Roth A, van Rossum M (2013). Modeling synapses.
           In *Computational Modeling Methods for Neuroscientists*.
           MIT Press, Cambridge, MA.
    .. [3] NEST 3.9+ source: ``models/aeif_cond_beta_multisynapse.h`` and
           ``models/aeif_cond_beta_multisynapse.cpp``.

    Examples
    --------
    Create a two-receptor neuron (excitatory + inhibitory):

    .. code-block:: python

        >>> import brainpy.state as bp
        >>> import saiunit as u
        >>> neuron = bp.aeif_cond_beta_multisynapse(
        ...     in_size=10,
        ...     tau_rise=(2.0, 0.5) * u.ms,
        ...     tau_decay=(20.0, 8.0) * u.ms,
        ...     E_rev=(0.0, -80.0) * u.mV,  # excitatory, inhibitory
        ... )
        >>> neuron.n_receptors
        2

    Simulate with receptor-specific spike events:

    .. code-block:: python

        >>> import brainstate as bst
        >>> with bst.environ.context(dt=0.1 * u.ms):
        ...     neuron.init_all_states()
        ...     # Excitatory spike to receptor 1
        ...     events = [(1, 5.0 * u.nS)]
        ...     spk = neuron.update(x=100.0 * u.pA, spike_events=events)
        ...     print(neuron.V.value)  # doctest: +SKIP

    Multi-receptor dictionary format:

    .. code-block:: python

        >>> events = [
        ...     {'receptor_type': 1, 'weight': 3.0 * u.nS},
        ...     {'receptor_type': 2, 'weight': 2.0 * u.nS},
        ... ]
        >>> spk = neuron.update(spike_events=events)  # doctest: +SKIP
    """
    __module__ = 'brainpy.state'

    # Integrator limits shared by all instances.
    _MIN_H = 1e-8 * u.ms
    _MAX_ITERS = 100000
    # Threshold mirroring NEST's std::numeric_limits<double>::epsilon().
    _EPS = np.finfo(np.float64).eps

    def __init__(
        self,
        in_size: Size,
        V_peak: ArrayLike = 0.0 * u.mV,
        V_reset: ArrayLike = -60.0 * u.mV,
        t_ref: ArrayLike = 0.0 * u.ms,
        g_L: ArrayLike = 30.0 * u.nS,
        C_m: ArrayLike = 281.0 * u.pF,
        E_L: ArrayLike = -70.6 * u.mV,
        Delta_T: ArrayLike = 2.0 * u.mV,
        tau_w: ArrayLike = 144.0 * u.ms,
        a: ArrayLike = 4.0 * u.nS,
        b: ArrayLike = 80.5 * u.pA,
        V_th: ArrayLike = -50.4 * u.mV,
        tau_rise: ArrayLike = (2.0,) * u.ms,
        tau_decay: ArrayLike = (20.0,) * u.ms,
        E_rev: ArrayLike = (0.0,) * u.mV,
        I_e: ArrayLike = 0.0 * u.pA,
        gsl_error_tol: ArrayLike = 1e-6,
        V_initializer: Callable = braintools.init.Constant(-70.6 * u.mV),
        g_initializer: Callable = braintools.init.Constant(0.0 * u.nS),
        w_initializer: Callable = braintools.init.Constant(0.0 * u.pA),
        spk_fun: Callable = braintools.surrogate.ReluGrad(),
        spk_reset: str = 'hard',
        ref_var: bool = False,
        name: str = None,
    ):
        r"""Initialize aeif_cond_beta_multisynapse neuron.

        All parameters are documented in the class docstring. Requires a
        simulation time step ``dt`` to be available from ``brainstate.environ``
        at construction time (it is used to configure the integrator and to
        convert ``t_ref`` into a step count).
        """
        super().__init__(in_size, name=name, spk_fun=spk_fun, spk_reset=spk_reset)

        # Per-neuron parameters, broadcast to the population shape.
        self.V_peak = braintools.init.param(V_peak, self.varshape)
        self.V_reset = braintools.init.param(V_reset, self.varshape)
        self.t_ref = braintools.init.param(t_ref, self.varshape)
        self.g_L = braintools.init.param(g_L, self.varshape)
        self.C_m = braintools.init.param(C_m, self.varshape)
        self.E_L = braintools.init.param(E_L, self.varshape)
        self.Delta_T = braintools.init.param(Delta_T, self.varshape)
        self.tau_w = braintools.init.param(tau_w, self.varshape)
        self.a = braintools.init.param(a, self.varshape)
        self.b = braintools.init.param(b, self.varshape)
        self.V_th = braintools.init.param(V_th, self.varshape)
        self.I_e = braintools.init.param(I_e, self.varshape)
        self.gsl_error_tol = gsl_error_tol

        # Per-receptor parameters are kept as unitless 1-D NumPy arrays
        # (tau_* in ms, E_rev in mV) so they can be validated eagerly.
        dftype = brainstate.environ.dftype()
        self.tau_rise = np.asarray(u.math.asarray(tau_rise / u.ms), dtype=dftype).reshape(-1)
        self.tau_decay = np.asarray(u.math.asarray(tau_decay / u.ms), dtype=dftype).reshape(-1)
        self.E_rev = np.asarray(u.math.asarray(E_rev / u.mV), dtype=dftype).reshape(-1)

        self.V_initializer = V_initializer
        self.g_initializer = g_initializer
        self.w_initializer = w_initializer
        self.ref_var = ref_var

        self._validate_parameters()

        # Beta normalization factor per receptor (units 1/ms, stored unitless).
        self._g0 = np.asarray(
            [self._beta_normalization_factor_scalar(tr, td) for tr, td in zip(self.tau_rise, self.tau_decay)],
            dtype=dftype,
        )

        # Per-receptor unit-aware constants for the vector field.
        self._tau_rise_ms = jnp.asarray(self.tau_rise) * u.ms
        self._tau_decay_ms = jnp.asarray(self.tau_decay) * u.ms
        self._E_rev_mV = jnp.asarray(self.E_rev) * u.mV

        dt = brainstate.environ.get_dt()
        self.integrator = AdaptiveRungeKuttaStep(
            method='RKF45',
            vf=self._vector_field,
            event_fn=self._event_fn,
            min_h=self._MIN_H,
            max_iters=self._MAX_ITERS,
            atol=self.gsl_error_tol,
            dt=dt,
        )

        # Refractory period expressed as a whole number of simulation steps.
        ditype = brainstate.environ.ditype()
        self.ref_count = u.math.asarray(u.math.ceil(self.t_ref / dt), dtype=ditype)

    @property
    def n_receptors(self):
        r"""Number of receptor ports.

        Returns
        -------
        int
            Number of receptor types, inferred from ``tau_rise.size``.
        """
        return int(self.tau_rise.size)

    @property
    def recordables(self):
        r"""List of recordable state variable names.

        Returns
        -------
        list of str
            Dynamic recordables following NEST naming: ``['V_m', 'w', 'g_1', 'g_2', ..., 'g_n']``.
        """
        return ['V_m', 'w', *[f'g_{i + 1}' for i in range(self.n_receptors)]]

    @classmethod
    def _beta_normalization_factor_scalar(cls, tau_rise: float, tau_decay: float):
        r"""Compute beta normalization factor for single receptor.

        Ensures unit weight produces unit peak conductance. Implements NEST's
        beta normalization formula, degenerating to alpha normalization
        :math:`e / \tau` when :math:`\tau_{\text{rise}} = \tau_{\text{decay}}`.

        Parameters
        ----------
        tau_rise : float
            Synaptic rise time constant (ms, unitless).
        tau_decay : float
            Synaptic decay time constant (ms, unitless).

        Returns
        -------
        float
            Normalization factor :math:`g_0` such that unit weight produces unit peak.
            If :math:`\tau_{\text{rise}} \approx \tau_{\text{decay}}`, returns :math:`e / \tau_{\text{decay}}`.

        Notes
        -----
        The normalization factor is:

        .. math::

            g_0 = \frac{1/\tau_{\text{rise}} - 1/\tau_{\text{decay}}}{\exp(-t_{\text{peak}}/\tau_{\text{decay}}) - \exp(-t_{\text{peak}}/\tau_{\text{rise}})},

        where :math:`t_{\text{peak}} = \tau_{\text{decay}} \tau_{\text{rise}} \log(\tau_{\text{decay}}/\tau_{\text{rise}}) / (\tau_{\text{decay}} - \tau_{\text{rise}})`.
        """
        tau_difference = tau_decay - tau_rise
        peak_value = 0.0
        if abs(tau_difference) > cls._EPS:
            t_peak = tau_decay * tau_rise * np.log(tau_decay / tau_rise) / tau_difference
            peak_value = np.exp(-t_peak / tau_decay) - np.exp(-t_peak / tau_rise)
        # Equal (or numerically indistinguishable) time constants: fall back to
        # the alpha-function normalization, as NEST does.
        if abs(peak_value) < cls._EPS:
            return np.e / tau_decay
        return (1.0 / tau_rise - 1.0 / tau_decay) / peak_value

    def _validate_parameters(self):
        r"""Validate model parameters at initialization.

        Raises
        ------
        ValueError
            If parameter constraints are violated (see class docstring for details).
            Specific checks include:

            - Receptor array size consistency (``tau_rise``, ``tau_decay``, ``E_rev``)
            - Strict positivity (``tau_rise``, ``tau_decay``, ``C_m``, ``tau_w``, ``gsl_error_tol``)
            - Ordering constraints (``tau_decay >= tau_rise``, ``V_peak >= V_th``, ``V_reset < V_peak``)
            - Non-negativity (``Delta_T``, ``t_ref``)
            - Overflow prevention (exponential term when ``Delta_T > 0``)
        """
        v_reset = self.V_reset
        v_peak = self.V_peak
        v_th = self.V_th
        delta_t = self.Delta_T / u.mV

        # Skip validation when parameters are JAX tracers (e.g. during jit).
        if any(is_tracer(v) for v in (v_reset, v_peak, v_th, delta_t)):
            return

        if self.E_rev.size != self.tau_rise.size or self.E_rev.size != self.tau_decay.size:
            raise ValueError(
                'The reversal potential, synaptic rise time and synaptic decay time arrays must have the same size.')
        if np.any(self.tau_rise <= 0.0) or np.any(self.tau_decay <= 0.0):
            raise ValueError('All synaptic time constants must be strictly positive')
        if np.any(self.tau_decay < self.tau_rise):
            raise ValueError('Synaptic rise time must be smaller than or equal to decay time.')
        if np.any(v_peak < v_th):
            raise ValueError('V_peak >= V_th required.')
        if np.any(v_reset >= v_peak):
            raise ValueError('Ensure that: V_reset < V_peak .')
        # Message mirrors NEST even though the check is for negativity.
        if np.any(delta_t < 0.0):
            raise ValueError('Delta_T must be positive.')
        if np.any(self.C_m <= 0.0 * u.pF):
            raise ValueError('Capacitance must be strictly positive.')
        if np.any(self.t_ref < 0.0 * u.ms):
            raise ValueError('Refractory time cannot be negative.')
        if np.any(self.tau_w <= 0.0 * u.ms):
            raise ValueError('All time constants must be strictly positive.')
        if np.any(self.gsl_error_tol <= 0.0):
            raise ValueError('The gsl_error_tol must be strictly positive.')

        # Mirror NEST overflow guard for exponential term at spike time.
        validate_aeif_overflow(v_peak, v_th, delta_t)

    def init_state(self, **kwargs):
        r"""Initialize all state variables.

        Creates ``HiddenState`` and ``ShortTermState`` attributes for membrane
        potential, adaptation current, receptor conductances, refractory counters,
        RKF45 step size, and delayed current buffer.

        Parameters
        ----------
        **kwargs
            Unused compatibility parameters accepted by the base-state API.

        Notes
        -----
        Initializes:

        - ``V`` (HiddenState): membrane potential from ``V_initializer``
        - ``w`` (HiddenState): adaptation current from ``w_initializer``
        - ``dg`` (ShortTermState): beta auxiliary states, initialized to zero
        - ``g`` (HiddenState): receptor conductances from ``g_initializer``
        - ``last_spike_time`` (ShortTermState): initialized to -1e7 ms
        - ``refractory_step_count`` (ShortTermState): initialized to 0
        - ``integration_step`` (ShortTermState): RKF45 step size, initialized to ``dt``
        - ``I_stim`` (ShortTermState): delayed current buffer, initialized to 0 pA
        - ``refractory`` (ShortTermState, optional): boolean indicator if ``ref_var=True``
        """
        ditype = brainstate.environ.ditype()
        dftype = brainstate.environ.dftype()
        dt = brainstate.environ.get_dt()

        V = braintools.init.param(self.V_initializer, self.varshape)
        w = braintools.init.param(self.w_initializer, self.varshape)
        g = braintools.init.param(self.g_initializer, self.varshape + (self.n_receptors,))
        # dg stored unitless (mantissa in nS/ms)
        zeros_dg = u.math.zeros(self.varshape + (self.n_receptors,), dtype=V.dtype)

        self.V = brainstate.HiddenState(V)
        self.w = brainstate.HiddenState(w)
        self.dg = brainstate.ShortTermState(zeros_dg)
        self.g = brainstate.HiddenState(g)
        self.last_spike_time = brainstate.ShortTermState(u.math.full(self.varshape, -1e7 * u.ms))
        self.refractory_step_count = brainstate.ShortTermState(u.math.full(self.varshape, 0, dtype=ditype))
        self.integration_step = brainstate.ShortTermState.init(braintools.init.Constant(dt), self.varshape)
        self.I_stim = brainstate.ShortTermState(u.math.full(self.varshape, 0.0 * u.pA, dtype=dftype))

        if self.ref_var:
            refractory = braintools.init.param(braintools.init.Constant(False), self.varshape)
            self.refractory = brainstate.ShortTermState(refractory)

    def get_spike(self, V: ArrayLike = None):
        r"""Compute surrogate spike output for gradient-based learning.

        Applies surrogate gradient function to scaled membrane potential for
        differentiable spike generation. Does not modify state variables.

        Parameters
        ----------
        V : ArrayLike, optional
            Membrane potential (mV). If None, uses current ``self.V.value``.

        Returns
        -------
        ArrayLike
            Surrogate spike output, shape ``(*in_size,)``, produced by
            ``spk_fun`` applied to ``(V - V_th) / (V_th - V_reset)``. The value
            range depends on ``spk_fun`` (typically [0, 1] for step-like
            surrogates).

        Notes
        -----
        This method is primarily used for gradient computation in training contexts.
        Actual spike detection during forward simulation uses hard thresholds in
        ``update`` method.
        """
        V = self.V.value if V is None else V
        v_scaled = (V - self.V_th) / (self.V_th - self.V_reset)
        return self.spk_fun(v_scaled)

    def _parse_spike_events(self, spike_events: Iterable, v_shape):
        r"""Parse incoming spike events into receptor-specific weight array.

        Converts event list/dict format into NumPy array with receptor-specific
        conductance increments, validating receptor types and weight non-negativity.

        Parameters
        ----------
        spike_events : Iterable or None
            Spike events as:

            - List of ``(receptor_type, weight)`` tuples
            - List of dicts with keys ``'receptor_type'``/``'receptor'`` and ``'weight'``
              (missing receptor defaults to 1, missing weight to 0 nS)
            - Single dict (auto-wrapped to list)
            - None (returns zero array)
        v_shape : tuple
            Neuron population shape for broadcasting.

        Returns
        -------
        np.ndarray
            Weight array (nS, unitless) with shape ``(*v_shape, n_receptors)``.
            Element ``[..., k]`` contains total conductance increment for receptor ``k+1``.

        Raises
        ------
        ValueError
            If receptor type out of range ``[1, n_receptors]``.
        ValueError
            If any weight is negative (conductance constraint).

        Notes
        -----
        Receptor types are 1-based (NEST convention). Internal indexing is 0-based.
        Multiple events for the same receptor are summed.
        """
        dftype = brainstate.environ.dftype()
        out = np.zeros(v_shape + (self.n_receptors,), dtype=dftype)
        if spike_events is None:
            return out
        if isinstance(spike_events, dict):
            spike_events = [spike_events]
        for ev in spike_events:
            if isinstance(ev, dict):
                receptor = int(ev.get('receptor_type', ev.get('receptor', 1)))
                # Default must carry conductance units: a bare 0.0 divided by
                # u.nS would yield a 1/nS quantity instead of 0 nS.
                weight = ev.get('weight', 0.0 * u.nS)
            else:
                receptor, weight = ev
                receptor = int(receptor)
            if receptor <= 0 or receptor > self.n_receptors:
                raise ValueError(f'Receptor type {receptor} out of range [1, {self.n_receptors}].')
            w_np = np.asarray(u.math.asarray(weight / u.nS), dtype=dftype)
            if np.any(w_np < 0.0):
                raise ValueError('Synaptic weights for conductance-based multisynapse models must be non-negative.')
            out[..., receptor - 1] += np.broadcast_to(w_np, v_shape)
        return out

    def _vector_field(self, state, extra):
        """Unit-aware vectorized RHS for all neurons simultaneously.

        Parameters
        ----------
        state : DotDict
            Keys: V, dg, g, w -- ODE state variables.
            ``dg`` and ``g`` have an extra trailing receptor dimension.
        extra : DotDict
            Keys: spike_mask, r, unstable, i_stim, v_peak_detect -- mutable
            auxiliary data carried through the integrator.

        Returns
        -------
        DotDict with same keys as ``state``, containing time derivatives.
        """
        is_refractory = extra.r > 0
        # Refractory neurons see V_reset; otherwise cap V at V_peak to keep the
        # exponential term bounded (NEST semantics).
        v_eff = u.math.where(is_refractory, self.V_reset, u.math.minimum(state.V, self.V_peak))

        # Synaptic current: sum over receptors g_k * (E_rev_k - V).
        # v_eff has shape (*varshape,), E_rev has shape (n_receptors,),
        # g has shape (*varshape, n_receptors); expand v_eff to (*varshape, 1).
        v_eff_expanded = u.math.expand_dims(v_eff, axis=-1)
        i_syn = u.math.sum(state.g * (self._E_rev_mV - v_eff_expanded), axis=-1)

        # Avoid division by zero when Delta_T == 0; the term is then multiplied
        # by Delta_T == 0 anyway, so the substitute value has no effect.
        delta_t_safe = u.math.where(self.Delta_T == 0.0 * u.mV, 1.0 * u.mV, self.Delta_T)
        exp_arg = u.math.clip((v_eff - self.V_th) / delta_t_safe, -500.0, 500.0)
        i_spike = self.g_L * self.Delta_T * u.math.exp(exp_arg)

        dV_raw = (
            -self.g_L * (v_eff - self.E_L) + i_spike
            + i_syn - state.w + self.I_e + extra.i_stim
        ) / self.C_m
        dV = u.math.where(is_refractory, u.math.zeros_like(dV_raw), dV_raw)

        # Beta synapse dynamics per receptor:
        #   ddg_k = -dg_k / tau_rise_k
        #   dg_k_dt = dg_k - g_k / tau_decay_k
        ddg = -state.dg / self._tau_rise_ms
        dg_dt = state.dg - state.g / self._tau_decay_ms

        dw = (self.a * (v_eff - self.E_L) - state.w) / self.tau_w

        return DotDict(V=dV, dg=ddg, g=dg_dt, w=dw)

    def _event_fn(self, state, extra, accept):
        """In-loop spike detection, reset, and refractory handling.

        Parameters
        ----------
        state : DotDict
            Keys: V, dg, g, w -- ODE state variables.
        extra : DotDict
            Keys: spike_mask, r, unstable, i_stim, v_peak_detect.
        accept : array, bool
            Mask of neurons whose RK substep was accepted.

        Returns
        -------
        (new_state, new_extra) DotDicts with updated spike/reset/refractory info.
        """
        # Sticky flag: any accepted substep with runaway V or w marks instability.
        unstable = extra.unstable | jnp.any(
            accept & ((state.V < -1e3 * u.mV) | (state.w < -1e6 * u.pA) | (state.w > 1e6 * u.pA))
        )

        # Refractory neurons are clamped to V_reset and cannot spike.
        refr_accept = accept & (extra.r > 0)
        new_V = u.math.where(refr_accept, self.V_reset, state.V)

        spike_now = accept & (extra.r <= 0) & (new_V >= extra.v_peak_detect)
        spike_mask = extra.spike_mask | spike_now

        new_V = u.math.where(spike_now, self.V_reset, new_V)
        new_w = u.math.where(spike_now, state.w + self.b, state.w)
        # +1 compensates for the single decrement applied after integration.
        r = u.math.where(spike_now & (self.ref_count > 0), self.ref_count + 1, extra.r)

        new_state = DotDict({**state, 'V': new_V, 'w': new_w})
        new_extra = DotDict({**extra, 'spike_mask': spike_mask, 'r': r, 'unstable': unstable})
        return new_state, new_extra

    def update(self, x=0.0 * u.pA, spike_events=None):
        r"""Advance model by one simulation timestep (NEST-compatible update).

        Integrates ODEs over :math:`(t, t+dt]` using adaptive RKF45 with
        vectorized integration, spike detection, refractory handling, and
        receptor-specific spike event application. Follows NEST's update
        ordering exactly.

        Parameters
        ----------
        x : ArrayLike, optional
            Continuous current input (pA), shape broadcastable to ``(*in_size,)``.
            Combined with registered current inputs via ``sum_current_inputs`` and
            stored in ``I_stim`` for use during the *next* step (one-step delay,
            NEST semantics). ``I_e`` is not delayed; it is applied directly in the
            vector field. Default: 0.0 pA.
        spike_events : Iterable or None, optional
            Incoming spike events as:

            - List of ``(receptor_type, weight)`` tuples
            - List of dicts with keys ``'receptor_type'``/``'receptor'`` and ``'weight'``
            - Single dict (auto-wrapped to list)
            - None (no spike input)

            Receptor types are 1-based: ``1 <= receptor_type <= n_receptors``.
            Weights (nS) must be non-negative. Default: None.

        Returns
        -------
        ArrayLike
            Binary spike indicator (0 or 1), shape ``(*in_size,)``, with dtype
            ``brainstate.environ.dftype()``. Value is 1.0 if at least one spike
            occurred during :math:`(t, t+dt]`, else 0.0.

        Raises
        ------
        ValueError
            If receptor type out of range ``[1, n_receptors]``.
        ValueError
            If any spike event weight is negative (conductance constraint).
        ValueError
            If ``add_delta_input`` stream contains negative values (mapped to receptor 1).
        ValueError
            If no receptor ports exist but ``delta_inputs`` or ``spike_events`` are non-zero.
        ValueError
            If numerical instability detected during integration (:math:`V < -1000` mV
            or :math:`|w| > 10^6` pA).

        Notes
        -----
        Integration is performed with an adaptive vectorized RKF45 loop,
        including in-loop spike/reset/adaptation events and optional
        multiple spikes per step. All arithmetic is unit-aware via
        ``saiunit.math``.
        """
        t = brainstate.environ.get('t')
        dt = brainstate.environ.get_dt()
        dftype = brainstate.environ.dftype()
        ditype = brainstate.environ.ditype()
        n_receptors = self.n_receptors
        v_shape = self.V.value.shape

        # Read state variables with their natural units.
        V = self.V.value  # mV
        dg = self.dg.value * (u.nS / u.ms)  # stored unitless, restore nS/ms
        g = self.g.value  # nS
        w = self.w.value  # pA
        r = self.refractory_step_count.value  # int
        i_stim = self.I_stim.value  # pA
        h = self.integration_step.value  # ms

        # Spike detection threshold: V_peak if Delta_T > 0, else V_th.
        v_peak_detect = u.math.where(self.Delta_T > 0.0 * u.mV, self.V_peak, self.V_th)

        # Current input for next step (one-step delay).
        new_i_stim = self.sum_current_inputs(x, self.V.value)  # pA

        # Parse spike events into per-receptor weight array.
        w_by_rec = self._parse_spike_events(spike_events, v_shape)

        # Default delta input mapped to receptor 1.
        # Use jnp.asarray (not np.asarray) so this path is JIT-compatible inside
        # brainstate.transform.for_loop, where sum_delta_inputs may return a tracer.
        w_default = u.get_mantissa(u.math.asarray(self.sum_delta_inputs(0.0 * u.nS) / u.nS))
        w_default = jnp.broadcast_to(jnp.asarray(w_default, dtype=dftype), v_shape)
        if n_receptors > 0:
            # Guard with is_tracer: concrete values support Python-level ValueError;
            # traced values (inside JIT) skip the eager check safely.
            if not is_tracer(w_default) and np.any(np.asarray(w_default) < 0.0):
                raise ValueError('Synaptic weights for conductance-based multisynapse models must be non-negative.')
            # Use JAX immutable update so w_by_rec stays JIT-compatible.
            w_by_rec = jnp.asarray(w_by_rec).at[..., 0].add(w_default)
        elif not is_tracer(w_default) and np.any(np.asarray(w_default) != 0.0):
            raise ValueError('No receptor ports available for incoming spike conductance.')

        # Beta normalization factors (unitless, per receptor).
        g0 = np.broadcast_to(self._g0, v_shape + (n_receptors,))

        # Adaptive RKF45 integration via generic integrator.
        ode_state = DotDict(V=V, dg=dg, g=g, w=w)
        extra = DotDict(
            spike_mask=jnp.zeros(v_shape, dtype=jnp.bool_),
            r=r,
            unstable=jnp.array(False),
            i_stim=i_stim,
            v_peak_detect=v_peak_detect,
        )
        ode_state, h, extra = self.integrator(state=ode_state, h=h, extra=extra)
        V, dg, g, w = ode_state.V, ode_state.dg, ode_state.g, ode_state.w
        spike_mask, r, unstable = extra.spike_mask, extra.r, extra.unstable

        # Post-loop stability check.
        brainstate.transform.jit_error_if(
            jnp.any(unstable), 'Numerical instability in aeif_cond_beta_multisynapse dynamics.'
        )

        # Decrement refractory counter once per simulation step.
        r = u.math.where(r > 0, r - 1, r)

        # Apply incoming spike events to dg states with beta normalization.
        # g0 carries units of 1/ms (stored unitless) and w_by_rec is in nS, so
        # their product is in nS/ms, matching NEST's dg_k += g0_k * w_k.
        dg_increment = jnp.asarray(g0 * w_by_rec) * (u.nS / u.ms)
        dg = dg + dg_increment

        # Write back state.
        self.V.value = V
        self.dg.value = u.get_mantissa(dg)  # store unitless mantissa
        self.g.value = g
        self.w.value = w
        self.refractory_step_count.value = jnp.asarray(u.get_mantissa(r), dtype=ditype)
        self.integration_step.value = h
        self.I_stim.value = new_i_stim + u.math.zeros(self.varshape) * u.pA

        last_spike_time = u.math.where(spike_mask, t + dt, self.last_spike_time.value)
        self.last_spike_time.value = jax.lax.stop_gradient(last_spike_time)
        if self.ref_var:
            self.refractory.value = jax.lax.stop_gradient(self.refractory_step_count.value > 0)

        return u.math.asarray(spike_mask, dtype=dftype)