# Source code for brainpy_state._nest.tanh_rate

# Copyright 2026 BrainX Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

# -*- coding: utf-8 -*-


from typing import Callable

import brainstate
import braintools
import saiunit as u
import jax.numpy as jnp
import numpy as np
from brainstate.typing import ArrayLike, Size

from brainpy_state._nest.lin_rate import _lin_rate_base
from ._utils import is_tracer

__all__ = [
    'tanh_rate_ipn',
    'tanh_rate_opn',
]


class _tanh_rate_base(_lin_rate_base):
    r"""Shared machinery for the tanh-rate neuron family.

    Supplies the hyperbolic-tangent input nonlinearity, unit multiplicative
    coupling factors, and the event parsing / scheduling helpers used by both
    ``tanh_rate_ipn`` and ``tanh_rate_opn``.
    """
    __module__ = 'brainpy.state'

    def _input(self, h, g, theta):
        r"""Evaluate the tanh input nonlinearity element-wise.

        Parameters
        ----------
        h : ndarray
            Input to transform (dimensionless).
        g : ndarray
            Gain (slope) of the tanh.
        theta : ndarray
            Horizontal shift subtracted from ``h`` before the gain is applied.

        Returns
        -------
        ndarray
            :math:`\tanh(g\,(h - \theta))`.
        """
        shifted = h - theta
        return np.tanh(g * shifted)

    @staticmethod
    def _mult_coupling_ex(rate):
        r"""Excitatory multiplicative coupling factor (identically one).

        Parameters
        ----------
        rate : ndarray
            Current rate; only its shape is used.

        Returns
        -------
        jax.Array
            Ones shaped like ``rate`` with dtype ``brainstate.environ.dftype()``.
        """
        return jnp.ones_like(rate, dtype=brainstate.environ.dftype())

    @staticmethod
    def _mult_coupling_in(rate):
        r"""Inhibitory multiplicative coupling factor (identically one).

        Parameters
        ----------
        rate : ndarray
            Current rate; only its shape is used.

        Returns
        -------
        jax.Array
            Ones shaped like ``rate`` with dtype ``brainstate.environ.dftype()``.
        """
        return jnp.ones_like(rate, dtype=brainstate.environ.dftype())

    def _extract_event_fields(self, ev, default_delay_steps: int):
        r"""Normalize one rate event into ``(rate, weight, multiplicity, delay)``.

        Parameters
        ----------
        ev : dict or tuple or list or scalar
            The event. Accepted forms:

            - dict with optional keys ``'rate'`` (falling back to ``'coeff'``,
              then ``'value'``, then ``0.0``), ``'weight'``, ``'multiplicity'``
              and ``'delay_steps'`` (falling back to ``'delay'``);
            - a 2-, 3- or 4-sequence ``(rate, weight[, delay_steps[,
              multiplicity]])``;
            - anything else, taken as the rate itself.

        default_delay_steps : int
            Delay used when the event does not carry one.

        Returns
        -------
        rate : Any
            Raw rate value (not yet converted to an array).
        weight : Any
            Synaptic weight (``1.0`` when omitted).
        multiplicity : Any
            Multiplicity factor (``1.0`` when omitted).
        delay_steps : int
            Delay converted by ``self._to_int_scalar``.

        Raises
        ------
        ValueError
            If a tuple/list event does not have 2, 3 or 4 entries.
        """
        if isinstance(ev, dict):
            # Same lookup precedence as a nested ``.get`` chain:
            # 'rate' > 'coeff' > 'value' > 0.0.
            fallback_rate = ev.get('value', 0.0)
            fallback_rate = ev.get('coeff', fallback_rate)
            rate = ev.get('rate', fallback_rate)
            weight = ev.get('weight', 1.0)
            multiplicity = ev.get('multiplicity', 1.0)
            delay_steps = ev.get('delay_steps', ev.get('delay', default_delay_steps))
        elif isinstance(ev, (tuple, list)):
            n_fields = len(ev)
            if n_fields not in (2, 3, 4):
                raise ValueError('Rate event tuples must have length 2, 3, or 4.')
            # Pad the optional trailing fields (delay_steps, multiplicity) with
            # their defaults so every accepted length unpacks into four values.
            trailing_defaults = (default_delay_steps, 1.0)
            rate, weight, delay_steps, multiplicity = (*ev, *trailing_defaults[n_fields - 2:])
        else:
            rate, weight, multiplicity, delay_steps = ev, 1.0, 1.0, default_delay_steps

        delay_steps = self._to_int_scalar(delay_steps, name='delay_steps')
        return rate, weight, multiplicity, delay_steps

    def _event_to_ex_in(self, ev, default_delay_steps: int, state_shape, g, theta):
        r"""Turn one event into excitatory / inhibitory arrays plus its delay.

        Parameters
        ----------
        ev : dict or tuple or list or scalar
            Rate event (see :meth:`_extract_event_fields`).
        default_delay_steps : int
            Delay used when the event omits one.
        state_shape : tuple
            Shape every returned array is broadcast to.
        g, theta : ndarray
            Gain and shift of the tanh, used only when
            ``self.linear_summation`` is ``False``.

        Returns
        -------
        ex : ndarray
            ``rate' * weight * multiplicity`` where the weight is
            non-negative, zero elsewhere.
        inh : ndarray
            The same product where the weight is negative, zero elsewhere.
        delay_steps : int
            The event's delay.

        Notes
        -----
        ``rate'`` is the raw rate when ``linear_summation`` is ``True``
        (the tanh is applied to the summed input later). Otherwise it is
        ``tanh(g * (rate - theta))``, applied per event here.
        """
        rate, weight, multiplicity, delay_steps = self._extract_event_fields(ev, default_delay_steps)

        def _fit(value):
            # Plain array broadcast onto the neuron grid.
            return self._broadcast_to_state(self._to_numpy(value), state_shape)

        rate_arr = _fit(rate)
        weight_arr = _fit(weight)
        mult_arr = _fit(multiplicity)
        # Routing is decided by the sign of the raw weight; zero counts as excitatory.
        is_excitatory = self._broadcast_to_state(
            np.asarray(u.get_mantissa(weight), dtype=brainstate.environ.dftype()) >= 0.0,
            state_shape,
        )

        drive = rate_arr if self.linear_summation else self._input(rate_arr, g, theta)
        contribution = drive * weight_arr * mult_arr

        excitatory = np.where(is_excitatory, contribution, 0.0)
        inhibitory = np.where(is_excitatory, 0.0, contribution)
        return excitatory, inhibitory, delay_steps

    def _accumulate_instant_events_tanh(self, events, state_shape, g, theta):
        r"""Sum the excitatory and inhibitory parts of all zero-delay events.

        Parameters
        ----------
        events : list or None
            Events to apply in the current step.
        state_shape : tuple
            Shape of the returned accumulators.
        g, theta : ndarray
            Tanh gain and shift forwarded to :meth:`_event_to_ex_in`.

        Returns
        -------
        ex, inh : ndarray
            Accumulated excitatory and inhibitory input, dtype
            ``brainstate.environ.dftype()``.

        Raises
        ------
        ValueError
            If an event carries a non-zero ``delay_steps``.
        """
        total_ex = np.zeros(state_shape, dtype=brainstate.environ.dftype())
        total_in = np.zeros_like(total_ex)
        for event in self._coerce_events(events):
            ex_part, in_part, lag = self._event_to_ex_in(
                event,
                default_delay_steps=0,
                state_shape=state_shape,
                g=g,
                theta=theta,
            )
            if lag != 0:
                raise ValueError('instant_rate_events must not specify non-zero delay_steps.')
            # In-place adds keep the accumulator dtype fixed.
            total_ex += ex_part
            total_in += in_part
        return total_ex, total_in

    def _schedule_delayed_events_tanh(self, events, step_idx: int, state_shape, g, theta):
        r"""Queue positive-delay events and return the zero-delay remainder.

        Parameters
        ----------
        events : list or None
            Events to schedule; events without a delay default to one step.
        step_idx : int
            Current simulation step index.
        state_shape : tuple
            Shape of the returned arrays.
        g, theta : ndarray
            Tanh gain and shift forwarded to :meth:`_event_to_ex_in`.

        Returns
        -------
        ex_now, inh_now : ndarray
            Contributions of events whose delay is exactly zero.

        Raises
        ------
        ValueError
            If an event carries a negative ``delay_steps``.

        Notes
        -----
        Events with positive delay are added to ``self._delayed_ex_queue`` /
        ``self._delayed_in_queue`` under key ``step_idx + delay_steps``.
        """
        now_ex = np.zeros(state_shape, dtype=brainstate.environ.dftype())
        now_in = np.zeros_like(now_ex)
        for event in self._coerce_events(events):
            ex_part, in_part, lag = self._event_to_ex_in(
                event,
                default_delay_steps=1,
                state_shape=state_shape,
                g=g,
                theta=theta,
            )
            if lag < 0:
                raise ValueError('delay_steps for delayed_rate_events must be >= 0.')
            if lag > 0:
                due_step = step_idx + lag
                self._queue_add(self._delayed_ex_queue, due_step, ex_part)
                self._queue_add(self._delayed_in_queue, due_step, in_part)
                continue
            now_ex += ex_part
            now_in += in_part
        return now_ex, now_in

    def _common_inputs_tanh(self, x, instant_rate_events, delayed_rate_events, g, theta):
        r"""Gather every input source for the current step.

        Parameters
        ----------
        x : Any
            External current input passed to ``self.sum_current_inputs``.
        instant_rate_events : list or None
            Events applied in this step.
        delayed_rate_events : list or None
            Events scheduled with a delay.
        g, theta : ndarray
            Tanh gain and shift.

        Returns
        -------
        state_shape : tuple
            Shape of ``self.rate.value``.
        step_idx : int
            Value of ``self._step_count``.
        delayed_ex, delayed_in : ndarray
            Queue entries due now plus zero-delay delayed events.
        instant_ex, instant_in : ndarray
            Instant events plus the positive / negative parts of
            ``self.sum_delta_inputs(0.0)``.
        mu_ext : ndarray
            Broadcast external current.
        """
        shape = self.rate.value.shape
        step = self._step_count

        queued_ex, queued_in = self._drain_delayed_queue(step, shape)
        fresh_ex, fresh_in = self._schedule_delayed_events_tanh(
            delayed_rate_events,
            step_idx=step,
            state_shape=shape,
            g=g,
            theta=theta,
        )
        delayed_ex = queued_ex + fresh_ex
        delayed_in = queued_in + fresh_in

        instant_ex, instant_in = self._accumulate_instant_events_tanh(
            instant_rate_events,
            state_shape=shape,
            g=g,
            theta=theta,
        )

        # Delta inputs bypass the event path; route them by sign.
        delta = self._broadcast_to_state(self._to_numpy(self.sum_delta_inputs(0.0)), shape)
        instant_ex += np.where(delta > 0.0, delta, 0.0)
        instant_in += np.where(delta < 0.0, delta, 0.0)

        current = self.sum_current_inputs(x, self.rate.value)
        mu_ext = self._broadcast_to_state(self._to_numpy(current), shape)

        return shape, step, delayed_ex, delayed_in, instant_ex, instant_in, mu_ext

    def _common_parameters_tanh(self, state_shape):
        r"""Broadcast ``tau``, ``sigma``, ``mu``, ``g`` and ``theta`` to the state.

        Parameters
        ----------
        state_shape : tuple
            Target shape.

        Returns
        -------
        tau : ndarray
            Time constant in ms (via ``self._to_numpy_ms``).
        sigma, mu, g, theta : ndarray
            Dimensionless parameters (via ``self._to_numpy``).
        """
        def _fit(value):
            return self._broadcast_to_state(self._to_numpy(value), state_shape)

        tau = self._broadcast_to_state(self._to_numpy_ms(self.tau), state_shape)
        sigma, mu, g, theta = [_fit(param) for param in (self.sigma, self.mu, self.g, self.theta)]
        return tau, sigma, mu, g, theta


