# Copyright 2026 BrainX Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# -*- coding: utf-8 -*-
import math
from typing import Callable
import brainstate
import braintools
import saiunit as u
import numpy as np
import jax.numpy as jnp
from brainstate.typing import ArrayLike, Size
from ._base import NESTNeuron
from ._utils import is_tracer
__all__ = [
    'siegert_neuron',
]

# SciPy supplies fast, high-accuracy erfcx/dawsn and adaptive quadrature; when
# it is unavailable the hand-rolled Gauss-Legendre helpers below are used.
try:
    from scipy import integrate as _sp_integrate
    from scipy import special as _sp_special
    _HAVE_SCIPY = True
except Exception:  # pragma: no cover - fallback path when SciPy is unavailable.
    _HAVE_SCIPY = False

# Gauss-Legendre nodes used by the scalar quadrature helpers.
_GAUSS_NODES, _GAUSS_WEIGHTS = np.polynomial.legendre.leggauss(64)
class siegert_neuron(NESTNeuron):
r"""NEST-compatible ``siegert_neuron`` mean-field rate model.
**1. Overview**
Mean-field rate model using the Siegert gain function of a noisy LIF neuron.
This model computes the population-averaged firing rate from drift-diffusion
input statistics via an analytic transfer function, enabling efficient
large-scale network simulation without explicit spike generation.
**2. Mathematical Description**
The rate dynamics follow a first-order ODE:
.. math::
\tau\,\frac{dr(t)}{dt} = -r(t) + \text{mean} + \Phi(\mu, \sigma^2),
where:
- :math:`r(t)` is the population firing rate (Hz)
- :math:`\tau` is the rate time constant
- :math:`\text{mean}` is a constant baseline drive
- :math:`\Phi(\mu, \sigma^2)` is the Siegert transfer function
- :math:`\mu` is the total drift input (mean membrane potential shift)
- :math:`\sigma^2` is the total diffusion input (variance)
The Siegert function :math:`\Phi` analytically computes the steady-state
firing rate of a leaky integrate-and-fire neuron receiving white noise with
drift :math:`\mu` and diffusion :math:`\sigma^2`, subject to threshold
:math:`\theta`, reset :math:`V_{\text{reset}}`, refractory period
:math:`t_{\text{ref}}`, and membrane time constant :math:`\tau_m` [2]_.
For colored noise (finite :math:`\tau_{\text{syn}} > 0`), a threshold shift
correction is applied [3]_:
.. math::
\Delta_{\text{th}} = \frac{\alpha}{2} \sqrt{\frac{\tau_{\text{syn}}}{\tau_m}},
where :math:`\alpha = |\zeta(1/2)| \sqrt{2} \approx 2.0653`.
The integration is performed via exact exponential propagators:
.. math::
r(t + \Delta t) = e^{-\Delta t / \tau} r(t) + \left(1 - e^{-\Delta t / \tau}\right)
\left(\text{mean} + \Phi(\mu, \sigma^2)\right).
**3. NEST-Compatible Update Ordering (Non-WFR Path)**
For each simulation step:
1. Collect delayed and instantaneous diffusion-event buffers from queues.
2. Sum all drift and diffusion contributions (delayed, instant, direct inputs).
3. Evaluate Siegert transfer function :math:`\Phi(\mu_{\text{total}}, \sigma^2_{\text{total}})`.
4. Update rate via exact exponential step: :math:`r \leftarrow P_1 r + P_2 (\text{mean} + \Phi)`.
5. Publish updated rate to ``delayed_rate`` and ``instant_rate`` buffers for outgoing connections.
This mirrors NEST's non-waveform-relaxation ``update_`` semantics where
emitted diffusion coefficients are overwritten with the post-update rate.
**4. Diffusion Event Handling**
Runtime diffusion events modulate drift and diffusion inputs. Events can be
supplied via two channels:
- ``instant_diffusion_events``: applied in the current step (delay = 0)
- ``delayed_diffusion_events``: scheduled by integer ``delay_steps`` (default 1)
Event format supports dicts, tuples, or lists. Dict keys:
- ``coeff`` (or ``rate``/``value``): base coefficient
- ``drift_factor``: multiplier for drift contribution
- ``diffusion_factor``: multiplier for diffusion contribution
- ``weight``: connection weight (default 1)
- ``multiplicity``: event count (default 1)
- ``delay_steps`` (or ``delay``): integer delay in steps
Tuple/list format: ``(coeff, drift_factor, diffusion_factor, delay_steps, weight, multiplicity)``.
Shorter tuples use default values for trailing fields.
Drift and diffusion contributions are computed as:
.. math::
\mu &= \text{coeff} \times \text{weight} \times \text{multiplicity} \times \text{drift\_factor}, \\
\sigma^2 &= \text{coeff} \times \text{weight} \times \text{multiplicity} \times \text{diffusion\_factor}.
**5. Siegert Transfer Function Computation**
The Siegert function is evaluated element-wise for array inputs. For each
population element, the computation handles three regimes:
- **Deterministic (σ² ≤ 0)**: If μ > θ, returns LIF firing rate; else 0.
- **Very subthreshold (θ - μ > 6σ)**: Returns 0 (Brunel 2000 fast path).
- **General diffusive**: Computes via integral of scaled complementary error
function (erfcx) and Dawson's integral, using either SciPy (if available)
or custom Gauss-Legendre quadrature with asymptotic expansions.
Numerical integration uses 64-point Gauss-Legendre quadrature for erfcx and
adaptive segmentation for Dawson's integral, ensuring relative accuracy
~1.5e-8.
Parameters
----------
in_size : Size
Population shape. Tuple of ints or single int for 1D populations.
Determines the spatial structure of the rate model. For example,
``(10, 10)`` creates a 10×10 grid of mean-field neurons.
tau : Quantity[ms], optional
Time constant of the first-order rate dynamics (must be > 0). Controls
the rate of convergence to the steady-state Siegert value. Smaller
values produce faster tracking of input changes. Default: ``1 ms``.
tau_m : Quantity[ms], optional
Membrane time constant used in the Siegert gain function (must be > 0).
Represents the passive membrane time constant of the modeled LIF neurons.
Default: ``5 ms``.
tau_syn : Quantity[ms], optional
Synaptic time constant for colored-noise threshold correction (must be ≥ 0).
When ``tau_syn > 0``, applies a threshold shift to account for finite
synaptic rise time [3]_. Use ``0 ms`` for white noise (no correction).
Default: ``0 ms``.
t_ref : Quantity[ms], optional
Refractory period in the Siegert gain function (must be ≥ 0). Represents
the absolute refractory period during which the neuron cannot spike.
Increases the interspike interval and reduces firing rates. Default: ``2 ms``.
mean : float, optional
Constant additive baseline drive in the rate ODE (dimensionless). Shifts
the firing rate upward without affecting dynamics. Can be scalar or
array matching ``in_size``. Default: ``0.0``.
theta : float, optional
Spike threshold relative to resting potential (dimensionless, corresponds
to mV in NEST). Must be > ``V_reset``. Defines the firing threshold in
the Siegert transfer function. Default: ``15.0``.
V_reset : float, optional
Reset potential relative to resting potential (dimensionless, corresponds
to mV in NEST). Must be < ``theta``. Neuron is reset to this value after
spiking in the underlying LIF model. Default: ``0.0``.
rate_initializer : Callable, optional
Initializer function for the ``rate`` state variable. Called as
``rate_initializer(shape, batch_size)`` during ``init_state()``. Default:
``braintools.init.Constant(0.0)`` (all neurons start at 0 Hz).
name : str, optional
Unique identifier for this module. If ``None``, auto-generated. Used for
logging and debugging.
Parameter Mapping
-----------------
========================= ========================== ==================================
NEST Parameter brainpy.state Parameter Description
========================= ========================== ==================================
``tau`` ``tau`` Rate dynamics time constant
``tau_m`` ``tau_m`` Membrane time constant (Siegert)
``tau_syn`` ``tau_syn`` Synaptic time constant (threshold shift)
``t_ref`` ``t_ref`` Refractory period (Siegert)
``mean`` ``mean`` Constant baseline drive
``theta`` ``theta`` Spike threshold (relative to rest)
``V_reset`` ``V_reset`` Reset potential (relative to rest)
``rate`` ``rate`` Current firing rate (Hz)
========================= ========================== ==================================
Raises
------
ValueError
If ``tau`` ≤ 0, ``tau_m`` ≤ 0, ``tau_syn`` < 0, ``t_ref`` < 0, or ``V_reset`` ≥ ``theta``.
ValueError
If ``instant_diffusion_events`` contains non-zero ``delay_steps``.
ValueError
If ``delayed_diffusion_events`` contains negative ``delay_steps``.
ValueError
If event tuples have length > 6 or < 1.
Notes
-----
**Computational Complexity**
- Siegert evaluation is the primary bottleneck (O(N) per neuron).
- Without SciPy, custom quadrature adds ~10× overhead.
- Delayed event queues are sparse dicts (O(1) insertion, O(K) retrieval
for K active delays).
**Numerical Stability:**
- Uses ``erfcx(x) = exp(x²) erfc(x)`` to avoid overflow for large x.
- Asymptotic expansions for erfcx and Dawson's integral when x > 8.
- Exact exponential propagators (``exp`` and ``expm1``) prevent drift accumulation.
**Batch Dimensions:**
States support an optional leading batch dimension for parallelizing multiple
network realizations. Initialize with ``init_state(batch_size=B)`` to create
shape ``(B, *in_size)``.
**Integration with NEST:**
This implementation reproduces NEST 3.9+ behavior for ``siegert_neuron`` in
non-waveform-relaxation mode. Key differences:
- NEST uses precise spike times; brainpy.state uses fixed-step integration.
- NEST's WFR mode (iterative delay resolution) is not implemented.
- Event formats are compatible but may differ in edge cases (consult NEST docs).
References
----------
.. [1] Hahne J, Dahmen D, Schuecker J, Frommer A, Bolten M, Helias M,
Diesmann M (2017). Integration of continuous-time dynamics in a spiking
neural network simulator. Frontiers in Neuroinformatics, 11:34.
DOI: ``10.3389/fninf.2017.00034``.
.. [2] Fourcaud N, Brunel N (2002). Dynamics of the firing probability of
noisy integrate-and-fire neurons. Neural Computation, 14(9):2057-2110.
DOI: ``10.1162/089976602320264015``.
.. [3] Schuecker J, Diesmann M, Helias M (2015). Modulated escape from a
metastable state driven by colored noise. Physical Review E, 92:052119.
DOI: ``10.1103/PhysRevE.92.052119``.
Examples
--------
**Basic usage with constant input:**
.. code-block:: python
>>> import brainpy.state as bp
>>> import saiunit as u
>>> import brainstate
>>> model = bp.siegert_neuron(in_size=10, tau=2*u.ms, tau_m=10*u.ms)
>>> model.init_all_states()
>>> with brainstate.environ.context(dt=0.1*u.ms):
... for _ in range(100):
... rate = model.update(drift_input=12.0, diffusion_input=4.0)
>>> print(rate) # Steady-state firing rate in Hz
**Using diffusion events for network coupling:**
.. code-block:: python
>>> model.init_all_states()
>>> event = {'coeff': 50.0, 'drift_factor': 0.1, 'diffusion_factor': 0.05, 'delay_steps': 1}
>>> with brainstate.environ.context(dt=0.1*u.ms):
... rate = model.update(delayed_diffusion_events=event)
>>> print(model.rate.value) # Rate after delayed event delivery
**Mean-field network with recurrent connections:**
.. code-block:: python
>>> exc = bp.siegert_neuron(in_size=800, tau=1*u.ms, theta=15.0)
>>> inh = bp.siegert_neuron(in_size=200, tau=1*u.ms, theta=15.0)
>>> exc.init_all_states()
>>> inh.init_all_states()
>>> # Simulate recurrent network (conceptual; requires projection setup)
>>> with brainstate.environ.context(dt=0.1*u.ms):
... for t in range(1000):
... exc_drive = exc.rate.value.sum() * 0.01
... inh_drive = inh.rate.value.sum() * -0.02
... exc.update(drift_input=exc_drive + inh_drive, diffusion_input=2.0)
... inh.update(drift_input=exc_drive, diffusion_input=1.0)
"""
__module__ = 'brainpy.state'
# NEST value: alpha = |zeta(1/2)| * sqrt(2) — prefactor of the colored-noise
# threshold shift (see class docstring, ref [3]).
_ALPHA = 2.0652531522312172

