# Copyright 2026 BrainX Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# -*- coding: utf-8 -*-
from typing import Callable
import brainstate
import braintools
import saiunit as u
import jax.numpy as jnp
import numpy as np
from brainstate.typing import ArrayLike, Size
from brainpy_state._nest.lin_rate import _lin_rate_base
from ._utils import is_tracer
# Public API of this module: only the concrete NEST-compatible model is exported;
# the Gaussian base class stays private.
__all__ = ['gauss_rate_ipn']
class _gauss_rate_base(_lin_rate_base):
    """Shared machinery for NEST-style Gaussian-gain rate neurons.

    Builds on ``_lin_rate_base`` and supplies:

    * the Gaussian input nonlinearity ``phi(h) = g * exp(-(h - mu)^2 / (2 sigma^2))``;
    * trivial (all-ones) multiplicative-coupling factors;
    * event parsing and routing of rate events into excitatory / inhibitory
      buffers by the sign of the event weight, including scheduling of
      delayed events in the per-step queues.
    """
    __module__ = 'brainpy.state'

    def _input(self, h, g, mu, sigma):
        """Evaluate the Gaussian gain function elementwise.

        Parameters
        ----------
        h : array-like
            Input value(s) to transform.
        g, mu, sigma : array-like
            Amplitude, center and width of the Gaussian.

        Returns
        -------
        numpy.ndarray
            ``g * exp(-(h - mu)**2 / (2 * sigma**2))``. With ``sigma == 0`` the
            exponent is ``-inf`` where ``h != mu`` (result 0) and ``0/0`` where
            ``h == mu`` (result NaN).
        """
        squared_offset = np.power(h - mu, 2.0)
        twice_variance = 2.0 * np.power(sigma, 2.0)
        return g * np.exp(-squared_offset / twice_variance)

    @staticmethod
    def _mult_coupling_ex(rate):
        """Excitatory multiplicative-coupling factor: ones shaped like ``rate``."""
        return jnp.ones_like(rate, dtype=brainstate.environ.dftype())

    @staticmethod
    def _mult_coupling_in(rate):
        """Inhibitory multiplicative-coupling factor: ones shaped like ``rate``."""
        return jnp.ones_like(rate, dtype=brainstate.environ.dftype())

    def _extract_event_fields(self, ev, default_delay_steps: int):
        """Normalize one rate event into ``(rate, weight, multiplicity, delay_steps)``.

        Accepted forms:

        * ``dict`` -- ``rate`` is taken from key ``'rate'``, else ``'coeff'``,
          else ``'value'``, else ``0.0``; ``'weight'`` and ``'multiplicity'``
          default to ``1.0``; delay from ``'delay_steps'``, else ``'delay'``,
          else ``default_delay_steps``.
        * ``tuple``/``list`` of length 2 ``(rate, weight)``, 3
          ``(rate, weight, delay_steps)`` or 4
          ``(rate, weight, delay_steps, multiplicity)``.
        * anything else -- treated as a bare rate with unit weight and
          multiplicity and the default delay.

        Raises
        ------
        ValueError
            If a tuple/list event does not have length 2, 3 or 4.
        """
        if isinstance(ev, dict):
            fallback_rate = ev.get('coeff', ev.get('value', 0.0))
            rate = ev.get('rate', fallback_rate)
            weight = ev.get('weight', 1.0)
            multiplicity = ev.get('multiplicity', 1.0)
            delay_steps = ev.get('delay_steps', ev.get('delay', default_delay_steps))
        elif isinstance(ev, (tuple, list)):
            n_fields = len(ev)
            if n_fields not in (2, 3, 4):
                raise ValueError('Rate event tuples must have length 2, 3, or 4.')
            rate, weight = ev[0], ev[1]
            delay_steps = ev[2] if n_fields >= 3 else default_delay_steps
            multiplicity = ev[3] if n_fields == 4 else 1.0
        else:
            rate, weight, multiplicity, delay_steps = ev, 1.0, 1.0, default_delay_steps
        return rate, weight, multiplicity, self._to_int_scalar(delay_steps, name='delay_steps')

    def _event_to_ex_in(self, ev, default_delay_steps: int, state_shape, g, mu, sigma):
        """Convert one event into excitatory / inhibitory contributions.

        The contribution is ``rate * weight * multiplicity`` when
        ``self.linear_summation`` is true, otherwise
        ``phi(rate) * weight * multiplicity``. Entries whose weight is ``>= 0``
        go to the excitatory array; the others go (with their negative sign
        kept) to the inhibitory array.

        Returns
        -------
        tuple
            ``(excitatory, inhibitory, delay_steps)``.
        """
        rate, weight, multiplicity, delay_steps = self._extract_event_fields(ev, default_delay_steps)
        broadcast = self._broadcast_to_state
        rate_arr = broadcast(self._to_numpy(rate), state_shape)
        weight_arr = broadcast(self._to_numpy(weight), state_shape)
        mult_arr = broadcast(self._to_numpy(multiplicity), state_shape)
        weight_mantissa = np.asarray(u.get_mantissa(weight), dtype=brainstate.environ.dftype())
        is_excitatory = broadcast(weight_mantissa >= 0.0, state_shape)
        # With linear summation the nonlinearity is applied later to the summed
        # input; otherwise it is applied to each event here.
        drive = rate_arr if self.linear_summation else self._input(rate_arr, g, mu, sigma)
        contribution = drive * weight_arr * mult_arr
        excitatory = np.where(is_excitatory, contribution, 0.0)
        inhibitory = np.where(is_excitatory, 0.0, contribution)
        return excitatory, inhibitory, delay_steps

    def _accumulate_instant_events_gauss(self, events, state_shape, g, mu, sigma):
        """Sum all instantaneous events into ``(excitatory, inhibitory)`` arrays.

        Raises
        ------
        ValueError
            If any instantaneous event carries a non-zero delay.
        """
        dftype = brainstate.environ.dftype()
        ex_total = np.zeros(state_shape, dtype=dftype)
        in_total = np.zeros(state_shape, dtype=dftype)
        for event in self._coerce_events(events):
            ex_part, in_part, delay = self._event_to_ex_in(
                event,
                default_delay_steps=0,
                state_shape=state_shape,
                g=g,
                mu=mu,
                sigma=sigma,
            )
            if delay != 0:
                raise ValueError('instant_rate_events must not specify non-zero delay_steps.')
            # In-place adds keep the accumulator dtype fixed at ``dftype``.
            ex_total += ex_part
            in_total += in_part
        return ex_total, in_total

    def _schedule_delayed_events_gauss(self, events, step_idx: int, state_shape, g, mu, sigma):
        """Queue delayed events and return the zero-delay part for this step.

        Events default to a delay of one step. Events with a positive delay are
        added to ``self._delayed_ex_queue`` / ``self._delayed_in_queue`` under
        key ``step_idx + delay``; zero-delay events are summed and returned.

        Raises
        ------
        ValueError
            If any event has a negative delay.
        """
        dftype = brainstate.environ.dftype()
        ex_now = np.zeros(state_shape, dtype=dftype)
        in_now = np.zeros(state_shape, dtype=dftype)
        for event in self._coerce_events(events):
            ex_part, in_part, delay = self._event_to_ex_in(
                event,
                default_delay_steps=1,
                state_shape=state_shape,
                g=g,
                mu=mu,
                sigma=sigma,
            )
            if delay < 0:
                raise ValueError('delay_steps for delayed_rate_events must be >= 0.')
            if delay > 0:
                due_step = step_idx + delay
                self._queue_add(self._delayed_ex_queue, due_step, ex_part)
                self._queue_add(self._delayed_in_queue, due_step, in_part)
                continue
            ex_now += ex_part
            in_now += in_part
        return ex_now, in_now

    def _common_inputs_gauss(self, x, instant_rate_events, delayed_rate_events, g, mu, sigma):
        """Collect every input source for the current step.

        Returns
        -------
        tuple
            ``(state_shape, step_idx, delayed_ex, delayed_in, instant_ex,
            instant_in, mu_ext)`` where the delayed arrays combine the drained
            queue entries with zero-delay delayed events, the instant arrays
            combine instantaneous events with the sign-split delta inputs, and
            ``mu_ext`` is the summed current input.
        """
        state_shape = self.rate.value.shape
        step_idx = self._step_count
        queued_ex, queued_in = self._drain_delayed_queue(step_idx, state_shape)
        fresh_ex, fresh_in = self._schedule_delayed_events_gauss(
            delayed_rate_events,
            step_idx=step_idx,
            state_shape=state_shape,
            g=g,
            mu=mu,
            sigma=sigma,
        )
        delayed_ex = queued_ex + fresh_ex
        delayed_in = queued_in + fresh_in
        instant_ex, instant_in = self._accumulate_instant_events_gauss(
            instant_rate_events,
            state_shape=state_shape,
            g=g,
            mu=mu,
            sigma=sigma,
        )
        # Delta inputs are routed by sign: positive -> excitatory, negative -> inhibitory.
        # NOTE(review): they are added raw, i.e. phi is not applied to them even
        # when linear_summation is False -- confirm this is intended.
        delta = self._broadcast_to_state(self._to_numpy(self.sum_delta_inputs(0.0)), state_shape)
        instant_ex += np.where(delta > 0.0, delta, 0.0)
        instant_in += np.where(delta < 0.0, delta, 0.0)
        current = self.sum_current_inputs(x, self.rate.value)
        mu_ext = self._broadcast_to_state(self._to_numpy(current), state_shape)
        return state_shape, step_idx, delayed_ex, delayed_in, instant_ex, instant_in, mu_ext

    def _common_parameters_gauss(self, state_shape):
        """Return ``(tau, sigma, mu, g)`` as NumPy arrays broadcast to ``state_shape``.

        ``tau`` is converted with ``_to_numpy_ms``; the others with ``_to_numpy``.
        """
        to_state = self._broadcast_to_state
        tau = to_state(self._to_numpy_ms(self.tau), state_shape)
        sigma, mu, g = [
            to_state(self._to_numpy(value), state_shape)
            for value in (self.sigma, self.mu, self.g)
        ]
        return tau, sigma, mu, g