class tanh_rate_ipn(_tanh_rate_base):
    r"""NEST-compatible ``tanh_rate_ipn`` nonlinear rate neuron with input noise.

    Stochastic rate model with hyperbolic-tangent nonlinearity applied to
    network inputs, matching NEST's ``rate_neuron_ipn`` template instantiated
    with ``tanh_rate`` gain function.

    **1. Model equations**

    The state :math:`X(t)` evolves according to the Langevin equation

    .. math::

       \tau\,dX(t)=
       \left[-\lambda X(t)+\mu+\phi(\cdot)\right]dt
       +\left[\sqrt{\tau}\,\sigma\right]dW(t),

    where the input nonlinearity is

    .. math::

       \phi(h)=\tanh(g(h-\theta)).

    Here :math:`W(t)` is a standard Brownian motion, :math:`\lambda` is the
    passive decay rate, :math:`\mu` is a constant drive, :math:`g` is the
    gain, and :math:`\theta` is the horizontal shift of the tanh function.

    **2. Numerical integration and noise implementation**

    For :math:`\lambda > 0`, integration uses the stochastic exponential Euler
    (exact for the Ornstein-Uhlenbeck process):

    .. math::

       P_1 &= \exp(-\lambda h / \tau), \\
       P_2 &= \frac{1 - P_1}{\lambda}, \\
       X_{n+1} &= P_1 X_n + P_2 \left[\mu + \phi(\cdot)\right]
                  + \sqrt{\frac{1-P_1^2}{2\lambda}} \sigma \xi_n,

    with :math:`\xi_n \sim \mathcal{N}(0,1)`. For :math:`\lambda = 0`, it
    reduces to Euler-Maruyama:

    .. math::

       X_{n+1} = X_n + \frac{h}{\tau} \left[\mu + \phi(\cdot)\right]
                 + \sqrt{\frac{h}{\tau}} \sigma \xi_n.

    **3. Update ordering (matching NEST ``rate_neuron_ipn`` with tanh)**

    Each simulation step proceeds as follows:

    1. Store outgoing delayed value as current ``rate``.
    2. Draw ``noise = sigma * xi`` from the standard normal distribution.
    3. Propagate intrinsic dynamics with stochastic exponential Euler
       (Euler-Maruyama for ``lambda=0``).
    4. Read delayed and instantaneous input buffers.
    5. Apply input contributions:

       - ``linear_summation=True``: apply tanh to the summed inputs. This is
         the sum over both branches, or each branch separately when
         ``mult_coupling=True``.
       - ``linear_summation=False``: apply tanh per event before summation.

    6. Apply rectification when ``rectify_output=True`` (clamp to
       ``>= rectify_rate``).
    7. Store outgoing instantaneous value as updated ``rate``.

    **4. Timing semantics, assumptions, and constraints**

    - Noise term is white (independent at each step, variance scales with
      :math:`dt`).
    - For :math:`\lambda > 0`, integration exactly preserves stationary
      variance of the OU process.
    - For :math:`\lambda = 0`, the process is non-stationary (variance grows
      linearly with time).
    - Multiplicative coupling factors are fixed to 1 for tanh_rate models.
      ``mult_coupling`` still matters when ``linear_summation=True``: it
      selects :math:`\phi(h_{ex}) + \phi(h_{in})` instead of
      :math:`\phi(h_{ex} + h_{in})`. It has no effect when
      ``linear_summation=False``.

    **5. Computational implications**

    Per :meth:`update` call:

    - Random number generation: :math:`O(\prod \mathrm{varshape})`. Samples
      come from NumPy's global generator (``np.random.normal``) unless
      ``noise`` is passed explicitly.
    - Exponential operations for :math:`P_1, P_2` when :math:`\lambda > 0`.
    - Event processing is linear in number of events per step.
    - Broadcasting parameters and inputs over ``self.varshape``.

    Parameters
    ----------
    in_size : Size
        Population shape specification consumed by
        :class:`brainstate.nn.Dynamics`. Determines output rate array shape.
    tau : ArrayLike, optional
        Time constant :math:`\tau` of rate dynamics. Must be positive.
        Accepts scalar or array broadcast to ``self.varshape``. Unitful values
        are converted to ms; unitless are interpreted as ms. Default
        ``10.0 * u.ms``.
    lambda_ : ArrayLike, optional
        Passive decay rate :math:`\lambda` (dimensionless). Must be
        non-negative. For :math:`\lambda > 0`, uses stochastic exponential
        Euler; for :math:`\lambda = 0`, uses Euler-Maruyama. Default ``1.0``.
    sigma : ArrayLike, optional
        Input noise scale (dimensionless). Must be non-negative. Scales the
        Wiener increment. Broadcast to ``self.varshape``. Default ``1.0``.
    mu : ArrayLike, optional
        Mean drive :math:`\mu` (dimensionless). Constant additive input.
        Broadcast to ``self.varshape``. Default ``0.0``.
    g : ArrayLike, optional
        Gain of tanh nonlinearity (dimensionless). Controls steepness of tanh.
        Broadcast to ``self.varshape``. Default ``1.0``.
    theta : ArrayLike, optional
        Threshold (horizontal shift) of tanh nonlinearity (dimensionless).
        Shifts the input :math:`h` before applying tanh. Broadcast to
        ``self.varshape``. Default ``0.0``.
    mult_coupling : bool, optional
        Kept for NEST compatibility. For ``tanh_rate`` models the
        multiplicative coupling factors are identically 1. With
        ``linear_summation=True`` this switch still changes the result: tanh
        is applied separately to the excitatory and inhibitory input sums
        instead of to their combined sum. With ``linear_summation=False`` it
        has no effect. Default ``False``.
    linear_summation : bool, optional
        Controls nonlinearity application order. If ``True``, sum inputs then
        apply tanh. If ``False``, apply tanh per event then sum weighted
        results. Default ``True``.
    rectify_rate : ArrayLike, optional
        Lower bound for output rate when ``rectify_output=True``
        (dimensionless). Must be non-negative. Broadcast to
        ``self.varshape``. Default ``0.0``.
    rectify_output : bool, optional
        If ``True``, clamp updated rate to ``>= rectify_rate`` after all
        dynamics. Default ``False``.
    rate_initializer : Callable, optional
        Initializer for state variable ``rate``. Evaluated in
        :meth:`init_state` as
        ``braintools.init.param(rate_initializer, self.varshape)``. No batch
        axis is added. Default ``braintools.init.Constant(0.0)``.
    noise_initializer : Callable, optional
        Initializer for state variable ``noise`` (recording), evaluated the
        same way. Default ``braintools.init.Constant(0.0)``.
    name : str or None, optional
        Module name passed to :class:`brainstate.nn.Dynamics`.

    Parameter Mapping
    -----------------
    .. list-table:: Parameter mapping to model symbols
       :header-rows: 1
       :widths: 22 18 22 38

       * - Parameter
         - Default
         - Math symbol
         - Semantics
       * - ``tau``
         - ``10.0 * u.ms``
         - :math:`\tau`
         - Time constant of rate dynamics (ms).
       * - ``lambda_``
         - ``1.0``
         - :math:`\lambda`
         - Passive decay rate (dimensionless, >= 0).
       * - ``sigma``
         - ``1.0``
         - :math:`\sigma`
         - Input noise scale (dimensionless, >= 0).
       * - ``mu``
         - ``0.0``
         - :math:`\mu`
         - Constant mean drive (dimensionless).
       * - ``g``
         - ``1.0``
         - :math:`g`
         - Gain of tanh nonlinearity (dimensionless).
       * - ``theta``
         - ``0.0``
         - :math:`\theta`
         - Horizontal shift of tanh nonlinearity (dimensionless).
       * - ``rectify_rate``
         - ``0.0``
         - :math:`r_{\min}`
         - Lower clamp bound when ``rectify_output=True``.

    Raises
    ------
    ValueError
        If ``tau <= 0``, ``lambda_ < 0``, ``sigma < 0``, or
        ``rectify_rate < 0`` (checked only for concrete, non-traced values).
    ValueError
        If ``instant_rate_events`` specify non-zero ``delay_steps``, or if
        ``delayed_rate_events`` specify negative ``delay_steps``.

    See Also
    --------
    tanh_rate_opn : Output-noise variant of tanh_rate.
    sigmoid_rate_ipn : Sigmoid nonlinearity with input noise.
    lin_rate_ipn : Linear (identity) nonlinearity with input noise.
    gauss_rate_ipn : Gaussian nonlinearity with input noise.

    Notes
    -----
    **Runtime events:**

    - ``instant_rate_events`` : applied in the current step (``delay_steps=0``).
    - ``delayed_rate_events`` : scheduled with integer ``delay_steps >= 0``
      (default ``1`` when omitted).

    **Event format** supports dict, tuple/list, or a bare scalar:

    - Dict: ``{'rate': value, 'weight': w, 'delay_steps': d,
      'multiplicity': m}``
    - Tuple: ``(rate, weight)``, ``(rate, weight, delay_steps)``, or
      ``(rate, weight, delay_steps, multiplicity)``
    - Scalar: interpreted as the rate with unit weight and multiplicity.

    References
    ----------
    .. [1] NEST Simulator documentation for ``rate_neuron_ipn``:
           https://nest-simulator.readthedocs.io/en/stable/models/rate_neuron_ipn.html
    .. [2] Hahne, J., Dahmen, D., Schuecker, J., Frommer, A., Bolten, M.,
           Helias, M., & Diesmann, M. (2017). Integration of continuous-time
           dynamics in a spiking neural network simulator.
           *Frontiers in Neuroinformatics*, 11, 34.

    Examples
    --------
    .. code-block:: python

       >>> import brainpy
       >>> import brainstate
       >>> import saiunit as u
       >>> import jax.numpy as jnp
       >>> with brainstate.environ.context(dt=0.1 * u.ms):
       ...     net = brainpy.state.tanh_rate_ipn(
       ...         in_size=10,
       ...         tau=10.0 * u.ms,
       ...         lambda_=1.0,
       ...         sigma=0.5,
       ...         mu=0.0,
       ...         g=1.0,
       ...         theta=0.0,
       ...         rectify_output=True,
       ...         rectify_rate=0.0,
       ...     )
       ...     net.init_all_states()
       ...     rate = net.update(x=0.1)
       ...     _ = rate.shape  # (10,)
    """

    __module__ = 'brainpy.state'

    def __init__(
        self,
        in_size: Size,
        tau: ArrayLike = 10.0 * u.ms,
        lambda_: ArrayLike = 1.0,
        sigma: ArrayLike = 1.0,
        mu: ArrayLike = 0.0,
        g: ArrayLike = 1.0,
        theta: ArrayLike = 0.0,
        mult_coupling: bool = False,
        linear_summation: bool = True,
        rectify_rate: ArrayLike = 0.0,
        rectify_output: bool = False,
        rate_initializer: Callable = braintools.init.Constant(0.0),
        noise_initializer: Callable = braintools.init.Constant(0.0),
        name: str = None,
    ):
        # The linear-rate base is configured with neutral per-branch
        # gains/shifts; the tanh nonlinearity uses ``self.g`` / ``self.theta``.
        super().__init__(
            in_size=in_size,
            tau=tau,
            sigma=sigma,
            mu=mu,
            g=g,
            mult_coupling=mult_coupling,
            g_ex=1.0,
            g_in=1.0,
            theta_ex=0.0,
            theta_in=0.0,
            linear_summation=linear_summation,
            rate_initializer=rate_initializer,
            noise_initializer=noise_initializer,
            name=name,
        )
        self.theta = braintools.init.param(theta, self.varshape)
        self.lambda_ = braintools.init.param(lambda_, self.varshape)
        self.rectify_rate = braintools.init.param(rectify_rate, self.varshape)
        self.rectify_output = bool(rectify_output)
        self._validate_parameters()

    @property
    def recordables(self):
        r"""List of recordable state variables.

        Returns
        -------
        list of str
            ``['rate', 'noise']``.
        """
        return ['rate', 'noise']

    @property
    def receptor_types(self):
        r"""Mapping of receptor type names to indices.

        Returns
        -------
        dict
            ``{'RATE': 0}`` for rate-based connections.
        """
        return {'RATE': 0}

    def _validate_parameters(self):
        r"""Validate parameter constraints at initialization.

        Raises
        ------
        ValueError
            If ``tau <= 0``, ``lambda_ < 0``, ``sigma < 0``, or
            ``rectify_rate < 0``.

        Notes
        -----
        The checks need concrete values, so they are skipped entirely when any
        validated parameter is a JAX tracer (e.g. during ``jit``).
        """
        if any(is_tracer(v) for v in (self.tau, self.lambda_, self.sigma, self.rectify_rate)):
            return
        # Use the same ms conversion as ``_common_parameters_tanh`` so unitful
        # and unitless ``tau`` are validated identically.
        if np.any(self._to_numpy_ms(self.tau) <= 0.0):
            raise ValueError('Time constant tau must be > 0.')
        if np.any(self.lambda_ < 0.0):
            raise ValueError('Passive decay rate lambda must be >= 0.')
        if np.any(self.sigma < 0.0):
            raise ValueError('Noise parameter sigma must be >= 0.')
        if np.any(self.rectify_rate < 0.0):
            raise ValueError('Rectifying rate must be >= 0.')

    def init_state(self, **kwargs):
        r"""Initialize state variables and internal buffers.

        Create ``rate``, ``noise``, ``instant_rate`` and ``delayed_rate``
        states, reset the step counter to zero, and empty both delay queues.

        Parameters
        ----------
        **kwargs
            Unused compatibility parameters accepted by the base-state API
            (e.g. a ``batch_size`` is ignored; no batch axis is created).

        Notes
        -----
        ``rate`` and ``noise`` come from ``rate_initializer`` and
        ``noise_initializer`` evaluated over ``self.varshape``.
        ``instant_rate`` and ``delayed_rate`` are independent copies of
        ``rate`` cast to ``brainstate.environ.dftype()``.
        """
        rate = braintools.init.param(self.rate_initializer, self.varshape)
        noise = braintools.init.param(self.noise_initializer, self.varshape)
        rate_np = self._to_numpy(rate)
        noise_np = self._to_numpy(noise)
        self.rate = brainstate.ShortTermState(rate_np)
        self.noise = brainstate.ShortTermState(noise_np)
        dftype = brainstate.environ.dftype()
        self.instant_rate = brainstate.ShortTermState(np.array(rate_np, dtype=dftype, copy=True))
        self.delayed_rate = brainstate.ShortTermState(np.array(rate_np, dtype=dftype, copy=True))
        self._step_count = 0
        self._delayed_ex_queue = {}
        self._delayed_in_queue = {}

    def update(self, x=0.0, instant_rate_events=None, delayed_rate_events=None, noise=None):
        r"""Advance rate dynamics by one simulation step.

        Execute stochastic exponential Euler integration with input noise,
        process delayed and instant events, apply tanh nonlinearity, and
        optionally rectify output.

        Parameters
        ----------
        x : ArrayLike, optional
            External current input (dimensionless). Passed through
            ``sum_current_inputs``, broadcast to state shape and added to
            ``mu``. Default ``0.0``.
        instant_rate_events : list or None, optional
            Rate events to apply in the current step (``delay_steps=0``).
            Each event can be dict, tuple, or scalar. Default None.
        delayed_rate_events : list or None, optional
            Rate events to schedule with integer delays (``delay_steps >= 0``,
            default ``1``). Default None.
        noise : ArrayLike or None, optional
            Custom standard-normal samples :math:`\xi`. If None, samples are
            drawn from NumPy's global generator (``np.random.normal``).
            Useful for reproducibility or testing. Default None.

        Returns
        -------
        rate : jax.Array
            Updated rate values broadcast to the state shape. Also stored in
            ``self.rate`` and ``self.instant_rate``.

        Notes
        -----
        **Update sequence:**

        1. Extract current step index and state shape.
        2. Drain queued delayed inputs scheduled for this step.
        3. Schedule new delayed events and accumulate zero-delay events.
        4. Accumulate instant events and split delta inputs by sign.
        5. Draw noise (or use provided ``noise``).
        6. Compute stochastic exponential Euler step:

           - For :math:`\lambda > 0`: use OU-exact propagators.
           - For :math:`\lambda = 0`: use Euler-Maruyama.

        7. Apply tanh nonlinearity to summed inputs (if
           ``linear_summation=True``) or use pre-transformed per-event values
           (if ``False``).
        8. Rectify output if ``rectify_output=True``.
        9. Update state variables and increment step counter.
        """
        h = float(u.get_mantissa(brainstate.environ.get_dt() / u.ms))
        state_shape = self.rate.value.shape
        tau, sigma, mu, g, theta = self._common_parameters_tanh(state_shape)
        lambda_ = self._broadcast_to_state(self._to_numpy(self.lambda_), state_shape)
        rectify_rate = self._broadcast_to_state(self._to_numpy(self.rectify_rate), state_shape)

        state_shape, step_idx, delayed_ex, delayed_in, instant_ex, instant_in, mu_ext = self._common_inputs_tanh(
            x=x,
            instant_rate_events=instant_rate_events,
            delayed_rate_events=delayed_rate_events,
            g=g,
            theta=theta,
        )

        dftype = brainstate.environ.dftype()
        rate_prev = jnp.broadcast_to(jnp.asarray(self.rate.value, dtype=dftype), state_shape)
        if noise is None:
            xi = jnp.asarray(np.random.normal(size=state_shape), dtype=dftype)
        else:
            xi = jnp.broadcast_to(jnp.asarray(noise, dtype=dftype), state_shape)
        noise_now = sigma * xi

        if np.any(lambda_ > 0.0):
            zero_lambda = lambda_ == 0.0
            # Divide by 1 where lambda == 0; those entries are overwritten below.
            safe_lambda = np.where(zero_lambda, 1.0, lambda_)
            P1 = np.exp(-lambda_ * h / tau)
            P2 = -np.expm1(-lambda_ * h / tau) / safe_lambda
            input_noise_factor = np.sqrt(-0.5 * np.expm1(-2.0 * lambda_ * h / tau) / safe_lambda)
            if np.any(zero_lambda):
                # Mixed populations: fall back to Euler-Maruyama where lambda == 0.
                P1 = np.where(zero_lambda, 1.0, P1)
                P2 = np.where(zero_lambda, h / tau, P2)
                input_noise_factor = np.where(zero_lambda, np.sqrt(h / tau), input_noise_factor)
        else:
            P1 = np.ones_like(lambda_)
            P2 = h / tau
            input_noise_factor = np.sqrt(h / tau)

        mu_total = mu + mu_ext
        rate_new = P1 * rate_prev + P2 * mu_total + input_noise_factor * noise_now

        H_ex = jnp.ones_like(rate_prev)
        H_in = jnp.ones_like(rate_prev)
        if self.mult_coupling:
            H_ex = self._mult_coupling_ex(rate_prev)
            H_in = self._mult_coupling_in(rate_prev)

        if self.linear_summation:
            if self.mult_coupling:
                # Separate nonlinearity per branch (NEST mult_coupling path).
                rate_new += P2 * H_ex * self._input(delayed_ex + instant_ex, g, theta)
                rate_new += P2 * H_in * self._input(delayed_in + instant_in, g, theta)
            else:
                rate_new += P2 * self._input(delayed_ex + instant_ex + delayed_in + instant_in, g, theta)
        else:
            # Nonlinear transform has already been applied per event in buffer handling.
            rate_new += P2 * H_ex * (delayed_ex + instant_ex)
            rate_new += P2 * H_in * (delayed_in + instant_in)

        if self.rectify_output:
            rate_new = jnp.where(rate_new < rectify_rate, rectify_rate, rate_new)

        self.rate.value = rate_new
        self.noise.value = noise_now
        self.delayed_rate.value = rate_prev
        self.instant_rate.value = rate_new
        self._step_count = step_idx + 1
        return rate_new