def __init__(
    self,
    in_size: Size,
    tau: ArrayLike = 1.0 * u.ms,
    tau_m: ArrayLike = 5.0 * u.ms,
    tau_syn: ArrayLike = 0.0 * u.ms,
    t_ref: ArrayLike = 2.0 * u.ms,
    mean: ArrayLike = 0.0,
    theta: ArrayLike = 15.0,
    V_reset: ArrayLike = 0.0,
    rate_initializer: Callable = braintools.init.Constant(0.0),
    name: str = None,
):
    """Build the population, expand parameters to ``varshape``, and validate.

    See the class docstring for the meaning and constraints of each
    parameter. Raises ``ValueError`` (via ``_validate_parameters``) for
    out-of-domain values.
    """
    super().__init__(in_size=in_size, name=name)
    # Each parameter is expanded to the population shape so later arithmetic
    # broadcasts element-wise without further checks.
    self.tau = braintools.init.param(tau, self.varshape)
    self.tau_m = braintools.init.param(tau_m, self.varshape)
    self.tau_syn = braintools.init.param(tau_syn, self.varshape)
    self.t_ref = braintools.init.param(t_ref, self.varshape)
    self.mean = braintools.init.param(mean, self.varshape)
    self.theta = braintools.init.param(theta, self.varshape)
    self.V_reset = braintools.init.param(V_reset, self.varshape)
    self.rate_initializer = rate_initializer
    # Sparse queues: step index -> accumulated drift/diffusion arrays for
    # delayed diffusion events (cleared again in init_state()).
    self._delayed_drift_queue = {}
    self._delayed_diffusion_queue = {}
    self._validate_parameters()
