# Source code for brainpy_state._nest.sigmoid_rate_gg_1998

# Copyright 2026 BrainX Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

# -*- coding: utf-8 -*-


from typing import Callable

import brainstate
import braintools
import saiunit as u
import jax.numpy as jnp
import numpy as np
from brainstate.typing import ArrayLike, Size

from brainpy_state._nest.lin_rate import _lin_rate_base
from ._utils import is_tracer

__all__ = [
    'sigmoid_rate_gg_1998_ipn',
]


class _sigmoid_rate_gg_1998_base(_lin_rate_base):
    r"""Base class for Gancarz-Grossberg (1998) quartic gain function rate neurons.

    Provides the quartic gain function :math:`\phi(h) = (gh)^4 / (0.1^4 + (gh)^4)`
    and the event/input processing helpers shared by ``sigmoid_rate_gg_1998``
    variants. Generic rate-neuron machinery (array conversion, broadcasting,
    event coercion and delayed-queue handling) comes from ``_lin_rate_base``.
    """
    __module__ = 'brainpy.state'

    # ``0.1 ** 4`` denominator constant of the quartic gain; it places the
    # half-activation point at ``|g * h| = 0.1``.
    _GG_1998_DENOMINATOR = 0.1 ** 4.0

    def _input(self, h, g):
        r"""Evaluate the Gancarz-Grossberg quartic gain function.

        Computes :math:`\phi(h) = (g h)^4 / (0.1^4 + (g h)^4)`, a steep saturating
        nonlinearity with half-activation at :math:`|h| = 0.1/|g|`. Because the
        exponent is even, :math:`\phi(-h) = \phi(h)`.

        Parameters
        ----------
        h : float or array_like
            Input activation (dimensionless), e.g. summed synaptic drive.
        g : float or array_like
            Gain parameter (dimensionless). Larger values steepen the curve.

        Returns
        -------
        phi : float or array
            Transformed activation with the broadcast shape of ``h`` and ``g``.
            Mathematically in :math:`[0, 1)`; may round to 1.0 for very large
            :math:`|g h|`.

        Notes
        -----
        The ``**`` operator is used instead of ``np.power`` so the same code
        works for Python scalars, NumPy arrays and JAX arrays (including
        tracers); for NumPy inputs the result is numerically unchanged.
        """
        gh4 = (g * h) ** 4.0
        return gh4 / (self._GG_1998_DENOMINATOR + gh4)

    @staticmethod
    def _mult_coupling_ex(rate):
        r"""Compute the excitatory multiplicative coupling factor (always 1.0).

        For ``sigmoid_rate_gg_1998`` variants, multiplicative coupling is not
        implemented; this method always returns ones regardless of ``rate``.

        Parameters
        ----------
        rate : array_like
            Current rate (only its shape is used).

        Returns
        -------
        jax.Array
            Ones with the shape of ``rate`` and the environment default float
            dtype (``brainstate.environ.dftype()``).
        """
        dftype = brainstate.environ.dftype()
        return jnp.ones_like(rate, dtype=dftype)

    @staticmethod
    def _mult_coupling_in(rate):
        r"""Compute the inhibitory multiplicative coupling factor (always 1.0).

        For ``sigmoid_rate_gg_1998`` variants, multiplicative coupling is not
        implemented; this method always returns ones regardless of ``rate``.

        Parameters
        ----------
        rate : array_like
            Current rate (only its shape is used).

        Returns
        -------
        jax.Array
            Ones with the shape of ``rate`` and the environment default float
            dtype (``brainstate.environ.dftype()``).
        """
        dftype = brainstate.environ.dftype()
        return jnp.ones_like(rate, dtype=dftype)

    def _extract_event_fields(self, ev, default_delay_steps: int):
        r"""Parse a rate event into ``(rate, weight, multiplicity, delay_steps)``.

        Accepts dict, tuple/list, or scalar formats and applies defaults for
        missing fields.

        Parameters
        ----------
        ev : dict or tuple or list or scalar
            Event specification. Supported formats:

            - Dict: ``{'rate': r, 'weight': w, 'delay_steps': d, 'multiplicity': m}``
              (all fields optional; ``rate`` defaults to 0.0, ``weight`` and
              ``multiplicity`` to 1.0; aliases ``'coeff'``/``'value'`` for
              ``'rate'`` and ``'delay'`` for ``'delay_steps'``).
            - Tuple/list: ``(rate, weight)``, ``(rate, weight, delay_steps)``, or
              ``(rate, weight, delay_steps, multiplicity)``.
            - Anything else: interpreted as ``rate`` with weight=1,
              multiplicity=1, delay_steps=``default_delay_steps``.
        default_delay_steps : int
            Delay used when the event does not specify one (0 for instantaneous
            events, 1 for delayed events in this class).

        Returns
        -------
        rate : float or array
            Event rate value.
        weight : float or array
            Synaptic weight (sign selects the excitatory/inhibitory branch).
        multiplicity : float or array
            Event multiplicity factor.
        delay_steps : int
            Delay in simulation steps, converted with ``_to_int_scalar``. Its sign
            is not checked here; callers enforce their own constraints.

        Raises
        ------
        ValueError
            If a tuple/list event does not have length 2, 3, or 4.
        """
        if isinstance(ev, dict):
            rate = ev.get('rate', ev.get('coeff', ev.get('value', 0.0)))
            weight = ev.get('weight', 1.0)
            multiplicity = ev.get('multiplicity', 1.0)
            delay_steps = ev.get('delay_steps', ev.get('delay', default_delay_steps))
        elif isinstance(ev, (tuple, list)):
            if len(ev) == 2:
                rate, weight = ev
                delay_steps = default_delay_steps
                multiplicity = 1.0
            elif len(ev) == 3:
                rate, weight, delay_steps = ev
                multiplicity = 1.0
            elif len(ev) == 4:
                rate, weight, delay_steps, multiplicity = ev
            else:
                raise ValueError('Rate event tuples must have length 2, 3, or 4.')
        else:
            rate = ev
            weight = 1.0
            multiplicity = 1.0
            delay_steps = default_delay_steps

        delay_steps = self._to_int_scalar(delay_steps, name='delay_steps')
        return rate, weight, multiplicity, delay_steps

    def _event_to_ex_in(self, ev, default_delay_steps: int, state_shape, g):
        r"""Convert a single rate event to excitatory/inhibitory contributions.

        Parses the event, broadcasts its fields to ``state_shape``, applies the
        gain function when ``linear_summation=False``, and routes the weighted
        value by weight sign.

        Parameters
        ----------
        ev : dict or tuple or list or scalar
            Rate event in any supported format (see ``_extract_event_fields``).
        default_delay_steps : int
            Default delay if the event does not specify one.
        state_shape : tuple of int
            Target broadcast shape (``rate.value.shape``).
        g : ndarray
            Gain parameter broadcast to ``state_shape``.

        Returns
        -------
        ex : ndarray
            Weighted value where ``weight >= 0``, zero elsewhere.
        inh : ndarray
            Weighted value where ``weight < 0``, zero elsewhere.
        delay_steps : int
            Parsed delay in steps.

        Notes
        -----
        With ``linear_summation=True`` the raw ``rate * weight * multiplicity`` is
        returned and the gain is applied later to branch sums. With
        ``linear_summation=False`` :math:`\phi(rate)` is applied before weighting
        (per-event nonlinearity).
        """
        rate, weight, multiplicity, delay_steps = self._extract_event_fields(ev, default_delay_steps)

        rate_np = self._broadcast_to_state(self._to_numpy(rate), state_shape)
        weight_np = self._broadcast_to_state(self._to_numpy(weight), state_shape)
        multiplicity_np = self._broadcast_to_state(self._to_numpy(multiplicity), state_shape)
        dftype = brainstate.environ.dftype()
        # Zero weights are routed to the excitatory branch (NEST convention).
        weight_sign = self._broadcast_to_state(
            np.asarray(u.get_mantissa(weight), dtype=dftype) >= 0.0,
            state_shape,
        )

        if self.linear_summation:
            weighted_value = rate_np * weight_np * multiplicity_np
        else:
            weighted_value = self._input(rate_np, g) * weight_np * multiplicity_np

        ex = np.where(weight_sign, weighted_value, 0.0)
        inh = np.where(weight_sign, 0.0, weighted_value)
        return ex, inh, delay_steps

    def _accumulate_instant_events_sigmoid_gg_1998(self, events, state_shape, g):
        r"""Accumulate instantaneous rate events into excitatory/inhibitory sums.

        Processes all events with ``delay_steps=0`` and sums their contributions.
        This method has no side effects on the neuron's queues.

        Parameters
        ----------
        events : None or list or nested structure
            Instantaneous events (flattened by ``_coerce_events``).
        state_shape : tuple of int
            Target array shape.
        g : ndarray
            Gain parameter broadcast to ``state_shape``.

        Returns
        -------
        ex : ndarray
            Total excitatory contribution (shape ``state_shape``).
        inh : ndarray
            Total inhibitory contribution (shape ``state_shape``).

        Raises
        ------
        ValueError
            If any event has non-zero ``delay_steps``.
        """
        dftype = brainstate.environ.dftype()
        ex = np.zeros(state_shape, dtype=dftype)
        inh = np.zeros(state_shape, dtype=dftype)
        for ev in self._coerce_events(events):
            ex_i, inh_i, delay_steps = self._event_to_ex_in(
                ev,
                default_delay_steps=0,
                state_shape=state_shape,
                g=g,
            )
            if delay_steps != 0:
                raise ValueError('instant_rate_events must not specify non-zero delay_steps.')
            ex += ex_i
            inh += inh_i
        return ex, inh

    def _schedule_delayed_events_sigmoid_gg_1998(self, events, step_idx: int, state_shape, g):
        r"""Schedule delayed rate events and return the zero-delay contributions.

        For each event:

        - If ``delay_steps == 0``, add it to the immediate excitatory/inhibitory sums.
        - If ``delay_steps > 0``, queue it for delivery at ``step_idx + delay_steps``.

        Parameters
        ----------
        events : None or list or nested structure
            Delayed events (flattened by ``_coerce_events``).
        step_idx : int
            Current simulation step index.
        state_shape : tuple of int
            Target array shape.
        g : ndarray
            Gain parameter broadcast to ``state_shape``.

        Returns
        -------
        ex_now : ndarray
            Excitatory contributions with zero delay (shape ``state_shape``).
        inh_now : ndarray
            Inhibitory contributions with zero delay (shape ``state_shape``).

        Raises
        ------
        ValueError
            If any event has negative ``delay_steps``. In that case no event
            from this call is queued.

        Notes
        -----
        Adds entries to ``self._delayed_ex_queue`` and ``self._delayed_in_queue``
        at ``target_step = step_idx + delay_steps``. Multiple events targeting
        the same step are combined by ``_queue_add``.
        """
        # Parse and validate every event before mutating the queues, so that an
        # invalid event late in the list cannot leave earlier ones queued.
        parsed = []
        for ev in self._coerce_events(events):
            ex_i, inh_i, delay_steps = self._event_to_ex_in(
                ev,
                default_delay_steps=1,
                state_shape=state_shape,
                g=g,
            )
            if delay_steps < 0:
                raise ValueError('delay_steps for delayed_rate_events must be >= 0.')
            parsed.append((ex_i, inh_i, delay_steps))

        dftype = brainstate.environ.dftype()
        ex_now = np.zeros(state_shape, dtype=dftype)
        inh_now = np.zeros(state_shape, dtype=dftype)
        for ex_i, inh_i, delay_steps in parsed:
            if delay_steps == 0:
                ex_now += ex_i
                inh_now += inh_i
            else:
                target_step = step_idx + delay_steps
                self._queue_add(self._delayed_ex_queue, target_step, ex_i)
                self._queue_add(self._delayed_in_queue, target_step, inh_i)
        return ex_now, inh_now

    def _common_inputs_sigmoid_gg_1998(self, x, instant_rate_events, delayed_rate_events, g):
        r"""Aggregate all input sources into excitatory/inhibitory sums and external drive.

        Accumulates instantaneous events, schedules new delayed events, drains the
        delayed queues for the current step, splits delta inputs by polarity, and
        computes the total external drive.

        Parameters
        ----------
        x : float or array_like
            External continuous current input.
        instant_rate_events : None or list or nested structure
            Instantaneous events (``delay_steps=0``).
        delayed_rate_events : None or list or nested structure
            Delayed events (``delay_steps >= 0``, default 1).
        g : ndarray
            Gain parameter broadcast to ``state_shape``.

        Returns
        -------
        state_shape : tuple of int
            Shape of ``rate.value``.
        step_idx : int
            Current step index (before increment).
        delayed_ex : ndarray
            Total delayed excitatory input (shape ``state_shape``).
        delayed_in : ndarray
            Total delayed inhibitory input (shape ``state_shape``).
        instant_ex : ndarray
            Total instantaneous excitatory input (includes positive delta inputs).
        instant_in : ndarray
            Total instantaneous inhibitory input (includes negative delta inputs).
        mu_ext : ndarray
            External drive from ``x`` and projection currents (via ``sum_current_inputs``).

        Raises
        ------
        ValueError
            Propagated from event parsing. Instant events are checked first, and
            delayed events are validated before any queue is modified.

        Notes
        -----
        Delta inputs (from projections calling ``add_delta_input``) are split by
        sign: positive values are added to ``instant_ex``, negative ones to
        ``instant_in``. Continuous inputs (via ``add_current_input``) are
        aggregated into ``mu_ext``.
        """
        state_shape = self.rate.value.shape
        step_idx = self._step_count

        # Side-effect-free: an invalid instant event raises before any queue changes.
        instant_ex, instant_in = self._accumulate_instant_events_sigmoid_gg_1998(
            instant_rate_events,
            state_shape=state_shape,
            g=g,
        )

        # Validate/queue new delayed events before draining this step's entries.
        # Newly queued entries target ``step_idx + d`` with ``d >= 1``; this
        # assumes ``_drain_delayed_queue`` only removes entries up to ``step_idx``.
        delayed_ex_now, delayed_in_now = self._schedule_delayed_events_sigmoid_gg_1998(
            delayed_rate_events,
            step_idx=step_idx,
            state_shape=state_shape,
            g=g,
        )
        delayed_ex, delayed_in = self._drain_delayed_queue(step_idx, state_shape)
        delayed_ex = delayed_ex + delayed_ex_now
        delayed_in = delayed_in + delayed_in_now

        delta_input = self._broadcast_to_state(self._to_numpy(self.sum_delta_inputs(0.0)), state_shape)
        instant_ex += np.where(delta_input > 0.0, delta_input, 0.0)
        instant_in += np.where(delta_input < 0.0, delta_input, 0.0)

        mu_ext = self._broadcast_to_state(self._to_numpy(self.sum_current_inputs(x, self.rate.value)), state_shape)

        return state_shape, step_idx, delayed_ex, delayed_in, instant_ex, instant_in, mu_ext

    def _common_parameters_sigmoid_gg_1998(self, state_shape):
        r"""Broadcast model parameters to the state shape.

        Converts the time constant to milliseconds and broadcasts all parameters
        to match the current state array shape.

        Parameters
        ----------
        state_shape : tuple of int
            Target broadcast shape (``rate.value.shape``).

        Returns
        -------
        tau : ndarray
            Time constant in milliseconds (shape ``state_shape``).
        sigma : ndarray
            Noise amplitude (shape ``state_shape``).
        mu : ndarray
            Constant drive (shape ``state_shape``).
        g : ndarray
            Gain parameter (shape ``state_shape``).
        """
        tau = self._broadcast_to_state(self._to_numpy_ms(self.tau), state_shape)
        sigma = self._broadcast_to_state(self._to_numpy(self.sigma), state_shape)
        mu = self._broadcast_to_state(self._to_numpy(self.mu), state_shape)
        g = self._broadcast_to_state(self._to_numpy(self.g), state_shape)
        return tau, sigma, mu, g


class sigmoid_rate_gg_1998_ipn(_sigmoid_rate_gg_1998_base):
    r"""NEST-compatible ``sigmoid_rate_gg_1998_ipn`` nonlinear rate neuron with input noise.

    Description
    -----------

    ``sigmoid_rate_gg_1998_ipn`` implements NEST's ``sigmoid_rate_gg_1998_ipn`` model:
    a continuous-time rate neuron driven by stochastic input noise and shaped by the
    historical Gancarz-Grossberg (1998) quartic gain function. The model corresponds
    to NEST's ``rate_neuron_ipn`` template instantiated with ``sigmoid_rate_gg_1998``
    nonlinearity. Multiplicative coupling factors are fixed to one for this model variant
    (the flag is kept for API compatibility).

    **1. Continuous-time stochastic dynamics**

    The neuron rate :math:`X(t)` evolves under the Itô stochastic differential equation:

    .. math::

       \tau\,dX(t) = \left[-\lambda X(t) + \mu + \phi(\cdot)\right]dt
       + \left[\sqrt{\tau}\,\sigma\right]dW(t),

    where :math:`\tau > 0` is the intrinsic time constant (ms), :math:`\lambda \ge 0`
    is the passive decay rate, :math:`\mu` is the constant mean drive, :math:`\sigma \ge 0`
    is the input noise amplitude, and :math:`W(t)` is a standard Wiener process.

    **2. Gancarz-Grossberg quartic gain function**

    The gain function :math:`\phi(h)` applies a steep, saturating nonlinearity to
    summed synaptic input :math:`h`:

    .. math::

       \phi(h) = \frac{(g\,h)^4}{0.1^4 + (g\,h)^4},

    where :math:`g` is the gain parameter controlling horizontal scaling. This
    form, introduced by Gancarz & Grossberg (1998) for saccade-generator modeling,
    is a Hill function of order 4 in :math:`g\,h`. It is flat near :math:`h = 0`,
    rises steeply around threshold, and saturates towards 1. Because the exponent
    is even, :math:`\phi(-h) = \phi(h)` and only :math:`|g|` matters. The constant
    denominator term :math:`0.1^4 = 10^{-4}` places the half-activation point at
    :math:`|h| = 0.1/|g|` (for example :math:`h = 0.1` when :math:`g = 1`).

    Unlike the logistic sigmoid :math:`g/(1+e^{-\beta(h-\theta)})`, the quartic gain
    has no separate slope or threshold parameter. :math:`g` alone rescales the input
    axis. For large inputs the curve approaches 1 algebraically,
    :math:`1 - \phi(h) \approx 10^{-4}/(g\,h)^4`, rather than exponentially.

    **3. Numerical integration scheme**

    At each time step :math:`h = dt`, the model applies the stochastic exponential
    Euler scheme (exact propagation of the linear decay; Hahne et al. 2017):

    .. math::

       X_{n+1} = P_1\,X_n + P_2\,(\mu + \mu_{ext} + \phi(\cdot))
       + \sqrt{\frac{-0.5\,\mathrm{expm1}(-2\lambda h/\tau)}{\lambda}}\;\sigma\,\xi_n,

    where :math:`P_1 = e^{-\lambda h/\tau}`,
    :math:`P_2 = -\mathrm{expm1}(-\lambda h/\tau)/\lambda` for :math:`\lambda > 0`,
    and :math:`\xi_n \sim \mathcal{N}(0, 1)`. For :math:`\lambda = 0`, the propagators
    reduce to :math:`P_1 = 1`, :math:`P_2 = h/\tau`, and the noise factor becomes
    :math:`\sqrt{h/\tau}` (the :math:`\lambda \to 0` limit of the expression above).

    The external drive :math:`\mu_{ext}` is the update argument ``x`` plus the
    continuous current inputs collected via ``sum_current_inputs()``. Delta inputs
    collected via ``sum_delta_inputs()`` are not part of :math:`\mu_{ext}`. They are
    split by sign and added to the instantaneous excitatory/inhibitory branch sums
    that enter :math:`\phi`.

    **4. Synaptic input routing and delayed events**

    When ``linear_summation=True`` (default), the gain :math:`\phi` applies to the
    total summed synaptic input across all sources:

    .. math::

       \phi(\cdot) = \phi(h_{ex,\mathrm{inst}} + h_{in,\mathrm{inst}}
       + h_{ex,\mathrm{del}} + h_{in,\mathrm{del}}),

    where subscripts denote instantaneous vs. delayed, excitatory vs. inhibitory
    branch sums. When ``linear_summation=False``, the gain applies to each incoming
    rate event **before** weighted summation, matching NEST's per-event nonlinearity
    mode.

    Delayed events are stored in per-step queues (``_delayed_ex_queue``,
    ``_delayed_in_queue``) indexed by target step. Event polarity (excitatory or
    inhibitory) is determined by the sign of the weight field. Events support
    dict, tuple, or scalar formats:

    - ``{'rate': r, 'weight': w, 'delay_steps': d, 'multiplicity': m}``
    - ``(r, w, d, m)`` or shorter tuples (defaults applied)
    - Scalar ``r`` (weight=1, multiplicity=1; ``delay_steps`` defaults to 0 for
      ``instant_rate_events`` and to 1 for ``delayed_rate_events``)

    **5. Update ordering**

    Per simulation step (matching NEST ``rate_neuron_ipn`` with ``sigmoid_rate_gg_1998``):

    1. Store outgoing delayed value as current ``rate`` (``delayed_rate``).
    2. Draw noise realization :math:`\xi_n`.
    3. Propagate intrinsic dynamics :math:`-\lambda X` with stochastic exponential Euler.
    4. Read delayed and instantaneous event buffers, schedule new delayed events.
    5. Apply gain function to branch sums (if ``linear_summation=True``) or per event
       (if ``linear_summation=False``).
    6. Apply rectification: if ``rectify_output=True``, clamp ``rate`` to
       :math:`\ge` ``rectify_rate``.
    7. Store outgoing instantaneous value as updated ``rate`` (``instant_rate``).

    **6. Assumptions, constraints, and failure modes**

    - Construction-time validation ensures ``tau > 0``, ``lambda >= 0``, ``sigma >= 0``,
      ``rectify_rate >= 0``.
    - Parameters are scalar or broadcastable to ``self.varshape``.
    - Delayed event ``delay_steps`` must be non-negative integers; instantaneous events
      (``instant_rate_events``) must have ``delay_steps=0`` or omit the field.
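    - Multiplicative coupling factors are identically one (``_mult_coupling_ex/in``
      return ``ones_like``), so ``mult_coupling=True`` adds no rate-dependent
      scaling. With ``linear_summation=True`` it still changes how :math:`\phi` is
      applied. Following NEST's ``rate_neuron_ipn``, the excitatory and inhibitory
      branch sums then pass through :math:`\phi` separately. This differs from
      :math:`\phi` of their total because :math:`\phi` is nonlinear.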
    - Per-step complexity is :math:`O(|\mathrm{state}| \cdot K)` for ``K`` events per step.

    Parameters
    ----------
    in_size : Size
        Population shape. Supports integer ``n`` (1D with ``n`` neurons), tuple ``(n, m)``
        (2D grid), or higher-dimensional tuples.
    tau : Quantity[ms], optional
        Intrinsic time constant :math:`\tau` of rate dynamics (ms). Must be positive.
        Default ``10 ms``.
    lambda_ : float or array_like, optional
        Passive decay rate :math:`\lambda` (dimensionless). Must be non-negative.
        Default ``1.0``.
    sigma : float or array_like, optional
        Input noise amplitude :math:`\sigma` (dimensionless). Must be non-negative.
        Zero disables stochastic forcing. Default ``1.0``.
    mu : float or array_like, optional
        Constant mean drive :math:`\mu` (dimensionless). Default ``0.0``.
    g : float or array_like, optional
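        Gain parameter :math:`g` of the quartic nonlinearity (dimensionless).
        Larger values steepen the gain curve and move the half-activation point
        :math:`|h| = 0.1/|g|` towards zero. Expected to be positive. The sign is
        not validated and has no effect, because :math:`\phi` depends on
        :math:`g^4`. Default ``1.0``.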
    mult_coupling : bool, optional
        Multiplicative coupling flag kept for NEST compatibility. For this model
        variant, multiplicative factors are identically one regardless of this
        setting. Default ``False``.
    linear_summation : bool, optional
        If ``True`` (default), apply gain :math:`\phi` to total summed synaptic input.
        If ``False``, apply :math:`\phi` to each event before weighted summation
        (per-event nonlinearity mode). Default ``True``.
    rectify_rate : float or array_like, optional
        Lower bound for rectified output when ``rectify_output=True``.
        Must be non-negative. Default ``0.0``.
    rectify_output : bool, optional
        If ``True``, clamp updated ``rate`` to :math:`\ge` ``rectify_rate`` at each step.
        Default ``False``.
    rate_initializer : Callable[[shape, batch_size], Array], optional
        Initializer for ``rate`` state. Default ``Constant(0.0)``.
    noise_initializer : Callable[[shape, batch_size], Array], optional
        Initializer for ``noise`` state. Default ``Constant(0.0)``.
    name : str, optional
        Module name for identification. Auto-generated if not provided.

    State Variables
    ---------------
    rate : ShortTermState
        Current neuron rate :math:`X(t)` (dimensionless, shape ``varshape``).
    noise : ShortTermState
        Last noise realization :math:`\sigma\,\xi_n` (dimensionless).
    instant_rate : ShortTermState
        Instantaneous outgoing rate (alias of ``rate`` after update).
    delayed_rate : ShortTermState
        Delayed outgoing rate (stored from previous step).
    _step_count : int
        Internal step counter (plain Python ``int`` starting at 0) used to
        schedule delayed events.

    Parameter Mapping
    -----------------

    ===============================  ===============================  =======================
    brainpy.state parameter          NEST parameter                   Notes
    ===============================  ===============================  =======================
    ``tau``                          ``tau``                          Time constant (ms)
    ``lambda_``                      ``lambda``                       Passive decay rate
    ``sigma``                        ``sigma``                        Input noise amplitude
    ``mu``                           ``mu``                           Constant drive
    ``g``                            ``g``                            Gain parameter
    ``mult_coupling``                ``mult_coupling``                (Always 1.0 for this model)
    ``linear_summation``             ``linear_summation``             Gain application mode
    ``rectify_rate``                 ``rectify_rate``                 Lower rectification bound
    ``rectify_output``               ``rectify_output``               Rectification flag
    ``rate_initializer``             ``rate`` (initialization)        Initial rate value
    ===============================  ===============================  =======================

    Recordables
    -----------
    - ``rate`` : Current neuron rate (main output)
    - ``noise`` : Last noise realization

    Receptor Types
    --------------
    - ``RATE`` : index 0 (single receptor port for all rate inputs)

    Notes
    -----
    **Runtime event API**:

    - ``instant_rate_events``: Applied in the current step with ``delay_steps=0``.
      Raises ``ValueError`` if non-zero delay is specified.
    - ``delayed_rate_events``: Queued for delivery after ``delay_steps`` steps
      (default ``delay_steps=1`` if omitted).

    **Event format flexibility**:

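    - Dict: ``{'rate': r, 'weight': w, 'delay_steps': d, 'multiplicity': m}``.
      All fields are optional: ``rate`` defaults to 0.0, and ``weight`` and
      ``multiplicity`` default to 1.0. ``'coeff'``/``'value'`` are accepted as
      aliases of ``'rate'``, and ``'delay'`` as an alias of ``'delay_steps'``.
    - Tuple: ``(rate, weight)``, ``(rate, weight, delay_steps)``, or
      ``(rate, weight, delay_steps, multiplicity)``.
    - Scalar: Interpreted as ``rate`` with weight=1 and multiplicity=1.
    - When ``delay_steps`` is omitted it defaults to 0 for ``instant_rate_events``
      and to 1 for ``delayed_rate_events``.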
    - Lists of any of the above are flattened and processed sequentially.

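    **Weight polarity**: Events with weight :math:`\ge 0` contribute to the excitatory
    branch; events with negative weight contribute to the inhibitory branch. Since
    the coupling factors are one, the split only changes results when
    ``linear_summation=True`` and ``mult_coupling=True``. In that mode :math:`\phi`
    is applied to each branch sum separately.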

    Examples
    --------
    Create a population and drive it with noisy input:

    .. code-block:: python

       >>> import brainpy.state as bst
       >>> import saiunit as u
       >>> import brainstate as bs
       >>> pop = bst.sigmoid_rate_gg_1998_ipn(
       ...     in_size=100, tau=20*u.ms, lambda_=0.5, sigma=0.1, g=2.0
       ... )
       >>> pop.init_all_states()
       >>> with bs.environ.context(dt=0.1*u.ms):
       ...     for _ in range(1000):
       ...         r = pop.update(x=0.5)

    Apply delayed rate events:

    .. code-block:: python

       >>> with bs.environ.context(dt=0.1*u.ms):
       ...     # Event at t+5 steps with weight 0.8
       ...     r = pop.update(delayed_rate_events=[{'rate': 1.0, 'weight': 0.8, 'delay_steps': 5}])

    References
    ----------
    .. [1] Gancarz G, Grossberg S (1998). A neural model of the saccade generator
       in the reticular formation. Neural Networks, 11(7):1159-1174.
       DOI: 10.1016/S0893-6080(98)00096-3.
    .. [2] Hahne J, Dahmen D, Schuecker J, Frommer A, Bolten M, Helias M,
       Diesmann M (2017). Integration of continuous-time dynamics in a spiking
       neural network simulator. Frontiers in Neuroinformatics, 11:34.
       DOI: 10.3389/fninf.2017.00034.

    See Also
    --------
    sigmoid_rate_ipn : Standard sigmoid gain function
    lin_rate_ipn : Linear rate neuron variant
    tanh_rate_ipn : Hyperbolic tangent gain function
    """

    __module__ = 'brainpy.state'

    def __init__(
        self,
        in_size: Size,
        tau: ArrayLike = 10.0 * u.ms,
        lambda_: ArrayLike = 1.0,
        sigma: ArrayLike = 1.0,
        mu: ArrayLike = 0.0,
        g: ArrayLike = 1.0,
        mult_coupling: bool = False,
        linear_summation: bool = True,
        rectify_rate: ArrayLike = 0.0,
        rectify_output: bool = False,
        rate_initializer: Callable = braintools.init.Constant(0.0),
        noise_initializer: Callable = braintools.init.Constant(0.0),
        name: str = None,
    ):
        r"""Construct the neuron; see the class docstring for parameter semantics.

        The shared base is given fixed values ``g_ex = g_in = 1.0`` and
        ``theta_ex = theta_in = 0.0``. The quartic gain (``_input``) of this
        model only reads ``g``. ``lambda_`` and ``rectify_rate`` are
        broadcast to ``varshape`` after the base is initialized. All
        construction-time checks then run via ``_validate_parameters``.
        """
        base_kwargs = dict(
            in_size=in_size,
            tau=tau,
            sigma=sigma,
            mu=mu,
            g=g,
            mult_coupling=mult_coupling,
            # Not read by the quartic gain; pinned to neutral values.
            g_ex=1.0,
            g_in=1.0,
            theta_ex=0.0,
            theta_in=0.0,
            linear_summation=linear_summation,
            rate_initializer=rate_initializer,
            noise_initializer=noise_initializer,
            name=name,
        )
        super().__init__(**base_kwargs)

        shape = self.varshape
        self.lambda_ = braintools.init.param(lambda_, shape)
        self.rectify_rate = braintools.init.param(rectify_rate, shape)
        self.rectify_output = bool(rectify_output)
        self._validate_parameters()

    @property
    def recordables(self):
        r"""Names of the state variables that can be recorded.

        Returns
        -------
        list of str
            A fresh ``['rate', 'noise']`` list on every access: the current rate
            and the most recent noise realization.
        """
        return list(('rate', 'noise'))

    @property
    def receptor_types(self):
        r"""Mapping from receptor port name to port index.

        Returns
        -------
        dict[str, int]
            ``{'RATE': 0}``. A single port receives every rate event.
        """
        return dict(RATE=0)

    def _validate_parameters(self):
        r"""Validate construction-time parameter constraints.

        Checks that the time constant, passive decay rate, noise amplitude, and
        rectification rate are in their admissible ranges. Each parameter is
        checked independently. A parameter that is a JAX tracer (e.g. during
        ``jit``) carries no concrete value and is skipped, while the remaining
        concrete parameters are still validated.

        Raises
        ------
        ValueError
            If ``tau <= 0`` (non-positive time constant).
        ValueError
            If ``lambda_ < 0`` (negative passive decay rate).
        ValueError
            If ``sigma < 0`` (negative noise amplitude).
        ValueError
            If ``rectify_rate < 0`` (negative rectification bound).
        """

        def _concrete(value):
            # Comparing a tracer would raise a concretization error, so traced
            # values are left unchecked individually.
            return not is_tracer(value)

        if _concrete(self.tau) and np.any(self.tau <= 0.0 * u.ms):
            raise ValueError('Time constant tau must be > 0.')
        if _concrete(self.lambda_) and np.any(self.lambda_ < 0.0):
            raise ValueError('Passive decay rate lambda must be >= 0.')
        if _concrete(self.sigma) and np.any(self.sigma < 0.0):
            raise ValueError('Noise parameter sigma must be >= 0.')
        if _concrete(self.rectify_rate) and np.any(self.rectify_rate < 0.0):
            raise ValueError('Rectifying rate must be >= 0.')

[docs] def init_state(self, **kwargs): r"""Initialize neuron state variables and delayed event queues. Allocates ``rate``, ``noise``, ``instant_rate``, ``delayed_rate``, internal step counter, and per-step delayed event dictionaries for excitatory and inhibitory branches. Parameters ---------- **kwargs Unused compatibility parameters accepted by the base-state API. Notes ----- Called automatically by ``init_all_states()``. Initializes: - ``rate`` : from ``rate_initializer`` - ``noise`` : from ``noise_initializer`` - ``instant_rate``, ``delayed_rate`` : copies of initial ``rate`` - ``_step_count`` : int64 scalar starting at 0 - ``_delayed_ex_queue``, ``_delayed_in_queue`` : empty dicts """ rate = braintools.init.param(self.rate_initializer, self.varshape) noise = braintools.init.param(self.noise_initializer, self.varshape) rate_np = self._to_numpy(rate) noise_np = self._to_numpy(noise) self.rate = brainstate.ShortTermState(rate_np) self.noise = brainstate.ShortTermState(noise_np) dftype = brainstate.environ.dftype() self.instant_rate = brainstate.ShortTermState(np.array(rate_np, dtype=dftype, copy=True)) self.delayed_rate = brainstate.ShortTermState(np.array(rate_np, dtype=dftype, copy=True)) self._step_count = 0 self._delayed_ex_queue = {} self._delayed_in_queue = {}
def update(
    self,
    x=0.0,
    instant_rate_events=None,
    delayed_rate_events=None,
    noise=None,
    _precomputed_ex=None,
    _precomputed_in=None,
):
    r"""Advance the input-noise rate neuron by one simulation step.

    One step (i) gathers synaptic and external drive, (ii) draws or
    accepts a standard-normal noise sample, and (iii) integrates the
    linear intrinsic dynamics with exponential-Euler propagators. It then
    (iv) adds the synaptic contribution passed through the quartic gain
    ``_input`` and (v) optionally clamps the result from below.

    Parameters
    ----------
    x : float or array_like, optional
        External continuous input (dimensionless), broadcast to
        ``varshape`` and aggregated via ``sum_current_inputs()``.
        Default ``0.0``.
    instant_rate_events : None, list, dict or tuple, optional
        Rate events delivered in this step. Every event must have
        ``delay_steps == 0`` or omit the field. Nested lists, dicts,
        tuples and scalars are accepted. Default ``None``.
    delayed_rate_events : None, list, dict or tuple, optional
        Rate events scheduled for a later step through
        ``delay_steps >= 0`` (``1`` when omitted). Same formats as
        ``instant_rate_events``. Default ``None``.
    noise : float or array_like, optional
        Explicit standard-normal sample :math:`\xi_n`, broadcast to the
        state shape. When ``None`` (default) a sample is drawn with
        ``np.random.normal``.
    _precomputed_ex, _precomputed_in : array_like, optional
        Internal fast path. When ``_precomputed_ex`` is not ``None``, the
        two arrays are used directly as the delayed excitatory and
        inhibitory input. In that case ``x`` and both event arguments are
        ignored, instantaneous and external drive are zero, and
        ``_step_count`` is left unchanged.

    Returns
    -------
    rate_new : array
        Updated rate :math:`X_{n+1}`, shaped like ``self.rate.value``.

    Raises
    ------
    ValueError
        Propagated from event handling in any of three cases:
        ``instant_rate_events`` carry a non-zero ``delay_steps``;
        ``delayed_rate_events`` carry a negative ``delay_steps``; or an
        event tuple does not have 2, 3 or 4 entries.

    Notes
    -----
    With time step :math:`h` (ms) and :math:`\xi_n \sim \mathcal{N}(0, 1)`:

    .. math::

        X_{n+1} = P_1 X_n + P_2 (\mu + \mu_{ext})
                  + c_{noise} \, \sigma \xi_n + P_2 \, I_{syn}

    For :math:`\lambda > 0`:

    - :math:`P_1 = e^{-\lambda h/\tau}`
    - :math:`P_2 = -\mathrm{expm1}(-\lambda h/\tau)/\lambda`
    - :math:`c_{noise} = \sqrt{-\tfrac12\,\mathrm{expm1}(-2\lambda h/\tau)/\lambda}`

    For :math:`\lambda = 0`: :math:`P_1 = 1`, :math:`P_2 = h/\tau` and
    :math:`c_{noise} = \sqrt{h/\tau}`.

    The synaptic term :math:`I_{syn}` depends on the flags:

    - ``linear_summation=True, mult_coupling=False``:
      :math:`\phi(h_{ex,\mathrm{del}} + h_{ex,\mathrm{inst}}
      + h_{in,\mathrm{del}} + h_{in,\mathrm{inst}})`.
    - ``linear_summation=True, mult_coupling=True``: the excitatory and
      inhibitory sums go through :math:`\phi` separately. Each is scaled
      by ``_mult_coupling_ex`` / ``_mult_coupling_in`` evaluated at
      :math:`X_n`; for this model these factors are documented as 1.0.
    - ``linear_summation=False``: the buffers already hold
      gain-transformed values. They are summed directly, scaled by the
      coupling factors, which are ones unless ``mult_coupling`` is set.

    If ``rectify_output`` is true, values below ``rectify_rate`` are
    replaced by ``rectify_rate``. Afterwards:

    - ``rate`` and ``instant_rate`` hold :math:`X_{n+1}`;
    - ``delayed_rate`` holds :math:`X_n`;
    - ``noise`` holds :math:`\sigma \xi_n`.

    Delayed events with ``delay_steps = d`` are delivered at step
    ``current_step + d``. Zero-delay entries in ``delayed_rate_events``
    are applied immediately instead of being queued.

    Random numbers come from NumPy's global generator, so seed it
    (e.g. ``np.random.seed``) for reproducible simulations.
    """
    # Step size as a plain float in milliseconds.
    dt_ms = float(u.get_mantissa(brainstate.environ.get_dt() / u.ms))
    float_dtype = brainstate.environ.dftype()
    shape = self.rate.value.shape

    # Per-neuron parameters, all broadcast to the state shape.
    tau, sigma, mu, g = self._common_parameters_sigmoid_gg_1998(shape)
    lam = self._broadcast_to_state(self._to_numpy(self.lambda_), shape)
    rect_floor = self._broadcast_to_state(self._to_numpy(self.rectify_rate), shape)

    x_prev = jnp.broadcast_to(jnp.asarray(self.rate.value, dtype=float_dtype), shape)

    # ---- gather synaptic / external drive --------------------------------
    if _precomputed_ex is None:
        (_, step_idx, del_ex, del_in, inst_ex, inst_in, mu_ext) = (
            self._common_inputs_sigmoid_gg_1998(
                x=x,
                instant_rate_events=instant_rate_events,
                delayed_rate_events=delayed_rate_events,
                g=g,
            )
        )
        self._step_count = step_idx + 1
    else:
        # Caller supplied the delayed input; everything else is zero and the
        # step counter is intentionally not advanced on this path.
        del_ex = jnp.asarray(_precomputed_ex, dtype=float_dtype)
        del_in = jnp.asarray(_precomputed_in, dtype=float_dtype)
        inst_ex = jnp.zeros(shape, dtype=float_dtype)
        inst_in = jnp.zeros(shape, dtype=float_dtype)
        mu_ext = jnp.zeros(shape, dtype=float_dtype)

    # ---- noise sample ----------------------------------------------------
    xi = (
        np.random.normal(size=shape)
        if noise is None
        else self._broadcast_to_state(self._to_numpy(noise), shape)
    )
    noise_now = sigma * xi

    # ---- exponential-Euler propagators -----------------------------------
    def _exp_euler_factors():
        # No leak anywhere: fall back to the lambda == 0 limits.
        if not np.any(lam > 0.0):
            return np.ones_like(lam), dt_ms / tau, np.sqrt(dt_ms / tau)
        # Guard the division; lambda == 0 entries are overwritten below.
        safe_lam = np.where(lam == 0.0, 1.0, lam)
        p1 = np.exp(-lam * dt_ms / tau)
        p2 = -np.expm1(-lam * dt_ms / tau) / safe_lam
        noise_fac = np.sqrt(-0.5 * np.expm1(-2.0 * lam * dt_ms / tau) / safe_lam)
        at_zero = lam == 0.0
        if np.any(at_zero):
            p1 = np.where(at_zero, 1.0, p1)
            p2 = np.where(at_zero, dt_ms / tau, p2)
            noise_fac = np.where(at_zero, np.sqrt(dt_ms / tau), noise_fac)
        return p1, p2, noise_fac

    P1, P2, noise_fac = _exp_euler_factors()

    # Intrinsic dynamics plus additive input noise.
    x_next = P1 * x_prev + P2 * (mu + mu_ext) + noise_fac * noise_now

    # ---- synaptic contribution -------------------------------------------
    if self.mult_coupling:
        H_ex = self._mult_coupling_ex(x_prev)
        H_in = self._mult_coupling_in(x_prev)
    else:
        H_ex = jnp.ones_like(x_prev)
        H_in = jnp.ones_like(x_prev)

    if not self.linear_summation:
        # Gain was already applied per event while the buffers were filled.
        x_next += P2 * H_ex * (del_ex + inst_ex)
        x_next += P2 * H_in * (del_in + inst_in)
    elif self.mult_coupling:
        x_next += P2 * H_ex * self._input(del_ex + inst_ex, g)
        x_next += P2 * H_in * self._input(del_in + inst_in, g)
    else:
        x_next += P2 * self._input(del_ex + inst_ex + del_in + inst_in, g)

    if self.rectify_output:
        x_next = jnp.where(x_next < rect_floor, rect_floor, x_next)

    # ---- commit state ----------------------------------------------------
    self.rate.value = x_next
    self.noise.value = noise_now
    self.delayed_rate.value = x_prev
    self.instant_rate.value = x_next
    return x_next