# Copyright 2026 BrainX Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# -*- coding: utf-8 -*-
import math
from typing import Callable
import brainstate
import braintools
import brainunit as u
import numpy as np
import jax.numpy as jnp
import jax.scipy.special as jax_special
from brainstate.typing import ArrayLike, Size
from brainpy_state._nest_base.base import NESTNeuron
from brainpy_state._nest_base.utils import is_tracer, cond_any
__all__ = ['siegert_neuron']

# SciPy is optional. When both submodules import cleanly the eager Siegert
# oracle uses ``scipy.integrate.quad`` / ``scipy.special``; otherwise the
# pure-NumPy fallbacks defined on ``siegert_neuron`` are used instead.
try:
    from scipy import integrate as _sp_integrate
    from scipy import special as _sp_special
except Exception:  # pragma: no cover - fallback path when SciPy is unavailable.
    _HAVE_SCIPY = False
else:
    _HAVE_SCIPY = True

# 64-point Gauss-Legendre nodes on [-1, 1] and their weights, shared by the
# scalar quadrature helpers and the JAX port.
_GAUSS_NODES, _GAUSS_WEIGHTS = np.polynomial.legendre.leggauss(64)
class siegert_neuron(NESTNeuron):
    r"""NEST-compatible ``siegert_neuron`` mean-field rate model.

    **1. Overview**

    Mean-field rate model using the Siegert gain function of a noisy LIF neuron.
    This model computes the population-averaged firing rate from drift-diffusion
    input statistics via an analytic transfer function, enabling efficient
    large-scale network simulation without explicit spike generation.

    **2. Mathematical Description**

    The rate dynamics follow a first-order ODE:

    .. math::

        \tau\,\frac{dr(t)}{dt} = -r(t) + \text{mean} + \Phi(\mu, \sigma^2),

    where:

    - :math:`r(t)` is the population firing rate (Hz)
    - :math:`\tau` is the rate time constant
    - :math:`\text{mean}` is a constant baseline drive
    - :math:`\Phi(\mu, \sigma^2)` is the Siegert transfer function
    - :math:`\mu` is the total drift input (mean membrane potential shift)
    - :math:`\sigma^2` is the total diffusion input (variance)

    The Siegert function :math:`\Phi` analytically computes the steady-state
    firing rate of a leaky integrate-and-fire neuron receiving white noise with
    drift :math:`\mu` and diffusion :math:`\sigma^2`, subject to threshold
    :math:`\theta`, reset :math:`V_{\text{reset}}`, refractory period
    :math:`t_{\text{ref}}`, and membrane time constant :math:`\tau_m` [2]_.

    For colored noise (finite :math:`\tau_{\text{syn}} > 0`), a threshold shift
    correction is applied [3]_:

    .. math::

        \Delta_{\text{th}} = \frac{\alpha}{2} \sqrt{\frac{\tau_{\text{syn}}}{\tau_m}},

    where :math:`\alpha = |\zeta(1/2)| \sqrt{2} \approx 2.0653`.

    The integration is performed via exact exponential propagators:

    .. math::

        r(t + \Delta t) = e^{-\Delta t / \tau} r(t) + \left(1 - e^{-\Delta t / \tau}\right)
        \left(\text{mean} + \Phi(\mu, \sigma^2)\right).

    **3. NEST-Compatible Update Ordering (Non-WFR Path)**

    For each simulation step, :meth:`update`:

    1. Sums the total drift :math:`\mu` from ``sum_current_inputs(x, r)``, the
       direct ``drift_input`` argument, and the delta inputs labeled
       ``'diffusion_mu'``.
    2. Sums the total diffusion :math:`\sigma^2` from the direct
       ``diffusion_input`` argument and the delta inputs labeled
       ``'diffusion_sigma2'``.
    3. Evaluates the Siegert transfer function :math:`\Phi(\mu, \sigma^2)` with
       the JAX port.
    4. Updates the rate via the exact exponential step
       :math:`r \leftarrow P_1 r + P_2 (\text{mean} + \Phi)`.
    5. Writes the updated rate to ``rate``, ``delayed_rate`` and
       ``instant_rate``.

    This mirrors NEST's non-waveform-relaxation ``update_`` semantics where
    emitted diffusion coefficients are overwritten with the post-update rate.

    **4. Diffusion Input Channels**

    Network coupling arrives through two labeled delta-input channels. A
    :class:`~brainpy_state.diffusion_connection` deposits into them:

    - ``'diffusion_mu'``: added to the drift :math:`\mu` (the connection
      deposits ``drift_factor * rate``);
    - ``'diffusion_sigma2'``: added to the diffusion :math:`\sigma^2` (the
      connection deposits ``diffusion_factor * rate``).

    The two channels carry distinct labels, so drift and diffusion
    contributions never contaminate each other. You can also pass direct drive
    to :meth:`update` through ``drift_input`` and ``diffusion_input``.

    **5. Siegert Transfer Function Computation**

    The Siegert function is evaluated element-wise. Each population element
    falls into one of three regimes:

    - **Deterministic** (:math:`\sigma^2 \le 0`): if :math:`\mu > \theta`, the
      noiseless LIF rate
      :math:`10^3 / (t_{\text{ref}} + \tau_m \ln\frac{\mu - V_{\text{reset}}}{\mu - \theta})`;
      otherwise 0.
    - **Deep subthreshold** (:math:`\theta - \mu > 6\sigma`): 0 (Brunel 2000
      fast path).
    - **General diffusive**: an integral of the scaled complementary error
      function :math:`\mathrm{erfcx}(x) = e^{x^2}\mathrm{erfc}(x)` combined
      with Dawson's integral. This regime splits into three cases by the signs
      of the shifted reset and threshold.

    There are two implementations:

    - :meth:`update` uses a ``jax.numpy`` port: fixed 64-point Gauss-Legendre
      rules plus ``1/x`` asymptotic series beyond 8. It therefore lowers under
      ``jit`` / ``brainstate.transform.for_loop``.
    - :meth:`siegert_rate` is an eager host-side oracle. With SciPy it uses
      ``scipy.integrate.quad`` (``epsrel=1.49e-8``) and ``scipy.special``.
      Without SciPy it falls back to segmented 64-point Gauss-Legendre
      quadrature and asymptotic expansions.

    Parameters
    ----------
    in_size : Size
        Population shape. Tuple of ints or single int for 1D populations.
        Determines the spatial structure of the rate model. For example,
        ``(10, 10)`` creates a 10×10 grid of mean-field neurons.
    tau : Quantity[ms], optional
        Time constant of the first-order rate dynamics (must be > 0). Controls
        the rate of convergence to the steady-state Siegert value. Smaller
        values produce faster tracking of input changes. Default: ``1 ms``.
    tau_m : Quantity[ms], optional
        Membrane time constant used in the Siegert gain function (must be > 0).
        Represents the passive membrane time constant of the modeled LIF
        neurons. Default: ``5 ms``.
    tau_syn : Quantity[ms], optional
        Synaptic time constant for the colored-noise threshold correction
        (must be ≥ 0). When ``tau_syn > 0``, a threshold shift is applied [3]_.
        Use ``0 ms`` for white noise (no correction). Default: ``0 ms``.
    t_ref : Quantity[ms], optional
        Refractory period in the Siegert gain function (must be ≥ 0). It
        lengthens the interspike interval and caps the rate near
        ``1000 / t_ref`` Hz. Default: ``2 ms``.
    mean : float, optional
        Constant additive baseline drive in the rate ODE (dimensionless). It
        shifts the target of the rate relaxation by a constant amount. Can be
        a scalar or an array matching ``in_size``. Default: ``0.0``.
    theta : float, optional
        Spike threshold relative to the resting potential (dimensionless,
        corresponds to mV in NEST). Must be > ``V_reset``. Default: ``15.0``.
    V_reset : float, optional
        Reset potential relative to the resting potential (dimensionless,
        corresponds to mV in NEST). Must be < ``theta``. Default: ``0.0``.
    rate_initializer : Callable, optional
        Initializer for the ``rate`` state variable. It is passed to
        ``braintools.init.param`` together with the population shape and the
        optional batch size during ``init_state()``. Default:
        ``braintools.init.Constant(0.0)`` (all neurons start at 0 Hz).
    name : str, optional
        Unique identifier for this module. If ``None``, auto-generated.

    Parameter Mapping
    -----------------

    ========================= ========================== ==================================
    NEST Parameter            brainpy.state Parameter    Description
    ========================= ========================== ==================================
    ``tau``                   ``tau``                    Rate dynamics time constant
    ``tau_m``                 ``tau_m``                  Membrane time constant (Siegert)
    ``tau_syn``               ``tau_syn``                Synaptic time constant (threshold shift)
    ``t_ref``                 ``t_ref``                  Refractory period (Siegert)
    ``mean``                  ``mean``                   Constant baseline drive
    ``theta``                 ``theta``                  Spike threshold (relative to rest)
    ``V_reset``               ``V_reset``                Reset potential (relative to rest)
    ``rate``                  ``rate``                   Current firing rate (Hz)
    ========================= ========================== ==================================

    Raises
    ------
    ValueError
        If ``tau`` ≤ 0, ``tau_m`` ≤ 0, ``tau_syn`` < 0, ``t_ref`` < 0, or
        ``V_reset`` ≥ ``theta``. These checks run at construction and are
        skipped when the parameters are JAX tracers.

    Notes
    -----
    **Computational Cost**

    - :meth:`update` evaluates the JAX Siegert port once per element per step.
      The cost is fixed (no adaptive quadrature), so the step is
      ``jit``-friendly.
    - :meth:`siegert_rate` loops over elements in Python and, with SciPy, runs
      an adaptive quadrature per element. Use it for analysis and validation
      rather than inside a simulation loop.

    **Numerical Stability**

    - Uses ``erfcx(x) = exp(x²) erfc(x)`` to avoid overflow for large x.
    - Uses ``1/x`` asymptotic expansions for erfcx, its integral, and Dawson's
      integral at large arguments.
    - The exact exponential propagators (``exp`` and ``expm1``) integrate the
      linear rate ODE exactly when the drive is held constant over a step.

    **Batch Dimensions**

    States support an optional leading batch dimension for parallelizing
    multiple network realizations. Initialize with ``init_state(batch_size=B)``
    to create shape ``(B, *in_size)``.

    **Integration with NEST**

    This implementation targets NEST 3.9+ ``siegert_neuron`` behavior in
    non-waveform-relaxation mode. Known differences:

    - NEST's WFR mode (iterative delay resolution) is not implemented.
    - Coupling arrives through the labeled delta-input channels described above
      rather than through NEST ``DiffusionConnectionEvent`` buffers.

    References
    ----------
    .. [1] Hahne J, Dahmen D, Schuecker J, Frommer A, Bolten M, Helias M,
           Diesmann M (2017). Integration of continuous-time dynamics in a spiking
           neural network simulator. Frontiers in Neuroinformatics, 11:34.
           DOI: ``10.3389/fninf.2017.00034``.
    .. [2] Fourcaud N, Brunel N (2002). Dynamics of the firing probability of
           noisy integrate-and-fire neurons. Neural Computation, 14(9):2057-2110.
           DOI: ``10.1162/089976602320264015``.
    .. [3] Schuecker J, Diesmann M, Helias M (2015). Modulated escape from a
           metastable state driven by colored noise. Physical Review E, 92:052119.
           DOI: ``10.1103/PhysRevE.92.052119``.

    Examples
    --------
    **Basic usage with constant input:**

    .. code-block:: python

        >>> from brainpy import state as bp
        >>> import brainunit as u
        >>> import brainstate
        >>> model = bp.siegert_neuron(in_size=10, tau=2*u.ms, tau_m=10*u.ms)
        >>> model.init_all_states()
        >>> with brainstate.environ.context(dt=0.1*u.ms):
        ...     for _ in range(100):
        ...         rate = model.update(drift_input=12.0, diffusion_input=4.0)
        >>> print(rate)  # Approaches the steady-state firing rate in Hz

    **Compiled time loop:**

    .. code-block:: python

        >>> import jax.numpy as jnp
        >>> model.init_all_states()
        >>> def step(i):
        ...     return model.update(drift_input=12.0, diffusion_input=4.0)
        >>> with brainstate.environ.context(dt=0.1*u.ms):
        ...     rates = brainstate.transform.for_loop(step, jnp.arange(100))
        >>> print(rates.shape)  # (100, 10)

    **Mean-field network with recurrent connections:**

    .. code-block:: python

        >>> exc = bp.siegert_neuron(in_size=800, tau=1*u.ms, theta=15.0)
        >>> inh = bp.siegert_neuron(in_size=200, tau=1*u.ms, theta=15.0)
        >>> exc.init_all_states()
        >>> inh.init_all_states()
        >>> # Simulate recurrent network (conceptual; requires projection setup)
        >>> with brainstate.environ.context(dt=0.1*u.ms):
        ...     for t in range(1000):
        ...         exc_drive = exc.rate.value.sum() * 0.01
        ...         inh_drive = inh.rate.value.sum() * -0.02
        ...         exc.update(drift_input=exc_drive + inh_drive, diffusion_input=2.0)
        ...         inh.update(drift_input=exc_drive, diffusion_input=1.0)
    """
    __module__ = 'brainpy.state'

    # NEST value: alpha = |zeta(1/2)| * sqrt(2)
    _ALPHA = 2.0652531522312172

    # Seam-(H) continuous-rate emitter: the Simulator allocates an emission holder
    # and captures ``rate`` each step so an outgoing diffusion_connection can read
    # the previous step's rate (NEST min_delay=1). See _network/_simulator.py.
    _emission_continuous = True
    _emission_attr = 'rate'

    def __init__(
        self,
        in_size: Size,
        tau: ArrayLike = 1.0 * u.ms,
        tau_m: ArrayLike = 5.0 * u.ms,
        tau_syn: ArrayLike = 0.0 * u.ms,
        t_ref: ArrayLike = 2.0 * u.ms,
        mean: ArrayLike = 0.0,
        theta: ArrayLike = 15.0,
        V_reset: ArrayLike = 0.0,
        rate_initializer: Callable = braintools.init.Constant(0.0),
        name: str = None,
    ):
        super().__init__(in_size=in_size, name=name)
        self.tau = braintools.init.param(tau, self.varshape)
        self.tau_m = braintools.init.param(tau_m, self.varshape)
        self.tau_syn = braintools.init.param(tau_syn, self.varshape)
        self.t_ref = braintools.init.param(t_ref, self.varshape)
        self.mean = braintools.init.param(mean, self.varshape)
        self.theta = braintools.init.param(theta, self.varshape)
        self.V_reset = braintools.init.param(V_reset, self.varshape)
        self.rate_initializer = rate_initializer
        self._validate_parameters()

    @property
    def recordables(self):
        """Names of the state variables that can be recorded."""
        return ['rate']

    @property
    def receptor_types(self):
        """Receptor-name to receptor-id mapping."""
        # NEST handles DiffusionConnectionEvent via receptor type 1.
        return {'DIFFUSION': 1}

    @staticmethod
    def _to_numpy(x):
        """Strip units (if any) and return a NumPy array in the default float dtype."""
        dftype = brainstate.environ.dftype()
        return np.asarray(u.get_mantissa(x), dtype=dftype)

    @staticmethod
    def _to_numpy_ms(x):
        """Express a time quantity in ms and return it as a NumPy float array."""
        dftype = brainstate.environ.dftype()
        return np.asarray(u.get_mantissa(x / u.ms), dtype=dftype)

    @staticmethod
    def _broadcast_to_state(x_np: np.ndarray, shape):
        """Return a read-only broadcast view of ``x_np`` with the given ``shape``."""
        return np.broadcast_to(x_np, shape)

    def _validate_parameters(self):
        """Check parameter ranges, raising ``ValueError`` on violation.

        Validation is skipped when any parameter is a JAX tracer, because
        concrete comparisons are not possible during tracing.
        """
        if any(is_tracer(v) for v in (self.tau, self.tau_m, self.tau_syn, self.t_ref, self.V_reset, self.theta)):
            return
        if cond_any(self.tau <= 0.0 * u.ms):
            raise ValueError('Time constant tau must be > 0.')
        if cond_any(self.tau_m <= 0.0 * u.ms):
            raise ValueError('Membrane time constant tau_m must be > 0.')
        if cond_any(self.tau_syn < 0.0 * u.ms):
            raise ValueError('Synaptic time constant tau_syn must be >= 0.')
        if cond_any(self.t_ref < 0.0 * u.ms):
            raise ValueError('Refractory period t_ref must be >= 0.')
        if cond_any(self.V_reset >= self.theta):
            raise ValueError('Reset potential V_reset must be smaller than threshold theta.')

    def init_state(self, batch_size: int = None, **kwargs):
        """Create the ``rate``, ``instant_rate`` and ``delayed_rate`` states.

        Parameters
        ----------
        batch_size : int, optional
            Optional leading batch size. It is forwarded to
            ``braintools.init.param`` together with ``self.varshape``.
            ``None`` (default) creates unbatched states.
        **kwargs
            Ignored. Accepted for compatibility with generic ``init_state``
            callers.
        """
        rate = braintools.init.param(self.rate_initializer, self.varshape, batch_size=batch_size)
        rate_np = self._to_numpy(rate)
        self.rate = brainstate.ShortTermState(rate_np)
        dftype = brainstate.environ.dftype()
        # Independent copies so the outgoing buffers never alias ``rate``.
        self.instant_rate = brainstate.ShortTermState(np.array(rate_np, dtype=dftype, copy=True))
        self.delayed_rate = brainstate.ShortTermState(np.array(rate_np, dtype=dftype, copy=True))

    @staticmethod
    def _gauss_legendre_scalar_integral(func, a: float, b: float):
        """Integrate scalar ``func`` over ``[a, b]`` with one 64-point Gauss-Legendre rule."""
        mid = 0.5 * (a + b)
        half = 0.5 * (b - a)
        pts = mid + half * _GAUSS_NODES
        dftype = brainstate.environ.dftype()
        vals = np.asarray([func(float(x)) for x in pts], dtype=dftype)
        return float(half * np.sum(_GAUSS_WEIGHTS * vals))

    @staticmethod
    def _erfcx_pos_scalar(x: float):
        r"""Scaled complementary error function ``exp(x**2) * erfc(x)`` for ``x >= 0``.

        Uses SciPy when it is available. Otherwise the product is evaluated
        directly for ``x < 25`` (``exp(x**2)`` is still finite there). Beyond
        that it uses the alternating asymptotic series
        ``erfcx(x) ~ 1/(x sqrt(pi)) * sum_k (-1)^k (2k-1)!! / (2 x^2)^k``.
        """
        if _HAVE_SCIPY:
            return float(_sp_special.erfcx(x))
        if x < 25.0:
            return math.exp(x * x) * math.erfc(x)
        inv = 1.0 / x
        inv2 = inv * inv
        # The series alternates in sign. All-positive coefficients would
        # overestimate erfcx by a relative amount of roughly 1/x^2.
        poly = (1.0 - 0.5 * inv2 + 0.75 * inv2 ** 2 - 1.875 * inv2 ** 3
                + 6.5625 * inv2 ** 4 - 29.53125 * inv2 ** 5 + 162.421875 * inv2 ** 6)
        return (inv / math.sqrt(math.pi)) * poly

    @staticmethod
    def _integral_erfcx_asympt(a: float, b: float):
        r"""Closed-form ``\int_a^b erfcx(s) ds`` from the asymptotic series (``a, b >= 8``).

        This is the term-by-term antiderivative of the alternating erfcx
        series, so its terms alternate as well: ``ln s``, ``+1/(4 s^2)``,
        ``-3/(16 s^4)``, ``+5/(16 s^6)``, ... It matches
        ``_integral_erfcx_asympt_jax``.
        """
        inv_a2 = 1.0 / (a * a)
        inv_b2 = 1.0 / (b * b)
        term0 = math.log(b / a)
        term1 = 0.25 * (inv_b2 - inv_a2)
        term2 = -(3.0 / 16.0) * (inv_b2 ** 2 - inv_a2 ** 2)
        term3 = (5.0 / 16.0) * (inv_b2 ** 3 - inv_a2 ** 3)
        term4 = -(105.0 / 128.0) * (inv_b2 ** 4 - inv_a2 ** 4)
        term5 = (945.0 / 320.0) * (inv_b2 ** 5 - inv_a2 ** 5)
        return (term0 + term1 + term2 + term3 + term4 + term5) / math.sqrt(math.pi)

    @classmethod
    def _integral_erfcx_pos(cls, a: float, b: float):
        r"""``\int_a^b erfcx(s) ds`` for non-negative bounds; the result is negative when ``b < a``.

        With SciPy this is an adaptive ``quad`` (``epsrel=1.49e-8``). Without
        SciPy, ``[lo, min(hi, 8)]`` is integrated with Gauss-Legendre segments
        at most 2 wide, and ``[max(lo, 8), hi]`` uses the closed-form
        asymptotic antiderivative.
        """
        if a == b:
            return 0.0
        sign = 1.0
        lo = float(a)
        hi = float(b)
        if lo > hi:
            sign = -1.0
            lo, hi = hi, lo
        if _HAVE_SCIPY:
            result, _ = _sp_integrate.quad(
                lambda s: float(_sp_special.erfcx(s)),
                lo,
                hi,
                epsabs=0.0,
                epsrel=1.49e-8,
                limit=1000,
            )
            return sign * float(result)
        split = 8.0
        total = 0.0
        if lo < split:
            hi_num = min(hi, split)
            width = hi_num - lo
            nseg = max(1, int(math.ceil(width / 2.0)))
            seg_w = width / nseg
            left = lo
            for _ in range(nseg):
                right = left + seg_w
                total += cls._gauss_legendre_scalar_integral(cls._erfcx_pos_scalar, left, right)
                left = right
        if hi > split:
            lo_as = max(lo, split)
            total += cls._integral_erfcx_asympt(lo_as, hi)
        return sign * total

    @classmethod
    def _dawsn_pos_scalar(cls, x: float):
        r"""Dawson's integral ``D(x) = exp(-x^2) \int_0^x exp(t^2) dt`` for ``x >= 0``.

        Uses SciPy when it is available. Otherwise it uses a Taylor series for
        ``x < 0.2``, a five-term ``1/x`` asymptotic series for ``x >= 8``, and
        Gauss-Legendre quadrature over segments at most 1 wide in between.
        """
        if _HAVE_SCIPY:
            return float(_sp_special.dawsn(x))
        if x == 0.0:
            return 0.0
        if x < 0.2:
            x2 = x * x
            return x * (
                1.0
                - (2.0 / 3.0) * x2
                + (4.0 / 15.0) * x2 * x2
                - (8.0 / 105.0) * x2 ** 3
                + (16.0 / 945.0) * x2 ** 4
            )
        if x >= 8.0:
            inv = 1.0 / x
            inv2 = inv * inv
            return (
                0.5 * inv
                + 0.25 * inv * inv2
                + (3.0 / 8.0) * inv * inv2 ** 2
                + (15.0 / 16.0) * inv * inv2 ** 3
                + (105.0 / 32.0) * inv * inv2 ** 4
            )
        # Dawson(x) = exp(-x^2) * integral_0^x exp(t^2) dt
        nseg = max(1, int(math.ceil(x / 1.0)))
        seg_w = x / nseg
        left = 0.0
        integral = 0.0
        for _ in range(nseg):
            right = left + seg_w
            integral += cls._gauss_legendre_scalar_integral(lambda t: math.exp(t * t), left, right)
            left = right
        return math.exp(-x * x) * integral

    @classmethod
    def _siegert_scalar(
        cls,
        mu: float,
        sigma_square: float,
        tau_m_ms: float,
        tau_syn_ms: float,
        t_ref_ms: float,
        theta: float,
        v_reset: float,
    ):
        """Host-side Siegert transfer for a single element, in Hz (times are in ms).

        The regimes are checked in this order: deterministic
        (``sigma_square <= 0``), the deep-subthreshold fast path, and then
        three cases on the signs of the shifted reset ``y_r`` and threshold
        ``y_th``.
        """
        if sigma_square <= 0.0:
            if mu > theta:
                return 1e3 / (t_ref_ms + tau_m_ms * math.log((mu - v_reset) / (mu - theta)))
            return 0.0
        sigma = math.sqrt(sigma_square)
        # NEST fast path for very subthreshold input (Brunel 2000, eq. 22 estimate).
        if (theta - mu) > 6.0 * sigma:
            return 0.0
        threshold_shift = (cls._ALPHA / 2.0) * math.sqrt(tau_syn_ms / tau_m_ms)
        y_th = (theta - mu) / sigma + threshold_shift
        y_r = (v_reset - mu) / sigma + threshold_shift
        sqrt_pi = math.sqrt(math.pi)
        if y_r > 0.0:
            # Both bounds positive: rescale by exp(-y_th^2) to stay finite.
            result = cls._integral_erfcx_pos(y_r, y_th)
            integral = (
                2.0 * cls._dawsn_pos_scalar(y_th)
                - 2.0 * math.exp(y_r * y_r - y_th * y_th) * cls._dawsn_pos_scalar(y_r)
                - math.exp(-y_th * y_th) * result
            )
            e = math.exp(-y_th * y_th)
            return 1e3 * e / (e * t_ref_ms + tau_m_ms * sqrt_pi * integral)
        if y_th < 0.0:
            # Both bounds negative: erfcx integral over the mirrored interval.
            integral = cls._integral_erfcx_pos(-y_th, -y_r)
            return 1e3 / (t_ref_ms + tau_m_ms * sqrt_pi * integral)
        # y_r <= 0 <= y_th.
        result = cls._integral_erfcx_pos(y_th, -y_r)
        integral = 2.0 * cls._dawsn_pos_scalar(y_th) + math.exp(-y_th * y_th) * result
        e = math.exp(-y_th * y_th)
        return 1e3 * e / (e * t_ref_ms + tau_m_ms * sqrt_pi * integral)

    @classmethod
    def _siegert_array(
        cls,
        mu: np.ndarray,
        sigma_square: np.ndarray,
        tau_m_ms: np.ndarray,
        tau_syn_ms: np.ndarray,
        t_ref_ms: np.ndarray,
        theta: np.ndarray,
        v_reset: np.ndarray,
    ):
        """Apply :meth:`_siegert_scalar` element-wise to same-shaped arrays (Hz)."""
        dftype = brainstate.environ.dftype()
        out = np.empty_like(mu, dtype=dftype)
        for idx in np.ndindex(mu.shape):
            out[idx] = cls._siegert_scalar(
                float(mu[idx]),
                float(sigma_square[idx]),
                float(tau_m_ms[idx]),
                float(tau_syn_ms[idx]),
                float(t_ref_ms[idx]),
                float(theta[idx]),
                float(v_reset[idx]),
            )
        return out

    # ------------------------------------------------------------------
    # JAX-native Siegert transfer (goal 15c, design B).
    #
    # The host ``_siegert_scalar`` path above (SciPy / numpy Gauss-Legendre)
    # is the quadrature *oracle*; it stays eager and drives ``siegert_rate``.
    # The ``*_jax`` methods below re-express the same three-branch algorithm in
    # ``jax.numpy`` so ``update`` lowers under ``brainstate.transform.for_loop``
    # / ``jit``. They are validated against the oracle in
    # ``_validation/siegert_diffusion_test.py``.
    # ------------------------------------------------------------------

    @staticmethod
    def _erfcx_jax(x):
        r"""Scaled complementary error function ``erfcx(x) = exp(x^2) erfc(x)``.

        Direct ``exp(x^2) erfc(x)`` for ``x < 8`` (clipped to avoid overflow in the
        unused branch); a seven-term ``1/x`` asymptotic series for ``x >= 8``.
        """
        x = jnp.asarray(x)
        inv = 1.0 / jnp.where(x != 0.0, x, 1.0)
        inv2 = inv * inv
        # erfcx(x) ~ 1/(x sqrt(pi)) * sum_k (-1)^k (2k-1)!!/(2x^2)^k (alternating).
        poly = (1.0 - 0.5 * inv2 + 0.75 * inv2 ** 2 - 1.875 * inv2 ** 3
                + 6.5625 * inv2 ** 4 - 29.53125 * inv2 ** 5 + 162.421875 * inv2 ** 6)
        asympt = inv / jnp.sqrt(jnp.pi) * poly
        x_safe = jnp.minimum(x, 8.0)
        direct = jnp.exp(x_safe * x_safe) * jax_special.erfc(x_safe)
        return jnp.where(x < 8.0, direct, asympt)

    @classmethod
    def _dawsn_jax(cls, x):
        r"""Dawson's integral ``D(x) = exp(-x^2) \int_0^x exp(t^2) dt`` (odd in x).

        The method depends on ``|x|``:

        - ``|x| < 0.2``: Taylor series.
        - ``|x| >= 8``: seven-term ``1/x`` asymptotic series.
        - In between: ``\int_0^{|x|} exp(t^2) dt`` by 64-point Gauss-Legendre
          quadrature on 8 equal segments, each at most 1 wide.

        Quadrature nodes are clipped to at most 8 before exponentiation, so the
        unused mid branch stays finite for large ``|x|``.
        """
        x = jnp.asarray(x)
        ax = jnp.abs(x)
        sgn = jnp.sign(x)
        x2 = ax * ax
        taylor = ax * (1.0 - (2.0 / 3.0) * x2 + (4.0 / 15.0) * x2 ** 2
                       - (8.0 / 105.0) * x2 ** 3 + (16.0 / 945.0) * x2 ** 4)
        inv = 1.0 / jnp.where(ax > 0.0, ax, 1.0)
        inv2 = inv * inv
        asympt = (0.5 * inv + 0.25 * inv * inv2 + (3.0 / 8.0) * inv * inv2 ** 2
                  + (15.0 / 16.0) * inv * inv2 ** 3 + (105.0 / 32.0) * inv * inv2 ** 4
                  + (945.0 / 64.0) * inv * inv2 ** 5 + (10395.0 / 128.0) * inv * inv2 ** 6)
        nseg = 8
        nodes = jnp.asarray(_GAUSS_NODES)
        weights = jnp.asarray(_GAUSS_WEIGHTS)
        seg_w = ax / nseg
        k = jnp.arange(nseg)
        left = seg_w[..., None] * k
        mid = left + 0.5 * seg_w[..., None]
        half_seg = 0.5 * seg_w
        # pts has shape (..., nseg, 64): one Gauss-Legendre rule per segment.
        pts = mid[..., None] + half_seg[..., None, None] * nodes
        integrand = jnp.exp(jnp.minimum(pts, 8.0) ** 2)
        integ_per_seg = jnp.sum(weights * integrand, axis=-1) * half_seg[..., None]
        integral = jnp.sum(integ_per_seg, axis=-1)
        mid_val = jnp.exp(-ax * ax) * integral
        out = jnp.where(ax < 0.2, taylor, jnp.where(ax >= 8.0, asympt, mid_val))
        return sgn * out

    @staticmethod
    def _integral_erfcx_asympt_jax(a, b):
        r"""Closed-form ``\int_a^b erfcx(s) ds`` via the ``1/s`` asymptotic series."""
        inv_a2 = 1.0 / (a * a)
        inv_b2 = 1.0 / (b * b)
        # Antiderivative of the alternating erfcx asymptotic series; the odd-order
        # terms are +, the even-order terms - (integral of (-1)^k (2k-1)!!/(2s^2)^k).
        term0 = jnp.log(b / a)
        term1 = 0.25 * (inv_b2 - inv_a2)
        term2 = -(3.0 / 16.0) * (inv_b2 ** 2 - inv_a2 ** 2)
        term3 = (5.0 / 16.0) * (inv_b2 ** 3 - inv_a2 ** 3)
        term4 = -(105.0 / 128.0) * (inv_b2 ** 4 - inv_a2 ** 4)
        term5 = (945.0 / 320.0) * (inv_b2 ** 5 - inv_a2 ** 5)
        return (term0 + term1 + term2 + term3 + term4 + term5) / jnp.sqrt(jnp.pi)

    @classmethod
    def _integral_erfcx_jax(cls, a, b):
        r"""``\int_a^b erfcx(s) ds`` for non-negative bounds (signed in ``b - a``).

        ``[lo, min(hi, 8)]`` is integrated with a fixed 64-point Gauss-Legendre
        rule, since ``erfcx`` is smooth and bounded there. ``[max(lo, 8), hi]``
        uses the closed-form asymptotic antiderivative.
        """
        a = jnp.asarray(a)
        b = jnp.asarray(b)
        lo = jnp.minimum(a, b)
        hi = jnp.maximum(a, b)
        sign = jnp.sign(b - a)
        split = 8.0
        c = jnp.minimum(hi, split)
        nodes = jnp.asarray(_GAUSS_NODES)
        weights = jnp.asarray(_GAUSS_WEIGHTS)
        mid = 0.5 * (lo + c)
        half = 0.5 * (c - lo)
        pts = mid[..., None] + half[..., None] * nodes
        gl = half * jnp.sum(weights * cls._erfcx_jax(pts), axis=-1)
        gl = jnp.where(lo < split, gl, 0.0)
        d = jnp.maximum(lo, split)
        # Clamp the upper bound too. When hi <= split this evaluates the series
        # on the empty interval [8, 8] (exactly 0) instead of log(0) / 1/0.
        # The masked branch then stays NaN-free, so gradients through
        # ``jnp.where`` remain finite. Values with hi > split are unchanged.
        e = jnp.maximum(hi, split)
        asy = cls._integral_erfcx_asympt_jax(d, e)
        asy = jnp.where(hi > split, asy, 0.0)
        return sign * (gl + asy)

    @classmethod
    def _siegert_phi_core(cls, mu, sigma_square, tau_m_ms, tau_syn_ms, t_ref_ms, theta, v_reset):
        r"""JAX three-branch Siegert transfer on broadcast arrays (Hz)."""
        mu = jnp.asarray(mu)
        sig2 = jnp.asarray(sigma_square)
        sqrt_pi = jnp.sqrt(jnp.pi)
        # Deterministic LIF (sigma^2 <= 0): guard the log argument to stay finite.
        gap = jnp.where(mu > theta, mu - theta, 1.0)
        ratio = jnp.where(mu > theta, (mu - v_reset) / gap, 2.0)
        det = jnp.where(mu > theta, 1e3 / (t_ref_ms + tau_m_ms * jnp.log(ratio)), 0.0)
        sigma = jnp.sqrt(jnp.maximum(sig2, 1e-12))
        shift = (cls._ALPHA / 2.0) * jnp.sqrt(tau_syn_ms / tau_m_ms)
        y_th = (theta - mu) / sigma + shift
        y_r = (v_reset - mu) / sigma + shift
        e_th = jnp.exp(-y_th * y_th)
        # Clamp heavy-function arguments to their valid (non-negative) ranges. This
        # is a no-op in each branch's *selected* region and keeps the unused
        # branches finite (value-safe jnp.where).
        yth_p = jnp.maximum(y_th, 0.0)
        yr_p = jnp.maximum(y_r, 0.0)
        myth_p = jnp.maximum(-y_th, 0.0)
        myr_p = jnp.maximum(-y_r, 0.0)
        # Branch A: y_r > 0.
        iA = cls._integral_erfcx_jax(yr_p, yth_p)
        expd = jnp.exp(jnp.minimum(y_r * y_r - y_th * y_th, 0.0))
        integ_A = 2.0 * cls._dawsn_jax(yth_p) - 2.0 * expd * cls._dawsn_jax(yr_p) - e_th * iA
        rate_A = 1e3 * e_th / (e_th * t_ref_ms + tau_m_ms * sqrt_pi * integ_A)
        # Branch B: y_th < 0.
        iB = cls._integral_erfcx_jax(myth_p, myr_p)
        rate_B = 1e3 / (t_ref_ms + tau_m_ms * sqrt_pi * iB)
        # Branch C: y_r <= 0 <= y_th.
        iC = cls._integral_erfcx_jax(yth_p, myr_p)
        integ_C = 2.0 * cls._dawsn_jax(yth_p) + e_th * iC
        rate_C = 1e3 * e_th / (e_th * t_ref_ms + tau_m_ms * sqrt_pi * integ_C)
        rate = jnp.where(y_r > 0.0, rate_A, jnp.where(y_th < 0.0, rate_B, rate_C))
        # Brunel (2000) deep-subthreshold fast path.
        rate = jnp.where((theta - mu) > 6.0 * sigma, 0.0, rate)
        rate = jnp.where(sig2 <= 0.0, det, rate)
        return jnp.maximum(rate, 0.0)

    def _siegert_phi_jax(self, mu: ArrayLike, sigma_square: ArrayLike):
        r"""Evaluate the JAX Siegert transfer with this model's parameters (Hz).

        JAX-lowering counterpart of :meth:`siegert_rate`. ``mu`` and
        ``sigma_square`` may be tracers; the model parameters are folded in as
        static constants. The result is a ``jax.numpy`` array that composes
        under ``brainstate.transform`` primitives.

        Parameters
        ----------
        mu : ArrayLike
            Drift input (mean membrane potential shift), broadcastable with
            ``sigma_square`` and the model parameters.
        sigma_square : ArrayLike
            Diffusion input (membrane potential variance); non-negative.

        Returns
        -------
        rate : jax.Array
            Firing rate in Hz (broadcast shape of inputs and parameters).

        See Also
        --------
        siegert_rate : Eager SciPy / Gauss-Legendre quadrature oracle.
        """
        theta = jnp.asarray(self._to_numpy(self.theta))
        v_reset = jnp.asarray(self._to_numpy(self.V_reset))
        tau_m_ms = jnp.asarray(self._to_numpy_ms(self.tau_m))
        tau_syn_ms = jnp.asarray(self._to_numpy_ms(self.tau_syn))
        t_ref_ms = jnp.asarray(self._to_numpy_ms(self.t_ref))
        return self._siegert_phi_core(
            jnp.asarray(mu), jnp.asarray(sigma_square),
            tau_m_ms, tau_syn_ms, t_ref_ms, theta, v_reset,
        )

    def siegert_rate(self, mu: ArrayLike, sigma_square: ArrayLike):
        r"""Evaluate the NEST-compatible Siegert transfer function.

        Computes the steady-state firing rate :math:`\Phi(\mu, \sigma^2)` of a
        noisy LIF neuron with drift ``mu`` and diffusion ``sigma_square``. It
        uses the analytic Siegert formula [2]_ with the optional colored-noise
        correction [3]_.

        This is the eager host-side oracle. Inputs are broadcast together with
        all model parameters (``tau_m``, ``tau_syn``, ``t_ref``, ``theta``,
        ``V_reset``), and the transfer is evaluated element by element.

        Parameters
        ----------
        mu : ArrayLike
            Drift input (mean membrane potential shift, dimensionless). Scalar
            or array broadcastable with ``sigma_square`` and the model
            parameters. Positive values depolarize the neuron.
        sigma_square : ArrayLike
            Diffusion input (membrane potential variance, dimensionless
            squared). Scalar or array broadcastable with ``mu`` and the model
            parameters. Must be non-negative; zero gives deterministic LIF
            behavior.

        Returns
        -------
        rate : ndarray
            Firing rate in Hz, with the broadcast shape of the inputs and
            model parameters. Values are non-negative. The rate is 0 when
            ``theta - mu > 6 * sigma``, or when ``sigma_square <= 0`` and
            ``mu <= theta``. The maximum rate is approximately
            ``1000 / t_ref`` Hz (refractory limit).

        Notes
        -----
        **Special Cases:**

        - If ``sigma_square`` ≤ 0: deterministic LIF (returns 0 if μ ≤ θ;
          otherwise the constant-input firing rate).
        - If (θ - μ) > 6σ: deep subthreshold (returns 0, Brunel 2000 fast path).
        - If ``t_ref`` = 0: no refractory limit (the rate can grow without
          bound for μ >> θ).

        **Implementation:**

        - With SciPy: ``scipy.integrate.quad`` and ``scipy.special``.
        - Without SciPy: segmented 64-point Gauss-Legendre quadrature and
          asymptotic expansions.

        **Broadcasting Rules:**

        The output shape is the NumPy broadcast shape of ``mu``,
        ``sigma_square`` and every model parameter used by the transfer. For
        example, if the model has ``in_size=(10,)``, ``mu`` is a scalar, and
        ``sigma_square`` has shape ``(10,)``, the output shape is ``(10,)``.

        Examples
        --------
        **Single neuron with varying drift:**

        .. code-block:: python

            >>> from brainpy import state as bp
            >>> import numpy as np
            >>> import brainunit as u
            >>> model = bp.siegert_neuron(in_size=1, tau_m=10*u.ms, t_ref=2*u.ms, theta=15.0)
            >>> mu_vals = np.linspace(0, 25, 50)
            >>> rates = model.siegert_rate(mu=mu_vals, sigma_square=2.0)
            >>> print(rates.shape)  # (50,)
            >>> print(rates.max())  # Maximum firing rate in Hz

        **Population with heterogeneous noise:**

        .. code-block:: python

            >>> model = bp.siegert_neuron(in_size=100, tau_m=10*u.ms)
            >>> sigma_sq = np.linspace(0.1, 5.0, 100)
            >>> rates = model.siegert_rate(mu=15.0, sigma_square=sigma_sq)
            >>> print(rates.shape)  # (100,)

        **2D grid with spatially varying input:**

        .. code-block:: python

            >>> model = bp.siegert_neuron(in_size=(10, 10), tau_m=10*u.ms)
            >>> mu_grid = np.random.uniform(10, 20, size=(10, 10))
            >>> rates = model.siegert_rate(mu=mu_grid, sigma_square=3.0)
            >>> print(rates.shape)  # (10, 10)
        """
        operands = (
            self._to_numpy(mu),
            self._to_numpy(sigma_square),
            self._to_numpy_ms(self.tau_m),
            self._to_numpy_ms(self.tau_syn),
            self._to_numpy_ms(self.t_ref),
            self._to_numpy(self.theta),
            self._to_numpy(self.V_reset),
        )
        # Broadcast over *every* operand. The model parameters need not share
        # a shape (e.g. a scalar ``theta`` next to a per-neuron ``tau_m``), so
        # a shape derived from (mu, sigma_square, theta) alone could be one the
        # remaining parameters cannot be broadcast to.
        state_shape = np.broadcast(*operands).shape
        return self._siegert_array(
            *(self._broadcast_to_state(arr, state_shape) for arr in operands)
        )

    def update(self, x=0.0, drift_input: ArrayLike = 0.0, diffusion_input: ArrayLike = 0.0):
        r"""Advance the Siegert rate by one step (NEST non-WFR semantics).

        Drift and diffusion are read from the dual-channel substrate seam that a
        :class:`~brainpy_state.diffusion_connection` deposits into (goal 15c,
        design A):

        - drift :math:`\mu = \mathrm{sum\_current\_inputs}(x, r)
          + \mathrm{drift\_input} + \mathrm{sum\_delta\_inputs}(0,\ \text{label}=`
          ``'diffusion_mu'`` :math:`)`,
        - diffusion :math:`\sigma^2 = \mathrm{diffusion\_input}
          + \mathrm{sum\_delta\_inputs}(0,\ \text{label}=` ``'diffusion_sigma2'``
          :math:`)`.

        The two channels carry distinct labels, so a single
        ``diffusion_connection`` making two seam deposits
        (``drift_factor * rate`` and ``diffusion_factor * rate``) never
        cross-contaminates :math:`\mu` and :math:`\sigma^2`. The rate then
        relaxes by the exact exponential propagator
        :math:`r \leftarrow P_1 r + P_2(\mathrm{mean} + \Phi)`.

        The Siegert transfer is evaluated with the JAX port
        :meth:`_siegert_phi_jax`, so the whole step lowers under
        ``brainstate.transform.for_loop`` / ``jit``.

        Parameters
        ----------
        x : ArrayLike, optional
            External drive forwarded to ``sum_current_inputs`` (dimensionless).
        drift_input : ArrayLike, optional
            Direct drift contribution added to the total drift before the Siegert
            evaluation.
        diffusion_input : ArrayLike, optional
            Direct diffusion (variance) contribution; must be non-negative.

        Returns
        -------
        rate_new : jax.Array
            Updated firing rate in Hz (shape ``self.rate.value.shape``).
        """
        h = float(u.get_mantissa(brainstate.environ.get_dt() / u.ms))
        state_shape = self.rate.value.shape
        rate_prev = jnp.asarray(self.rate.value)
        # Drift: current-input seam + direct arg + labeled diffusion-drift channel.
        mu_total = (self.sum_current_inputs(x, rate_prev) + drift_input
                    + self.sum_delta_inputs(0.0, label='diffusion_mu'))
        # Diffusion (variance): direct arg + labeled diffusion-variance channel.
        sigma_square_total = (diffusion_input
                              + self.sum_delta_inputs(0.0, label='diffusion_sigma2'))
        drive = self._siegert_phi_jax(mu_total, sigma_square_total)
        tau = jnp.asarray(self._to_numpy_ms(self.tau))
        mean = jnp.asarray(self._to_numpy(self.mean))
        # Exact propagators; expm1 keeps P2 accurate when h << tau.
        p1 = jnp.exp(-h / tau)
        p2 = -jnp.expm1(-h / tau)
        rate_new = jnp.broadcast_to(p1 * rate_prev + p2 * (mean + drive), state_shape)
        self.rate.value = rate_new
        # NEST non-WFR: outgoing delayed/instant buffers carry the final rate.
        self.delayed_rate.value = rate_new
        self.instant_rate.value = rate_new
        return rate_new