@property
def recordables(self):
    """Names of state variables that can be recorded (NEST ``recordables`` analogue)."""
    return ['rate']
@property
def receptor_types(self):
    """Mapping of receptor names to NEST receptor-type ids."""
    # NEST handles DiffusionConnectionEvent via receptor type 1.
    return {'DIFFUSION': 1}
@staticmethod
def _to_numpy(x):
    """Strip units from *x* and return its magnitude as a NumPy array in the default float dtype."""
    return np.asarray(u.get_mantissa(x), dtype=brainstate.environ.dftype())
@staticmethod
def _to_numpy_ms(x):
    """Express the time quantity *x* in milliseconds and return the magnitude as a NumPy array."""
    return np.asarray(u.get_mantissa(x / u.ms), dtype=brainstate.environ.dftype())
@staticmethod
def _broadcast_to_state(x_np: np.ndarray, shape):
    """Broadcast *x_np* to the population *shape* (returns a read-only view)."""
    view = np.broadcast_to(x_np, shape)
    return view
@staticmethod
def _to_int_scalar(x, name: str):
    """Coerce *x* (possibly unit-carrying) to a plain Python int.

    Raises
    ------
    ValueError
        If *x* does not contain exactly one element.
    """
    flat = np.asarray(u.get_mantissa(x), dtype=brainstate.environ.dftype()).reshape(-1)
    if flat.size != 1:
        raise ValueError(f'{name} must be scalar.')
    return int(flat[0])
@staticmethod
def _coerce_events(events):
    """Normalize *events* into a flat list of event specifications.

    Accepts ``None`` (no events), a single dict, a tuple/list of event
    specs (dicts/tuples/lists), a flat tuple/list of scalars describing one
    event, or a bare scalar.
    """
    if events is None:
        return []
    if isinstance(events, dict):
        # A single dict event.
        return [events]
    if isinstance(events, (tuple, list)):
        if not events:
            return []
        if isinstance(events[0], (dict, tuple, list)):
            # Already a collection of event specs.
            return list(events) if isinstance(events, tuple) else events
        # A flat sequence of scalars describes exactly one event.
        return [events] if isinstance(events, tuple) else [tuple(events)]
    # Bare scalar: a single pure-drift event.
    return [events]
@staticmethod
def _queue_add(queue: dict, step_idx: int, value: np.ndarray):
    """Accumulate *value* into ``queue[step_idx]``, copying on first insert."""
    existing = queue.get(step_idx)
    if existing is None:
        # First contribution at this step: store an owned, dtype-coerced copy
        # so later in-place accumulation cannot alias the caller's array.
        queue[step_idx] = np.array(value, dtype=brainstate.environ.dftype(), copy=True)
    else:
        queue[step_idx] = existing + value
def _drain_delayed_queue(self, step_idx: int, state_shape):
    """Pop the drift/diffusion contributions due at *step_idx*.

    Returns a pair of writable arrays of shape *state_shape*; missing queue
    entries materialize as zeros.
    """
    dftype = brainstate.environ.dftype()

    def _materialize(buffered):
        # Absent entry -> zeros; otherwise broadcast to a writable copy.
        if buffered is None:
            return np.zeros(state_shape, dtype=dftype)
        broadcast = self._broadcast_to_state(np.asarray(buffered, dtype=dftype), state_shape)
        return np.array(broadcast, copy=True)

    drift = _materialize(self._delayed_drift_queue.pop(step_idx, None))
    diffusion = _materialize(self._delayed_diffusion_queue.pop(step_idx, None))
    return drift, diffusion
