Source code for brainpy_state._nest.threshold_lin_rate

# Copyright 2026 BrainX Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

# -*- coding: utf-8 -*-


from typing import Callable

import brainstate
import braintools
import saiunit as u
import jax.numpy as jnp
import numpy as np
from brainstate.typing import ArrayLike, Size

from brainpy_state._nest.lin_rate import _lin_rate_base
from ._utils import is_tracer

__all__ = [
    'threshold_lin_rate_ipn',
    'threshold_lin_rate_opn',
]


class _threshold_lin_rate_base(_lin_rate_base):
    """Shared event/gain machinery for the NEST ``threshold_lin_rate_*`` neurons.

    Supplies the threshold-linear gain ``min(max(g * (h - theta), 0), alpha)``,
    constant (unit) multiplicative-coupling factors, and helpers that turn
    rate events into excitatory (``weight >= 0``) and inhibitory
    (``weight < 0``) contributions.

    NOTE(review): relies on helpers inherited from ``_lin_rate_base``
    (``_to_numpy``, ``_to_numpy_ms``, ``_broadcast_to_state``,
    ``_to_int_scalar``, ``_coerce_events``, ``_queue_add``,
    ``_drain_delayed_queue``) that are defined outside this file section.
    """
    __module__ = 'brainpy.state'

    def _input(self, h, g, theta, alpha):
        """Threshold-linear gain: clip ``g * (h - theta)`` into ``[0, alpha]``."""
        rectified = np.maximum(g * (h - theta), 0.0)
        return np.minimum(rectified, alpha)

    @staticmethod
    def _mult_coupling_ex(rate):
        """Excitatory coupling factor; always ones for threshold-linear neurons."""
        return jnp.ones_like(rate, dtype=brainstate.environ.dftype())

    @staticmethod
    def _mult_coupling_in(rate):
        """Inhibitory coupling factor; always ones for threshold-linear neurons."""
        return jnp.ones_like(rate, dtype=brainstate.environ.dftype())

    def _extract_event_fields(self, ev, default_delay_steps: int):
        """Normalise one event into ``(rate, weight, multiplicity, delay_steps)``.

        Accepted forms:

        - ``dict``: rate from the first present key among ``'rate'``,
          ``'coeff'``, ``'value'`` (else ``0.0``); ``'weight'`` (default
          ``1.0``); ``'multiplicity'`` (default ``1.0``); delay from
          ``'delay_steps'`` or ``'delay'`` (else ``default_delay_steps``).
        - ``tuple``/``list`` of length 2, 3 or 4:
          ``(rate, weight[, delay_steps[, multiplicity]])``.
        - anything else: used as the rate, with unit weight and multiplicity.

        Raises
        ------
        ValueError
            If a tuple/list event does not have 2, 3 or 4 entries.
        """
        if isinstance(ev, dict):
            rate = 0.0
            for key in ('rate', 'coeff', 'value'):
                if key in ev:
                    rate = ev[key]
                    break
            weight = ev.get('weight', 1.0)
            multiplicity = ev.get('multiplicity', 1.0)
            if 'delay_steps' in ev:
                delay_steps = ev['delay_steps']
            else:
                delay_steps = ev.get('delay', default_delay_steps)
        elif isinstance(ev, (tuple, list)):
            n_fields = len(ev)
            if n_fields not in (2, 3, 4):
                raise ValueError('Rate event tuples must have length 2, 3, or 4.')
            rate, weight = ev[0], ev[1]
            delay_steps = ev[2] if n_fields >= 3 else default_delay_steps
            multiplicity = ev[3] if n_fields == 4 else 1.0
        else:
            rate, weight, multiplicity = ev, 1.0, 1.0
            delay_steps = default_delay_steps

        return rate, weight, multiplicity, self._to_int_scalar(delay_steps, name='delay_steps')

    def _event_to_ex_in(self, ev, default_delay_steps: int, state_shape, g, theta, alpha):
        """Split one event into excitatory and inhibitory arrays of ``state_shape``.

        The contribution is ``x * weight * multiplicity``. ``x`` is the raw
        rate when ``linear_summation`` is true. Otherwise it is the
        threshold-linear gain of the rate, matching NEST, which applies the
        nonlinearity per event while buffering. The sign of ``weight`` routes
        the contribution to one branch; the other branch receives zeros.

        Returns
        -------
        tuple
            ``(ex, inh, delay_steps)``.
        """
        rate, weight, multiplicity, delay_steps = self._extract_event_fields(ev, default_delay_steps)

        def as_state(value):
            return self._broadcast_to_state(self._to_numpy(value), state_shape)

        rate_np = as_state(rate)
        weight_np = as_state(weight)
        multiplicity_np = as_state(multiplicity)
        # Routing uses the sign of the raw weight mantissa (weight >= 0 -> excitatory).
        is_excitatory = self._broadcast_to_state(
            np.asarray(u.get_mantissa(weight), dtype=brainstate.environ.dftype()) >= 0.0,
            state_shape,
        )

        presynaptic = rate_np if self.linear_summation else self._input(rate_np, g, theta, alpha)
        weighted_value = presynaptic * weight_np * multiplicity_np

        ex = np.where(is_excitatory, weighted_value, 0.0)
        inh = np.where(is_excitatory, 0.0, weighted_value)
        return ex, inh, delay_steps

    def _threshold_event_contributions(self, events, default_delay_steps, state_shape, g, theta, alpha):
        """Yield ``(ex, inh, delay_steps)`` for each event in ``events``, in order."""
        for ev in self._coerce_events(events):
            yield self._event_to_ex_in(
                ev,
                default_delay_steps=default_delay_steps,
                state_shape=state_shape,
                g=g,
                theta=theta,
                alpha=alpha,
            )

    def _accumulate_instant_events_threshold(self, events, state_shape, g, theta, alpha):
        """Sum instantaneous events into ``(ex, inh)`` arrays.

        Raises
        ------
        ValueError
            If any instantaneous event specifies a non-zero delay.
        """
        dftype = brainstate.environ.dftype()
        ex_total = np.zeros(state_shape, dtype=dftype)
        in_total = np.zeros(state_shape, dtype=dftype)
        for ex_part, in_part, delay_steps in self._threshold_event_contributions(
            events, 0, state_shape, g, theta, alpha
        ):
            if delay_steps != 0:
                raise ValueError('instant_rate_events must not specify non-zero delay_steps.')
            # In-place add keeps the accumulator dtype at ``dftype``.
            ex_total += ex_part
            in_total += in_part
        return ex_total, in_total

    def _schedule_delayed_events_threshold(self, events, step_idx: int, state_shape, g, theta, alpha):
        """Queue delayed events and return contributions due in this step.

        Events with ``delay_steps > 0`` are added to the delay queues at
        ``step_idx + delay_steps``. Events with ``delay_steps == 0`` are summed
        into the returned ``(ex, inh)`` arrays. The default delay is 1 step.

        Raises
        ------
        ValueError
            If any event has a negative delay.
        """
        dftype = brainstate.environ.dftype()
        due_ex = np.zeros(state_shape, dtype=dftype)
        due_in = np.zeros(state_shape, dtype=dftype)
        for ex_part, in_part, delay_steps in self._threshold_event_contributions(
            events, 1, state_shape, g, theta, alpha
        ):
            if delay_steps < 0:
                raise ValueError('delay_steps for delayed_rate_events must be >= 0.')
            if delay_steps > 0:
                arrival_step = step_idx + delay_steps
                self._queue_add(self._delayed_ex_queue, arrival_step, ex_part)
                self._queue_add(self._delayed_in_queue, arrival_step, in_part)
                continue
            due_ex += ex_part
            due_in += in_part
        return due_ex, due_in

    def _common_inputs_threshold(self, x, instant_rate_events, delayed_rate_events, g, theta, alpha):
        """Gather every input source for the current step.

        Order of collection:

        1. drain the delay queues for the current step;
        2. schedule new delayed events, keeping the zero-delay ones;
        3. accumulate instantaneous events;
        4. sign-split delta inputs into the instantaneous branches;
        5. sum current inputs into ``mu_ext``.

        Returns
        -------
        tuple
            ``(state_shape, step_idx, delayed_ex, delayed_in, instant_ex,
            instant_in, mu_ext)``.
        """
        state_shape = self.rate.value.shape
        step_idx = self._step_count
        gain_kwargs = dict(state_shape=state_shape, g=g, theta=theta, alpha=alpha)

        queued_ex, queued_in = self._drain_delayed_queue(step_idx, state_shape)
        due_ex, due_in = self._schedule_delayed_events_threshold(
            delayed_rate_events, step_idx=step_idx, **gain_kwargs
        )
        delayed_ex = queued_ex + due_ex
        delayed_in = queued_in + due_in

        instant_ex, instant_in = self._accumulate_instant_events_threshold(instant_rate_events, **gain_kwargs)

        delta = self._broadcast_to_state(self._to_numpy(self.sum_delta_inputs(0.0)), state_shape)
        instant_ex += np.where(delta > 0.0, delta, 0.0)
        instant_in += np.where(delta < 0.0, delta, 0.0)

        mu_ext = self._broadcast_to_state(
            self._to_numpy(self.sum_current_inputs(x, self.rate.value)),
            state_shape,
        )
        return state_shape, step_idx, delayed_ex, delayed_in, instant_ex, instant_in, mu_ext

    def _common_parameters_threshold(self, state_shape):
        """Return ``(tau_ms, sigma, mu, g, theta, alpha)`` broadcast to ``state_shape``."""
        def per_neuron(value):
            return self._broadcast_to_state(self._to_numpy(value), state_shape)

        tau = self._broadcast_to_state(self._to_numpy_ms(self.tau), state_shape)
        others = tuple(per_neuron(p) for p in (self.sigma, self.mu, self.g, self.theta, self.alpha))
        return (tau,) + others