class gauss_rate_ipn(_gauss_rate_base):
    r"""NEST-compatible ``gauss_rate_ipn`` nonlinear rate neuron with input noise.

    Implements a stochastic rate-based neuron with Gaussian gain function and input
    noise, matching NEST's ``gauss_rate_ipn`` model. The dynamics combine passive
    decay, mean drive, network input (processed through a Gaussian nonlinearity),
    and additive Brownian noise.

    **1. Model equations**

    The stochastic differential equation governing the rate dynamics is:

    .. math::

       \tau\,dX(t)=
       \left[-\lambda X(t)+\mu+I_{\mathrm{net}}(t)\right]dt
       +\left[\sqrt{\tau}\,\sigma\right]dW(t),

    where :math:`W(t)` is a standard Wiener process and :math:`I_{\mathrm{net}}(t)`
    is the effective network input after applying the Gaussian gain function
    :math:`\phi(h)`:

    .. math::

       \phi(h)=g\exp\left(-\frac{(h-\mu)^2}{2\sigma^2}\right).

    The gain function produces a bell-shaped response centered at :math:`\mu` with
    width controlled by :math:`\sigma` and amplitude scaled by :math:`g`.

    **2. NEST parameter coupling (critical implementation detail)**

    NEST's ``gauss_rate_ipn`` model uses the same parameter names ``mu`` and
    ``sigma`` for two distinct purposes:

    1. **SDE parameters**: ``mu`` is the mean drive in the drift term;
       ``sigma`` scales the diffusion coefficient (input noise strength).
    2. **Gain-function parameters**: ``mu`` is the Gaussian center (location of
       peak response); ``sigma`` is the Gaussian width (standard deviation of
       the bell curve).

    This implementation preserves NEST's dual-role design. Consequently:

    - The default here is ``sigma=0.0`` (no input noise).
    - When ``sigma=0``, the gain function is undefined at ``h=mu`` (0/0 form)
      and evaluates to ``NaN``. With ``linear_summation=True``, :math:`\phi` is
      evaluated on the summed network input every step, even when that sum is
      zero. The defaults ``sigma=0.0`` and ``mu=0.0`` therefore turn the rate
      into ``NaN`` on the first update without network input. Use
      ``sigma > 0`` for a well-defined gain.
    - Both roles share the same parameter instance, so changes affect both the
      SDE noise term and the gain-function shape.

    **3. Update ordering (matching NEST ``rate_neuron_ipn_impl.h``)**

    Per simulation step of duration ``dt``, in NEST's order. This implementation
    gathers the network input before integrating. That gives the same result,
    because the input does not depend on the updated rate.

    1. **Store outgoing delayed value**: The pre-update ``rate`` becomes
       ``delayed_rate``.
    2. **Draw noise sample**: Compute ``noise = sigma * xi`` where ``xi ~ N(0,1)``.
    3. **Propagate intrinsic dynamics**: Apply stochastic exponential Euler
       (reduces to Euler-Maruyama when ``lambda=0``):

       .. math::

          X_{\mathrm{new}} = e^{-\lambda h/\tau}X_{\mathrm{prev}}
          + \frac{1-e^{-\lambda h/\tau}}{\lambda}(\mu + \mu_{\mathrm{ext}})
          + \sqrt{\frac{1-e^{-2\lambda h/\tau}}{2\lambda}}\,\sigma\xi,

       where :math:`h=dt` and special handling applies when ``lambda=0``.
    4. **Drain delayed queues**: Retrieve and sum delayed excitatory/inhibitory
       contributions scheduled for the current step.
    5. **Process instantaneous events**: Parse and accumulate ``instant_rate_events``
       and zero-delay entries from ``delayed_rate_events``.
    6. **Apply Gaussian gain function**:

       - ``linear_summation=True``: Sum all network inputs, then apply :math:`\phi`.
       - ``linear_summation=False``: Apply :math:`\phi` to each event value before
         summation (nonlinearity applied during event buffering).

    7. **Include multiplicative coupling** (if enabled): Scale excitatory/inhibitory
       branches by state-dependent factors (trivially ``1.0`` for this model).
    8. **Apply rectification** (if enabled): Clamp ``rate >= rectify_rate``.
    9. **Store outgoing instantaneous value**: The updated ``rate`` becomes
       ``instant_rate`` for immediate event transmission.

    **4. Assumptions and constraints**

    Mathematical validity:

    - ``tau > 0`` (time constant must be positive).
    - ``lambda >= 0`` (passive decay rate must be non-negative).
    - ``sigma >= 0`` (noise/gain width cannot be negative).
    - When ``sigma=0``, the gain function is undefined at ``h=mu``, matching NEST's
      potential NaN generation.

    Event semantics:

    - Events are specified as ``(rate, weight)`` tuples, ``(rate, weight, delay_steps)``
      triples, ``(rate, weight, delay_steps, multiplicity)`` 4-tuples, or dicts with
      ``'rate'``, ``'weight'``, ``'delay_steps'``, ``'multiplicity'`` keys.
    - ``instant_rate_events`` must have ``delay_steps=0`` (enforced with exception).
    - ``delayed_rate_events`` support integer ``delay_steps >= 0``.
    - Negative weights create inhibitory contributions (sign-based routing).

    **5. Computational implications**

    Integration method: Stochastic exponential Euler is exact for linear drift with
    additive noise (Ornstein-Uhlenbeck process) but approximate when network input
    is present. Accuracy degrades if ``dt`` is not sufficiently small relative to
    ``tau/lambda``.

    Delay queue management: Each delayed event is stored in a dictionary keyed by
    target step index. Memory scales with the number of active delayed events.
    Unbounded delays can lead to memory growth.

    Gaussian evaluation: Computing :math:`\exp(-(h-\mu)^2/(2\sigma^2))` per event
    (when ``linear_summation=False``) or per step (when ``linear_summation=True``)
    is vectorized via NumPy. For ``sigma=0``, evaluations at ``h=mu`` produce NaN.

    Parameters
    ----------
    in_size : Size
        Population shape specification. Determines ``self.varshape`` and the shape
        of state variables ``rate``, ``noise``, etc. Can be an integer (1D population)
        or tuple of integers (multi-dimensional population).
    tau : Quantity[ms], optional
        Time constant :math:`\tau` of rate dynamics. Must be positive. Controls the
        temporal scale of both drift and diffusion terms. Default ``10 ms``.
    lambda_ : float, optional
        Passive decay rate :math:`\lambda \ge 0`. When ``lambda=0``, dynamics reduce
        to driftless Brownian motion with external drive. Larger values produce
        stronger relaxation toward the mean drive. Default ``1.0``.
    sigma : float, optional
        Shared dual-role parameter (matching NEST):

        1. **Diffusion coefficient**: Scales input noise as :math:`\sqrt{\tau}\sigma dW(t)`.
        2. **Gaussian width**: Standard deviation of the gain function :math:`\phi(h)`.

        Must be non-negative. Default ``0.0``. This means no noise, but the gain
        function is undefined at ``h=mu``, so the rate becomes NaN (see section 2).
    mu : float, optional
        Shared dual-role parameter (matching NEST):

        1. **Mean drive**: Constant drift term in the SDE.
        2. **Gaussian center**: Location of peak response in :math:`\phi(h)`.

        Default ``0.0``.
    g : float, optional
        Gain amplitude parameter. Scales the maximum value of the Gaussian nonlinearity
        :math:`\phi(h)`. When ``g=1``, peak response is 1.0 at ``h=mu``.
        Default ``1.0``.
    mult_coupling : bool, optional
        Enable multiplicative coupling (state-dependent input scaling). For
        ``gauss_rate_ipn``, the coupling factors are trivially ``1.0``, so the
        values are unchanged. With ``linear_summation=True``, however, enabling it
        applies :math:`\phi` to the excitatory and inhibitory sums separately.
        Default ``False``.
    linear_summation : bool, optional
        NEST switch controlling where the Gaussian nonlinearity is applied:

        - ``True`` (default): Sum all network inputs first, then apply :math:`\phi`
          to the total. Results in :math:`\phi(\sum h_i w_i)`.
        - ``False``: Apply :math:`\phi` to each event's rate value during buffering,
          then sum the transformed contributions. Results in :math:`\sum \phi(h_i) w_i`.
    rectify_rate : float, optional
        Lower bound for output clamping when ``rectify_output=True``. Must be
        non-negative. Default ``0.0``.
    rectify_output : bool, optional
        If ``True``, apply rectification ``rate = max(rate, rectify_rate)`` after
        all updates. Prevents negative firing rates. Default ``False``.
    rate_initializer : Callable, optional
        Initializer for the ``rate`` state variable, evaluated with
        ``braintools.init.param`` over ``self.varshape``.
        Default ``braintools.init.Constant(0.0)``.
    noise_initializer : Callable, optional
        Initializer for the ``noise`` state variable (stores last noise sample).
        Default ``braintools.init.Constant(0.0)``.
    name : str or None, optional
        Module identifier. Default ``None``.

    Parameter Mapping
    -----------------
    .. list-table:: Parameter mapping to NEST and model symbols
       :header-rows: 1
       :widths: 20 18 18 44

       * - Parameter
         - Default
         - Math symbol
         - Semantics / NEST correspondence
       * - ``tau``
         - ``10 ms``
         - :math:`\tau`
         - Time constant of rate dynamics (NEST: ``tau``).
       * - ``lambda_``
         - ``1.0``
         - :math:`\lambda`
         - Passive decay rate (NEST: ``lambda`` in template parameters).
       * - ``sigma``
         - ``0.0``
         - :math:`\sigma`
         - Dual role: input-noise scale in SDE **and** Gaussian width in :math:`\phi(h)` (NEST: ``sigma``).
       * - ``mu``
         - ``0.0``
         - :math:`\mu`
         - Dual role: mean drive in SDE **and** Gaussian center in :math:`\phi(h)` (NEST: ``mu``).
       * - ``g``
         - ``1.0``
         - :math:`g`
         - Gain amplitude of Gaussian nonlinearity (NEST: ``g`` in template nonlinearity).
       * - ``rectify_rate``
         - ``0.0``
         - :math:`r_{\mathrm{min}}`
         - Lower bound for output rectification (NEST: ``rectify_rate``).
       * - ``rectify_output``
         - ``False``
         - —
         - Enable output clamping to :math:`\ge r_{\mathrm{min}}` (NEST: ``rectify_output``).
       * - ``linear_summation``
         - ``True``
         - —
         - Apply :math:`\phi` to sum (``True``) vs. per-event (``False``) (NEST: ``linear_summation``).
       * - ``mult_coupling``
         - ``False``
         - —
         - Enable multiplicative coupling (coupling factors are ones for this model; NEST compatibility).

    Attributes
    ----------
    rate : brainstate.ShortTermState
        Current firing rate of shape ``self.varshape``. Updated each step.
    noise : brainstate.ShortTermState
        Last noise sample :math:`\sigma\xi` of shape ``self.varshape``.
    instant_rate : brainstate.ShortTermState
        Copy of ``rate`` after update, used for zero-delay event transmission.
    delayed_rate : brainstate.ShortTermState
        Copy of ``rate`` before update, used for non-zero delay event transmission.

    Notes
    -----
    **Runtime event semantics**:

    - ``instant_rate_events``: Applied in the current step with zero delay.
      Format: scalar, ``(rate, weight)``, ``(rate, weight, 0)``,
      ``(rate, weight, 0, multiplicity)``, or dict with keys
      ``'rate'``, ``'weight'``, ``'delay_steps'`` (must be 0), ``'multiplicity'``.
    - ``delayed_rate_events``: Scheduled for future delivery based on ``delay_steps``.
      Format: same as above, but ``delay_steps`` can be any non-negative integer.
    - ``x``: External current input (additive to ``mu``), summed via
      ``sum_current_inputs(x, rate)``.

    **Failure modes**:

    - **NaN generation**: When ``sigma=0`` and the Gaussian argument ``h`` exactly
      equals ``mu``, :math:`\phi(h) = g \exp(0/0)` is undefined and the rate
      becomes NaN. Under the defaults, ``h = 0 = mu`` whenever there is no
      network input, so this is the default outcome. NEST also produces NaN
      in this case.
    - **Invalid event delays**: ``instant_rate_events`` with non-zero ``delay_steps``
      raise ``ValueError``. Negative delays in ``delayed_rate_events`` also raise
      ``ValueError``.
    - **Invalid parameters**: ``tau <= 0``, ``lambda_ < 0``, ``sigma < 0`` or
      ``rectify_rate < 0`` raise ``ValueError`` during construction.

    **Relationship to other models**:

    - ``gauss_rate_ipn`` is the NEST input-noise template instantiated with Gaussian
      nonlinearities. The base template ``rate_neuron_ipn`` supports arbitrary input
      nonlinearities and multiplicative-coupling functions.
    - ``gauss_rate_opn`` is the output-noise variant (noise added after nonlinearity).
    - For linear gain (``g * h``), use ``lin_rate_ipn`` instead.

    Examples
    --------
    Minimal usage. A positive ``sigma`` is required: with the default
    ``sigma=0.0`` and no input, the Gaussian gain evaluates ``0/0`` and the
    rate becomes NaN.

    .. code-block:: python

        >>> import brainpy_state as bpst
        >>> import saiunit as u
        >>> import brainstate
        >>> brainstate.environ.set_dt(0.1 * u.ms)
        >>> neuron = bpst.gauss_rate_ipn(in_size=100, sigma=1.0)
        >>> neuron.init_all_states()
        >>> # Simulate 10 steps with no input
        >>> for _ in range(10):
        ...     rate = neuron.update()

    With network events and external drive:

    .. code-block:: python

        >>> neuron = bpst.gauss_rate_ipn(
        ...     in_size=50,
        ...     tau=20.0 * u.ms,
        ...     lambda_=1.5,
        ...     sigma=0.5,
        ...     mu=0.0,
        ...     g=2.0,
        ...     linear_summation=True,
        ...     rectify_output=True,
        ...     rectify_rate=0.0,
        ... )
        >>> neuron.init_all_states()
        >>> # Apply instantaneous rate input and delayed event
        >>> rate = neuron.update(
        ...     x=1.0,  # external drive
        ...     instant_rate_events=(0.5, 1.0),  # (rate, weight)
        ...     delayed_rate_events=(1.0, 2.0, 5),  # (rate, weight, delay_steps)
        ... )

    References
    ----------
    .. [1] Hahne J, Dahmen D, Schuecker J, Frommer A, Bolten M, Helias M,
           Diesmann M (2017). Integration of continuous-time dynamics in a
           spiking neural network simulator. *Frontiers in Neuroinformatics*, 11:34.
           DOI: `10.3389/fninf.2017.00034 <https://doi.org/10.3389/fninf.2017.00034>`_.
    .. [2] Hahne J, Helias M, Kunkel S, Igarashi J, Bolten M, Frommer A,
           Diesmann M (2015). A unified framework for spiking and gap-junction
           interactions in distributed neuronal network simulations.
           *Frontiers in Neuroinformatics*, 9:22.
           DOI: `10.3389/fninf.2015.00022 <https://doi.org/10.3389/fninf.2015.00022>`_.
    .. [3] NEST Documentation: ``gauss_rate_ipn`` model.
           https://nest-simulator.readthedocs.io/en/stable/models/gauss_rate_ipn.html
    """
    __module__ = 'brainpy.state'

    def __init__(
        self,
        in_size: Size,
        tau: ArrayLike = 10.0 * u.ms,
        lambda_: ArrayLike = 1.0,
        sigma: ArrayLike = 0.0,
        mu: ArrayLike = 0.0,
        g: ArrayLike = 1.0,
        mult_coupling: bool = False,
        linear_summation: bool = True,
        rectify_rate: ArrayLike = 0.0,
        rectify_output: bool = False,
        rate_initializer: Callable = braintools.init.Constant(0.0),
        noise_initializer: Callable = braintools.init.Constant(0.0),
        name: str = None,
    ):
        # The linear-rate base stores tau/sigma/mu/g and the event machinery;
        # its linear-gain parameters are fixed to neutral values because the
        # Gaussian nonlinearity replaces them.
        super().__init__(
            in_size=in_size,
            tau=tau,
            sigma=sigma,
            mu=mu,
            g=g,
            mult_coupling=mult_coupling,
            g_ex=1.0,
            g_in=1.0,
            theta_ex=0.0,
            theta_in=0.0,
            linear_summation=linear_summation,
            rate_initializer=rate_initializer,
            noise_initializer=noise_initializer,
            name=name,
        )
        self.lambda_ = braintools.init.param(lambda_, self.varshape)
        self.rectify_rate = braintools.init.param(rectify_rate, self.varshape)
        self.rectify_output = bool(rectify_output)
        self._validate_parameters()

    @property
    def recordables(self):
        r"""List of state variable names that can be recorded during simulation.

        Returns
        -------
        list of str
            ``['rate', 'noise']`` — firing rate and noise sample state variables.

        Notes
        -----
        These names correspond to attributes that can be monitored by recording
        devices or logged during simulation. Accessing ``neuron.rate.value`` and
        ``neuron.noise.value`` retrieves the current values.
        """
        return ['rate', 'noise']

    @property
    def receptor_types(self):
        r"""Dictionary mapping receptor port names to integer indices.

        Returns
        -------
        dict
            ``{'RATE': 0}`` — single receptor type for rate-based input.

        Notes
        -----
        Only one generic ``'RATE'`` receptor is defined, for NEST API
        compatibility. Excitatory/inhibitory routing in this implementation is
        based on weight sign, not receptor type.
        """
        return {'RATE': 0}

    def _validate_parameters(self):
        r"""Check parameter validity and raise exceptions for invalid configurations.

        Enforces:

        - ``tau > 0`` (time constant must be positive)
        - ``lambda >= 0`` (passive decay rate must be non-negative)
        - ``sigma >= 0`` (noise/gain width must be non-negative)
        - ``rectify_rate >= 0`` (lower rectification bound must be non-negative)

        Raises
        ------
        ValueError
            If any parameter violates its constraint.

        Notes
        -----
        Called automatically during ``__init__``. Validation is skipped when any
        of the checked parameters is a JAX tracer (e.g. under ``jit``), because
        comparisons on tracers cannot be turned into Python booleans. The
        ``sigma=0`` case (undefined gain at ``h=mu``) is accepted, as in NEST.
        """
        # All four compared parameters must be concrete; checking only some of
        # them would still crash on a traced ``lambda_`` or ``rectify_rate``.
        if any(is_tracer(v) for v in (self.tau, self.lambda_, self.sigma, self.rectify_rate)):
            return
        if np.any(self.tau <= 0.0 * u.ms):
            raise ValueError('Time constant tau must be > 0.')
        if np.any(self.lambda_ < 0.0):
            raise ValueError('Passive decay rate lambda must be >= 0.')
        if np.any(self.sigma < 0.0):
            raise ValueError('Noise parameter sigma must be >= 0.')
        if np.any(self.rectify_rate < 0.0):
            raise ValueError('Rectifying rate must be >= 0.')

    def init_state(self, **kwargs):
        r"""Initialize all state variables and internal delay queues.

        Parameters
        ----------
        **kwargs
            Unused compatibility parameters accepted by the base-state API.

        Notes
        -----
        This method initializes:

        - ``self.rate``: Current firing rate, initialized via ``rate_initializer``.
        - ``self.noise``: Last noise sample, initialized via ``noise_initializer``.
        - ``self.instant_rate``: Copy of ``rate`` for zero-delay transmission.
        - ``self.delayed_rate``: Copy of ``rate`` for delayed transmission.
        - ``self._step_count``: Internal step counter (Python ``int``), starting at 0.
        - ``self._delayed_ex_queue``: Dictionary ``{step_idx: excitatory_contribution}``
          for delayed excitatory events, initially empty.
        - ``self._delayed_in_queue``: Dictionary ``{step_idx: inhibitory_contribution}``
          for delayed inhibitory events, initially empty.

        Examples
        --------
        .. code-block:: python

            >>> neuron = bpst.gauss_rate_ipn(in_size=100)
            >>> neuron.init_state()
            >>> neuron.rate.value.shape
            (100,)
        """
        rate = braintools.init.param(self.rate_initializer, self.varshape)
        noise = braintools.init.param(self.noise_initializer, self.varshape)
        rate_np = self._to_numpy(rate)
        noise_np = self._to_numpy(noise)
        self.rate = brainstate.ShortTermState(rate_np)
        self.noise = brainstate.ShortTermState(noise_np)
        dftype = brainstate.environ.dftype()
        # Independent copies so later writes to one state never alias another.
        self.instant_rate = brainstate.ShortTermState(np.array(rate_np, dtype=dftype, copy=True))
        self.delayed_rate = brainstate.ShortTermState(np.array(rate_np, dtype=dftype, copy=True))
        self._step_count = 0
        self._delayed_ex_queue = {}
        self._delayed_in_queue = {}

    def update(self, x=0.0, instant_rate_events=None, delayed_rate_events=None, noise=None):
        r"""Advance dynamics by one time step using stochastic exponential Euler.

        Implements the NEST ``gauss_rate_ipn`` update cycle. It stores the
        delayed output, draws noise, integrates the SDE, processes events,
        applies the Gaussian nonlinearity, and updates the firing rate.

        Parameters
        ----------
        x : ArrayLike, optional
            External additive drive (current input), broadcast to ``self.varshape``.
            Added to ``mu`` in the drift term. Default ``0.0``.
        instant_rate_events : scalar, tuple, list of tuples, or None, optional
            Rate events applied in the current step with zero delay. Each event can be:

            - Scalar: Interpreted as ``(value, weight=1.0)``.
            - ``(rate, weight)``: Rate value and synaptic weight.
            - ``(rate, weight, delay_steps)``: Must have ``delay_steps=0`` (raises
              ``ValueError`` otherwise).
            - ``(rate, weight, delay_steps, multiplicity)``: 4-tuple with multiplicity
              factor.
            - Dict with keys ``'rate'``, ``'weight'``, ``'delay_steps'``,
              ``'multiplicity'``.

            Weights are signed: positive for excitatory, negative for inhibitory.
            Default ``None`` (no events).
        delayed_rate_events : scalar, tuple, list of tuples, or None, optional
            Rate events scheduled for future delivery based on ``delay_steps``
            (default delay: 1 step). Format is the same as ``instant_rate_events``,
            but ``delay_steps`` can be any non-negative integer. Zero-delay events
            are applied immediately. Negative delays raise ``ValueError``.
            Default ``None``.
        noise : ArrayLike or None, optional
            Optional external standard-normal sample :math:`\xi` to use instead of
            drawing one. Must be broadcast-compatible with the state shape. When
            ``None``, ``xi`` is drawn with ``numpy.random.normal``. That is NumPy's
            global RNG, not ``brainstate.random``, so seeding brainstate does not
            make it reproducible. Default ``None``.

        Returns
        -------
        rate_new : jax.Array
            Updated firing rate of shape ``self.rate.value.shape``, after applying
            all dynamics, network input, Gaussian nonlinearity, multiplicative
            coupling, and optional rectification.

        Raises
        ------
        ValueError
            - If any ``instant_rate_events`` entry specifies non-zero ``delay_steps``.
            - If any ``delayed_rate_events`` entry has negative ``delay_steps``.

        Notes
        -----
        **Update propagation coefficients**:

        - ``P1 = exp(-lambda * dt / tau)``: State persistence factor.
        - ``P2 = (1 - exp(-lambda * dt / tau)) / lambda``: Input integration factor
          (``dt / tau`` when ``lambda=0``).
        - ``input_noise_factor = sqrt((1 - exp(-2*lambda*dt/tau)) / (2*lambda))``:
          Diffusion coefficient (``sqrt(dt / tau)`` when ``lambda=0``).

        **Gaussian nonlinearity application**:

        - ``linear_summation=True``: ``phi`` of the total input (or, with
          ``mult_coupling``, of the excitatory and inhibitory sums separately).
          ``phi`` is evaluated even when the input is zero, so ``sigma=0`` with
          ``mu=0`` yields NaN.
        - ``linear_summation=False``: Each event's rate is transformed during
          buffering, so summed values already include ``phi`` per event.

        **Rectification**: If ``rectify_output=True``, the final rate is clamped to
        ``max(rate_new, rectify_rate)``.

        **State side effects**: Updates ``self.rate``, ``self.noise``,
        ``self.delayed_rate``, ``self.instant_rate``, ``self._step_count``, and
        modifies delay queues ``self._delayed_ex_queue`` and ``self._delayed_in_queue``.
        """
        # Step size as a plain float in ms, to pair with the ms-converted tau.
        h = float(u.get_mantissa(brainstate.environ.get_dt() / u.ms))
        state_shape = self.rate.value.shape
        tau, sigma, mu, g = self._common_parameters_gauss(state_shape)
        lambda_ = self._broadcast_to_state(self._to_numpy(self.lambda_), state_shape)
        rectify_rate = self._broadcast_to_state(self._to_numpy(self.rectify_rate), state_shape)

        # Network input for this step: drained queue entries, zero-delay delayed
        # events, instantaneous events, delta inputs and the external current.
        # The returned state_shape equals the one computed above.
        _, step_idx, delayed_ex, delayed_in, instant_ex, instant_in, mu_ext = self._common_inputs_gauss(
            x=x,
            instant_rate_events=instant_rate_events,
            delayed_rate_events=delayed_rate_events,
            g=g,
            mu=mu,
            sigma=sigma,
        )

        dftype = brainstate.environ.dftype()
        rate_prev = jnp.broadcast_to(jnp.asarray(self.rate.value, dtype=dftype), state_shape)
        if noise is None:
            xi = jnp.asarray(np.random.normal(size=state_shape), dtype=dftype)
        else:
            xi = jnp.broadcast_to(jnp.asarray(noise, dtype=dftype), state_shape)
        noise_now = sigma * xi

        # Exact propagators of the linear part; lambda == 0 entries fall back to
        # the Euler-Maruyama limits (the np.where guards avoid division by zero).
        if np.any(lambda_ > 0.0):
            P1 = np.exp(-lambda_ * h / tau)
            P2 = -np.expm1(-lambda_ * h / tau) / np.where(lambda_ == 0.0, 1.0, lambda_)
            input_noise_factor = np.sqrt(
                -0.5 * np.expm1(-2.0 * lambda_ * h / tau) / np.where(lambda_ == 0.0, 1.0, lambda_)
            )
            zero_lambda = lambda_ == 0.0
            if np.any(zero_lambda):
                P1 = np.where(zero_lambda, 1.0, P1)
                P2 = np.where(zero_lambda, h / tau, P2)
                input_noise_factor = np.where(zero_lambda, np.sqrt(h / tau), input_noise_factor)
        else:
            P1 = np.ones_like(lambda_)
            P2 = h / tau
            input_noise_factor = np.sqrt(h / tau)

        mu_total = mu + mu_ext
        rate_new = P1 * rate_prev + P2 * mu_total + input_noise_factor * noise_now

        # Coupling factors are evaluated on the pre-update rate, as in NEST.
        H_ex = jnp.ones_like(rate_prev)
        H_in = jnp.ones_like(rate_prev)
        if self.mult_coupling:
            H_ex = self._mult_coupling_ex(rate_prev)
            H_in = self._mult_coupling_in(rate_prev)

        if self.linear_summation:
            if self.mult_coupling:
                rate_new += P2 * H_ex * self._input(delayed_ex + instant_ex, g, mu, sigma)
                rate_new += P2 * H_in * self._input(delayed_in + instant_in, g, mu, sigma)
            else:
                rate_new += P2 * self._input(delayed_ex + instant_ex + delayed_in + instant_in, g, mu, sigma)
        else:
            # Nonlinear transform has already been applied per event in buffer handling.
            rate_new += P2 * H_ex * (delayed_ex + instant_ex)
            rate_new += P2 * H_in * (delayed_in + instant_in)

        if self.rectify_output:
            rate_new = jnp.where(rate_new < rectify_rate, rectify_rate, rate_new)

        self.rate.value = rate_new
        self.noise.value = noise_now
        self.delayed_rate.value = rate_prev
        self.instant_rate.value = rate_new
        self._step_count = step_idx + 1
        return rate_new