def _extract_event_fields(self, ev, default_delay_steps: int):
    """Decode one diffusion event into its six canonical fields.

    Supports dicts (with ``rate``/``value`` aliases for ``coeff`` and a
    ``delay`` alias for ``delay_steps``), tuples/lists of length 1-6 in the
    order ``(coeff, drift_factor, diffusion_factor, delay_steps, weight,
    multiplicity)``, and bare scalars (pure-drift coefficient).

    Returns
    -------
    tuple
        ``(coeff, drift_factor, diffusion_factor, weight, multiplicity,
        delay_steps)`` with ``delay_steps`` coerced to a Python int.

    Raises
    ------
    ValueError
        If a tuple/list event has length 0 or more than 6.
    """
    if isinstance(ev, dict):
        coeff = ev.get('coeff', ev.get('rate', ev.get('value', 0.0)))
        drift_factor = ev.get('drift_factor', 1.0)
        diffusion_factor = ev.get('diffusion_factor', 0.0)
        weight = ev.get('weight', 1.0)
        multiplicity = ev.get('multiplicity', 1.0)
        delay_steps = ev.get('delay_steps', ev.get('delay', default_delay_steps))
    elif isinstance(ev, (tuple, list)):
        if not 1 <= len(ev) <= 6:
            raise ValueError('Diffusion event tuples must have length 1 to 6.')
        # Positional layout; trailing fields fall back to their defaults.
        defaults = (0.0, 1.0, 0.0, default_delay_steps, 1.0, 1.0)
        coeff, drift_factor, diffusion_factor, delay_steps, weight, multiplicity = (
            tuple(ev) + defaults[len(ev):]
        )
    else:
        # Bare scalar: interpreted as a pure-drift coefficient.
        coeff = ev
        drift_factor = 1.0
        diffusion_factor = 0.0
        weight = 1.0
        multiplicity = 1.0
        delay_steps = default_delay_steps
    delay_steps = self._to_int_scalar(delay_steps, name='delay_steps')
    return coeff, drift_factor, diffusion_factor, weight, multiplicity, delay_steps
def _event_to_drift_diffusion(self, ev, default_delay_steps: int, state_shape):
    """Convert one event into broadcast drift/diffusion arrays and its delay.

    Drift is ``coeff * weight * multiplicity * drift_factor``; diffusion is
    the same product with ``diffusion_factor`` instead.
    """
    fields = self._extract_event_fields(ev, default_delay_steps)
    coeff, drift_factor, diffusion_factor, weight, multiplicity, delay_steps = fields

    def _as_state(value):
        # Strip units and broadcast a field to the population shape.
        return self._broadcast_to_state(self._to_numpy(value), state_shape)

    weighted = _as_state(coeff) * _as_state(weight) * _as_state(multiplicity)
    drift = _as_state(drift_factor) * weighted
    diffusion = _as_state(diffusion_factor) * weighted
    return drift, diffusion, delay_steps
def _accumulate_instant_events(self, events, state_shape):
    """Sum the drift/diffusion contributions of zero-delay events.

    Raises
    ------
    ValueError
        If any event carries a non-zero ``delay_steps``.
    """
    dftype = brainstate.environ.dftype()
    drift_total = np.zeros(state_shape, dtype=dftype)
    diffusion_total = np.zeros(state_shape, dtype=dftype)
    for event in self._coerce_events(events):
        drift, diffusion, delay = self._event_to_drift_diffusion(
            event,
            default_delay_steps=0,
            state_shape=state_shape,
        )
        if delay != 0:
            raise ValueError('instant_diffusion_events must not specify non-zero delay_steps.')
        drift_total += drift
        diffusion_total += diffusion
    return drift_total, diffusion_total
def _schedule_delayed_events(self, events, step_idx: int, state_shape):
    """Route delayed events into the queues; return any zero-delay contributions.

    Events with ``delay_steps == 0`` are summed and returned for immediate
    application; all others are enqueued at ``step_idx + delay_steps``.

    Raises
    ------
    ValueError
        If an event carries a negative ``delay_steps``.
    """
    dftype = brainstate.environ.dftype()
    immediate_drift = np.zeros(state_shape, dtype=dftype)
    immediate_diffusion = np.zeros(state_shape, dtype=dftype)
    for event in self._coerce_events(events):
        drift, diffusion, delay = self._event_to_drift_diffusion(
            event,
            default_delay_steps=1,
            state_shape=state_shape,
        )
        if delay < 0:
            raise ValueError('delay_steps for delayed_diffusion_events must be >= 0.')
        if delay == 0:
            immediate_drift += drift
            immediate_diffusion += diffusion
        else:
            due_step = step_idx + delay
            self._queue_add(self._delayed_drift_queue, due_step, drift)
            self._queue_add(self._delayed_diffusion_queue, due_step, diffusion)
    return immediate_drift, immediate_diffusion
def _validate_parameters(self):
    """Raise ``ValueError`` for parameter values outside the model's domain.

    Checks run in a fixed order (tau, tau_m, tau_syn, t_ref, then
    V_reset < theta); the first failing check raises. Validation is skipped
    entirely when any parameter is a JAX tracer, since concrete comparisons
    are impossible under tracing.
    """
    # Skip validation when parameters are JAX tracers (e.g. during jit).
    if any(is_tracer(v) for v in (self.tau, self.tau_m, self.tau_syn, self.t_ref, self.V_reset, self.theta)):
        return
    if np.any(self.tau <= 0.0 * u.ms):
        raise ValueError('Time constant tau must be > 0.')
    if np.any(self.tau_m <= 0.0 * u.ms):
        raise ValueError('Membrane time constant tau_m must be > 0.')
    if np.any(self.tau_syn < 0.0 * u.ms):
        raise ValueError('Synaptic time constant tau_syn must be >= 0.')
    if np.any(self.t_ref < 0.0 * u.ms):
        raise ValueError('Refractory period t_ref must be >= 0.')
    if np.any(self.V_reset >= self.theta):
        raise ValueError('Reset potential V_reset must be smaller than threshold theta.')