class threshold_lin_rate_ipn(_threshold_lin_rate_base):
    r"""NEST-compatible input-noise threshold-linear rate neuron.

    Implements the NEST ``threshold_lin_rate_ipn`` model, an input-noise rate
    neuron with a threshold-linear gain function. The gain is piecewise linear
    with a lower bound of 0 and an upper saturation bound. It is commonly used
    to model neural populations with firing-rate constraints and additive
    stochastic drive.

    Mathematical Description
    ------------------------

    **1. Continuous-Time Stochastic Dynamics**

    The rate state :math:`X(t)` evolves according to the Langevin equation:

    .. math::

       \tau\,dX(t) = [-\lambda X(t) + \mu + I_\mathrm{net}(t)]\,dt
       + \sqrt{\tau}\,\sigma\,dW(t),

    where:

    - :math:`\tau > 0` is the time constant (ms).
    - :math:`\lambda \ge 0` is the passive decay rate (dimensionless). It
      controls exponential relaxation; :math:`\lambda=0` yields driftless
      diffusion.
    - :math:`\mu` is the mean drive (dimensionless, constant). The external
      drive ``x`` passed to :meth:`update` (via ``sum_current_inputs``) is
      added to :math:`\mu` each step.
    - :math:`\sigma \ge 0` is the input-noise strength (dimensionless).
    - :math:`W(t)` is a standard Wiener process.
    - :math:`I_\mathrm{net}(t)` is the network input (see below).

    **2. Threshold-Linear Gain Function**

    The input nonlinearity :math:`\phi(h)` is a threshold-linear function with
    saturation:

    .. math::

       \phi(h) = \min(\max(g(h-\theta), 0), \alpha),

    where:

    - :math:`g > 0` is the gain slope (dimensionless).
    - :math:`\theta` is the activation threshold (dimensionless).
    - :math:`\alpha > 0` is the saturation level (dimensionless).

    For :math:`g > 0`, this function is zero for :math:`h < \theta`. It is
    linear with slope :math:`g` for :math:`\theta \le h < \theta + \alpha/g`,
    and it saturates at :math:`\alpha` for :math:`h \ge \theta + \alpha/g`.

    **3. Network Input Structure**

    Each incoming event :math:`k` carries a presynaptic rate :math:`r_k`, a
    weight :math:`w_k` and a multiplicity :math:`m_k`. Events with
    :math:`w_k \ge 0` form the excitatory branch; events with :math:`w_k < 0`
    form the inhibitory branch. The network input is:

    .. math::

       I_\mathrm{net}(t) = \phi\Big(\sum_k w_k m_k r_k\Big)
       \quad\text{(if linear\_summation=True)},

    or, with the nonlinearity applied to each event's rate before weighting:

    .. math::

       I_\mathrm{net}(t) = \sum_k w_k m_k\,\phi(r_k)
       \quad\text{(if linear\_summation=False)}.

    **Note**: Unlike the base ``rate_neuron_ipn`` model, multiplicative
    coupling :math:`H_\mathrm{ex}(X)`, :math:`H_\mathrm{in}(X)` is **not**
    supported for threshold-linear neurons in NEST. The ``mult_coupling``
    parameter is accepted for API compatibility but has no effect on dynamics
    (the coupling factors are constant 1.0).

    **4. Discrete-Time Integration (Stochastic Exponential Euler)**

    For time step :math:`h=dt` (in ms), the model uses exact Ornstein-Uhlenbeck
    integration for the linear part:

    .. math::

       X_{n+1} = P_1 X_n + P_2 (\mu + I_\mathrm{net,n}) + N\,\xi_n,

    where :math:`\xi_n\sim\mathcal{N}(0,1)` is standard Gaussian noise.

    **For** :math:`\lambda > 0`:

    .. math::

       P_1 = \exp\left(-\frac{\lambda h}{\tau}\right), \quad
       P_2 = \frac{1-P_1}{\lambda}, \quad
       N = \sigma\sqrt{\frac{1-P_1^2}{2\lambda}}.

    **For** :math:`\lambda = 0` (Euler-Maruyama):

    .. math::

       P_1=1, \quad P_2=\frac{h}{\tau}, \quad N=\sigma\sqrt{\frac{h}{\tau}}.

    **5. Update Ordering (Matching NEST ``rate_neuron_ipn_impl.h``)**

    Per simulation step:

    1. **Store outgoing delayed value**: the current ``rate`` is recorded as
       ``delayed_rate``.
    2. **Draw noise**: sample :math:`\xi_n\sim\mathcal{N}(0,1)` and compute
       :math:`\mathrm{noise}_n=\sigma\,\xi_n`.
    3. **Propagate intrinsic dynamics**: apply stochastic exponential Euler to
       :math:`X_n` with external drive and noise.
    4. **Read event buffers**: drain the delayed events arriving at the current
       step and accumulate the instantaneous events.
    5. **Apply network input with threshold-linear gain**:

       - ``linear_summation=True``: the nonlinearity is applied to the summed
         branch input during the update,
         :math:`I_\mathrm{net}=\phi(I_\mathrm{ex}+I_\mathrm{in})`.
       - ``linear_summation=False``: the nonlinearity is applied to each
         event's rate during buffering,
         :math:`I_\mathrm{net}=\sum_k w_k m_k\,\phi(r_k)`.

    6. **Rectification** (optional): if ``rectify_output=True``, clamp
       :math:`X_{n+1}\gets\max(X_{n+1},\,\mathrm{rectify\_rate})`.
    7. **Update state variables**: ``rate``, ``noise``, ``delayed_rate``,
       ``instant_rate``, ``_step_count``.

    **6. Numerical Stability and Computational Complexity**

    - Construction enforces :math:`\tau>0`, :math:`\lambda\ge 0`,
      :math:`\sigma\ge 0` and :math:`\mathrm{rectify\_rate}\ge 0` for
      parameters that are concrete, i.e. not JAX tracers.
    - The threshold-linear gain is evaluated with ``np.minimum`` and
      ``np.maximum`` for numerically stable clipping.
    - Per-call cost is :math:`O(\prod\mathrm{varshape})`, using vectorized
      NumPy/JAX operations in ``brainstate.environ.dftype()``.
    - The exponential Euler scheme is numerically stable for all :math:`h>0`.

    Parameters
    ----------
    in_size : Size
        Population shape (tuple or int). All per-neuron parameters are broadcast
        to ``self.varshape``.
    tau : ArrayLike, optional
        Time constant :math:`\tau` (ms). Scalar or array broadcastable to
        ``self.varshape``. Must be :math:`>0`. Default: ``10.0 * u.ms``.
    lambda_ : ArrayLike, optional
        Passive decay rate :math:`\lambda` (dimensionless). Scalar or array
        broadcastable to ``self.varshape``. Must be :math:`\ge 0`. Controls
        exponential relaxation (:math:`\lambda=0` yields driftless diffusion).
        Default: ``1.0``.
    sigma : ArrayLike, optional
        Input-noise scale :math:`\sigma` (dimensionless). Scalar or array
        broadcastable to ``self.varshape``. Must be :math:`\ge 0`. Default:
        ``1.0``.
    mu : ArrayLike, optional
        Mean drive :math:`\mu` (dimensionless). Scalar or array broadcastable to
        ``self.varshape``. External constant input to the rate dynamics. Default:
        ``0.0``.
    g : ArrayLike, optional
        Gain slope :math:`g` (dimensionless) for the threshold-linear function
        :math:`\phi(h)=\min(\max(g(h-\theta),0),\alpha)`. Scalar or array
        broadcastable to ``self.varshape``. Default: ``1.0``.
    theta : ArrayLike, optional
        Activation threshold :math:`\theta` (dimensionless). The gain function is
        zero for :math:`h<\theta`. Scalar or array broadcastable to
        ``self.varshape``. Default: ``0.0``.
    alpha : ArrayLike, optional
        Saturation level :math:`\alpha` (dimensionless). The gain function
        saturates at :math:`\alpha` for large inputs. Scalar or array
        broadcastable to ``self.varshape``. Default: ``np.inf`` (no saturation).
    mult_coupling : bool, optional
        API compatibility flag. Has **no effect** on dynamics for threshold-linear
        neurons (multiplicative coupling factors are constant 1.0). Default:
        ``False``.
    linear_summation : bool, optional
        Controls where the threshold-linear gain is applied. If ``True``, the
        gain is applied once per step to the weighted sum of all (excitatory
        and inhibitory) inputs. If ``False``, the gain is applied to each
        event's presynaptic rate before it is weighted and buffered, matching
        NEST event handlers. Default: ``True``.
    rectify_rate : ArrayLike, optional
        Lower bound :math:`X_\mathrm{min}` for the rate when
        ``rectify_output=True`` (dimensionless). Scalar or array broadcastable to
        ``self.varshape``. Must be :math:`\ge 0`. Default: ``0.0``.
    rectify_output : bool, optional
        If ``True``, clamp the rate output to
        :math:`X\ge\mathrm{rectify\_rate}` after each update step. Default:
        ``False``.
    rate_initializer : Callable, optional
        Initializer for the ``rate`` state variable :math:`X_0`. Callable
        compatible with the ``braintools.init`` API. Default:
        ``braintools.init.Constant(0.0)``.
    noise_initializer : Callable, optional
        Initializer for the ``noise`` state variable, which records the last
        noise sample :math:`\sigma\,\xi_{n-1}`. Callable compatible with the
        ``braintools.init`` API. Default: ``braintools.init.Constant(0.0)``.
    name : str or None, optional
        Module name for identification in hierarchies. If ``None``, an
        auto-generated name is used. Default: ``None``.

    Parameter Mapping
    -----------------

    The following table maps NEST ``threshold_lin_rate_ipn`` parameters to
    brainpy.state equivalents:

    =============================== ===================== =========
    NEST Parameter                  brainpy.state         Default
    =============================== ===================== =========
    ``tau``                         ``tau``               10 ms
    ``lambda``                      ``lambda_``           1.0
    ``sigma``                       ``sigma``             1.0
    ``mu``                          ``mu``                0.0
    ``g`` (gain slope)              ``g``                 1.0
    ``theta`` (threshold)           ``theta``             0.0
    ``alpha`` (saturation)          ``alpha``             inf
    ``mult_coupling``               ``mult_coupling``     False
                                    (no effect)
    ``linear_summation``            ``linear_summation``  True
    ``rectify_rate``                ``rectify_rate``      0.0
    ``rectify_output``              ``rectify_output``    False
    =============================== ===================== =========

    Attributes
    ----------
    rate : brainstate.ShortTermState
        Current rate state :math:`X_n`. ``init_state`` creates it with shape
        ``self.varshape``.
    noise : brainstate.ShortTermState
        Last noise sample :math:`\sigma\,\xi_{n-1}` (same shape as ``rate``).
    instant_rate : brainstate.ShortTermState
        Rate value after the current update (same shape as ``rate``).
    delayed_rate : brainstate.ShortTermState
        Rate value before the current update, used for delayed projections
        (same shape as ``rate``).
    _step_count : int
        Internal step counter used for delayed-event scheduling.
    _delayed_ex_queue : dict
        Internal queue mapping a target step index to the accumulated
        excitatory delayed events.
    _delayed_in_queue : dict
        Internal queue mapping a target step index to the accumulated
        inhibitory delayed events.

    Raises
    ------
    ValueError
        If ``tau <= 0``, ``lambda_ < 0``, ``sigma < 0``, or
        ``rectify_rate < 0``.
    ValueError
        If ``instant_rate_events`` contain non-zero ``delay_steps``.
    ValueError
        If ``delayed_rate_events`` contain negative ``delay_steps``.
    ValueError
        If event tuples have a length other than 2, 3, or 4.

    Notes
    -----
    **Runtime Event Semantics**

    - ``instant_rate_events``: applied in the current step without delay. Each
      event can be:

      - A scalar (treated as the ``rate`` value with ``weight=1.0``).
      - A tuple ``(rate, weight)``, ``(rate, weight, delay_steps)`` or
        ``(rate, weight, delay_steps, multiplicity)``.
      - A dict with keys ``'rate'``/``'coeff'``/``'value'``, ``'weight'``,
        ``'delay_steps'``/``'delay'`` and ``'multiplicity'``.

    - ``delayed_rate_events``: scheduled with integer ``delay_steps`` (units of
      simulation time step, default 1). Same format as
      ``instant_rate_events``.

    - Sign convention: events with ``weight >= 0`` contribute to the excitatory
      branch; events with ``weight < 0`` contribute to the inhibitory branch.

    - For ``linear_summation=False``, each event's rate is transformed by the
      threshold-linear gain during buffering, matching NEST event handlers.

    **Comparison to Other Rate Neuron Variants**

    - ``rate_neuron_ipn``: uses a linear or custom gain function with optional
      multiplicative coupling. ``threshold_lin_rate_ipn`` is a special case
      with a fixed threshold-linear gain and no multiplicative coupling.
    - ``threshold_lin_rate_opn``: output-noise variant (noise applied after the
      dynamics) rather than input noise (noise applied inside the dynamics
      propagation).

    **Failure Modes**

    - No automatic failure handling. Invalid time constants, decay rates,
      noise scales or rectification bounds are caught at construction by
      ``_validate_parameters``, except for values that are JAX tracers.
    - Invalid event formats raise ``ValueError`` during update.
    - Numerical instability is unlikely because of the exact OU integration
      and stable clipping. However, extreme parameter combinations (very large
      :math:`\sigma`, very small :math:`\tau`) may lead to rate explosions
      without ``rectify_output=True``.
    - Internally drawn noise uses NumPy's global RNG (``np.random.normal``).
      Seed it with ``np.random.seed``, or pass ``noise`` explicitly, for
      reproducible runs.

    Examples
    --------
    **Example 1**: Minimal threshold-linear rate neuron with external drive.

    .. code-block:: python

       >>> import brainstate
       >>> import brainpy.state as bst
       >>> import saiunit as u
       >>> model = bst.threshold_lin_rate_ipn(
       ...     in_size=10, tau=20*u.ms, sigma=0.5, g=2.0, theta=1.0
       ... )
       >>> model.init_state()
       >>> with brainstate.environ.context(dt=0.1 * u.ms):
       ...     rate = model.update(x=0.5)  # external drive
       >>> print(rate.shape)
       (10,)

    **Example 2**: Saturating threshold-linear neuron with rectification.

    .. code-block:: python

       >>> model = bst.threshold_lin_rate_ipn(
       ...     in_size=5,
       ...     tau=10*u.ms,
       ...     lambda_=2.0,
       ...     g=1.0, theta=0.5, alpha=5.0,
       ...     rectify_rate=0.0, rectify_output=True
       ... )
       >>> model.init_state()

    **Example 3**: Update with instantaneous and delayed events.

    .. code-block:: python

       >>> model = bst.threshold_lin_rate_ipn(in_size=3, tau=10*u.ms, sigma=0.1)
       >>> model.init_state()
       >>> instant_event = {'rate': 2.0, 'weight': 0.1}
       >>> delayed_event = {'rate': 1.5, 'weight': -0.05, 'delay_steps': 3}
       >>> with brainstate.environ.context(dt=0.1 * u.ms):
       ...     rate = model.update(
       ...         x=0.2,
       ...         instant_rate_events=instant_event,
       ...         delayed_rate_events=delayed_event
       ...     )

    References
    ----------
    .. [1] NEST Simulator Documentation: ``threshold_lin_rate_ipn``
           https://nest-simulator.readthedocs.io/en/stable/models/rate_neuron_ipn.html
    .. [2] NEST Simulator Documentation: ``threshold_lin_rate`` nonlinearity
           https://nest-simulator.readthedocs.io/en/stable/models/rate_transformer_node.html
    .. [3] Hahne, J., Dahmen, D., Schuecker, J., Frommer, A., Bolten, M.,
           Helias, M., & Diesmann, M. (2017). Integration of continuous-time
           dynamics in a spiking neural network simulator.
           *Frontiers in Neuroinformatics*, 11, 34.
           https://doi.org/10.3389/fninf.2017.00034

    See Also
    --------
    threshold_lin_rate_opn : Output-noise variant of the threshold-linear rate neuron.
    rate_neuron_ipn : General input-noise rate neuron with custom gain functions.
    lin_rate : Deterministic linear rate neuron (``sigma=0``, no threshold).
    """

    __module__ = 'brainpy.state'

    def __init__(
        self,
        in_size: Size,
        tau: ArrayLike = 10.0 * u.ms,
        lambda_: ArrayLike = 1.0,
        sigma: ArrayLike = 1.0,
        mu: ArrayLike = 0.0,
        g: ArrayLike = 1.0,
        theta: ArrayLike = 0.0,
        alpha: ArrayLike = np.inf,
        mult_coupling: bool = False,
        linear_summation: bool = True,
        rectify_rate: ArrayLike = 0.0,
        rectify_output: bool = False,
        rate_initializer: Callable = braintools.init.Constant(0.0),
        noise_initializer: Callable = braintools.init.Constant(0.0),
        name: str = None,
    ):
        # The excitatory/inhibitory gain parameters of the linear base model are
        # pinned to neutral values; this model uses its own threshold-linear gain.
        super().__init__(
            in_size=in_size,
            tau=tau,
            sigma=sigma,
            mu=mu,
            g=g,
            mult_coupling=mult_coupling,
            g_ex=1.0,
            g_in=1.0,
            theta_ex=0.0,
            theta_in=0.0,
            linear_summation=linear_summation,
            rate_initializer=rate_initializer,
            noise_initializer=noise_initializer,
            name=name,
        )
        self.theta = braintools.init.param(theta, self.varshape)
        self.alpha = braintools.init.param(alpha, self.varshape)
        self.lambda_ = braintools.init.param(lambda_, self.varshape)
        self.rectify_rate = braintools.init.param(rectify_rate, self.varshape)
        self.rectify_output = bool(rectify_output)
        self._validate_parameters()

    @property
    def recordables(self):
        r"""List of state variable names that can be recorded during simulation.

        Returns
        -------
        list of str
            ``['rate', 'noise']``. The ``rate`` variable records the current rate
            state :math:`X_n`, and ``noise`` records the last noise sample
            :math:`\sigma\,\xi_{n-1}`.
        """
        return ['rate', 'noise']

    @property
    def receptor_types(self):
        r"""Receptor type dictionary for projection compatibility.

        Returns
        -------
        dict[str, int]
            ``{'RATE': 0}``. Rate neurons have a single receptor port for all
            rate-based inputs. Excitatory and inhibitory inputs are separated
            internally by the sign of the event weight (``weight >= 0`` goes to
            the excitatory branch, ``weight < 0`` to the inhibitory branch).
        """
        return {'RATE': 0}

    def _validate_parameters(self):
        r"""Validate model parameters at construction time.

        Each check is skipped only for a parameter that is a JAX tracer (e.g.
        during ``jit`` tracing), because traced values cannot be compared
        eagerly. Concrete parameters are always checked.

        Raises
        ------
        ValueError
            If ``tau <= 0``, ``lambda_ < 0``, ``sigma < 0``, or
            ``rectify_rate < 0``.
        """
        def concrete(value):
            return not is_tracer(value)

        if concrete(self.tau) and np.any(self.tau <= 0.0 * u.ms):
            raise ValueError('Time constant tau must be > 0.')
        if concrete(self.lambda_) and np.any(self.lambda_ < 0.0):
            raise ValueError('Passive decay rate lambda must be >= 0.')
        if concrete(self.sigma) and np.any(self.sigma < 0.0):
            raise ValueError('Noise parameter sigma must be >= 0.')
        if concrete(self.rectify_rate) and np.any(self.rectify_rate < 0.0):
            raise ValueError('Rectifying rate must be >= 0.')

    def init_state(self, **kwargs):
        r"""Initialize all state variables for simulation.

        Parameters
        ----------
        **kwargs
            Unused compatibility parameters accepted by the base-state API
            (e.g. ``batch_size`` is ignored; states always have shape
            ``self.varshape``).

        Notes
        -----
        This method initializes:

        - ``rate``: current rate state :math:`X_n`, from ``rate_initializer``.
        - ``noise``: last noise sample :math:`\sigma\,\xi_{n-1}`, from
          ``noise_initializer``.
        - ``instant_rate`` and ``delayed_rate``: independent copies of the
          initial rate, cast to ``brainstate.environ.dftype()``.
        - ``_step_count``: internal step counter, reset to 0.
        - ``_delayed_ex_queue``, ``_delayed_in_queue``: empty delay queues.
        """
        rate = braintools.init.param(self.rate_initializer, self.varshape)
        noise = braintools.init.param(self.noise_initializer, self.varshape)
        rate_np = self._to_numpy(rate)
        noise_np = self._to_numpy(noise)
        self.rate = brainstate.ShortTermState(rate_np)
        self.noise = brainstate.ShortTermState(noise_np)
        dftype = brainstate.environ.dftype()
        # Explicit copies so that later writes to one state never alias another.
        self.instant_rate = brainstate.ShortTermState(np.array(rate_np, dtype=dftype, copy=True))
        self.delayed_rate = brainstate.ShortTermState(np.array(rate_np, dtype=dftype, copy=True))
        self._step_count = 0
        self._delayed_ex_queue = {}
        self._delayed_in_queue = {}

    @staticmethod
    def _ipn_propagators(lambda_, h, tau):
        r"""Return the exact OU propagators ``(P1, P2, input_noise_factor)``.

        For :math:`\lambda>0`: :math:`P_1=e^{-\lambda h/\tau}`,
        :math:`P_2=(1-P_1)/\lambda` and
        :math:`\sqrt{(1-P_1^2)/(2\lambda)}` for the noise factor, each
        evaluated with ``expm1`` for accuracy. Elements with
        :math:`\lambda=0` use the Euler-Maruyama limit :math:`P_1=1`,
        :math:`P_2=h/\tau` and :math:`\sqrt{h/\tau}`.

        Parameters
        ----------
        lambda_ : np.ndarray
            Per-neuron passive decay rate (non-negative).
        h : float
            Time step in ms.
        tau : np.ndarray
            Per-neuron time constant in ms.
        """
        if np.any(lambda_ > 0.0):
            zero_lambda = lambda_ == 0.0
            # Placeholder divisor for lambda == 0 entries; those are overwritten below.
            safe_lambda = np.where(zero_lambda, 1.0, lambda_)
            P1 = np.exp(-lambda_ * h / tau)
            P2 = -np.expm1(-lambda_ * h / tau) / safe_lambda
            input_noise_factor = np.sqrt(-0.5 * np.expm1(-2.0 * lambda_ * h / tau) / safe_lambda)
            if np.any(zero_lambda):
                P1 = np.where(zero_lambda, 1.0, P1)
                P2 = np.where(zero_lambda, h / tau, P2)
                input_noise_factor = np.where(zero_lambda, np.sqrt(h / tau), input_noise_factor)
            return P1, P2, input_noise_factor
        return np.ones_like(lambda_), h / tau, np.sqrt(h / tau)

    def update(self, x=0.0, instant_rate_events=None, delayed_rate_events=None, noise=None):
        r"""Perform one simulation step of stochastic threshold-linear rate dynamics.

        Parameters
        ----------
        x : ArrayLike, optional
            External drive (scalar or array broadcastable to ``self.varshape``).
            Added to ``mu`` as constant forcing. Default is ``0.0``.
        instant_rate_events : None, dict, tuple, list, or iterable, optional
            Instantaneous rate events applied in the current step without
            delay. See the class docstring for the event format. Default is
            ``None``.
        delayed_rate_events : None, dict, tuple, list, or iterable, optional
            Delayed rate events scheduled with integer ``delay_steps`` (units of
            simulation time step). See the class docstring for the event
            format. Default is ``None``.
        noise : ArrayLike, optional
            Externally supplied standard-normal sample :math:`\xi_n` (scalar or
            array broadcastable to the state shape); it is scaled by ``sigma``.
            If ``None`` (default), :math:`\xi_n\sim\mathcal{N}(0,1)` is drawn
            from NumPy's global RNG.

        Returns
        -------
        rate_new : jax.Array
            Updated rate state :math:`X_{n+1}`, of shape
            ``self.rate.value.shape`` and dtype ``brainstate.environ.dftype()``.

        Raises
        ------
        ValueError
            For malformed events, or delays that are invalid for their event
            kind.

        Notes
        -----
        **Update algorithm**:

        1. Collect input contributions: delayed events arriving at the current
           step, newly scheduled delayed events with ``delay_steps=0``,
           instantaneous events, sign-separated delta inputs, and current
           inputs via ``sum_current_inputs(x, rate)``.
        2. Compute the propagator coefficients (see ``_ipn_propagators``).
        3. Propagate intrinsic dynamics:
           :math:`X' = P_1 X_n + P_2(\mu + \mu_\mathrm{ext}) + N\,\xi_n`.
        4. Apply the network input with threshold-linear gain:

           - ``linear_summation=True``:
             :math:`X' \gets X' + P_2\,\phi(I_\mathrm{ex}+I_\mathrm{in})`.
           - ``linear_summation=False``: :math:`X' \gets X' + P_2\,(I_\mathrm{ex}+I_\mathrm{in})`,
             where the branches already hold per-event :math:`w\,m\,\phi(r)`.

        5. Apply optional output rectification:
           :math:`X_{n+1}\gets\max(X',\,\mathrm{rectify\_rate})`.
        6. Update state variables: ``rate``, ``noise``, ``delayed_rate``,
           ``instant_rate``, ``_step_count``.
        """
        h = float(u.get_mantissa(brainstate.environ.get_dt() / u.ms))
        dftype = brainstate.environ.dftype()
        state_shape = self.rate.value.shape

        tau, sigma, mu, g, theta, alpha = self._common_parameters_threshold(state_shape)
        lambda_ = self._broadcast_to_state(self._to_numpy(self.lambda_), state_shape)
        rectify_rate = self._broadcast_to_state(self._to_numpy(self.rectify_rate), state_shape)

        (
            state_shape, step_idx, delayed_ex, delayed_in, instant_ex, instant_in, mu_ext,
        ) = self._common_inputs_threshold(
            x=x,
            instant_rate_events=instant_rate_events,
            delayed_rate_events=delayed_rate_events,
            g=g,
            theta=theta,
            alpha=alpha,
        )

        rate_prev = jnp.broadcast_to(jnp.asarray(self.rate.value, dtype=dftype), state_shape)
        if noise is None:
            xi = jnp.asarray(np.random.normal(size=state_shape), dtype=dftype)
        else:
            xi = jnp.broadcast_to(jnp.asarray(noise, dtype=dftype), state_shape)
        noise_now = sigma * xi

        P1, P2, input_noise_factor = self._ipn_propagators(lambda_, h, tau)

        mu_total = mu + mu_ext
        rate_new = P1 * rate_prev + P2 * mu_total + input_noise_factor * noise_now

        # Coupling factors are identically one for this model; kept for parity
        # with NEST's rate_neuron_ipn update structure.
        H_ex = jnp.ones_like(rate_prev)
        H_in = jnp.ones_like(rate_prev)
        if self.mult_coupling:
            H_ex = self._mult_coupling_ex(rate_prev)
            H_in = self._mult_coupling_in(rate_prev)

        if self.linear_summation:
            if self.mult_coupling:
                rate_new += P2 * H_ex * self._input(delayed_ex + instant_ex, g, theta, alpha)
                rate_new += P2 * H_in * self._input(delayed_in + instant_in, g, theta, alpha)
            else:
                rate_new += P2 * self._input(
                    delayed_ex + instant_ex + delayed_in + instant_in, g, theta, alpha
                )
        else:
            # Nonlinear transform has already been applied per event in buffer handling.
            rate_new += P2 * H_ex * (delayed_ex + instant_ex)
            rate_new += P2 * H_in * (delayed_in + instant_in)

        if self.rectify_output:
            rate_new = jnp.where(rate_new < rectify_rate, rectify_rate, rate_new)

        self.rate.value = rate_new
        self.noise.value = noise_now
        self.delayed_rate.value = rate_prev
        self.instant_rate.value = rate_new
        self._step_count = step_idx + 1
        return rate_new
