# Source code for brainpy_state._nest.rate_neuron_opn

# Copyright 2026 BrainX Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

# -*- coding: utf-8 -*-


from typing import Callable

import brainstate
import braintools
import saiunit as u
import jax.numpy as jnp
import numpy as np
from brainstate.typing import ArrayLike, Size

from brainpy_state._nest.lin_rate import _lin_rate_base
from ._utils import is_tracer

# Public API of this module: only the model class is exported.
__all__ = ['rate_neuron_opn']


class rate_neuron_opn(_lin_rate_base):
    r"""NEST-compatible ``rate_neuron_opn`` output-noise rate-neuron template.

    ``rate_neuron_opn`` implements the NEST template model
    ``rate_neuron_opn<TNonlinearities>`` with the deterministic dynamics

    .. math::

       \tau \frac{dX(t)}{dt}
       = -X(t) + \mu + I_\mathrm{net}(t),

    and output noise applied after the nonlinearity:

    .. math::

       X_\mathrm{noisy}(t)
       = X(t) + \sqrt{\frac{\tau}{h}}\,\sigma\,\xi(t),

    where :math:`X(t)` is the deterministic rate state, :math:`\tau` is the
    time constant, :math:`\mu` is the mean drive, :math:`\sigma\ge 0` is the
    output-noise strength, :math:`h` is the simulation time step, and
    :math:`\xi(t)\sim\mathcal{N}(0,1)` is standard Gaussian white noise
    approximated as piecewise constant over :math:`h`.

    With default callables this is equivalent to NEST ``lin_rate_opn``:

    - ``input(h) = g * h``
    - ``mult_coupling_ex(rate) = g_ex * (theta_ex - rate)``
    - ``mult_coupling_in(rate) = g_in * (theta_in + rate)``

    Mathematical Description
    ------------------------

    **1. Continuous-Time Deterministic Dynamics**

    The deterministic rate state :math:`X(t)` evolves according to

    .. math::

       \tau \frac{dX(t)}{dt} = -X(t) + \mu + I_\mathrm{net}(t),

    where :math:`\tau>0` is the time constant and :math:`I_\mathrm{net}(t)` is
    the network input decomposed as

    .. math::

       I_\mathrm{net}(t) = H_\mathrm{ex}(X_\mathrm{noisy}) \cdot g(I_\mathrm{ex}(t))
                         + H_\mathrm{in}(X_\mathrm{noisy}) \cdot g(I_\mathrm{in}(t)),

    where:

    - :math:`I_\mathrm{ex}(t)` and :math:`I_\mathrm{in}(t)` are excitatory and
      inhibitory synaptic input branches.
    - :math:`g(\cdot)` is the input nonlinearity. Default: :math:`g(h)=g\,h`.
    - :math:`H_\mathrm{ex}(X_\mathrm{noisy})` and
      :math:`H_\mathrm{in}(X_\mathrm{noisy})` are optional multiplicative
      coupling factors dependent on the *noisy* rate. Default:
      :math:`H_\mathrm{ex}=g_\mathrm{ex}(\theta_\mathrm{ex}-X_\mathrm{noisy})`,
      :math:`H_\mathrm{in}=g_\mathrm{in}(\theta_\mathrm{in}+X_\mathrm{noisy})`.
      Only active if ``mult_coupling=True``.

    The ``linear_summation`` switch controls whether the nonlinearity is
    applied to the summed input or to individual synaptic branches:

    - ``linear_summation=True``:
      :math:`I_\mathrm{net}(t) = H\cdot g(I_\mathrm{ex}+I_\mathrm{in})`.
    - ``linear_summation=False``:
      :math:`I_\mathrm{net}(t) = H_\mathrm{ex}\cdot g(I_\mathrm{ex})
      + H_\mathrm{in}\cdot g(I_\mathrm{in})`.

    **2. Output Noise (Postsynaptic Noise Model)**

    Output noise is added *after* the deterministic dynamics, creating a noisy
    observation of the rate:

    .. math::

       X_\mathrm{noisy}(t) = X(t) + \sqrt{\frac{\tau}{h}}\,\sigma\,\xi(t),

    where :math:`\xi(t)\sim\mathcal{N}(0,1)` is standard Gaussian white noise.
    The scaling factor :math:`\sqrt{\tau/h}` ensures that the noise amplitude
    is independent of the discretization time step :math:`h` in the limit
    :math:`h\to 0`.

    **Critical difference from input-noise model**: The noisy rate
    :math:`X_\mathrm{noisy}` is used for *multiplicative coupling* evaluation
    (if ``mult_coupling=True``) and as the *outgoing signal* to downstream
    neurons, but the noise does *not* feed back into the deterministic
    dynamics. This contrasts with the input-noise variant (``rate_neuron_ipn``)
    where noise enters the differential equation directly.

    **3. Discrete-Time Integration**

    For time step :math:`h=dt` (in ms), the deterministic part uses exponential
    Euler integration (exact for the linear ODE):

    .. math::

       X_{n+1} = P_1 X_n + P_2 (\mu + I_\mathrm{net,n}),

    where

    .. math::

       P_1 = \exp(-h/\tau), \quad P_2 = 1 - P_1 = -\mathrm{expm1}(-h/\tau).

    Output noise is added independently at each step:

    .. math::

       X_\mathrm{noisy,n} = X_n + \sqrt{\frac{\tau}{h}}\,\sigma\,\xi_n,

    where :math:`\xi_n\sim\mathcal{N}(0,1)` is drawn at each step.

    **4. Update Ordering (Matching NEST ``rate_neuron_opn_impl.h``)**

    Per simulation step:

    1. Draw noise sample :math:`\xi_n\sim\mathcal{N}(0,1)`, compute
       :math:`\mathrm{noise}_n = \sigma\,\xi_n`.
    2. Compute noisy rate:
       :math:`X_\mathrm{noisy,n} = X_n + \sqrt{\tau/h}\,\mathrm{noise}_n`.
    3. Propagate deterministic intrinsic dynamics:
       :math:`X' = P_1 X_n + P_2 (\mu + \mu_\mathrm{ext})`.
    4. Read delayed and instantaneous event buffers.
    5. Apply network input according to NEST semantics:

       - ``linear_summation=True``: nonlinearity applied to summed branch input
         during update.
       - ``linear_summation=False``: nonlinearity applied per incoming event
         while buffering (handled in event processing).

    6. If ``mult_coupling=True``, multiplicative coupling factors
       :math:`H_\mathrm{ex}(X_\mathrm{noisy,n})` and
       :math:`H_\mathrm{in}(X_\mathrm{noisy,n})` are evaluated at the *noisy*
       rate (matching NEST ``rate_neuron_opn_impl.h``).
    7. Store updated ``rate``, ``noise``, and expose ``noisy_rate`` as
       outgoing delayed/instantaneous event value.

    **5. Stability Constraints and Computational Implications**

    - Construction enforces :math:`\tau>0`, :math:`\sigma\ge 0`.
    - The deterministic dynamics are unconditionally stable (exponential
      relaxation to :math:`\mu + I_\mathrm{net}` with time constant :math:`\tau`).
    - Output noise does not affect stability but may violate rate bounds; no
      automatic rectification is provided (unlike ``rate_neuron_ipn``).
    - Noise variance scales as :math:`\tau\sigma^2/h` per step. For fixed
      :math:`\tau` and :math:`\sigma`, this diverges as :math:`h\to 0`,
      reflecting the white-noise nature of :math:`\xi(t)`.
    - The exponential Euler scheme is numerically stable for all :math:`h>0`.
    - Per-call cost is :math:`O(\prod\mathrm{varshape})`, using vectorized
      NumPy/JAX array operations for coefficient evaluation and state update;
      input buffers are allocated in the floating-point dtype returned by
      ``brainstate.environ.dftype()``.

    Parameters
    ----------
    in_size : Size
        Population shape specification (tuple of int or single int). All
        per-neuron parameters are broadcast to ``self.varshape``. For example,
        ``in_size=10`` creates 10 neurons, ``in_size=(4, 5)`` creates a 4×5
        grid.
    tau : ArrayLike, optional
        Time constant :math:`\tau` (saiunit quantity with ms dimension).
        Scalar or array broadcastable to ``self.varshape``. Must be :math:`>0`.
        Controls the exponential relaxation rate of the deterministic dynamics.
        Default: ``10.0 * u.ms``.
    sigma : ArrayLike, optional
        Output-noise scale :math:`\sigma` (dimensionless scalar or array).
        Broadcastable to ``self.varshape``. Must be :math:`\ge 0`. Determines
        the standard deviation of the Gaussian noise added to the output rate.
        Default: ``1.0``.
    mu : ArrayLike, optional
        Mean drive :math:`\mu` (dimensionless scalar or array). Broadcastable
        to ``self.varshape``. External constant input to the rate dynamics,
        added to the network input. Default: ``0.0``.
    g : ArrayLike, optional
        Linear gain parameter :math:`g` (dimensionless scalar or array).
        Broadcastable to ``self.varshape``. Used by the default input
        nonlinearity :math:`g(h)=g\,h`. Ignored if ``input_nonlinearity`` is
        provided. Default: ``1.0``.
    mult_coupling : bool, optional
        Enable multiplicative coupling (rate-dependent synaptic efficacy). If
        ``True``, applies :math:`H_\mathrm{ex}(X_\mathrm{noisy})` and
        :math:`H_\mathrm{in}(X_\mathrm{noisy})` to synaptic inputs, evaluated
        at the *noisy* rate. If ``False``, :math:`H_\mathrm{ex}=H_\mathrm{in}=1`.
        Default: ``False``.
    g_ex : ArrayLike, optional
        Excitatory multiplicative coupling gain :math:`g_\mathrm{ex}`
        (dimensionless scalar or array). Broadcastable to ``self.varshape``.
        Only used if ``mult_coupling=True``. Default: ``1.0``.
    g_in : ArrayLike, optional
        Inhibitory multiplicative coupling gain :math:`g_\mathrm{in}`
        (dimensionless scalar or array). Broadcastable to ``self.varshape``.
        Only used if ``mult_coupling=True``. Default: ``1.0``.
    theta_ex : ArrayLike, optional
        Excitatory coupling reference rate :math:`\theta_\mathrm{ex}`
        (dimensionless scalar or array). Broadcastable to ``self.varshape``.
        Only used if ``mult_coupling=True``. Default: ``0.0``.
    theta_in : ArrayLike, optional
        Inhibitory coupling reference rate :math:`\theta_\mathrm{in}`
        (dimensionless scalar or array). Broadcastable to ``self.varshape``.
        Only used if ``mult_coupling=True``. Default: ``0.0``.
    linear_summation : bool, optional
        NEST switch controlling where the input nonlinearity is applied. If
        ``True``, the nonlinearity is applied to the sum of excitatory and
        inhibitory inputs (post-summation). If ``False``, the nonlinearity is
        applied separately to each input branch before summation (per-branch).
        Default: ``True``.
    input_nonlinearity : Callable[[ArrayLike], ArrayLike] or Callable[[rate_neuron_opn, ArrayLike], ArrayLike] or None, optional
        Custom input nonlinearity :math:`g(\cdot)` replacing the default
        :math:`g(h)=g\,h`. Callable signature can be ``f(h)`` (receives an
        array of shape ``state_shape`` with dtype ``brainstate.environ.dftype()``,
        returns an array of the same shape) or ``f(model, h)`` (receives the
        model instance and the array, returns an array). The input is a JAX
        array built with ``jax.numpy``, so the callable must be vectorized and
        accept JAX arrays. If ``None``, uses the default linear gain.
        Default: ``None``.
    mult_coupling_ex_fn : Callable[[ArrayLike], ArrayLike] or Callable[[rate_neuron_opn, ArrayLike], ArrayLike] or None, optional
        Custom excitatory multiplicative coupling function
        :math:`H_\mathrm{ex}(X_\mathrm{noisy})`. Callable signature can be
        ``f(rate)`` or ``f(model, rate)``. Must return array of same shape as
        input. Evaluated at the *noisy* rate. If ``None``, uses default
        :math:`g_\mathrm{ex}(\theta_\mathrm{ex}-X_\mathrm{noisy})`. Default:
        ``None``.
    mult_coupling_in_fn : Callable[[ArrayLike], ArrayLike] or Callable[[rate_neuron_opn, ArrayLike], ArrayLike] or None, optional
        Custom inhibitory multiplicative coupling function
        :math:`H_\mathrm{in}(X_\mathrm{noisy})`. Callable signature can be
        ``f(rate)`` or ``f(model, rate)``. Must return array of same shape as
        input. Evaluated at the *noisy* rate. If ``None``, uses default
        :math:`g_\mathrm{in}(\theta_\mathrm{in}+X_\mathrm{noisy})`. Default:
        ``None``.
    rate_initializer : Callable, optional
        Initializer for the deterministic ``rate`` state variable :math:`X_0`.
        Callable compatible with ``braintools.init`` API (signature:
        ``(shape, batch_size) -> ArrayLike``). Default:
        ``braintools.init.Constant(0.0)``.
    noise_initializer : Callable, optional
        Initializer for the ``noise`` state variable (records last noise sample
        :math:`\sigma\,\xi_{n-1}`). Callable compatible with ``braintools.init``
        API. Default: ``braintools.init.Constant(0.0)``.
    noisy_rate_initializer : Callable, optional
        Initializer for the ``noisy_rate`` state variable
        :math:`X_\mathrm{noisy,0}` and outgoing event values. Callable
        compatible with ``braintools.init`` API. Default:
        ``braintools.init.Constant(0.0)``.
    name : str or None, optional
        Module name for identification in hierarchies. If ``None``,
        auto-generates a unique name. Default: ``None``.

    Parameter Mapping
    -----------------

    The following table maps NEST ``rate_neuron_opn`` / ``lin_rate_opn``
    parameters to brainpy.state equivalents:

    =============================== ========================== ===========
    NEST Parameter                  brainpy.state              Default
    =============================== ========================== ===========
    ``tau``                         ``tau``                    10 ms
    ``sigma``                       ``sigma``                  1.0
    ``mu``                          ``mu``                     0.0
    ``g`` (nonlinearity gain)       ``g``                      1.0
    ``mult_coupling``               ``mult_coupling``          False
    ``g_ex``, ``g_in``              ``g_ex``, ``g_in``         1.0
    ``theta_ex``, ``theta_in``      ``theta_ex``, ``theta_in`` 0.0
    ``linear_summation``            ``linear_summation``       True
    =============================== ========================== ===========

    Attributes
    ----------
    rate : brainstate.ShortTermState
        Deterministic rate state :math:`X_n` (float64 array of shape
        ``self.varshape`` or ``(batch_size,) + self.varshape``). This is the
        noise-free rate variable.
    noise : brainstate.ShortTermState
        Last noise sample :math:`\sigma\,\xi_{n-1}` (float64 array, same shape
        as ``rate``). Records the noise term used in the previous step.
    noisy_rate : brainstate.ShortTermState
        Noisy rate :math:`X_\mathrm{noisy,n} = X_n + \sqrt{\tau/h}\,\mathrm{noise}_n`
        (float64 array, same shape as ``rate``). This is the outgoing signal
        sent to downstream neurons and used for multiplicative coupling
        evaluation.
    instant_rate : brainstate.ShortTermState
        Noisy rate value for instantaneous event propagation (float64 array,
        same shape as ``rate``). Set to ``noisy_rate`` after each update.
    delayed_rate : brainstate.ShortTermState
        Noisy rate value for delayed projections (float64 array, same shape as
        ``rate``). Set to ``noisy_rate`` after each update.
    _step_count : brainstate.ShortTermState
        Internal step counter for delayed event scheduling (int64 scalar).
        Incremented by 1 after each ``update`` call.
    _delayed_ex_queue : dict
        Internal queue mapping ``step_idx`` (int) to accumulated excitatory
        delayed events (float64 array of shape ``state_shape``).
    _delayed_in_queue : dict
        Internal queue mapping ``step_idx`` (int) to accumulated inhibitory
        delayed events (float64 array of shape ``state_shape``).

    Raises
    ------
    ValueError
        If ``tau <= 0`` (checked during ``__init__`` via
        ``_validate_parameters``).
    ValueError
        If ``sigma < 0`` (checked during ``__init__`` via
        ``_validate_parameters``).
    ValueError
        If ``instant_rate_events`` contain non-zero ``delay_steps`` (checked
        during ``update`` via ``_accumulate_instant_events``).
    ValueError
        If ``delayed_rate_events`` contain negative ``delay_steps`` (checked
        during ``update`` via ``_schedule_delayed_events``).
    ValueError
        If event tuples have length other than 2, 3, or 4 (checked during
        ``update`` via ``_extract_event_fields``).

    Notes
    -----
    **Runtime Events**

    Events can be provided to ``update()`` via ``instant_rate_events`` and
    ``delayed_rate_events`` parameters. Each event can be specified as:

    - **Scalar**: Treated as ``rate`` value with ``weight=1.0``.
    - **Tuple**: ``(rate, weight)`` or ``(rate, weight, delay_steps)`` or
      ``(rate, weight, delay_steps, multiplicity)``.
    - **Dict**: Keys ``'rate'``/``'coeff'``/``'value'`` (event value),
      ``'weight'`` (synaptic weight), ``'delay_steps'``/``'delay'`` (integer
      delay in time steps), ``'multiplicity'`` (event count).

    **Sign Convention**: Events with ``weight >= 0`` contribute to the
    excitatory branch; events with ``weight < 0`` contribute to the inhibitory
    branch.

    **Linear Summation Semantics**: For ``linear_summation=False``, event
    values are transformed by the input nonlinearity during buffering (matching
    NEST event handlers). For ``linear_summation=True``, the nonlinearity is
    applied to the summed input during the update step.

    **Comparison to ``rate_neuron_ipn``**

    The ``_opn`` variant uses output noise (applied after nonlinearity and
    transmitted to downstream neurons), while ``_ipn`` uses input noise (applied
    before dynamics propagation, directly affecting the state evolution). This
    leads to different stationary distributions, noise scaling, and stability
    properties. In ``_opn``, noise does not feed back into the deterministic
    dynamics.

    Examples
    --------
    Minimal output-noise rate neuron:

    .. code-block:: python

       >>> import brainpy.state as bst
       >>> import saiunit as u
       >>> model = bst.rate_neuron_opn(in_size=10, tau=20*u.ms, sigma=0.5)
       >>> model.init_all_states(batch_size=1)
       >>> rate = model(x=0.1)  # external drive
       >>> print(rate.shape)
       (1, 10)

    Multiplicative coupling with custom nonlinearity:

    .. code-block:: python

       >>> import numpy as np
       >>> def tanh_nonlin(h):
       ...     return np.tanh(h)
       >>> model = bst.rate_neuron_opn(
       ...     in_size=5,
       ...     tau=10*u.ms,
       ...     sigma=0.3,
       ...     mult_coupling=True,
       ...     g_ex=1.5, theta_ex=1.0,
       ...     input_nonlinearity=tanh_nonlin
       ... )

    Accessing noisy rate output:

    .. code-block:: python

       >>> model = bst.rate_neuron_opn(in_size=3, tau=10*u.ms, sigma=0.2)
       >>> model.init_all_states()
       >>> rate_deterministic = model.update(x=0.5)  # propagates deterministic dynamics
       >>> rate_noisy = model.noisy_rate.value        # includes output noise
       >>> print(rate_noisy.shape)
       (3,)

    References
    ----------
    .. [1] NEST Simulator Documentation: ``rate_neuron_opn``
           https://nest-simulator.readthedocs.io/en/stable/models/rate_neuron_opn.html
    .. [2] Hahne, J., Dahmen, D., Schuecker, J., Frommer, A., Bolten, M.,
           Helias, M., & Diesmann, M. (2017). Integration of continuous-time
           dynamics in a spiking neural network simulator.
           *Frontiers in Neuroinformatics*, 11, 34.

    See Also
    --------
    rate_neuron_ipn : Input-noise variant of the rate neuron template.
    lin_rate : Deterministic linear rate neuron (``sigma=0``).
    """

    __module__ = 'brainpy.state'

    def __init__(
        self,
        in_size: Size,
        tau: ArrayLike = 10.0 * u.ms,
        sigma: ArrayLike = 1.0,
        mu: ArrayLike = 0.0,
        g: ArrayLike = 1.0,
        mult_coupling: bool = False,
        g_ex: ArrayLike = 1.0,
        g_in: ArrayLike = 1.0,
        theta_ex: ArrayLike = 0.0,
        theta_in: ArrayLike = 0.0,
        linear_summation: bool = True,
        input_nonlinearity: Callable | None = None,
        mult_coupling_ex_fn: Callable | None = None,
        mult_coupling_in_fn: Callable | None = None,
        rate_initializer: Callable = braintools.init.Constant(0.0),
        noise_initializer: Callable = braintools.init.Constant(0.0),
        noisy_rate_initializer: Callable = braintools.init.Constant(0.0),
        name: str | None = None,
    ):
        r"""Construct the output-noise rate neuron and validate its parameters.

        The shared model parameters (``tau``, ``sigma``, ``mu``, ``g``, the
        multiplicative-coupling parameters, ``linear_summation``) and the
        ``rate``/``noise`` initializers are forwarded to ``_lin_rate_base``.
        The optional user callables and the ``noisy_rate`` initializer are
        stored on this instance. ``_validate_parameters`` then checks ``tau``
        and ``sigma``. See the class docstring for the meaning of every
        argument.

        Raises
        ------
        ValueError
            If ``tau`` or ``sigma`` is rejected by ``_validate_parameters``.
        """
        super().__init__(
            in_size=in_size,
            tau=tau,
            sigma=sigma,
            mu=mu,
            g=g,
            mult_coupling=mult_coupling,
            g_ex=g_ex,
            g_in=g_in,
            theta_ex=theta_ex,
            theta_in=theta_in,
            linear_summation=linear_summation,
            rate_initializer=rate_initializer,
            noise_initializer=noise_initializer,
            name=name,
        )

        # Template hooks; ``None`` selects the lin_rate default formulas.
        self.input_nonlinearity = input_nonlinearity
        self.mult_coupling_ex_fn = mult_coupling_ex_fn
        self.mult_coupling_in_fn = mult_coupling_in_fn
        # Not forwarded to the base class; kept here for the ``noisy_rate`` state.
        self.noisy_rate_initializer = noisy_rate_initializer

        self._validate_parameters()

    @property
    def recordables(self):
        r"""List of state variable names that can be recorded.

        Returns
        -------
        list of str
            ``['rate', 'noise', 'noisy_rate']``.
        """
        return ['rate', 'noise', 'noisy_rate']

    @property
    def receptor_types(self):
        r"""Receptor type dictionary for projection compatibility.

        Returns
        -------
        dict[str, int]
            ``{'RATE': 0}``. Rate neurons have a single receptor type.
        """
        return {'RATE': 0}

    def _validate_parameters(self):
        r"""Validate ``tau`` and ``sigma`` at construction time.

        Raises
        ------
        ValueError
            If any element of ``tau`` is not strictly positive (NaN included),
            or any element of ``sigma`` is negative or NaN.

        Notes
        -----
        Called at the end of ``__init__``. Validation is skipped when either
        parameter is a JAX tracer (e.g. inside ``jax.jit``), because concrete
        comparisons cannot be evaluated on tracers.
        """
        if any(is_tracer(v) for v in (self.tau, self.sigma)):
            return
        # Phrase the checks positively: every comparison with NaN is False, so
        # ``np.any(x <= 0)`` would silently accept NaN entries, whereas
        # ``not np.all(x > 0)`` rejects them.
        if not np.all(self.tau > 0.0 * u.ms):
            raise ValueError('Time constant tau must be > 0.')
        if not np.all(self.sigma >= 0.0):
            raise ValueError('Noise parameter sigma must be >= 0.')

    def _call_nl(self, fn: Callable, x: np.ndarray):
        r"""Call user-provided nonlinearity with flexible signature.

        Parameters
        ----------
        fn : Callable
            User-provided function with signature ``f(x)`` or ``f(model, x)``.
        x : np.ndarray
            Input array (float64).

        Returns
        -------
        np.ndarray
            Output of ``fn``, coerced to float64 NumPy array.

        Notes
        -----
        Tries ``fn(self, x)`` first (passing model instance), then falls back
        to ``fn(x)`` if signature mismatch occurs.
        """
        try:
            return fn(self, x)
        except TypeError as first_error:
            try:
                return fn(x)
            except TypeError:
                raise first_error

    def _input_transform(self, h: np.ndarray, state_shape):
        r"""Apply the input nonlinearity :math:`g(h)`.

        Parameters
        ----------
        h : array_like
            Pre-nonlinearity input. Converted to ``brainstate.environ.dftype()``
            and broadcast to ``state_shape`` before use.
        state_shape : tuple
            Target broadcast shape for the output.

        Returns
        -------
        array
            Transformed input :math:`g(h)` of shape ``state_shape``, produced
            with ``jax.numpy`` operations.

        Notes
        -----
        If ``input_nonlinearity`` is ``None``, the default linear gain
        :math:`g(h)=g\,h` is used. Otherwise the user callable is invoked via
        ``_call_nl`` (as ``f(model, h)`` or ``f(h)``). Its result is converted
        to ``dftype`` and broadcast to ``state_shape``.
        """
        dftype = brainstate.environ.dftype()
        h_arr = jnp.broadcast_to(jnp.asarray(h, dtype=dftype), state_shape)
        if self.input_nonlinearity is None:
            g = self._broadcast_to_state(self._to_numpy(self.g), state_shape)
            return g * h_arr
        y = self._call_nl(self.input_nonlinearity, h_arr)
        return jnp.broadcast_to(jnp.asarray(y, dtype=dftype), state_shape)

    def _mult_ex_transform(self, rate: np.ndarray, state_shape):
        r"""Compute excitatory multiplicative coupling factor :math:`H_\mathrm{ex}(X_\mathrm{noisy})`.

        Parameters
        ----------
        rate : array_like
            Current noisy rate :math:`X_\mathrm{noisy}`. Converted to
            ``brainstate.environ.dftype()`` and broadcast to ``state_shape``.
        state_shape : tuple
            Target broadcast shape for the output.

        Returns
        -------
        array
            Coupling factor :math:`H_\mathrm{ex}(X_\mathrm{noisy})` of shape
            ``state_shape``, produced with ``jax.numpy`` operations.

        Notes
        -----
        If ``mult_coupling_ex_fn`` is ``None``, uses the default
        :math:`g_\mathrm{ex}(\theta_\mathrm{ex}-X_\mathrm{noisy})`. Otherwise
        the user callable is invoked via ``_call_nl`` and its result is
        converted to ``dftype`` and broadcast to ``state_shape``. The factor is
        evaluated at the *noisy* rate (matching NEST ``rate_neuron_opn_impl.h``).
        """
        dftype = brainstate.environ.dftype()
        rate_arr = jnp.broadcast_to(jnp.asarray(rate, dtype=dftype), state_shape)
        if self.mult_coupling_ex_fn is None:
            g_ex = self._broadcast_to_state(self._to_numpy(self.g_ex), state_shape)
            theta_ex = self._broadcast_to_state(self._to_numpy(self.theta_ex), state_shape)
            return g_ex * (theta_ex - rate_arr)
        y = self._call_nl(self.mult_coupling_ex_fn, rate_arr)
        return jnp.broadcast_to(jnp.asarray(y, dtype=dftype), state_shape)

    def _mult_in_transform(self, rate: np.ndarray, state_shape):
        r"""Evaluate the inhibitory coupling factor :math:`H_\mathrm{in}(X_\mathrm{noisy})`.

        Parameters
        ----------
        rate : array_like
            Noisy rate :math:`X_\mathrm{noisy}`. Cast to
            ``brainstate.environ.dftype()`` and broadcast to ``state_shape``.
        state_shape : tuple
            Shape of the returned factor.

        Returns
        -------
        array
            :math:`H_\mathrm{in}(X_\mathrm{noisy})` with shape ``state_shape``.

        Notes
        -----
        With a user ``mult_coupling_in_fn`` the callable is dispatched through
        ``_call_nl`` and its output is cast and broadcast. Otherwise the lin_rate
        default :math:`g_\mathrm{in}(\theta_\mathrm{in}+X_\mathrm{noisy})` is
        returned. The factor is evaluated at the *noisy* rate, as in NEST
        ``rate_neuron_opn_impl.h``.
        """
        float_type = brainstate.environ.dftype()
        noisy = jnp.broadcast_to(jnp.asarray(rate, dtype=float_type), state_shape)
        custom_fn = self.mult_coupling_in_fn
        if custom_fn is not None:
            factor = self._call_nl(custom_fn, noisy)
            return jnp.broadcast_to(jnp.asarray(factor, dtype=float_type), state_shape)
        gain = self._broadcast_to_state(self._to_numpy(self.g_in), state_shape)
        offset = self._broadcast_to_state(self._to_numpy(self.theta_in), state_shape)
        return gain * (offset + noisy)

    def _extract_event_fields(self, ev, default_delay_steps: int):
        r"""Normalize one rate event into ``(rate, weight, multiplicity, delay_steps)``.

        Parameters
        ----------
        ev : scalar, dict, tuple, or list
            Event specification:

            - ``dict``: keys ``'rate'`` / ``'coeff'`` / ``'value'`` (first one
              present wins, else ``0.0``), ``'weight'`` (default ``1.0``),
              ``'multiplicity'`` (default ``1.0``), ``'delay_steps'`` /
              ``'delay'`` (default ``default_delay_steps``).
            - ``tuple``/``list``: ``(rate, weight)``,
              ``(rate, weight, delay_steps)`` or
              ``(rate, weight, delay_steps, multiplicity)``.
            - anything else: used as the rate, with unit weight and multiplicity.
        default_delay_steps : int
            Delay used when the event does not specify one.

        Returns
        -------
        tuple
            ``(rate, weight, multiplicity, delay_steps)``, where ``delay_steps``
            has been converted by ``self._to_int_scalar``.

        Raises
        ------
        ValueError
            If a tuple/list event does not have 2, 3, or 4 entries.
        """
        if isinstance(ev, dict):
            rate = ev.get('rate', ev.get('coeff', ev.get('value', 0.0)))
            weight = ev.get('weight', 1.0)
            multiplicity = ev.get('multiplicity', 1.0)
            delay_steps = ev.get('delay_steps', ev.get('delay', default_delay_steps))
        elif isinstance(ev, (tuple, list)):
            n_fields = len(ev)
            if n_fields not in (2, 3, 4):
                raise ValueError('Rate event tuples must have length 2, 3, or 4.')
            # Pad the omitted trailing fields (delay_steps, multiplicity) with
            # their defaults so every accepted length unpacks the same way.
            fallback_tail = (default_delay_steps, 1.0)[n_fields - 2:]
            rate, weight, delay_steps, multiplicity = (*ev, *fallback_tail)
        else:
            rate, weight, multiplicity, delay_steps = ev, 1.0, 1.0, default_delay_steps

        return rate, weight, multiplicity, self._to_int_scalar(delay_steps, name='delay_steps')

    def _event_to_ex_in(self, ev, default_delay_steps: int, state_shape):
        r"""Split one event into excitatory and inhibitory branch contributions.

        Parameters
        ----------
        ev : scalar, dict, tuple, or list
            Event specification (see ``_extract_event_fields``).
        default_delay_steps : int
            Delay used when the event does not specify one.
        state_shape : tuple
            Shape every contribution is broadcast to.

        Returns
        -------
        ex : np.ndarray
            Contribution routed to the excitatory branch (zero where the
            weight is negative).
        inh : np.ndarray
            Contribution routed to the inhibitory branch (zero where the
            weight is non-negative).
        delay_steps : int
            Integer delay of the event in simulation steps.

        Notes
        -----
        The contribution is ``value * weight * multiplicity``. ``value`` is the
        raw event rate when ``linear_summation`` is ``True``. Otherwise it is
        the input nonlinearity applied to that rate, so for per-branch
        summation the transform happens while buffering, as in NEST's event
        handlers. A weight ``>= 0`` selects the excitatory branch.
        """
        rate, weight, multiplicity, delay_steps = self._extract_event_fields(ev, default_delay_steps)

        to_state = self._broadcast_to_state
        rate_np = to_state(self._to_numpy(rate), state_shape)
        weight_np = to_state(self._to_numpy(weight), state_shape)
        multiplicity_np = to_state(self._to_numpy(multiplicity), state_shape)
        float_type = brainstate.environ.dftype()
        is_excitatory = to_state(
            np.asarray(u.math.asarray(weight), dtype=float_type) >= 0.0,
            state_shape,
        )

        drive = rate_np if self.linear_summation else self._input_transform(rate_np, state_shape)
        contribution = drive * weight_np * multiplicity_np

        ex = np.where(is_excitatory, contribution, 0.0)
        inh = np.where(is_excitatory, 0.0, contribution)
        return ex, inh, delay_steps

    def _accumulate_instant_events(self, events, state_shape):
        r"""Sum all instantaneous (zero-delay) events into branch totals.

        Parameters
        ----------
        events : None, dict, tuple, list, or iterable
            Instantaneous event specification(s), normalized by
            ``self._coerce_events``.
        state_shape : tuple
            Shape of the returned totals.

        Returns
        -------
        ex : np.ndarray
            Summed excitatory contribution, dtype ``brainstate.environ.dftype()``.
        inh : np.ndarray
            Summed inhibitory contribution, same shape and dtype.

        Raises
        ------
        ValueError
            If an event carries a non-zero ``delay_steps``.
        """
        float_type = brainstate.environ.dftype()
        total_ex = np.zeros(state_shape, dtype=float_type)
        total_in = np.zeros(state_shape, dtype=float_type)
        for event in self._coerce_events(events):
            ex_part, in_part, delay = self._event_to_ex_in(
                event,
                default_delay_steps=0,
                state_shape=state_shape,
            )
            # Instantaneous connections cannot be delayed.
            if delay != 0:
                raise ValueError('instant_rate_events must not specify non-zero delay_steps.')
            total_ex += ex_part
            total_in += in_part
        return total_ex, total_in

    def _schedule_delayed_events(self, events, step_idx: int, state_shape):
        r"""Queue delayed events and return the ones due immediately.

        Parameters
        ----------
        events : None, dict, tuple, list, or iterable
            Delayed event specification(s), normalized by
            ``self._coerce_events``. Events without an explicit delay default
            to one step.
        step_idx : int
            Index of the current simulation step.
        state_shape : tuple
            Shape of the returned arrays.

        Returns
        -------
        ex_now : np.ndarray
            Excitatory contributions of events with ``delay_steps == 0``.
        inh_now : np.ndarray
            Inhibitory contributions of events with ``delay_steps == 0``.

        Raises
        ------
        ValueError
            If an event has a negative ``delay_steps``.

        Notes
        -----
        Events with a positive delay are added via ``self._queue_add`` to
        ``_delayed_ex_queue`` / ``_delayed_in_queue`` under the key
        ``step_idx + delay_steps``.
        """
        float_type = brainstate.environ.dftype()
        now_ex = np.zeros(state_shape, dtype=float_type)
        now_in = np.zeros(state_shape, dtype=float_type)
        for event in self._coerce_events(events):
            ex_part, in_part, delay = self._event_to_ex_in(
                event,
                default_delay_steps=1,
                state_shape=state_shape,
            )
            if delay < 0:
                raise ValueError('delay_steps for delayed_rate_events must be >= 0.')
            if delay > 0:
                due_step = step_idx + delay
                self._queue_add(self._delayed_ex_queue, due_step, ex_part)
                self._queue_add(self._delayed_in_queue, due_step, in_part)
                continue
            # Zero-delay events take effect in the current step.
            now_ex += ex_part
            now_in += in_part
        return now_ex, now_in

    def _common_inputs_template(self, x, instant_rate_events, delayed_rate_events):
        r"""Gather every input contribution for the current update step.

        Sources are combined in this order:

        1. Delayed events that fall due now (drained from the internal queues).
        2. Newly supplied delayed events whose ``delay_steps`` is ``0``.
        3. Instantaneous events.
        4. Delta inputs from ``sum_delta_inputs(0.0)``; positive values join the
           excitatory branch and negative values the inhibitory branch.
        5. Current inputs from ``sum_current_inputs(x, self.rate.value)``.

        Parameters
        ----------
        x : ArrayLike
            External drive passed to ``update``.
        instant_rate_events : None, dict, tuple, list, or iterable
            Instantaneous events.
        delayed_rate_events : None, dict, tuple, list, or iterable
            Delayed events.

        Returns
        -------
        state_shape : tuple
            Shape of ``self.rate.value``.
        step_idx : int
            Current simulation step index, read from ``self._step_count``.
        delayed_ex, delayed_in : np.ndarray
            Delayed excitatory / inhibitory input arriving at this step.
        instant_ex, instant_in : np.ndarray
            Instantaneous excitatory / inhibitory input, including delta inputs.
        mu_ext : np.ndarray
            Summed current input (external drive ``x`` plus registered currents),
            broadcast to ``state_shape``.
        """
        state_shape = self.rate.value.shape
        counter = np.asarray(self._step_count.value, dtype=brainstate.environ.ditype())
        step_idx = int(counter.reshape(-1)[0])

        # Delayed branch: queued contributions due now plus zero-delay new ones.
        queued_ex, queued_in = self._drain_delayed_queue(step_idx, state_shape)
        fresh_ex, fresh_in = self._schedule_delayed_events(
            delayed_rate_events,
            step_idx=step_idx,
            state_shape=state_shape,
        )
        delayed_ex = queued_ex + fresh_ex
        delayed_in = queued_in + fresh_in

        # Instantaneous branch, augmented with sign-split delta inputs.
        instant_ex, instant_in = self._accumulate_instant_events(
            instant_rate_events,
            state_shape=state_shape,
        )
        delta = self._broadcast_to_state(self._to_numpy(self.sum_delta_inputs(0.0)), state_shape)
        instant_ex += np.where(delta > 0.0, delta, 0.0)
        instant_in += np.where(delta < 0.0, delta, 0.0)

        current = self.sum_current_inputs(x, self.rate.value)
        mu_ext = self._broadcast_to_state(self._to_numpy(current), state_shape)

        return state_shape, step_idx, delayed_ex, delayed_in, instant_ex, instant_in, mu_ext