def init_state(self, **kwargs):
    """Allocate state variables and clear the delayed-event queues.

    Creates the dynamic ``rate`` state plus the ``instant_rate`` and
    ``delayed_rate`` output buffers (all initialized from
    ``rate_initializer``), resets the internal step counter to 0, and
    empties both delayed-event queues.

    NOTE(review): ``kwargs`` (e.g. ``batch_size``) is currently ignored —
    states are allocated with shape ``self.varshape`` only; confirm against
    the batch semantics documented on the class.
    """
    # (Removed a stray "[docs]" Sphinx-extraction artifact that preceded this
    # method; it was not valid Python.)
    rate = braintools.init.param(self.rate_initializer, self.varshape)
    rate_np = self._to_numpy(rate)
    self.rate = brainstate.ShortTermState(rate_np)
    dftype = brainstate.environ.dftype()
    # Independent copies: the output buffers must not alias the rate state.
    self.instant_rate = brainstate.ShortTermState(np.array(rate_np, dtype=dftype, copy=True))
    self.delayed_rate = brainstate.ShortTermState(np.array(rate_np, dtype=dftype, copy=True))
    ditype = brainstate.environ.ditype()
    self._step_count = brainstate.ShortTermState(np.asarray(0, dtype=ditype))
    self._delayed_drift_queue = {}
    self._delayed_diffusion_queue = {}
@staticmethod
def _gauss_legendre_scalar_integral(func, a: float, b: float):
    """Integrate ``func`` over ``[a, b]`` with fixed 64-point Gauss-Legendre quadrature."""
    center = 0.5 * (a + b)
    half_width = 0.5 * (b - a)
    # Map the reference nodes on [-1, 1] onto [a, b].
    abscissae = center + half_width * _GAUSS_NODES
    samples = np.asarray(
        [func(float(node)) for node in abscissae],
        dtype=brainstate.environ.dftype(),
    )
    return float(half_width * np.sum(_GAUSS_WEIGHTS * samples))
@staticmethod
def _erfcx_pos_scalar(x: float):
    """Scaled complementary error function ``erfcx(x) = exp(x^2) * erfc(x)`` for x >= 0.

    Uses SciPy when available; otherwise evaluates the definition directly
    for moderate ``x`` and, for large ``x`` (where ``exp(x*x)`` would
    overflow), the asymptotic series (DLMF 7.12.1)
    ``erfcx(x) ~ (1/(x sqrt(pi))) (1 - 1/(2x^2) + 3/(4x^4) - 15/(8x^6) + 105/(16x^8))``.

    BUG FIX: the series is alternating; the previous code summed every term
    with a positive sign, giving ~0.16% relative error at x = 25.
    """
    if _HAVE_SCIPY:
        return float(_sp_special.erfcx(x))
    if x < 25.0:
        return math.exp(x * x) * math.erfc(x)
    inv = 1.0 / x
    inv2 = inv * inv
    # Alternating signs, coefficients (2n-1)!!/2^n: 1/2, 3/4, 15/8, 105/16.
    poly = 1.0 - 0.5 * inv2 + 0.75 * inv2 * inv2 - 1.875 * inv2 ** 3 + 6.5625 * inv2 ** 4
    return (inv / math.sqrt(math.pi)) * poly
@staticmethod
def _integral_erfcx_asympt(a: float, b: float):
    """Integral of ``erfcx`` over ``[a, b]`` via the large-argument asymptotic series.

    Term-by-term integration of the alternating series (DLMF 7.12.1)
    ``erfcx(s) ~ (1/sqrt(pi)) (s^-1 - s^-3/2 + 3 s^-5/4 - 15 s^-7/8 + 105 s^-9/16)``
    yields the antiderivative
    ``F(s) = [ln s + 1/(4 s^2) - 3/(16 s^4) + 5/(16 s^6) - 105/(128 s^8)] / sqrt(pi)``.
    Intended for ``8 <= a <= b`` where the truncated series is accurate.

    BUG FIX: the previous code used a non-alternating series, so the signs
    of the 1/(4s^2) and 5/(16 s^6) terms were flipped.
    """
    inv_a2 = 1.0 / (a * a)
    inv_b2 = 1.0 / (b * b)
    term0 = math.log(b / a)
    # Signs alternate because the integrated erfcx series alternates.
    term1 = 0.25 * (inv_b2 - inv_a2)
    term2 = -(3.0 / 16.0) * (inv_b2 * inv_b2 - inv_a2 * inv_a2)
    term3 = (5.0 / 16.0) * (inv_b2 ** 3 - inv_a2 ** 3)
    term4 = -(105.0 / 128.0) * (inv_b2 ** 4 - inv_a2 ** 4)
    return (term0 + term1 + term2 + term3 + term4) / math.sqrt(math.pi)