class threshold_lin_rate_opn(_threshold_lin_rate_base):
    r"""NEST-compatible output-noise threshold-linear rate neuron.

    Implements the NEST ``threshold_lin_rate_opn`` model, an output-noise rate
    neuron with a threshold-linear input nonlinearity. Unlike the input-noise
    variant (``threshold_lin_rate_ipn``), the neuron's own noise is never
    integrated into the rate dynamics: it is added to the transmitted (output)
    rate only, which leads to different stationary statistics and a different
    noise scaling.

    Mathematical Description
    ------------------------

    **1. Continuous-Time Deterministic Dynamics with Output Noise**

    The rate state :math:`X(t)` evolves according to the deterministic ODE:

    .. math::

        \tau\frac{dX(t)}{dt} = -X(t) + \mu + I_\mathrm{net}(t),

    where:

    - :math:`\tau > 0` is the time constant (ms).
    - :math:`\mu` is the mean drive (dimensionless, external constant input).
    - :math:`I_\mathrm{net}(t)` is the network input (see below).

    The **output** rate :math:`X_\mathrm{noisy}(t)` is obtained by adding noise
    to the deterministic state:

    .. math::

        X_\mathrm{noisy}(t) = X(t) + \sqrt{\frac{\tau}{h}}\,\sigma\,\xi(t),

    where:

    - :math:`\sigma \ge 0` is the output-noise scale (dimensionless).
    - :math:`\xi(t)\sim\mathcal{N}(0,1)` is standard Gaussian noise, drawn
      independently at every simulation step.
    - :math:`h=dt` is the simulation time step (ms).

    The per-step output-noise variance is :math:`\tau\sigma^2/h`, i.e. it
    *grows* as :math:`h` shrinks. The :math:`\sqrt{\tau/h}` factor keeps the
    product of per-step variance and step size (:math:`\tau\sigma^2`)
    independent of :math:`h`. The step-wise samples therefore approximate white
    noise of fixed intensity, whose effect after downstream low-pass filtering
    does not depend on the resolution.

    **2. Threshold-Linear Gain Function**

    The input nonlinearity :math:`\phi(h)` is identical to the input-noise
    variant:

    .. math::

        \phi(h) = \min(\max(g(h-\theta), 0), \alpha),

    where:

    - :math:`g > 0` is the gain slope (dimensionless).
    - :math:`\theta` is the activation threshold (dimensionless).
    - :math:`\alpha > 0` is the saturation level (dimensionless).

    **3. Network Input Structure**

    Let :math:`I_\mathrm{ex}(t)` and :math:`I_\mathrm{in}(t)` be the summed
    excitatory and inhibitory inputs (events are routed to a branch by the sign
    of their weight). With ``linear_summation=True`` the network input is:

    .. math::

        I_\mathrm{net}(t) = \phi(I_\mathrm{ex}(t) + I_\mathrm{in}(t))
        \quad\text{(mult\_coupling=False)},

    .. math::

        I_\mathrm{net}(t) = \phi(I_\mathrm{ex}(t)) + \phi(I_\mathrm{in}(t))
        \quad\text{(mult\_coupling=True)}.

    With ``linear_summation=False`` the gain is applied to each incoming event
    when it is buffered, before summation. :math:`I_\mathrm{net}` is then the
    plain sum of the already-transformed excitatory and inhibitory
    contributions.

    **Note on multiplicative coupling**: For the threshold-linear model the
    multiplicative coupling factors are constant
    (:math:`H_\mathrm{ex}=H_\mathrm{in}=1`), so ``mult_coupling`` never scales
    the input. It is nevertheless **not** a no-op. With
    ``linear_summation=True`` it switches the gain from
    :math:`\phi(I_\mathrm{ex}+I_\mathrm{in})` to
    :math:`\phi(I_\mathrm{ex})+\phi(I_\mathrm{in})`, exactly as NEST does.

    **4. Discrete-Time Integration (Exponential Euler)**

    For time step :math:`h=dt` (in ms), the deterministic dynamics are
    integrated using exponential Euler:

    .. math::

        X_{n+1} = P_1 X_n + P_2 (\mu + I_\mathrm{net,n}),

    where:

    .. math::

        P_1 = \exp\left(-\frac{h}{\tau}\right), \quad
        P_2 = 1 - P_1 = -\mathrm{expm1}\left(-\frac{h}{\tau}\right).

    The noisy output is:

    .. math::

        X_\mathrm{noisy,n} = X_n + \sqrt{\frac{\tau}{h}}\,\sigma\,\xi_n,

    where :math:`\xi_n\sim\mathcal{N}(0,1)`.

    **5. Update Ordering (Matching NEST ``rate_neuron_opn_impl.h``)**

    Per simulation step:

    1. **Draw noise**: sample :math:`\xi_n\sim\mathcal{N}(0,1)` and compute
       :math:`\mathrm{noise}_n=\sigma\,\xi_n`.
    2. **Build noisy output**: compute
       :math:`X_\mathrm{noisy,n}=X_n+\sqrt{\tau/h}\,\mathrm{noise}_n` and store
       it as both ``delayed_rate`` and ``instant_rate`` (the outgoing values
       for projections).
    3. **Propagate deterministic dynamics**: apply exponential Euler to
       :math:`X_n` with the drive :math:`\mu + \mu_\mathrm{ext}`.
    4. **Read event buffers**: drain delayed events arriving at the current
       step and accumulate instantaneous events.
    5. **Apply network input with threshold-linear gain**:

       - ``linear_summation=True``, ``mult_coupling=False``:
         :math:`X_{n+1} \gets X_{n+1} + P_2\,\phi(I_\mathrm{ex}+I_\mathrm{in})`.
       - ``linear_summation=True``, ``mult_coupling=True``:
         :math:`X_{n+1} \gets X_{n+1} + P_2\,[\phi(I_\mathrm{ex})+\phi(I_\mathrm{in})]`.
       - ``linear_summation=False``:
         :math:`X_{n+1} \gets X_{n+1} + P_2\,(I_\mathrm{ex}+I_\mathrm{in})`,
         where the gain was already applied per event while buffering.

    6. **Update state variables**: ``rate``, ``noise``, ``noisy_rate``,
       ``delayed_rate``, ``instant_rate``, ``_step_count``.

    **Note**: Unlike the input-noise variant, there is **no** rectification
    option for output-noise neurons. The noise is applied to the output only
    and does not affect the internal deterministic state.

    **6. Numerical Stability and Computational Complexity**

    - Construction enforces :math:`\tau>0` and :math:`\sigma\ge 0`.
    - The threshold-linear gain is evaluated with ``np.minimum`` and
      ``np.maximum`` (clipping).
    - Per-call cost is :math:`O(\prod\mathrm{varshape})`, using vectorized
      NumPy/JAX operations in the precision given by
      ``brainstate.environ.dftype()``.
    - The exponential Euler scheme is numerically stable for all :math:`h>0`
      (:math:`0 < P_1 < 1`).
    - Internally generated noise comes from NumPy's global random generator
      (``np.random.normal``). Seed it, or pass ``noise`` explicitly to
      :meth:`update`, for reproducible runs.

    Parameters
    ----------
    in_size : Size
        Population shape (tuple or int). All per-neuron parameters are
        broadcast to ``self.varshape``.
    tau : ArrayLike, optional
        Time constant :math:`\tau` (ms). Scalar or array broadcastable to
        ``self.varshape``. Must be :math:`>0`. Default: ``10.0 * u.ms``.
    sigma : ArrayLike, optional
        Output-noise scale :math:`\sigma` (dimensionless). Scalar or array
        broadcastable to ``self.varshape``. Must be :math:`\ge 0`.
        Default: ``1.0``.
    mu : ArrayLike, optional
        Mean drive :math:`\mu` (dimensionless). Scalar or array broadcastable
        to ``self.varshape``. External constant input to the rate dynamics.
        Default: ``0.0``.
    g : ArrayLike, optional
        Gain slope :math:`g` (dimensionless) of the threshold-linear function
        :math:`\phi(h)=\min(\max(g(h-\theta),0),\alpha)`. Scalar or array
        broadcastable to ``self.varshape``. Default: ``1.0``.
    theta : ArrayLike, optional
        Activation threshold :math:`\theta` (dimensionless). The gain function
        is zero for :math:`h<\theta`. Scalar or array broadcastable to
        ``self.varshape``. Default: ``0.0``.
    alpha : ArrayLike, optional
        Saturation level :math:`\alpha` (dimensionless). The gain function
        saturates at :math:`\alpha` for large inputs. Scalar or array
        broadcastable to ``self.varshape``. Default: ``np.inf`` (no
        saturation).
    mult_coupling : bool, optional
        Multiplicative-coupling switch. The coupling factors themselves are
        constant 1.0 for threshold-linear neurons. However, with
        ``linear_summation=True``, setting this to ``True`` makes the gain act
        on the excitatory and inhibitory sums separately
        (:math:`\phi(I_\mathrm{ex})+\phi(I_\mathrm{in})` instead of
        :math:`\phi(I_\mathrm{ex}+I_\mathrm{in})`), matching NEST. It has no
        effect when ``linear_summation=False``. Default: ``False``.
    linear_summation : bool, optional
        Controls where the threshold-linear gain is applied. If ``True``, the
        gain is applied to the summed input (see ``mult_coupling`` for how the
        branches are combined). If ``False``, the gain is applied to each
        incoming event while it is buffered, before summation (matching NEST
        event semantics). Default: ``True``.
    rate_initializer : Callable, optional
        Initializer for the ``rate`` state variable :math:`X_0`. Callable
        compatible with the ``braintools.init`` API.
        Default: ``braintools.init.Constant(0.0)``.
    noise_initializer : Callable, optional
        Initializer for the ``noise`` state variable (the most recent noise
        sample :math:`\sigma\,\xi`). Callable compatible with the
        ``braintools.init`` API. Default: ``braintools.init.Constant(0.0)``.
    noisy_rate_initializer : Callable, optional
        Initializer for the ``noisy_rate`` state variable
        :math:`X_\mathrm{noisy,0}` (initial noisy output). It is also used as
        the initial ``instant_rate`` and ``delayed_rate``. Callable compatible
        with the ``braintools.init`` API.
        Default: ``braintools.init.Constant(0.0)``.
    name : str or None, optional
        Module name for identification in hierarchies. If ``None``, an
        auto-generated name is used. Default: ``None``.

    Parameter Mapping
    -----------------

    The following table maps NEST ``threshold_lin_rate_opn`` parameters to
    brainpy.state equivalents:

    =============================== ===================== =========
    NEST Parameter                  brainpy.state         Default
    =============================== ===================== =========
    ``tau``                         ``tau``               10 ms
    ``sigma``                       ``sigma``             1.0
    ``mu``                          ``mu``                0.0
    ``g`` (gain slope)              ``g``                 1.0
    ``theta`` (threshold)           ``theta``             0.0
    ``alpha`` (saturation)          ``alpha``             inf
    ``mult_coupling``               ``mult_coupling``     False
    ``linear_summation``            ``linear_summation``  True
    =============================== ===================== =========

    **Note**: Unlike ``threshold_lin_rate_ipn``, this model does **not** have
    ``lambda`` (the passive decay is fixed at 1.0), ``rectify_rate``, or
    ``rectify_output`` parameters.

    Attributes
    ----------
    rate : brainstate.ShortTermState
        Current deterministic rate state :math:`X_n`. ``init_state`` creates it
        with shape ``self.varshape`` in ``brainstate.environ.dftype()``
        precision. After :meth:`update` it holds a JAX array with the
        broadcast state shape.
    noise : brainstate.ShortTermState
        Noise sample :math:`\sigma\,\xi` drawn in the most recent update (same
        shape and dtype as ``rate``).
    noisy_rate : brainstate.ShortTermState
        Noisy output rate :math:`X_\mathrm{noisy,n}=X_n+\sqrt{\tau/h}\,\sigma\,\xi_n`
        (same shape and dtype as ``rate``).
    instant_rate : brainstate.ShortTermState
        Noisy rate value used for instantaneous projections (same shape and
        dtype as ``rate``).
    delayed_rate : brainstate.ShortTermState
        Noisy rate value used for delayed projections (same shape and dtype as
        ``rate``).
    _step_count : int
        Plain Python integer step counter used for delayed-event scheduling.
        It is reset to 0 by ``init_state``.
    _delayed_ex_queue : dict
        Internal queue mapping ``step_idx`` to accumulated excitatory delayed
        events.
    _delayed_in_queue : dict
        Internal queue mapping ``step_idx`` to accumulated inhibitory delayed
        events.

    Raises
    ------
    ValueError
        If ``tau <= 0`` or ``sigma < 0``.
    ValueError
        If ``instant_rate_events`` contain non-zero ``delay_steps``.
    ValueError
        If ``delayed_rate_events`` contain negative ``delay_steps``.
    ValueError
        If event tuples have length other than 2, 3, or 4.

    Notes
    -----
    **Runtime Event Semantics**

    Event formats are identical to :class:`threshold_lin_rate_ipn`:

    - ``instant_rate_events``: Applied in the current step without delay.
    - ``delayed_rate_events``: Scheduled with integer ``delay_steps``.
    - Sign convention: ``weight >= 0`` → excitatory, ``weight < 0`` →
      inhibitory.

    **Comparison to Input-Noise Variant**

    The key differences between ``threshold_lin_rate_opn`` (output noise) and
    ``threshold_lin_rate_ipn`` (input noise) are:

    - **Noise location**: Output noise is added to the transmitted rate and
      never enters the internal state. Input noise is integrated into the rate
      dynamics.
    - **Internal state**: With output noise, the internal state :math:`X` is
      not driven by the neuron's own noise. With input noise, :math:`X` itself
      is a stochastic process.
    - **Dynamics**: The output-noise model has simpler deterministic dynamics
      (:math:`\lambda=1.0` fixed) with additive output corruption.
    - **Rectification**: The input-noise variant supports ``rectify_output``.
      The output-noise variant does not (its noise is on the output only).

    **Failure Modes**

    - No automatic failure handling. Non-positive time constants and negative
      noise parameters are caught at construction by
      ``_validate_parameters``. This check is skipped when the parameters are
      JAX tracers.
    - Invalid event formats raise ``ValueError`` during update.
    - The output-noise amplitude :math:`\sqrt{\tau/h}\,\sigma` becomes large
      for small time steps. This is by design: it keeps the noise intensity
      (variance × step) independent of :math:`h`.

    Examples
    --------
    **Example 1**: Minimal output-noise threshold-linear neuron.

    .. code-block:: python

        >>> import brainpy.state as bst
        >>> import saiunit as u
        >>> model = bst.threshold_lin_rate_opn(
        ...     in_size=10, tau=20*u.ms, sigma=0.5, g=2.0, theta=1.0
        ... )
        >>> model.init_all_states()
        >>> rate = model(x=0.5)  # deterministic state
        >>> noisy_rate = model.noisy_rate.value  # noisy output

    **Example 2**: Saturating threshold-linear neuron with output noise.

    .. code-block:: python

        >>> model = bst.threshold_lin_rate_opn(
        ...     in_size=5,
        ...     tau=10*u.ms,
        ...     sigma=0.2,
        ...     g=1.5, theta=0.5, alpha=5.0
        ... )
        >>> model.init_all_states()

    **Example 3**: Update with events (identical to the input-noise variant).

    .. code-block:: python

        >>> model = bst.threshold_lin_rate_opn(in_size=3, tau=10*u.ms, sigma=0.1)
        >>> model.init_all_states()
        >>> instant_event = {'rate': 2.0, 'weight': 0.1}
        >>> delayed_event = {'rate': 1.5, 'weight': -0.05, 'delay_steps': 3}
        >>> rate = model.update(
        ...     x=0.2,
        ...     instant_rate_events=instant_event,
        ...     delayed_rate_events=delayed_event
        ... )

    References
    ----------
    .. [1] NEST Simulator Documentation: ``threshold_lin_rate_opn``
           https://nest-simulator.readthedocs.io/en/stable/models/rate_neuron_opn.html
    .. [2] NEST Simulator Documentation: ``threshold_lin_rate`` nonlinearity
           https://nest-simulator.readthedocs.io/en/stable/models/rate_transformer_node.html
    .. [3] Hahne, J., Dahmen, D., Schuecker, J., Frommer, A., Bolten, M.,
           Helias, M., & Diesmann, M. (2017). Integration of continuous-time
           dynamics in a spiking neural network simulator. *Frontiers in
           Neuroinformatics*, 11, 34. https://doi.org/10.3389/fninf.2017.00034

    See Also
    --------
    threshold_lin_rate_ipn : Input-noise variant of the threshold-linear rate neuron.
    rate_neuron_opn : General output-noise rate neuron with custom gain functions.
    lin_rate : Deterministic linear rate neuron (``sigma=0``, no threshold).
    """
    __module__ = 'brainpy.state'

    def __init__(
        self,
        in_size: Size,
        tau: ArrayLike = 10.0 * u.ms,
        sigma: ArrayLike = 1.0,
        mu: ArrayLike = 0.0,
        g: ArrayLike = 1.0,
        theta: ArrayLike = 0.0,
        alpha: ArrayLike = np.inf,
        mult_coupling: bool = False,
        linear_summation: bool = True,
        rate_initializer: Callable = braintools.init.Constant(0.0),
        noise_initializer: Callable = braintools.init.Constant(0.0),
        noisy_rate_initializer: Callable = braintools.init.Constant(0.0),
        name: str = None,
    ):
        # The base class is configured with neutral branch gains/thresholds
        # (g_ex = g_in = 1, theta_ex = theta_in = 0). The threshold-linear
        # parameters theta/alpha are stored on this class and fed to the gain
        # function explicitly in `update`.
        super().__init__(
            in_size=in_size,
            tau=tau,
            sigma=sigma,
            mu=mu,
            g=g,
            mult_coupling=mult_coupling,
            g_ex=1.0,
            g_in=1.0,
            theta_ex=0.0,
            theta_in=0.0,
            linear_summation=linear_summation,
            rate_initializer=rate_initializer,
            noise_initializer=noise_initializer,
            name=name,
        )
        self.theta = braintools.init.param(theta, self.varshape)
        self.alpha = braintools.init.param(alpha, self.varshape)
        self.noisy_rate_initializer = noisy_rate_initializer
        self._validate_parameters()

    @property
    def recordables(self):
        r"""List of state variable names that can be recorded during simulation.

        Returns
        -------
        list of str
            ``['rate', 'noise', 'noisy_rate']``. ``rate`` is the deterministic
            rate state :math:`X_n`, ``noise`` is the most recent noise sample
            :math:`\sigma\,\xi`, and ``noisy_rate`` is the noisy output
            :math:`X_\mathrm{noisy,n}`.

        Notes
        -----
        ``noisy_rate`` is the value written to ``instant_rate`` and
        ``delayed_rate``, i.e. the value transmitted to downstream neurons.
        """
        return ['rate', 'noise', 'noisy_rate']

    @property
    def receptor_types(self):
        r"""Receptor type dictionary for projection compatibility.

        Returns
        -------
        dict[str, int]
            ``{'RATE': 0}``. Rate neurons expose a single receptor port for all
            rate-based inputs. Excitatory and inhibitory inputs are separated
            internally by the sign of the event weight.

        Notes
        -----
        Presumably consulted by projection objects when validating connection
        targets (not shown here). Unlike spiking neurons with separate
        AMPA/GABA receptor ports, rate neurons use sign-based branch routing
        (``weight >= 0`` → excitatory branch, ``weight < 0`` → inhibitory
        branch).
        """
        return {'RATE': 0}

    def _validate_parameters(self):
        r"""Validate model parameters at construction time.

        Raises
        ------
        ValueError
            If ``tau <= 0`` or ``sigma < 0``.

        Notes
        -----
        Called from ``__init__``. Validation is skipped when ``tau`` or
        ``sigma`` are JAX tracers, because concrete comparisons are not
        possible under tracing. Unlike the input-noise variant, this model has
        no ``lambda_`` or ``rectify_rate`` parameters to validate.
        """
        # Skip validation when parameters are JAX tracers (e.g. during jit).
        if any(is_tracer(v) for v in (self.tau, self.sigma)):
            return
        if np.any(self.tau <= 0.0 * u.ms):
            raise ValueError('Time constant tau must be > 0.')
        if np.any(self.sigma < 0.0):
            raise ValueError('Noise parameter sigma must be >= 0.')

    def init_state(self, **kwargs):
        r"""Initialize all state variables for simulation.

        Parameters
        ----------
        **kwargs
            Unused compatibility parameters accepted by the base-state API.
            In particular, a ``batch_size`` keyword is ignored: all states are
            created with shape ``self.varshape``.

        Notes
        -----
        This method initializes:

        - ``rate``: deterministic rate state :math:`X_0` from ``rate_initializer``.
        - ``noise``: noise sample from ``noise_initializer``.
        - ``noisy_rate``: noisy output :math:`X_\mathrm{noisy,0}` from
          ``noisy_rate_initializer``.
        - ``instant_rate`` and ``delayed_rate``: independent copies of
          ``noisy_rate``, because the outgoing values of this model are noisy.
        - ``_step_count``: plain integer step counter, reset to 0.
        - ``_delayed_ex_queue`` and ``_delayed_in_queue``: empty delay queues.

        All rate/noise arrays are NumPy arrays cast to
        ``brainstate.environ.dftype()``. This matches the dtype that
        :meth:`update` writes, so the state dtype does not change after the
        first step.
        """
        dftype = brainstate.environ.dftype()

        def _initial_array(initializer):
            # Materialize the initializer on the population shape and cast
            # (with a copy) to the environment float dtype.
            value = braintools.init.param(initializer, self.varshape)
            return np.array(self._to_numpy(value), dtype=dftype, copy=True)

        rate_np = _initial_array(self.rate_initializer)
        noise_np = _initial_array(self.noise_initializer)
        noisy_rate_np = _initial_array(self.noisy_rate_initializer)

        self.rate = brainstate.ShortTermState(rate_np)
        self.noise = brainstate.ShortTermState(noise_np)
        self.noisy_rate = brainstate.ShortTermState(noisy_rate_np)
        # Separate copies so the outgoing buffers never alias `noisy_rate`.
        self.instant_rate = brainstate.ShortTermState(noisy_rate_np.copy())
        self.delayed_rate = brainstate.ShortTermState(noisy_rate_np.copy())
        self._step_count = 0
        self._delayed_ex_queue = {}
        self._delayed_in_queue = {}

    def update(self, x=0.0, instant_rate_events=None, delayed_rate_events=None, noise=None):
        r"""Perform one simulation step of threshold-linear rate dynamics with output noise.

        Parameters
        ----------
        x : ArrayLike, optional
            External drive (scalar or array broadcastable to
            ``self.varshape``). It is turned into the extra drive
            :math:`\mu_\mathrm{ext}` by the inherited
            ``_common_inputs_threshold`` helper. Default is ``0.0``.
        instant_rate_events : None, dict, tuple, list, or iterable, optional
            Instantaneous rate events applied in the current step without
            delay. See the class docstring for the event format.
            Default is ``None``.
        delayed_rate_events : None, dict, tuple, list, or iterable, optional
            Delayed rate events scheduled with integer ``delay_steps`` (in
            units of the simulation time step). See the class docstring for
            the event format. Default is ``None``.
        noise : ArrayLike, optional
            Externally supplied standard-normal sample :math:`\xi_n` (scalar
            or array broadcastable to the state shape). If ``None`` (default),
            :math:`\xi_n\sim\mathcal{N}(0,1)` is drawn with NumPy's global
            generator (``np.random.normal``).

        Returns
        -------
        rate_new : jax.Array
            Updated deterministic rate state :math:`X_{n+1}`, in
            ``brainstate.environ.dftype()`` precision and with the broadcast
            state shape.

        Notes
        -----
        **Update algorithm**:

        1. **Draw noise and compute the noisy output** from the *current*
           state:

           .. math::

               \mathrm{noise}_n = \sigma\,\xi_n, \quad
               X_\mathrm{noisy,n} = X_n + \sqrt{\frac{\tau}{h}}\,\mathrm{noise}_n.

           :math:`X_\mathrm{noisy,n}` is stored as ``delayed_rate`` and
           ``instant_rate`` (outgoing values for projections).

        2. **Collect input contributions** via the inherited
           ``_common_inputs_threshold`` helper. Per that helper, this
           includes delayed events arriving at the current step (from the
           internal queues), newly scheduled zero-delay events, instantaneous
           events, and the external drive ``mu_ext`` derived from ``x``.

        3. **Compute propagator coefficients** (deterministic exponential
           Euler):

           .. math::

               P_1 = \exp(-h/\tau), \quad P_2 = 1 - P_1 = -\mathrm{expm1}(-h/\tau).

        4. **Propagate deterministic dynamics**:

           .. math::

               X_{n+1} = P_1 X_n + P_2(\mu + \mu_\mathrm{ext}).

        5. **Apply network input with threshold-linear gain**
           :math:`\phi(h)=\min(\max(g(h-\theta),0),\alpha)`:

           - ``linear_summation=True``, ``mult_coupling=False``:
             :math:`X_{n+1} \gets X_{n+1} + P_2\,\phi(I_\mathrm{ex}+I_\mathrm{in})`.
           - ``linear_summation=True``, ``mult_coupling=True``:
             :math:`X_{n+1} \gets X_{n+1} + P_2\,[\phi(I_\mathrm{ex})+\phi(I_\mathrm{in})]`
             (the coupling factors are constant 1, evaluated on the noisy
             rate as in NEST).
           - ``linear_summation=False``:
             :math:`X_{n+1} \gets X_{n+1} + P_2\,(I_\mathrm{ex}+I_\mathrm{in})`,
             where the gain was already applied per event while buffering.

        6. **Update state variables**: ``rate``, ``noise``, ``noisy_rate``,
           ``delayed_rate``, ``instant_rate``, ``_step_count``.

        **Key difference from the input-noise variant**: Noise is added to
        the output *before* the deterministic update and never enters the
        integration. The internal state :math:`X_n` therefore evolves without
        the neuron's own noise, and only the transmitted rate is noisy.

        **Numerical stability**: The threshold-linear gain uses ``np.minimum``
        and ``np.maximum`` for clipping. The exponential Euler scheme uses
        ``np.expm1`` to evaluate :math:`1-e^{-x}` without cancellation for
        small :math:`x`. The per-step noise variance :math:`\tau\sigma^2/h`
        grows as :math:`h\to 0`, keeping the noise intensity (variance × step)
        independent of the resolution.

        **Failure modes**: No automatic failure handling. Non-positive time
        constants and negative noise parameters are caught at construction by
        ``_validate_parameters``. Invalid event formats raise ``ValueError``.
        """
        h = float(u.get_mantissa(brainstate.environ.get_dt() / u.ms))
        dftype = brainstate.environ.dftype()
        state_shape = self.rate.value.shape

        tau, sigma, mu, g, theta, alpha = self._common_parameters_threshold(state_shape)
        state_shape, step_idx, delayed_ex, delayed_in, instant_ex, instant_in, mu_ext = self._common_inputs_threshold(
            x=x,
            instant_rate_events=instant_rate_events,
            delayed_rate_events=delayed_rate_events,
            g=g,
            theta=theta,
            alpha=alpha,
        )

        rate_prev = jnp.broadcast_to(jnp.asarray(self.rate.value, dtype=dftype), state_shape)

        # Step 1: draw this step's noise (internal draws use NumPy's global RNG).
        if noise is None:
            xi = jnp.asarray(np.random.normal(size=state_shape), dtype=dftype)
        else:
            xi = jnp.broadcast_to(jnp.asarray(noise, dtype=dftype), state_shape)
        noise_now = sigma * xi

        # Exponential-Euler propagators with the decay fixed at lambda = 1.
        step_ratio = h / tau
        P1 = np.exp(-step_ratio)
        P2 = -np.expm1(-step_ratio)  # 1 - exp(-h/tau) without cancellation
        output_noise_factor = np.sqrt(tau / h)

        # The transmitted value is built from the state *before* propagation,
        # matching the NEST ordering in rate_neuron_opn_impl.h.
        noisy_rate = rate_prev + output_noise_factor * noise_now

        mu_total = mu + mu_ext
        rate_new = P1 * rate_prev + P2 * mu_total

        # NEST evaluates the coupling factors on the noisy (transmitted) rate.
        if self.mult_coupling:
            H_ex = self._mult_coupling_ex(noisy_rate)
            H_in = self._mult_coupling_in(noisy_rate)
        else:
            H_ex = jnp.ones_like(rate_prev)
            H_in = jnp.ones_like(rate_prev)

        if self.linear_summation:
            if self.mult_coupling:
                # With mult_coupling, NEST applies the gain to each branch separately.
                rate_new += P2 * H_ex * self._input(delayed_ex + instant_ex, g, theta, alpha)
                rate_new += P2 * H_in * self._input(delayed_in + instant_in, g, theta, alpha)
            else:
                rate_new += P2 * self._input(delayed_ex + instant_ex + delayed_in + instant_in, g, theta, alpha)
        else:
            # Nonlinear transform has already been applied per event in buffer handling.
            rate_new += P2 * H_ex * (delayed_ex + instant_ex)
            rate_new += P2 * H_in * (delayed_in + instant_in)

        self.rate.value = rate_new
        self.noise.value = noise_now
        self.noisy_rate.value = noisy_rate
        self.delayed_rate.value = noisy_rate
        self.instant_rate.value = noisy_rate
        self._step_count = step_idx + 1
        return rate_new