class tanh_rate_opn(_tanh_rate_base):
    r"""NEST-compatible ``tanh_rate_opn`` nonlinear rate neuron with output noise.

    Deterministic rate model with output-coupled additive noise and
    hyperbolic-tangent nonlinearity applied to network inputs, matching NEST's
    ``rate_neuron_opn`` template instantiated with the ``tanh_rate`` gain
    function.

    **1. Model equations**

    The internal state :math:`X(t)` evolves deterministically

    .. math::

        \tau\frac{dX(t)}{dt}=-X(t)+\mu+\phi(\cdot),

    where the input nonlinearity is

    .. math::

        \phi(h)=\tanh(g(h-\theta)).

    The observed (communicated) rate includes white noise added at the output:

    .. math::

        X_\mathrm{noisy}(t)=X(t)+\sqrt{\frac{\tau}{h}}\sigma\xi(t),

    with :math:`\xi(t) \sim \mathcal{N}(0,1)` drawn independently every step and
    :math:`h=dt` the simulation step size. The per-step noise variance is
    :math:`\sigma^2\tau/h`. The :math:`1/\sqrt{h}` factor is the usual
    discretization of continuous-time white noise. As a result, the noise
    *intensity* :math:`\sigma^2\tau` is independent of the step size, while the
    per-sample variance is not. This matches NEST's implementation.

    **2. Numerical integration**

    Deterministic exponential Euler integration:

    .. math::

        P_1 &= \exp(-h / \tau), \\
        P_2 &= 1 - P_1, \\
        X_{n+1} &= P_1 X_n + P_2 \left[\mu + \phi(\cdot)\right].

    The noisy rate for outgoing communication is computed from the rate at the
    start of the step:

    .. math::

        X_{\mathrm{noisy},n} = X_n + \sqrt{\frac{\tau}{h}} \sigma \xi_n.

    **3. Update ordering (matching NEST ``rate_neuron_opn`` with tanh)**

    Each simulation step proceeds as follows:

    1. Draw ``noise = sigma * xi`` with ``xi`` standard normal.
    2. Build ``noisy_rate`` from the current ``rate`` by adding scaled noise.
    3. Store ``noisy_rate`` as delayed outgoing value (for delayed connections).
    4. Propagate deterministic intrinsic dynamics with exponential Euler.
    5. Add event-driven input contributions:

       - ``linear_summation=True``: apply tanh to summed branch inputs.
       - ``linear_summation=False``: apply tanh per event before summation.

    6. Store ``noisy_rate`` as instantaneous outgoing value (for instant
       connections).

    **4. Timing semantics, assumptions, and constraints**

    - Noise is added to the output only (internal state :math:`X` remains
      deterministic).
    - The output-noise variance per step is :math:`\sigma^2\tau/h`. The
      :math:`\sqrt{\tau/h}` scaling keeps the noise intensity independent of
      :math:`h`; it does not do so for the per-sample variance.
    - Multiplicative coupling factors are fixed to 1 for tanh_rate models.
      ``mult_coupling`` is nevertheless *not* a no-op when
      ``linear_summation=True``. In that case it selects how tanh is applied
      to the excitatory and inhibitory input sums:

      - ``True``: separately, :math:`\phi(h_{ex}) + \phi(h_{in})`.
      - ``False``: once, to the combined sum :math:`\phi(h_{ex} + h_{in})`.

      With ``linear_summation=False`` the switch has no effect.
    - Unlike input-noise models, there is no rectification option for
      output-noise models.
    - When ``noise`` is not passed to :meth:`update`, samples are drawn from
      NumPy's global generator (``np.random.normal``). Seed it with
      ``np.random.seed``, or pass ``noise`` explicitly, for reproducible runs.

    **5. Computational implications**

    Per :meth:`update` call:

    - Random number generation: one standard-normal draw per neuron.
    - :math:`P_1` and :math:`P_2` are computed with ``np.exp`` and
      ``np.expm1`` of :math:`-h/\tau`.
    - Event processing is linear in number of events per step.
    - Parameters and inputs are broadcast to the shape of the ``rate`` state.

    Parameters
    ----------
    in_size : Size
        Population shape specification consumed by
        :class:`brainstate.nn.Dynamics`. Determines output rate array shape.
    tau : ArrayLike, optional
        Time constant :math:`\tau` of rate dynamics. Must be positive. Accepts
        scalar or array broadcast to ``self.varshape``. Unitful values are
        converted to ms; unitless are interpreted as ms.
        Default ``10.0 * u.ms``.
    sigma : ArrayLike, optional
        Output noise scale (dimensionless). Must be non-negative. Scales the
        Gaussian white noise added to the output. Broadcast to
        ``self.varshape``. Default ``1.0``.
    mu : ArrayLike, optional
        Mean drive :math:`\mu` (dimensionless). Constant additive input.
        Broadcast to ``self.varshape``. Default ``0.0``.
    g : ArrayLike, optional
        Gain of tanh nonlinearity (dimensionless). Controls steepness of tanh.
        Broadcast to ``self.varshape``. Default ``1.0``.
    theta : ArrayLike, optional
        Threshold (horizontal shift) of tanh nonlinearity (dimensionless).
        Shifts the input :math:`h` before applying tanh. Broadcast to
        ``self.varshape``. Default ``0.0``.
    mult_coupling : bool, optional
        Kept for NEST compatibility. For ``tanh_rate`` models the
        multiplicative coupling factors are identically 1. With
        ``linear_summation=True`` this switch still decides how tanh is
        applied: separately to the excitatory and inhibitory input sums
        (``True``) or once to their combined sum (``False``). It has no effect
        when ``linear_summation=False``. Default ``False``.
    linear_summation : bool, optional
        Controls nonlinearity application order. If ``True``, sum inputs then
        apply tanh. If ``False``, apply tanh per event then sum weighted
        results. Default ``True``.
    rate_initializer : Callable, optional
        Initializer for state variable ``rate``, evaluated as
        ``braintools.init.param(rate_initializer, self.varshape)`` in
        :meth:`init_state`. Default ``braintools.init.Constant(0.0)``.
    noise_initializer : Callable, optional
        Initializer for state variable ``noise`` (recording).
        Default ``braintools.init.Constant(0.0)``.
    noisy_rate_initializer : Callable, optional
        Initializer for state variable ``noisy_rate`` (output with noise).
        Default ``braintools.init.Constant(0.0)``.
    name : str or None, optional
        Module name passed to :class:`brainstate.nn.Dynamics`.

    Parameter Mapping
    -----------------

    .. list-table:: Parameter mapping to model symbols
       :header-rows: 1
       :widths: 22 18 22 38

       * - Parameter
         - Default
         - Math symbol
         - Semantics
       * - ``tau``
         - ``10.0 * u.ms``
         - :math:`\tau`
         - Time constant of rate dynamics (ms).
       * - ``sigma``
         - ``1.0``
         - :math:`\sigma`
         - Output noise scale (dimensionless, >= 0).
       * - ``mu``
         - ``0.0``
         - :math:`\mu`
         - Constant mean drive (dimensionless).
       * - ``g``
         - ``1.0``
         - :math:`g`
         - Gain of tanh nonlinearity (dimensionless).
       * - ``theta``
         - ``0.0``
         - :math:`\theta`
         - Horizontal shift of tanh nonlinearity (dimensionless).

    Raises
    ------
    ValueError
        If ``tau <= 0`` or ``sigma < 0``.
    ValueError
        If ``instant_rate_events`` specify non-zero ``delay_steps``, or if
        ``delayed_rate_events`` specify negative ``delay_steps``.

    See Also
    --------
    tanh_rate_ipn : Input-noise variant of tanh_rate.
    sigmoid_rate_opn : Sigmoid nonlinearity with output noise.
    lin_rate_opn : Linear (identity) nonlinearity with output noise.
    gauss_rate_opn : Gaussian nonlinearity with output noise.

    Notes
    -----
    **Runtime events:**

    - ``instant_rate_events`` : applied in the current step (``delay_steps=0``).
    - ``delayed_rate_events`` : scheduled with integer ``delay_steps >= 0``.

    **Event format** supports dict or tuple (identical to
    :class:`tanh_rate_ipn`):

    - Dict: ``{'rate': value, 'weight': w, 'delay_steps': d, 'multiplicity': m}``
    - Tuple: ``(rate, weight)``, ``(rate, weight, delay_steps)``, or
      ``(rate, weight, delay_steps, multiplicity)``

    References
    ----------
    .. [1] NEST Simulator documentation for ``rate_neuron_opn``:
           https://nest-simulator.readthedocs.io/en/stable/models/rate_neuron_opn.html
    .. [2] Hahne, J., Dahmen, D., Schuecker, J., Frommer, A., Bolten, M.,
           Helias, M., & Diesmann, M. (2017). Integration of continuous-time
           dynamics in a spiking neural network simulator.
           *Frontiers in Neuroinformatics*, 11, 34.

    Examples
    --------
    .. code-block:: python

        >>> import brainpy
        >>> import brainstate
        >>> import saiunit as u
        >>> import jax.numpy as jnp
        >>> with brainstate.environ.context(dt=0.1 * u.ms):
        ...     net = brainpy.state.tanh_rate_opn(
        ...         in_size=10,
        ...         tau=10.0 * u.ms,
        ...         sigma=0.5,
        ...         mu=0.0,
        ...         g=1.0,
        ...         theta=0.0,
        ...     )
        ...     net.init_all_states()
        ...     rate = net.update(x=0.1)
        ...     _ = rate.shape  # (10,)
        ...     _ = net.noisy_rate.value.shape  # (10,)
    """
    __module__ = 'brainpy.state'

    def __init__(
        self,
        in_size: Size,
        tau: ArrayLike = 10.0 * u.ms,
        sigma: ArrayLike = 1.0,
        mu: ArrayLike = 0.0,
        g: ArrayLike = 1.0,
        theta: ArrayLike = 0.0,
        mult_coupling: bool = False,
        linear_summation: bool = True,
        rate_initializer: Callable = braintools.init.Constant(0.0),
        noise_initializer: Callable = braintools.init.Constant(0.0),
        noisy_rate_initializer: Callable = braintools.init.Constant(0.0),
        name: str = None,
    ):
        # The linear-rate base class handles tau/sigma/mu/g and the shared
        # state plumbing. The tanh base class supplies its own multiplicative
        # coupling factors (always 1), so the lin_rate coupling parameters
        # g_ex/g_in/theta_ex/theta_in are passed as fixed placeholders.
        super().__init__(
            in_size=in_size,
            tau=tau,
            sigma=sigma,
            mu=mu,
            g=g,
            mult_coupling=mult_coupling,
            g_ex=1.0,
            g_in=1.0,
            theta_ex=0.0,
            theta_in=0.0,
            linear_summation=linear_summation,
            rate_initializer=rate_initializer,
            noise_initializer=noise_initializer,
            name=name,
        )
        # The tanh shift is not a base-class parameter, so it is stored here.
        self.theta = braintools.init.param(theta, self.varshape)
        self.noisy_rate_initializer = noisy_rate_initializer
        self._validate_parameters()

    @property
    def recordables(self):
        r"""List of recordable state variables.

        Returns
        -------
        list of str
            ``['rate', 'noise', 'noisy_rate']``.
        """
        return ['rate', 'noise', 'noisy_rate']

    @property
    def receptor_types(self):
        r"""Mapping of receptor type names to indices.

        Returns
        -------
        dict
            ``{'RATE': 0}`` for rate-based connections.
        """
        return {'RATE': 0}

    def _validate_parameters(self):
        r"""Validate parameter constraints at initialization.

        Raises
        ------
        ValueError
            If ``tau <= 0`` or ``sigma < 0``.
        """
        # Skip validation when parameters are JAX tracers (e.g. during jit).
        if any(is_tracer(v) for v in (self.tau, self.sigma)):
            return
        if np.any(self.tau <= 0.0 * u.ms):
            raise ValueError('Time constant tau must be > 0.')
        if np.any(self.sigma < 0.0):
            raise ValueError('Noise parameter sigma must be >= 0.')

    def init_state(self, **kwargs):
        r"""Initialize state variables and internal buffers.

        Creates the ``rate``, ``noise``, ``noisy_rate``, ``instant_rate`` and
        ``delayed_rate`` states. Also resets the step counter and empties the
        excitatory/inhibitory delay queues.

        Parameters
        ----------
        **kwargs
            Unused compatibility parameters accepted by the base-state API.

        Notes
        -----
        ``rate``, ``noise`` and ``noisy_rate`` are built from their
        initializers via ``braintools.init.param(initializer, self.varshape)``.
        ``instant_rate`` and ``delayed_rate`` start as independent copies of
        ``noisy_rate``, cast to ``brainstate.environ.dftype()``.
        """
        rate = braintools.init.param(self.rate_initializer, self.varshape)
        noise = braintools.init.param(self.noise_initializer, self.varshape)
        noisy_rate = braintools.init.param(self.noisy_rate_initializer, self.varshape)
        rate_np = self._to_numpy(rate)
        noise_np = self._to_numpy(noise)
        noisy_rate_np = self._to_numpy(noisy_rate)
        self.rate = brainstate.ShortTermState(rate_np)
        self.noise = brainstate.ShortTermState(noise_np)
        self.noisy_rate = brainstate.ShortTermState(noisy_rate_np)
        dftype = brainstate.environ.dftype()
        # Explicit copies so the outgoing buffers never share memory with
        # ``noisy_rate`` or with each other.
        self.instant_rate = brainstate.ShortTermState(np.array(noisy_rate_np, dtype=dftype, copy=True))
        self.delayed_rate = brainstate.ShortTermState(np.array(noisy_rate_np, dtype=dftype, copy=True))
        self._step_count = 0
        self._delayed_ex_queue = {}
        self._delayed_in_queue = {}

    def update(self, x=0.0, instant_rate_events=None, delayed_rate_events=None, noise=None):
        r"""Advance rate dynamics by one simulation step.

        Runs deterministic exponential Euler integration, builds the noisy
        output rate, processes delayed and instant events, and applies the tanh
        nonlinearity to network input.

        Parameters
        ----------
        x : ArrayLike, optional
            External input (dimensionless). Broadcast to the state shape and
            added to ``mu``. Default ``0.0``.
        instant_rate_events : list or None, optional
            Rate events to apply in the current step (``delay_steps=0``). Each
            event can be dict, tuple, or scalar. Default None.
        delayed_rate_events : list or None, optional
            Rate events to schedule with integer delays (``delay_steps >= 0``).
            Default None.
        noise : ArrayLike or None, optional
            Custom standard-normal samples :math:`\xi`, before scaling by
            ``sigma``, broadcast to the state shape. If None, samples are drawn
            with ``np.random.normal`` (NumPy's global generator). Useful for
            reproducibility or testing. Default None.

        Returns
        -------
        rate : jax.Array
            Updated deterministic rate values, broadcast to the shape of the
            ``rate`` state. The dtype follows JAX promotion of the
            ``dftype``-cast state with the parameter arrays. Note: the
            communicated output is ``noisy_rate``, not ``rate``.

        Notes
        -----
        **Update sequence:**

        1. Draw noise (or use provided ``noise``).
        2. Compute ``noisy_rate`` from the current ``rate`` by adding scaled
           noise: :math:`X_{\mathrm{noisy}} = X + \sqrt{\tau/h} \sigma \xi`.
        3. Store ``noisy_rate`` as delayed outgoing value.
        4. Drain queued delayed inputs scheduled for this step.
        5. Schedule new delayed events and accumulate zero-delay events.
        6. Accumulate instant events and split delta inputs by sign.
        7. Compute deterministic exponential Euler step:
           :math:`X_{n+1} = P_1 X_n + P_2 [\mu + \phi(\cdot)]`.
        8. Apply the tanh nonlinearity. With ``linear_summation=True`` it is
           applied to the summed inputs: per branch if ``mult_coupling`` is
           set, otherwise to the combined sum. With ``linear_summation=False``
           the values were already transformed per event.
        9. Store ``noisy_rate`` as instantaneous outgoing value.
        10. Update state variables and increment step counter.
        """
        h = float(u.get_mantissa(brainstate.environ.get_dt() / u.ms))
        state_shape = self.rate.value.shape
        tau, sigma, mu, g, theta = self._common_parameters_tanh(state_shape)
        state_shape, step_idx, delayed_ex, delayed_in, instant_ex, instant_in, mu_ext = self._common_inputs_tanh(
            x=x,
            instant_rate_events=instant_rate_events,
            delayed_rate_events=delayed_rate_events,
            g=g,
            theta=theta,
        )

        dftype = brainstate.environ.dftype()
        rate_prev = jnp.broadcast_to(jnp.asarray(self.rate.value, dtype=dftype), state_shape)

        # Standard-normal samples xi; scaled by sigma below.
        if noise is None:
            xi = jnp.asarray(np.random.normal(size=state_shape), dtype=dftype)
        else:
            xi = jnp.broadcast_to(jnp.asarray(noise, dtype=dftype), state_shape)
        noise_now = sigma * xi

        # Exponential-Euler propagators; expm1 keeps P2 = 1 - exp(-h/tau)
        # accurate when h << tau.
        P1 = np.exp(-h / tau)
        P2 = -np.expm1(-h / tau)
        # sqrt(tau/h) output-noise scaling (per-step variance sigma^2 * tau / h).
        output_noise_factor = np.sqrt(tau / h)

        # Output noise is applied to the rate at the start of the step; the
        # internal rate itself stays deterministic.
        noisy_rate = rate_prev + output_noise_factor * noise_now

        mu_total = mu + mu_ext
        rate_new = P1 * rate_prev + P2 * mu_total

        if self.mult_coupling:
            H_ex = self._mult_coupling_ex(noisy_rate)
            H_in = self._mult_coupling_in(noisy_rate)
        else:
            H_ex = jnp.ones_like(rate_prev)
            H_in = jnp.ones_like(rate_prev)

        if self.linear_summation:
            if self.mult_coupling:
                # tanh applied separately to the excitatory and inhibitory sums.
                rate_new += P2 * H_ex * self._input(delayed_ex + instant_ex, g, theta)
                rate_new += P2 * H_in * self._input(delayed_in + instant_in, g, theta)
            else:
                # tanh applied once to the combined input.
                rate_new += P2 * self._input(delayed_ex + instant_ex + delayed_in + instant_in, g, theta)
        else:
            # Nonlinear transform has already been applied per event in buffer handling.
            rate_new += P2 * H_ex * (delayed_ex + instant_ex)
            rate_new += P2 * H_in * (delayed_in + instant_in)

        self.rate.value = rate_new
        self.noise.value = noise_now
        self.noisy_rate.value = noisy_rate
        self.delayed_rate.value = noisy_rate
        self.instant_rate.value = noisy_rate
        self._step_count = step_idx + 1
        return rate_new