@classmethod
def _integral_erfcx_pos(cls, a: float, b: float):
    """Integrate ``erfcx`` between non-negative bounds ``a`` and ``b``.

    Reversed bounds are handled with a sign flip. With SciPy, adaptive
    ``quad`` is used (relative tolerance ~1.49e-8). Without SciPy, the
    region below 8 is integrated with fixed Gauss-Legendre panels of width
    <= 2 and the tail above 8 with the asymptotic antiderivative in
    ``_integral_erfcx_asympt``.
    """
    if a == b:
        return 0.0
    sign = 1.0
    lo = float(a)
    hi = float(b)
    if lo > hi:
        # Swap so lo <= hi; `sign` restores the original orientation.
        sign = -1.0
        lo, hi = hi, lo
    if _HAVE_SCIPY:
        result, _ = _sp_integrate.quad(
            lambda s: float(_sp_special.erfcx(s)),
            lo,
            hi,
            epsabs=0.0,
            epsrel=1.49e-8,
            limit=1000,
        )
        return sign * float(result)
    # No SciPy: numeric quadrature on [lo, min(hi, 8)] ...
    split = 8.0
    total = 0.0
    if lo < split:
        hi_num = min(hi, split)
        width = hi_num - lo
        # Panels of width <= 2 keep the fixed 64-point rule accurate.
        nseg = max(1, int(math.ceil(width / 2.0)))
        seg_w = width / nseg
        left = lo
        for _ in range(nseg):
            right = left + seg_w
            total += cls._gauss_legendre_scalar_integral(cls._erfcx_pos_scalar, left, right)
            left = right
    # ... plus the asymptotic tail on [max(lo, 8), hi].
    if hi > split:
        lo_as = max(lo, split)
        total += cls._integral_erfcx_asympt(lo_as, hi)
    return sign * total
@classmethod
def _dawsn_pos_scalar(cls, x: float):
    """Dawson's integral ``D(x) = exp(-x^2) * int_0^x exp(t^2) dt`` for x >= 0.

    Uses SciPy when available; otherwise a Maclaurin series for small x,
    an asymptotic series for large x, and segmented Gauss-Legendre
    quadrature of ``exp(t^2)`` in between.
    """
    if _HAVE_SCIPY:
        return float(_sp_special.dawsn(x))
    if x == 0.0:
        return 0.0
    if x < 0.2:
        # Maclaurin series: x - 2x^3/3 + 4x^5/15 - 8x^7/105 + 16x^9/945.
        x2 = x * x
        return x * (
            1.0
            - (2.0 / 3.0) * x2
            + (4.0 / 15.0) * x2 * x2
            - (8.0 / 105.0) * x2 ** 3
            + (16.0 / 945.0) * x2 ** 4
        )
    if x >= 8.0:
        # Asymptotic series: 1/(2x) + 1/(4x^3) + 3/(8x^5) + 15/(16x^7) + 105/(32x^9).
        inv = 1.0 / x
        inv2 = inv * inv
        return (
            0.5 * inv
            + 0.25 * inv * inv2
            + (3.0 / 8.0) * inv * inv2 ** 2
            + (15.0 / 16.0) * inv * inv2 ** 3
            + (105.0 / 32.0) * inv * inv2 ** 4
        )
    # Dawson(x) = exp(-x^2) * integral_0^x exp(t^2) dt
    # Integrate exp(t^2) in panels of width <= 1 to keep the fixed-order
    # quadrature accurate, then apply the exp(-x^2) prefactor.
    nseg = max(1, int(math.ceil(x / 1.0)))
    seg_w = x / nseg
    left = 0.0
    integral = 0.0
    for _ in range(nseg):
        right = left + seg_w
        integral += cls._gauss_legendre_scalar_integral(lambda t: math.exp(t * t), left, right)
        left = right
    return math.exp(-x * x) * integral
@classmethod
def _siegert_scalar(
    cls,
    mu: float,
    sigma_square: float,
    tau_m_ms: float,
    tau_syn_ms: float,
    t_ref_ms: float,
    theta: float,
    v_reset: float,
):
    """Siegert steady-state LIF firing rate (Hz) for one population element.

    Inputs are the drift ``mu`` and diffusion ``sigma_square`` (both relative
    to rest, dimensionless), the membrane/synaptic time constants and
    refractory period in ms, and the threshold/reset potentials. The factor
    ``1e3`` converts 1/ms to Hz.
    """
    if sigma_square <= 0.0:
        # Deterministic (noise-free) LIF: fires only for suprathreshold drive.
        if mu > theta:
            return 1e3 / (t_ref_ms + tau_m_ms * math.log((mu - v_reset) / (mu - theta)))
        return 0.0
    sigma = math.sqrt(sigma_square)
    # NEST fast path for very subthreshold input (Brunel 2000, eq. 22 estimate).
    if (theta - mu) > 6.0 * sigma:
        return 0.0
    # Colored-noise correction: shift both integration bounds by
    # (alpha/2) * sqrt(tau_syn / tau_m); zero for white noise (tau_syn = 0).
    threshold_shift = (cls._ALPHA / 2.0) * math.sqrt(tau_syn_ms / tau_m_ms)
    y_th = (theta - mu) / sigma + threshold_shift
    y_r = (v_reset - mu) / sigma + threshold_shift
    sqrt_pi = math.sqrt(math.pi)
    if y_r > 0.0:
        # Both bounds positive (subthreshold drive): express the integral via
        # Dawson terms plus an erfcx integral, rescaled by exp(-y_th^2) so
        # exp(y^2) factors never overflow.
        result = cls._integral_erfcx_pos(y_r, y_th)
        integral = (
            2.0 * cls._dawsn_pos_scalar(y_th)
            - 2.0 * math.exp(y_r * y_r - y_th * y_th) * cls._dawsn_pos_scalar(y_r)
            - math.exp(-y_th * y_th) * result
        )
        e = math.exp(-y_th * y_th)
        return 1e3 * e / (e * t_ref_ms + tau_m_ms * sqrt_pi * integral)
    if y_th < 0.0:
        # Both bounds negative (strongly suprathreshold): a plain erfcx
        # integral over the reflected interval suffices.
        integral = cls._integral_erfcx_pos(-y_th, -y_r)
        return 1e3 / (t_ref_ms + tau_m_ms * sqrt_pi * integral)
    # Mixed-sign bounds: Dawson term for [0, y_th] plus an erfcx integral
    # over [y_th, -y_r], again rescaled by exp(-y_th^2).
    result = cls._integral_erfcx_pos(y_th, -y_r)
    integral = 2.0 * cls._dawsn_pos_scalar(y_th) + math.exp(-y_th * y_th) * result
    e = math.exp(-y_th * y_th)
    return 1e3 * e / (e * t_ref_ms + tau_m_ms * sqrt_pi * integral)