[docs] def init_state(self, **kwargs): r"""Initialize all state variables for simulation. This method must be called before the first ``update()`` call. It creates all internal state variables (``rate``, ``noise``, ``noisy_rate``, ``instant_rate``, ``delayed_rate``, ``_step_count``) and resets the delayed event queues. Parameters ---------- **kwargs Unused compatibility parameters accepted by the base-state API. Notes ----- **Initialized State Variables** This method initializes the following state variables: - **rate** (``brainstate.ShortTermState``): Deterministic rate state :math:`X_n` (float64 array). Initialized using ``rate_initializer``. - **noise** (``brainstate.ShortTermState``): Last noise sample :math:`\sigma\,\xi_{n-1}` (float64 array). Initialized using ``noise_initializer``. - **noisy_rate** (``brainstate.ShortTermState``): Noisy rate :math:`X_\mathrm{noisy,n} = X_n + \sqrt{\tau/h}\,\mathrm{noise}_n` (float64 array). Initialized using ``noisy_rate_initializer``. - **instant_rate** (``brainstate.ShortTermState``): Noisy rate value for instantaneous event propagation (float64 array). Initialized as a copy of ``noisy_rate``. - **delayed_rate** (``brainstate.ShortTermState``): Noisy rate value for delayed projections (float64 array). Initialized as a copy of ``noisy_rate``. - **_step_count** (``brainstate.ShortTermState``): Internal step counter for delayed event scheduling (int64 scalar). Initialized to ``0``. - **_delayed_ex_queue** (dict): Internal queue mapping ``step_idx`` (int) to accumulated excitatory delayed events (float64 array). Initialized as empty dict. - **_delayed_in_queue** (dict): Internal queue mapping ``step_idx`` (int) to accumulated inhibitory delayed events (float64 array). Initialized as empty dict. **Array Precision** All state arrays are float64 NumPy arrays. All parameters (``tau``, ``sigma``, ``mu``, etc.) are coerced to float64 during initialization. 
**Repeated Calls** Calling ``init_state()`` multiple times will overwrite existing state variables and clear the delayed event queues. This can be used to reset the model to initial conditions. Examples -------- Initialize a single population: .. code-block:: python >>> import brainpy.state as bst >>> import saiunit as u >>> model = bst.rate_neuron_opn(in_size=10, tau=20*u.ms) >>> model.init_state() >>> print(model.rate.value.shape) (10,) Custom initializers: .. code-block:: python >>> import braintools >>> model = bst.rate_neuron_opn( ... in_size=5, ... tau=10*u.ms, ... rate_initializer=braintools.init.Normal(0.5, 0.1), ... noisy_rate_initializer=braintools.init.Normal(0.5, 0.1) ... ) >>> model.init_state() >>> print(model.rate.value.mean()) # approximately 0.5 See Also -------- update : Perform one simulation step after initialization. """ rate = braintools.init.param(self.rate_initializer, self.varshape) noise = braintools.init.param(self.noise_initializer, self.varshape) noisy_rate = braintools.init.param(self.noisy_rate_initializer, self.varshape) rate_np = self._to_numpy(rate) noise_np = self._to_numpy(noise) noisy_rate_np = self._to_numpy(noisy_rate) self.rate = brainstate.ShortTermState(rate_np) self.noise = brainstate.ShortTermState(noise_np) self.noisy_rate = brainstate.ShortTermState(noisy_rate_np) dftype = brainstate.environ.dftype() self.instant_rate = brainstate.ShortTermState(np.array(noisy_rate_np, dtype=dftype, copy=True)) self.delayed_rate = brainstate.ShortTermState(np.array(noisy_rate_np, dtype=dftype, copy=True)) ditype = brainstate.environ.ditype() self._step_count = brainstate.ShortTermState(np.asarray(0, dtype=ditype)) self._delayed_ex_queue = {} self._delayed_in_queue = {}
[docs] def update(self, x=0.0, instant_rate_events=None, delayed_rate_events=None, noise=None, _precomputed_ex=None, _precomputed_in=None): r"""Perform one simulation step of output-noise rate dynamics. This method implements the core update algorithm for the output-noise rate neuron model. It propagates the deterministic rate dynamics, applies output noise, processes delayed and instantaneous synaptic events, and evaluates optional multiplicative coupling factors. Parameters ---------- x : ArrayLike, optional External drive (dimensionless scalar or array). Broadcastable to ``self.varshape`` or current batch shape. Added to ``mu`` as constant forcing term. Default: ``0.0``. instant_rate_events : None or dict or tuple or list or iterable, optional Instantaneous rate events applied in the current step without delay. Each event can be: - Scalar (treated as ``rate`` value with ``weight=1.0``). - Tuple: ``(rate, weight)`` or ``(rate, weight, delay_steps)`` or ``(rate, weight, delay_steps, multiplicity)``. - Dict with keys ``'rate'``/``'coeff'``/``'value'``, ``'weight'``, ``'delay_steps'``/``'delay'``, ``'multiplicity'``. Events with non-zero ``delay_steps`` will raise ``ValueError``. Default: ``None`` (no instantaneous events). delayed_rate_events : None or dict or tuple or list or iterable, optional Delayed rate events scheduled with integer ``delay_steps`` (units of simulation time step :math:`h`). Same format as ``instant_rate_events``. Events with ``delay_steps=0`` are applied immediately. Events with ``delay_steps>0`` are queued and applied after the specified delay. Negative ``delay_steps`` raise ``ValueError``. Default: ``None`` (no delayed events). noise : ArrayLike or None, optional Externally supplied noise sample :math:`\xi_n` (dimensionless scalar or array). Broadcastable to current batch shape. If ``None`` (default), draws :math:`\xi_n\sim\mathcal{N}(0,1)` internally using ``np.random.normal``. 
If provided, must have zero mean and unit variance for correct noise amplitude. Default: ``None``. Returns ------- rate_new : np.ndarray Updated deterministic rate state :math:`X_{n+1}` (float64 array of shape ``self.rate.value.shape``). This is the noise-free rate after one simulation step. To access the noisy rate, use ``self.noisy_rate.value``. Raises ------ ValueError If ``instant_rate_events`` contain non-zero ``delay_steps``. ValueError If ``delayed_rate_events`` contain negative ``delay_steps``. ValueError If event tuples have length other than 2, 3, or 4. Notes ----- **Update Algorithm** The method performs the following steps in order: **1. Input Collection** Collect all input contributions for the current step: - Delayed events arriving at current step (drained from internal queues ``_delayed_ex_queue`` and ``_delayed_in_queue``). - Newly scheduled delayed events with ``delay_steps=0`` (from ``delayed_rate_events``). - Instantaneous events (from ``instant_rate_events``). - Delta inputs via ``sum_delta_inputs(0.0)`` (sign-separated into excitatory/inhibitory branches). - Current inputs via ``sum_current_inputs(x, rate)`` (external drive and synaptic inputs). **2. Propagator Coefficients** Compute exponential Euler integration coefficients: .. math:: P_1 = \exp(-h/\tau), \quad P_2 = 1 - P_1 = -\mathrm{expm1}(-h/\tau), where :math:`h` is the simulation time step (in ms) and :math:`\tau` is the time constant. Uses ``np.expm1`` for numerically stable evaluation of :math:`1-e^{-x}`. **3. Output Noise** Draw noise sample :math:`\xi_n\sim\mathcal{N}(0,1)` (or use external ``noise`` parameter) and compute noisy rate: .. math:: X_\mathrm{noisy,n} = X_n + \sqrt{\frac{\tau}{h}}\,\sigma\,\xi_n. The scaling factor :math:`\sqrt{\tau/h}` ensures correct amplitude scaling in the :math:`h\to 0` limit. **4. Deterministic Dynamics Propagation** Propagate the deterministic part of the dynamics: .. 
math:: X' = P_1 X_n + P_2(\mu + \mu_\mathrm{ext}), where :math:`\mu_\mathrm{ext}` is the external drive from ``x`` and current inputs. **5. Multiplicative Coupling** If ``mult_coupling=True``, evaluate multiplicative coupling factors at the *noisy* rate: .. math:: H_\mathrm{ex}(X_\mathrm{noisy,n}), \quad H_\mathrm{in}(X_\mathrm{noisy,n}). If ``mult_coupling=False``, :math:`H_\mathrm{ex}=H_\mathrm{in}=1`. **6. Network Input Application** Apply network input according to ``linear_summation`` mode: - **linear_summation=True**: Nonlinearity applied to summed input: .. math:: X_{n+1} = X' + P_2 [H_\mathrm{ex}\cdot g(I_\mathrm{ex}) + H_\mathrm{in}\cdot g(I_\mathrm{in})]. If ``mult_coupling=False``, simplifies to: .. math:: X_{n+1} = X' + P_2 g(I_\mathrm{ex} + I_\mathrm{in}). - **linear_summation=False**: Nonlinearity applied per branch during event processing. Network input is already transformed: .. math:: X_{n+1} = X' + P_2 [H_\mathrm{ex}\cdot I_\mathrm{ex} + H_\mathrm{in}\cdot I_\mathrm{in}]. **7. State Updates** Update all state variables: - ``rate``: Deterministic rate :math:`X_{n+1}`. - ``noise``: Noise sample :math:`\sigma\,\xi_n`. - ``noisy_rate``: Noisy rate :math:`X_\mathrm{noisy,n}`. - ``delayed_rate``: Noisy rate for delayed projections. - ``instant_rate``: Noisy rate for instantaneous projections. - ``_step_count``: Incremented by 1. **Key Distinction from ``rate_neuron_ipn``** The noisy rate :math:`X_\mathrm{noisy,n}` is used for multiplicative coupling and as the outgoing signal to downstream neurons, but the noise does *not* feed back into the deterministic dynamics (i.e., :math:`X'` depends only on the noise-free rate :math:`X_n`). This contrasts with the input-noise variant (``rate_neuron_ipn``) where noise enters the differential equation directly. **Numerical Stability** - The exponential Euler scheme is unconditionally stable for all :math:`h>0`. - Uses ``np.expm1(-h/tau)`` to avoid catastrophic cancellation for small :math:`h/\tau`. 
- Noise scaling :math:`\sqrt{\tau/h}` ensures correct amplitude in the :math:`h\to 0` limit, but per-step noise variance diverges (reflecting white-noise nature). **Failure Modes** - **Invalid time constants**: Caught at construction by ``_validate_parameters`` (enforces :math:`\tau>0`, :math:`\sigma\ge 0`). - **Invalid events**: Raises ``ValueError`` for events with incorrect ``delay_steps`` or tuple length. - **Unbounded rates**: No automatic rectification or clipping. Noisy rate can exceed any bounds. - **NaN propagation**: If input contains NaN, all downstream states will be NaN. No automatic detection or recovery. See Also -------- init_state : Initialize all state variables before first update. rate_neuron_ipn.update : Input-noise variant update method. """ h = float(u.get_mantissa(brainstate.environ.get_dt() / u.ms)) dftype = brainstate.environ.dftype() state_shape = self.rate.value.shape tau = self._broadcast_to_state(self._to_numpy_ms(self.tau), state_shape) sigma = self._broadcast_to_state(self._to_numpy(self.sigma), state_shape) mu = self._broadcast_to_state(self._to_numpy(self.mu), state_shape) rate_prev = jnp.broadcast_to(jnp.asarray(self.rate.value, dtype=dftype), state_shape) if _precomputed_ex is not None: # JIT-compatible path: bypass all Python queue operations. 
delayed_ex = jnp.asarray(_precomputed_ex, dtype=dftype) delayed_in = jnp.asarray(_precomputed_in, dtype=dftype) instant_ex = jnp.zeros(state_shape, dtype=dftype) instant_in = jnp.zeros(state_shape, dtype=dftype) mu_ext = jnp.zeros(state_shape, dtype=dftype) else: _, step_idx, delayed_ex, delayed_in, instant_ex, instant_in, mu_ext = self._common_inputs_template( x=x, instant_rate_events=instant_rate_events, delayed_rate_events=delayed_rate_events, ) ditype = brainstate.environ.ditype() self._step_count.value = np.asarray(step_idx + 1, dtype=ditype) if noise is None: xi = np.random.normal(size=state_shape) else: xi = jnp.broadcast_to(jnp.asarray(noise, dtype=dftype), state_shape) noise_now = sigma * xi P1 = np.exp(-h / tau) P2 = -np.expm1(-h / tau) output_noise_factor = np.sqrt(tau / h) noisy_rate = rate_prev + output_noise_factor * noise_now mu_total = mu + mu_ext rate_new = P1 * rate_prev + P2 * mu_total H_ex = jnp.ones(state_shape, dtype=dftype) H_in = jnp.ones(state_shape, dtype=dftype) if self.mult_coupling: H_ex = self._mult_ex_transform(noisy_rate, state_shape) H_in = self._mult_in_transform(noisy_rate, state_shape) if self.linear_summation: if self.mult_coupling: rate_new += P2 * H_ex * self._input_transform(delayed_ex + instant_ex, state_shape) rate_new += P2 * H_in * self._input_transform(delayed_in + instant_in, state_shape) else: rate_new += P2 * self._input_transform( delayed_ex + instant_ex + delayed_in + instant_in, state_shape, ) else: rate_new += P2 * H_ex * (delayed_ex + instant_ex) rate_new += P2 * H_in * (delayed_in + instant_in) self.rate.value = rate_new self.noise.value = noise_now self.noisy_rate.value = noisy_rate self.delayed_rate.value = noisy_rate self.instant_rate.value = noisy_rate return rate_new