@classmethod
def _siegert_array(
    cls,
    mu: np.ndarray,
    sigma_square: np.ndarray,
    tau_m_ms: np.ndarray,
    tau_syn_ms: np.ndarray,
    t_ref_ms: np.ndarray,
    theta: np.ndarray,
    v_reset: np.ndarray,
):
    """Element-wise Siegert rate for pre-broadcast (same-shape) parameter arrays."""
    dftype = brainstate.environ.dftype()
    columns = (mu, sigma_square, tau_m_ms, tau_syn_ms, t_ref_ms, theta, v_reset)
    # Evaluate the scalar kernel once per element; ravel() also works for
    # broadcast (non-contiguous) views.
    rates = [
        cls._siegert_scalar(*(float(value) for value in row))
        for row in zip(*(np.ravel(column) for column in columns))
    ]
    return np.asarray(rates, dtype=dftype).reshape(mu.shape)
def siegert_rate(self, mu: ArrayLike, sigma_square: ArrayLike):
    r"""Evaluate the NEST-compatible Siegert transfer function.

    Computes the steady-state firing rate :math:`\Phi(\mu, \sigma^2)` of a
    noisy LIF neuron with drift ``mu`` and diffusion ``sigma_square``, using
    the analytic Siegert formula [2]_ with optional colored-noise correction [3]_.

    The computation is vectorized over population elements. Inputs are broadcast
    with model parameters (``theta``, ``tau_m``, etc.) to produce an output
    array matching the broadcast shape.

    Parameters
    ----------
    mu : ArrayLike
        Drift input (mean membrane potential shift, dimensionless). Scalar or
        array broadcastable with ``sigma_square`` and model parameters. Positive
        values depolarize the neuron. Typically in the range [0, 30] for
        physiological parameters.
    sigma_square : ArrayLike
        Diffusion input (membrane potential variance, dimensionless squared).
        Scalar or array broadcastable with ``mu`` and model parameters. Must be
        non-negative. Typical values: 0.1-10 for moderate noise. Zero produces
        deterministic LIF behavior.

    Returns
    -------
    rate : ndarray
        Firing rate in Hz (shape matches broadcast of inputs and model parameters).
        Values >= 0. Returns 0 for subthreshold inputs (mu < theta with low noise).
        Maximum rate is approximately ``1000 / t_ref`` Hz (refractory limit).

    Notes
    -----
    **Special Cases:**

    - If ``sigma_square`` <= 0: deterministic LIF (returns 0 if mu <= theta,
      else fires at the constant-input rate).
    - If (theta - mu) > 6*sigma: deep subthreshold (returns 0, Brunel 2000 fast path).
    - If ``t_ref`` = 0: no refractory limit (rate can diverge for mu >> theta).

    **Performance:**

    - Without SciPy: uses 64-point Gauss-Legendre quadrature (~10x slower).
    - With SciPy: uses ``scipy.integrate.quad`` and ``scipy.special`` (faster).

    **Broadcasting Rules:**

    Output shape is ``np.broadcast(mu, sigma_square, theta).shape``. For
    example, if the model has ``in_size=(10,)``, ``mu`` is scalar, and
    ``sigma_square`` has shape ``(10,)``, the output shape is ``(10,)``.

    Examples
    --------
    .. code-block:: python

        >>> import brainpy.state as bp
        >>> import numpy as np
        >>> import saiunit as u
        >>> model = bp.siegert_neuron(in_size=1, tau_m=10*u.ms, t_ref=2*u.ms, theta=15.0)
        >>> rates = model.siegert_rate(mu=np.linspace(0, 25, 50), sigma_square=2.0)
        >>> rates.shape
        (50,)
        >>> model = bp.siegert_neuron(in_size=(10, 10), tau_m=10*u.ms)
        >>> model.siegert_rate(mu=np.random.uniform(10, 20, size=(10, 10)),
        ...                    sigma_square=3.0).shape
        (10, 10)
    """
    # (Removed a stray "[docs]" Sphinx-extraction artifact that preceded this
    # method; it was not valid Python.)
    # Strip units, then broadcast every input and parameter to one shape.
    mu_np = self._to_numpy(mu)
    sigma_np = self._to_numpy(sigma_square)
    state_shape = np.broadcast(
        mu_np,
        sigma_np,
        self._to_numpy(self.theta),
    ).shape
    mu_b = self._broadcast_to_state(mu_np, state_shape)
    sigma_b = self._broadcast_to_state(sigma_np, state_shape)
    # Time constants are handed to the scalar kernel in milliseconds.
    tau_m_b = self._broadcast_to_state(self._to_numpy_ms(self.tau_m), state_shape)
    tau_syn_b = self._broadcast_to_state(self._to_numpy_ms(self.tau_syn), state_shape)
    t_ref_b = self._broadcast_to_state(self._to_numpy_ms(self.t_ref), state_shape)
    theta_b = self._broadcast_to_state(self._to_numpy(self.theta), state_shape)
    v_reset_b = self._broadcast_to_state(self._to_numpy(self.V_reset), state_shape)
    return self._siegert_array(
        mu_b,
        sigma_b,
        tau_m_b,
        tau_syn_b,
        t_ref_b,
        theta_b,
        v_reset_b,
    )
def update(
    self,
    x=0.0,
    drift_input: ArrayLike = 0.0,
    diffusion_input: ArrayLike = 0.0,
    instant_diffusion_events=None,
    delayed_diffusion_events=None,
    _precomputed_drive=None,
):
    r"""Advance the population rate by one simulation timestep.

    Integrates the first-order rate ODE
    :math:`\tau\,dr/dt = -r + \text{mean} + \Phi(\mu, \sigma^2)` with exact
    exponential propagators :math:`P_1 = e^{-\Delta t/\tau}` and
    :math:`P_2 = 1 - P_1`, which makes the step unconditionally stable for
    any ``tau > 0`` and ``dt > 0``.

    Each call (i) drains the delayed-event queues for the current step
    index, (ii) schedules newly supplied delayed events, (iii) accumulates
    instant events, (iv) adds direct inputs plus the standard Dynamics
    current/delta hooks, (v) evaluates the Siegert transfer function on the
    summed drift/diffusion statistics, and (vi) publishes the new rate to
    ``rate``, ``delayed_rate`` and ``instant_rate`` (NEST non-WFR
    semantics) before incrementing the internal step counter.

    Parameters
    ----------
    x : ArrayLike, optional
        External input forwarded to ``sum_current_inputs()`` for
        compatibility with the standard Dynamics input API. Default ``0.0``.
    drift_input : ArrayLike, optional
        Direct drift (mean input) contribution, scalar or broadcastable to
        ``in_size``. Default ``0.0``.
    diffusion_input : ArrayLike, optional
        Direct diffusion (variance) contribution, scalar or broadcastable
        to ``in_size``. Default ``0.0``.
    instant_diffusion_events : None, dict, tuple or list, optional
        Diffusion events applied in the current step. Every event must
        carry ``delay_steps=0`` (implicit or explicit); the accumulator
        raises ``ValueError`` otherwise. Default ``None``.
    delayed_diffusion_events : None, dict, tuple or list, optional
        Diffusion events scheduled for future delivery after a
        non-negative integer step delay (default delay 1); zero-delay
        events are applied immediately. Default ``None``.
    _precomputed_drive : ArrayLike, optional
        JIT-compatible fast path: when given, event handling and the
        Siegert evaluation are skipped and this value is used directly as
        the drive term. Default ``None``.

    Returns
    -------
    rate : ndarray
        Updated firing rate in Hz (shape matches ``in_size`` or
        ``(batch_size, *in_size)``); also stored in ``self.rate.value``.

    Raises
    ------
    ValueError
        Propagated from the event helpers for instant events with non-zero
        delay, delayed events with negative delay, or malformed event
        tuples.
    """
    int_dtype = brainstate.environ.ditype()
    float_dtype = brainstate.environ.dftype()
    # Timestep in milliseconds as a plain Python float.
    dt_ms = float(u.get_mantissa(brainstate.environ.get_dt() / u.ms))
    shape = self.rate.value.shape

    if _precomputed_drive is not None:
        # JIT-compatible fast path: skip event queues and the Siegert
        # transfer function entirely; the caller supplies the drive.
        drive = jnp.broadcast_to(jnp.asarray(_precomputed_drive, dtype=float_dtype), shape)
        previous_rate = jnp.broadcast_to(jnp.asarray(self.rate.value, dtype=float_dtype), shape)
        tau_b = np.broadcast_to(self._to_numpy_ms(self.tau), shape)
        mean_b = np.broadcast_to(self._to_numpy(self.mean), shape)
        decay = np.exp(-dt_ms / tau_b)
        gain = -np.expm1(-dt_ms / tau_b)  # 1 - decay, computed stably
        updated = decay * previous_rate + gain * (mean_b + drive)
        self.rate.value = updated
        self.delayed_rate.value = updated
        self.instant_rate.value = updated
        return updated

    current_step = int(np.asarray(self._step_count.value, dtype=int_dtype).reshape(-1)[0])
    # Events whose delay expires at this step index.
    mu_delayed, var_delayed = self._drain_delayed_queue(current_step, shape)
    # Newly supplied delayed events; zero-delay ones are returned for
    # immediate application instead of being queued.
    mu_now, var_now = self._schedule_delayed_events(
        delayed_diffusion_events,
        step_idx=current_step,
        state_shape=shape,
    )
    mu_delayed += mu_now
    var_delayed += var_now
    mu_instant, var_instant = self._accumulate_instant_events(
        instant_diffusion_events,
        state_shape=shape,
    )
    # Direct inputs plus the standard Dynamics current/delta input hooks.
    mu_direct = self._broadcast_to_state(
        self._to_numpy(self.sum_current_inputs(x, self.rate.value) + drift_input + self.sum_delta_inputs(0.0)),
        shape,
    )
    var_direct = self._broadcast_to_state(self._to_numpy(diffusion_input), shape)

    total_mu = mu_delayed + mu_instant + mu_direct
    total_var = var_delayed + var_instant + var_direct

    previous_rate = self._broadcast_to_state(self._to_numpy(self.rate.value), shape)
    tau_b = self._broadcast_to_state(self._to_numpy_ms(self.tau), shape)
    mean_b = self._broadcast_to_state(self._to_numpy(self.mean), shape)
    tau_m_b = self._broadcast_to_state(self._to_numpy_ms(self.tau_m), shape)
    tau_syn_b = self._broadcast_to_state(self._to_numpy_ms(self.tau_syn), shape)
    t_ref_b = self._broadcast_to_state(self._to_numpy_ms(self.t_ref), shape)
    theta_b = self._broadcast_to_state(self._to_numpy(self.theta), shape)
    v_reset_b = self._broadcast_to_state(self._to_numpy(self.V_reset), shape)

    drive = self._siegert_array(total_mu, total_var, tau_m_b, tau_syn_b, t_ref_b, theta_b, v_reset_b)

    decay = np.exp(-dt_ms / tau_b)
    gain = -np.expm1(-dt_ms / tau_b)  # 1 - decay, computed stably
    updated = decay * previous_rate + gain * (mean_b + drive)

    self.rate.value = updated
    # NEST non-WFR semantics: both outgoing buffers carry the final rate.
    self.delayed_rate.value = np.array(updated, dtype=float_dtype, copy=True)
    self.instant_rate.value = np.array(updated, dtype=float_dtype, copy=True)
    self._step_count.value = np.asarray(current_step + 1, dtype=int_dtype)
    return updated