# Source code for brainpy_state._nest_neuron.tanh_rate

# Copyright 2026 BrainX Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

# -*- coding: utf-8 -*-


from typing import Callable

import brainstate
import braintools
import brainunit as u
import jax.numpy as jnp
import numpy as np
from brainstate.typing import ArrayLike, Size

from brainpy_state._nest_neuron.lin_rate import _lin_rate_base
from brainpy_state._nest_base.utils import is_tracer

# Public API of this module: the input-noise (``ipn``) and output-noise
# (``opn``) tanh rate neuron classes defined below.
__all__ = [
    'tanh_rate_ipn',
    'tanh_rate_opn',
]


class _tanh_rate_base(_lin_rate_base):
    r"""Shared machinery for the tanh-based rate neurons.

    Supplies the hyperbolic-tangent gain function and multiplicative coupling
    factors that are identically 1. Both ``tanh_rate_ipn`` and
    ``tanh_rate_opn`` derive from this class.
    """
    __module__ = 'brainpy.state'

    #: Names of the parameters that pin down φ(h) = tanh(g·(h − θ)), i.e. the
    #: gain and the threshold; they identify φ for the
    #: ``linear_summation=False`` homogeneity guard.
    _phi_param_names = ('g', 'theta')

    def _activation(self, h):
        """Evaluate ``φ(h) = tanh(g·(h − θ))`` with JAX, taking ``g`` and ``theta`` from ``self``."""
        gain = u.get_mantissa(self.g)
        shift = u.get_mantissa(self.theta)
        centred = h - shift
        return jnp.tanh(gain * centred)

    def _mult_factors(self, rate):
        """Return two all-ones factors shaped like ``rate`` (coupling is never scaled)."""
        unity = jnp.ones_like(rate)
        return (unity, unity)

    def _common_parameters_tanh(self, state_shape):
        r"""Bring the model parameters to the shape of the state.

        Each parameter is converted with ``self._to_numpy`` (``tau`` with
        ``self._to_numpy_ms``) and then passed through
        ``self._broadcast_to_state`` together with ``state_shape``.

        Parameters
        ----------
        state_shape : tuple
            Shape every returned parameter is broadcast to.

        Returns
        -------
        tau : ndarray
            Time constant in ms, broadcast to ``state_shape``.
        sigma : ndarray
            Noise scale, broadcast to ``state_shape``.
        mu : ndarray
            Mean drive, broadcast to ``state_shape``.
        g : ndarray
            Gain parameter, broadcast to ``state_shape``.
        theta : ndarray
            Threshold parameter, broadcast to ``state_shape``.
        """

        def _fit(values):
            # Single place where the target shape is applied.
            return self._broadcast_to_state(values, state_shape)

        tau_ms = _fit(self._to_numpy_ms(self.tau))
        noise_scale = _fit(self._to_numpy(self.sigma))
        drive = _fit(self._to_numpy(self.mu))
        gain = _fit(self._to_numpy(self.g))
        shift = _fit(self._to_numpy(self.theta))
        return tau_ms, noise_scale, drive, gain, shift


class tanh_rate_ipn(_tanh_rate_base):
    r"""NEST-compatible ``tanh_rate_ipn`` nonlinear rate neuron with input noise.

    Stochastic rate model with hyperbolic-tangent nonlinearity applied to
    network inputs, matching NEST's ``rate_neuron_ipn`` template instantiated
    with ``tanh_rate`` gain function.

    **1. Model equations**

    The state :math:`X(t)` evolves according to the Langevin equation

    .. math::

       \tau\,dX(t)=
       \left[-\lambda X(t)+\mu+\phi(\cdot)\right]dt
       +\left[\sqrt{\tau}\,\sigma\right]dW(t),

    where the input nonlinearity is

    .. math::

       \phi(h)=\tanh(g(h-\theta)).

    Here :math:`W(t)` is a standard Brownian motion, :math:`\lambda` is the
    passive decay rate, :math:`\mu` is a constant drive, :math:`g` is the
    gain, and :math:`\theta` is the horizontal shift of the tanh function.

    **2. Numerical integration and noise implementation**

    For :math:`\lambda > 0`, integration uses the stochastic exponential Euler
    (exact for the Ornstein-Uhlenbeck process):

    .. math::

       P_1 &= \exp(-\lambda h / \tau), \\
       P_2 &= \frac{1 - P_1}{\lambda}, \\
       X_{n+1} &= P_1 X_n + P_2 \left[\mu + \phi(\cdot)\right]
                  + \sqrt{\frac{1-P_1^2}{2\lambda}} \sigma \xi_n,

    with :math:`\xi_n \sim \mathcal{N}(0,1)`. For :math:`\lambda = 0`, it
    reduces to Euler-Maruyama:

    .. math::

       X_{n+1} = X_n + \frac{h}{\tau} \left[\mu + \phi(\cdot)\right]
                 + \sqrt{\frac{h}{\tau}} \sigma \xi_n.

    **3. Update ordering (matching NEST ``rate_neuron_ipn`` with tanh)**

    Each simulation step proceeds as follows:

    1. Store outgoing delayed value as current ``rate``.
    2. Draw ``noise = sigma * xi`` from the standard normal distribution.
    3. Propagate intrinsic dynamics with stochastic exponential Euler
       (Euler-Maruyama for ``lambda=0``).
    4. Read delayed and instantaneous input buffers.
    5. Apply input contributions:

       - ``linear_summation=True``: apply tanh to summed branch inputs.
       - ``linear_summation=False``: apply tanh per event before summation.

    6. Apply rectification when ``rectify_output=True`` (clamp to
       ``>= rectify_rate``).
    7. Store outgoing instantaneous value as updated ``rate``.

    **4. Timing semantics, assumptions, and constraints**

    - Noise term is white (independent at each step, variance scales with
      :math:`dt`).
    - For :math:`\lambda > 0`, integration exactly preserves stationary
      variance of the OU process.
    - For :math:`\lambda = 0`, the process is non-stationary (variance grows
      linearly with time).
    - Multiplicative coupling factors are fixed to 1 for tanh_rate models
      (``mult_coupling`` has no effect; kept for NEST API compatibility).

    **5. Computational implications**

    Per :meth:`update` call:

    - Random number generation: :math:`O(\prod \mathrm{varshape})`.
    - Exponential operations for :math:`P_1, P_2` when :math:`\lambda > 0`.
    - Event processing is linear in number of events per step.
    - Broadcasting parameters and inputs over ``self.varshape``.

    Parameters
    ----------
    in_size : Size
        Population shape specification consumed by
        :class:`brainstate.nn.Dynamics`. Determines output rate array shape.
    tau : ArrayLike, optional
        Time constant :math:`\tau` of rate dynamics. Must be positive.
        Accepts scalar or array broadcast to ``self.varshape``. Unitful values
        are converted to ms; unitless are interpreted as ms. Default
        ``10.0 * u.ms``.
    lambda_ : ArrayLike, optional
        Passive decay rate :math:`\lambda` (dimensionless). Must be
        non-negative. For :math:`\lambda > 0`, uses stochastic exponential
        Euler; for :math:`\lambda = 0`, uses Euler-Maruyama. Default ``1.0``.
    sigma : ArrayLike, optional
        Input noise scale (dimensionless). Must be non-negative. Scales the
        Wiener increment. Broadcast to ``self.varshape``. Default ``1.0``.
    mu : ArrayLike, optional
        Mean drive :math:`\mu` (dimensionless). Constant additive input.
        Broadcast to ``self.varshape``. Default ``0.0``.
    g : ArrayLike, optional
        Gain of tanh nonlinearity (dimensionless). Controls steepness of tanh.
        Broadcast to ``self.varshape``. Default ``1.0``.
    theta : ArrayLike, optional
        Threshold (horizontal shift) of tanh nonlinearity (dimensionless).
        Shifts the input :math:`h` before applying tanh. Broadcast to
        ``self.varshape``. Default ``0.0``.
    mult_coupling : bool, optional
        Kept for NEST compatibility. For ``tanh_rate`` models, multiplicative
        coupling factors are identically 1, so this switch has no effect.
        Default ``False``.
    linear_summation : bool, optional
        Controls nonlinearity application order. If ``True``, sum inputs then
        apply tanh. If ``False``, apply tanh per event then sum weighted
        results. Default ``True``.
    rectify_rate : ArrayLike, optional
        Lower bound for output rate when ``rectify_output=True``
        (dimensionless). Must be non-negative. Broadcast to
        ``self.varshape``. Default ``0.0``.
    rectify_output : bool, optional
        If ``True``, clamp updated rate to ``>= rectify_rate`` after all
        dynamics. Default ``False``.
    rate_initializer : Callable, optional
        Initializer for state variable ``rate``. Resolved in
        :meth:`init_state` through
        ``braintools.init.param(rate_initializer, self.varshape)``. Default
        ``braintools.init.Constant(0.0)``.
    noise_initializer : Callable, optional
        Initializer for state variable ``noise`` (recording). Default
        ``braintools.init.Constant(0.0)``.
    name : str or None, optional
        Module name passed to :class:`brainstate.nn.Dynamics`.

    Parameter Mapping
    -----------------
    .. list-table:: Parameter mapping to model symbols
       :header-rows: 1
       :widths: 22 18 22 38

       * - Parameter
         - Default
         - Math symbol
         - Semantics
       * - ``tau``
         - ``10.0 * u.ms``
         - :math:`\tau`
         - Time constant of rate dynamics (ms).
       * - ``lambda_``
         - ``1.0``
         - :math:`\lambda`
         - Passive decay rate (dimensionless, >= 0).
       * - ``sigma``
         - ``1.0``
         - :math:`\sigma`
         - Input noise scale (dimensionless, >= 0).
       * - ``mu``
         - ``0.0``
         - :math:`\mu`
         - Constant mean drive (dimensionless).
       * - ``g``
         - ``1.0``
         - :math:`g`
         - Gain of tanh nonlinearity (dimensionless).
       * - ``theta``
         - ``0.0``
         - :math:`\theta`
         - Horizontal shift of tanh nonlinearity (dimensionless).
       * - ``rectify_rate``
         - ``0.0``
         - :math:`r_{\min}`
         - Lower clamp bound when ``rectify_output=True``.

    Raises
    ------
    ValueError
        If ``tau <= 0``, ``lambda_ < 0``, ``sigma < 0``, or
        ``rectify_rate < 0``.
    ValueError
        If ``instant_rate_events`` specify non-zero ``delay_steps``, or if
        ``delayed_rate_events`` specify negative ``delay_steps``.

    See Also
    --------
    tanh_rate_opn : Output-noise variant of tanh_rate.
    sigmoid_rate_ipn : Sigmoid nonlinearity with input noise.
    lin_rate_ipn : Linear (identity) nonlinearity with input noise.
    gauss_rate_ipn : Gaussian nonlinearity with input noise.

    Notes
    -----
    **Runtime events:**

    - ``instant_rate_events`` : applied in the current step (``delay_steps=0``).
    - ``delayed_rate_events`` : scheduled with integer ``delay_steps >= 0``.

    **Event format** supports dict or tuple:

    - Dict: ``{'rate': value, 'weight': w, 'delay_steps': d,
      'multiplicity': m}``
    - Tuple: ``(rate, weight)``, ``(rate, weight, delay_steps)``, or
      ``(rate, weight, delay_steps, multiplicity)``

    NOTE(review): :meth:`update` on this class only accepts ``x`` and
    ``noise``; the event arguments described above are not handled in this
    class and are presumably provided by the base class or the connection
    layer -- confirm.

    References
    ----------
    .. [1] NEST Simulator documentation for ``rate_neuron_ipn``:
           https://nest-simulator.readthedocs.io/en/stable/models/rate_neuron_ipn.html
    .. [2] Hahne, J., Dahmen, D., Schuecker, J., Frommer, A., Bolten, M.,
           Helias, M., & Diesmann, M. (2017). Integration of continuous-time
           dynamics in a spiking neural network simulator.
           *Frontiers in Neuroinformatics*, 11, 34.

    Examples
    --------
    .. code-block:: python

       >>> import brainpy
       >>> import brainstate
       >>> import brainunit as u
       >>> import jax.numpy as jnp
       >>> with brainstate.environ.context(dt=0.1 * u.ms):
       ...     net = brainpy.state.tanh_rate_ipn(
       ...         in_size=10,
       ...         tau=10.0 * u.ms,
       ...         lambda_=1.0,
       ...         sigma=0.5,
       ...         mu=0.0,
       ...         g=1.0,
       ...         theta=0.0,
       ...         rectify_output=True,
       ...         rectify_rate=0.0,
       ...     )
       ...     net.init_all_states()
       ...     rate = net.update(x=0.1)
       ...     _ = rate.shape  # (10,)
    """

    __module__ = 'brainpy.state'

    def __init__(
        self,
        in_size: Size,
        tau: ArrayLike = 10.0 * u.ms,
        lambda_: ArrayLike = 1.0,
        sigma: ArrayLike = 1.0,
        mu: ArrayLike = 0.0,
        g: ArrayLike = 1.0,
        theta: ArrayLike = 0.0,
        mult_coupling: bool = False,
        linear_summation: bool = True,
        rectify_rate: ArrayLike = 0.0,
        rectify_output: bool = False,
        rate_initializer: Callable = braintools.init.Constant(0.0),
        noise_initializer: Callable = braintools.init.Constant(0.0),
        name: str = None,
    ):
        # The base-class g_ex/g_in/theta_ex/theta_in arguments are pinned to
        # constants; this model does not expose them.
        super().__init__(
            in_size=in_size,
            tau=tau,
            sigma=sigma,
            mu=mu,
            g=g,
            mult_coupling=mult_coupling,
            g_ex=1.0,
            g_in=1.0,
            theta_ex=0.0,
            theta_in=0.0,
            linear_summation=linear_summation,
            rate_initializer=rate_initializer,
            noise_initializer=noise_initializer,
            name=name,
        )
        # Parameters set here, after the base initializer has run (they need
        # ``self.varshape``).
        self.theta = braintools.init.param(theta, self.varshape)
        self.lambda_ = braintools.init.param(lambda_, self.varshape)
        self.rectify_rate = braintools.init.param(rectify_rate, self.varshape)
        self.rectify_output = bool(rectify_output)
        self._validate_parameters()

    @property
    def recordables(self):
        r"""List of recordable state variables.

        Returns
        -------
        list of str
            ``['rate', 'noise']``.
        """
        return ['rate', 'noise']

    @property
    def receptor_types(self):
        r"""Mapping of receptor type names to indices.

        Returns
        -------
        dict
            ``{'RATE': 0}`` for rate-based connections.
        """
        return {'RATE': 0}

    def _validate_parameters(self):
        r"""Validate parameter constraints at initialization.

        Validation is skipped entirely when any of the checked parameters is
        a JAX tracer, because a tracer has no concrete value to test.

        Raises
        ------
        ValueError
            If ``tau <= 0``, ``lambda_ < 0``, ``sigma < 0``, or
            ``rectify_rate < 0``.
        """
        # Every parameter reduced with ``np.any`` below must be concrete; a
        # traced value (e.g. during jit) cannot be turned into a Python bool.
        checked = (self.tau, self.lambda_, self.sigma, self.rectify_rate)
        if any(is_tracer(v) for v in checked):
            return
        if np.any(self.tau <= 0.0 * u.ms):
            raise ValueError('Time constant tau must be > 0.')
        if np.any(self.lambda_ < 0.0):
            raise ValueError('Passive decay rate lambda must be >= 0.')
        if np.any(self.sigma < 0.0):
            raise ValueError('Noise parameter sigma must be >= 0.')
        if np.any(self.rectify_rate < 0.0):
            raise ValueError('Rectifying rate must be >= 0.')

    def init_state(self, **kwargs):
        r"""Initialize state variables and internal buffers.

        Create the ``rate``, ``noise``, ``instant_rate`` and ``delayed_rate``
        states, then call ``self._alloc_phi_rate`` with the initial rate.

        Parameters
        ----------
        **kwargs
            Unused compatibility parameters accepted by the base-state API.

        Notes
        -----
        State variables are initialized from ``rate_initializer`` and
        ``noise_initializer``. Both ``instant_rate`` and ``delayed_rate`` are
        initialized as independent copies of ``rate`` in the environment's
        default float dtype.
        """
        rate = braintools.init.param(self.rate_initializer, self.varshape)
        noise = braintools.init.param(self.noise_initializer, self.varshape)
        rate_np = self._to_numpy(rate)
        noise_np = self._to_numpy(noise)
        self.rate = brainstate.ShortTermState(rate_np)
        self.noise = brainstate.ShortTermState(noise_np)
        dftype = brainstate.environ.dftype()
        # ``copy=True`` keeps the two outgoing buffers from aliasing ``rate``.
        self.instant_rate = brainstate.ShortTermState(np.array(rate_np, dtype=dftype, copy=True))
        self.delayed_rate = brainstate.ShortTermState(np.array(rate_np, dtype=dftype, copy=True))
        # Inherited helper; presumably allocates storage for phi(rate) -- confirm.
        self._alloc_phi_rate(rate_np)

    def update(self, x=0.0, noise=None):
        r"""Advance rate dynamics by one simulation step.

        Perform one stochastic exponential Euler step with input noise, add
        the coupling contribution, and optionally rectify the result.

        Parameters
        ----------
        x : ArrayLike, optional
            External input, forwarded to ``self._read_coupling``; the
            ``mu_ext`` term it returns is added to ``mu``. Default ``0.0``.
        noise : ArrayLike or None, optional
            Custom noise samples :math:`\xi` drawn from
            :math:`\mathcal{N}(0,1)`, broadcast to the state shape. If None,
            samples are drawn with ``brainstate.random.randn``. Useful for
            reproducibility or testing. Default None.

        Returns
        -------
        rate : Array
            Updated rate values with the shape of ``self.rate.value``.

        Notes
        -----
        **Update sequence:**

        1. Read the step size :math:`h` (in ms) and broadcast the parameters
           to the state shape.
        2. Obtain ``(mu_ext, h_a, h_b)`` from ``self._read_coupling(x)``.
        3. Draw noise (or use provided ``noise``) and scale it by ``sigma``.
        4. Build the propagators:

           - For :math:`\lambda > 0`: use OU-exact propagators.
           - For :math:`\lambda = 0`: use Euler-Maruyama.

        5. Propagate the previous rate with drive and noise, then add
           ``P2 * self._coupling_increment(rate_prev, h_a, h_b)``.
        6. Rectify output if ``rectify_output=True``.
        7. Store ``rate``, ``noise``, ``delayed_rate`` (previous rate) and
           ``instant_rate`` (new rate), and call ``self._store_phi_rate``.
        """
        h = float(u.get_mantissa(brainstate.environ.get_dt() / u.ms))
        state_shape = self.rate.value.shape

        # g and theta are not needed in this method (``_activation`` reads
        # them from ``self``), so they are discarded here.
        tau, sigma, mu, _, _ = self._common_parameters_tanh(state_shape)
        lambda_ = self._broadcast_to_state(self._to_numpy(self.lambda_), state_shape)
        rectify_rate = self._broadcast_to_state(self._to_numpy(self.rectify_rate), state_shape)

        mu_ext, h_a, h_b = self._read_coupling(x)

        dftype = brainstate.environ.dftype()
        rate_prev = jnp.broadcast_to(jnp.asarray(self.rate.value, dtype=dftype), state_shape)

        if noise is None:
            xi = brainstate.random.randn(*state_shape)
        else:
            xi = jnp.broadcast_to(jnp.asarray(noise, dtype=dftype), state_shape)
        noise_now = sigma * xi

        if np.any(lambda_ > 0.0):
            zero_lambda = lambda_ == 0.0
            # Replace zeros in the divisor so the element-wise division is
            # defined everywhere; those entries are overwritten below.
            safe_lambda = np.where(zero_lambda, 1.0, lambda_)
            P1 = np.exp(-lambda_ * h / tau)
            P2 = -np.expm1(-lambda_ * h / tau) / safe_lambda
            input_noise_factor = np.sqrt(
                -0.5 * np.expm1(-2.0 * lambda_ * h / tau) / safe_lambda
            )
            if np.any(zero_lambda):
                # Mixed population: entries with lambda == 0 fall back to
                # the Euler-Maruyama coefficients.
                P1 = np.where(zero_lambda, 1.0, P1)
                P2 = np.where(zero_lambda, h / tau, P2)
                input_noise_factor = np.where(zero_lambda, np.sqrt(h / tau), input_noise_factor)
        else:
            # No strictly positive decay anywhere: plain Euler-Maruyama.
            P1 = np.ones_like(lambda_)
            P2 = h / tau
            input_noise_factor = np.sqrt(h / tau)

        mu_total = mu + mu_ext
        rate_new = P1 * rate_prev + P2 * mu_total + input_noise_factor * noise_now
        rate_new = rate_new + P2 * self._coupling_increment(rate_prev, h_a, h_b)

        if self.rectify_output:
            rate_new = jnp.where(rate_new < rectify_rate, rectify_rate, rate_new)

        self.rate.value = rate_new
        self.noise.value = noise_now
        # Delayed output carries the pre-update rate; instantaneous output
        # carries the freshly computed one.
        self.delayed_rate.value = rate_prev
        self.instant_rate.value = rate_new
        self._store_phi_rate(rate_new)
        return rate_new
class tanh_rate_opn(_tanh_rate_base):
    r"""NEST-compatible ``tanh_rate_opn`` nonlinear rate neuron with output noise.

    Deterministic rate model with output-coupled additive noise and
    hyperbolic-tangent nonlinearity applied to network inputs, matching NEST's
    ``rate_neuron_opn`` template instantiated with ``tanh_rate`` gain function.

    **1. Model equations**

    The internal state :math:`X(t)` evolves deterministically

    .. math::

       \tau\frac{dX(t)}{dt}=-X(t)+\mu+\phi(\cdot),

    where the input nonlinearity is

    .. math::

       \phi(h)=\tanh(g(h-\theta)).

    The observed rate includes white noise added at the output:

    .. math::

       X_\mathrm{noisy}(t)=X(t)+\sqrt{\frac{\tau}{h}}\sigma\xi(t),

    with :math:`\xi(t) \sim \mathcal{N}(0,1)` and :math:`h=dt` the simulation
    step size. Each per-step noise sample is scaled by :math:`\sqrt{\tau/h}`
    (so its variance is :math:`\tau\sigma^2/h`), matching NEST's
    implementation.

    **2. Numerical integration**

    Deterministic exponential Euler integration:

    .. math::

       P_1 &= \exp(-h / \tau), \\
       P_2 &= 1 - P_1, \\
       X_{n+1} &= P_1 X_n + P_2 \left[\mu + \phi(\cdot)\right].

    The noisy rate for outgoing communication is computed as

    .. math::

       X_{\mathrm{noisy},n} = X_n + \sqrt{\frac{\tau}{h}} \sigma \xi_n.

    **3. Update ordering (matching NEST ``rate_neuron_opn`` with tanh)**

    Each simulation step proceeds as follows:

    1. Draw ``noise = sigma * xi`` from the standard normal distribution.
    2. Build ``noisy_rate`` from the current ``rate`` by adding scaled noise.
    3. Store ``noisy_rate`` as delayed outgoing value (for delayed
       connections).
    4. Propagate deterministic intrinsic dynamics with exponential Euler.
    5. Add event-driven input contributions:

       - ``linear_summation=True``: apply tanh to summed branch inputs.
       - ``linear_summation=False``: apply tanh per event before summation.

    6. Store ``noisy_rate`` as instantaneous outgoing value (for instant
       connections).

    **4. Timing semantics, assumptions, and constraints**

    - Noise is added to the output only (internal state :math:`X` remains
      deterministic).
    - The output noise sample is scaled by :math:`\sqrt{\tau/h}`, so its
      per-step variance :math:`\tau\sigma^2/h` grows as the step size
      :math:`h` shrinks.
    - Multiplicative coupling factors are fixed to 1 for tanh_rate models
      (``mult_coupling`` has no effect; kept for NEST API compatibility).
    - Unlike input-noise models, there is no rectification option for
      output-noise models.

    **5. Computational implications**

    Per :meth:`update` call:

    - Random number generation: :math:`O(\prod \mathrm{varshape})`.
    - Exponential operations for :math:`P_1, P_2`.
    - Event processing is linear in number of events per step.
    - Broadcasting parameters and inputs over ``self.varshape``.

    Parameters
    ----------
    in_size : Size
        Population shape specification consumed by
        :class:`brainstate.nn.Dynamics`. Determines output rate array shape.
    tau : ArrayLike, optional
        Time constant :math:`\tau` of rate dynamics. Must be positive.
        Accepts scalar or array broadcast to ``self.varshape``. Unitful values
        are converted to ms; unitless are interpreted as ms. Default
        ``10.0 * u.ms``.
    sigma : ArrayLike, optional
        Output noise scale (dimensionless). Must be non-negative. Scales the
        Gaussian white noise added to the output. Broadcast to
        ``self.varshape``. Default ``1.0``.
    mu : ArrayLike, optional
        Mean drive :math:`\mu` (dimensionless). Constant additive input.
        Broadcast to ``self.varshape``. Default ``0.0``.
    g : ArrayLike, optional
        Gain of tanh nonlinearity (dimensionless). Controls steepness of tanh.
        Broadcast to ``self.varshape``. Default ``1.0``.
    theta : ArrayLike, optional
        Threshold (horizontal shift) of tanh nonlinearity (dimensionless).
        Shifts the input :math:`h` before applying tanh. Broadcast to
        ``self.varshape``. Default ``0.0``.
    mult_coupling : bool, optional
        Kept for NEST compatibility. For ``tanh_rate`` models, multiplicative
        coupling factors are identically 1, so this switch has no effect.
        Default ``False``.
    linear_summation : bool, optional
        Controls nonlinearity application order. If ``True``, sum inputs then
        apply tanh. If ``False``, apply tanh per event then sum weighted
        results. Default ``True``.
    rate_initializer : Callable, optional
        Initializer for state variable ``rate``. Resolved in
        :meth:`init_state` through
        ``braintools.init.param(rate_initializer, self.varshape)``. Default
        ``braintools.init.Constant(0.0)``.
    noise_initializer : Callable, optional
        Initializer for state variable ``noise`` (recording). Default
        ``braintools.init.Constant(0.0)``.
    noisy_rate_initializer : Callable, optional
        Initializer for state variable ``noisy_rate`` (output with noise).
        Default ``braintools.init.Constant(0.0)``.
    name : str or None, optional
        Module name passed to :class:`brainstate.nn.Dynamics`.

    Parameter Mapping
    -----------------
    .. list-table:: Parameter mapping to model symbols
       :header-rows: 1
       :widths: 22 18 22 38

       * - Parameter
         - Default
         - Math symbol
         - Semantics
       * - ``tau``
         - ``10.0 * u.ms``
         - :math:`\tau`
         - Time constant of rate dynamics (ms).
       * - ``sigma``
         - ``1.0``
         - :math:`\sigma`
         - Output noise scale (dimensionless, >= 0).
       * - ``mu``
         - ``0.0``
         - :math:`\mu`
         - Constant mean drive (dimensionless).
       * - ``g``
         - ``1.0``
         - :math:`g`
         - Gain of tanh nonlinearity (dimensionless).
       * - ``theta``
         - ``0.0``
         - :math:`\theta`
         - Horizontal shift of tanh nonlinearity (dimensionless).

    Raises
    ------
    ValueError
        If ``tau <= 0`` or ``sigma < 0``.
    ValueError
        If ``instant_rate_events`` specify non-zero ``delay_steps``, or if
        ``delayed_rate_events`` specify negative ``delay_steps``.

    See Also
    --------
    tanh_rate_ipn : Input-noise variant of tanh_rate.
    sigmoid_rate_opn : Sigmoid nonlinearity with output noise.
    lin_rate_opn : Linear (identity) nonlinearity with output noise.
    gauss_rate_opn : Gaussian nonlinearity with output noise.

    Notes
    -----
    **Runtime events:**

    - ``instant_rate_events`` : applied in the current step (``delay_steps=0``).
    - ``delayed_rate_events`` : scheduled with integer ``delay_steps >= 0``.

    **Event format** supports dict or tuple (identical to
    :class:`tanh_rate_ipn`):

    - Dict: ``{'rate': value, 'weight': w, 'delay_steps': d,
      'multiplicity': m}``
    - Tuple: ``(rate, weight)``, ``(rate, weight, delay_steps)``, or
      ``(rate, weight, delay_steps, multiplicity)``

    NOTE(review): :meth:`update` on this class only accepts ``x`` and
    ``noise``; the event arguments described above are not handled in this
    class and are presumably provided by the base class or the connection
    layer -- confirm.

    References
    ----------
    .. [1] NEST Simulator documentation for ``rate_neuron_opn``:
           https://nest-simulator.readthedocs.io/en/stable/models/rate_neuron_opn.html
    .. [2] Hahne, J., Dahmen, D., Schuecker, J., Frommer, A., Bolten, M.,
           Helias, M., & Diesmann, M. (2017). Integration of continuous-time
           dynamics in a spiking neural network simulator.
           *Frontiers in Neuroinformatics*, 11, 34.

    Examples
    --------
    .. code-block:: python

       >>> import brainpy
       >>> import brainstate
       >>> import brainunit as u
       >>> import jax.numpy as jnp
       >>> with brainstate.environ.context(dt=0.1 * u.ms):
       ...     net = brainpy.state.tanh_rate_opn(
       ...         in_size=10,
       ...         tau=10.0 * u.ms,
       ...         sigma=0.5,
       ...         mu=0.0,
       ...         g=1.0,
       ...         theta=0.0,
       ...     )
       ...     net.init_all_states()
       ...     rate = net.update(x=0.1)
       ...     _ = rate.shape  # (10,)
       ...     _ = net.noisy_rate.value.shape  # (10,)
    """

    __module__ = 'brainpy.state'

    def __init__(
        self,
        in_size: Size,
        tau: ArrayLike = 10.0 * u.ms,
        sigma: ArrayLike = 1.0,
        mu: ArrayLike = 0.0,
        g: ArrayLike = 1.0,
        theta: ArrayLike = 0.0,
        mult_coupling: bool = False,
        linear_summation: bool = True,
        rate_initializer: Callable = braintools.init.Constant(0.0),
        noise_initializer: Callable = braintools.init.Constant(0.0),
        noisy_rate_initializer: Callable = braintools.init.Constant(0.0),
        name: str = None,
    ):
        # The base-class g_ex/g_in/theta_ex/theta_in arguments are pinned to
        # constants; this model does not expose them.
        super().__init__(
            in_size=in_size,
            tau=tau,
            sigma=sigma,
            mu=mu,
            g=g,
            mult_coupling=mult_coupling,
            g_ex=1.0,
            g_in=1.0,
            theta_ex=0.0,
            theta_in=0.0,
            linear_summation=linear_summation,
            rate_initializer=rate_initializer,
            noise_initializer=noise_initializer,
            name=name,
        )
        # Set after the base initializer has run (``theta`` needs
        # ``self.varshape``).
        self.theta = braintools.init.param(theta, self.varshape)
        self.noisy_rate_initializer = noisy_rate_initializer
        self._validate_parameters()

    @property
    def recordables(self):
        r"""List of recordable state variables.

        Returns
        -------
        list of str
            ``['rate', 'noise', 'noisy_rate']``.
        """
        return ['rate', 'noise', 'noisy_rate']

    @property
    def receptor_types(self):
        r"""Mapping of receptor type names to indices.

        Returns
        -------
        dict
            ``{'RATE': 0}`` for rate-based connections.
        """
        return {'RATE': 0}

    def _validate_parameters(self):
        r"""Validate parameter constraints at initialization.

        Validation is skipped entirely when ``tau`` or ``sigma`` is a JAX
        tracer, because a tracer has no concrete value to test.

        Raises
        ------
        ValueError
            If ``tau <= 0`` or ``sigma < 0``.
        """
        # Skip validation when parameters are JAX tracers (e.g. during jit).
        if any(is_tracer(v) for v in (self.tau, self.sigma)):
            return
        if np.any(self.tau <= 0.0 * u.ms):
            raise ValueError('Time constant tau must be > 0.')
        if np.any(self.sigma < 0.0):
            raise ValueError('Noise parameter sigma must be >= 0.')

    def init_state(self, **kwargs):
        r"""Initialize state variables and internal buffers.

        Create the ``rate``, ``noise``, ``noisy_rate``, ``instant_rate`` and
        ``delayed_rate`` states, then call ``self._alloc_phi_rate`` with the
        initial rate.

        Parameters
        ----------
        **kwargs
            Unused compatibility parameters accepted by the base-state API.

        Notes
        -----
        State variables are initialized from ``rate_initializer``,
        ``noise_initializer``, and ``noisy_rate_initializer``. Both
        ``instant_rate`` and ``delayed_rate`` are initialized as independent
        copies of ``noisy_rate`` in the environment's default float dtype.
        """
        rate = braintools.init.param(self.rate_initializer, self.varshape)
        noise = braintools.init.param(self.noise_initializer, self.varshape)
        noisy_rate = braintools.init.param(self.noisy_rate_initializer, self.varshape)
        rate_np = self._to_numpy(rate)
        noise_np = self._to_numpy(noise)
        noisy_rate_np = self._to_numpy(noisy_rate)
        self.rate = brainstate.ShortTermState(rate_np)
        self.noise = brainstate.ShortTermState(noise_np)
        self.noisy_rate = brainstate.ShortTermState(noisy_rate_np)
        dftype = brainstate.environ.dftype()
        # ``copy=True`` keeps the two outgoing buffers from aliasing
        # ``noisy_rate``.
        self.instant_rate = brainstate.ShortTermState(np.array(noisy_rate_np, dtype=dftype, copy=True))
        self.delayed_rate = brainstate.ShortTermState(np.array(noisy_rate_np, dtype=dftype, copy=True))
        # Inherited helper; presumably allocates storage for phi(rate) -- confirm.
        self._alloc_phi_rate(rate_np)

    def update(self, x=0.0, noise=None):
        r"""Advance rate dynamics by one simulation step.

        Build the noisy output from the current rate, then perform one
        deterministic exponential Euler step including the coupling
        contribution.

        Parameters
        ----------
        x : ArrayLike, optional
            External input, forwarded to ``self._read_coupling``; the
            ``mu_ext`` term it returns is added to ``mu``. Default ``0.0``.
        noise : ArrayLike or None, optional
            Custom noise samples :math:`\xi` drawn from
            :math:`\mathcal{N}(0,1)`, broadcast to the state shape. If None,
            samples are drawn with ``brainstate.random.randn``. Useful for
            reproducibility or testing. Default None.

        Returns
        -------
        rate : Array
            Updated deterministic rate values with the shape of
            ``self.rate.value``. Note: the communicated output is
            ``noisy_rate``, not ``rate``.

        Notes
        -----
        **Update sequence:**

        1. Read the step size :math:`h` (in ms) and broadcast the parameters
           to the state shape.
        2. Obtain ``(mu_ext, h_a, h_b)`` from ``self._read_coupling(x)``.
        3. Draw noise (or use provided ``noise``) and scale it by ``sigma``.
        4. Compute ``noisy_rate`` from the current ``rate``:
           :math:`X_{\mathrm{noisy}} = X + \sqrt{\tau/h} \sigma \xi`.
        5. Compute the deterministic exponential Euler step
           :math:`X_{n+1} = P_1 X_n + P_2 \mu_{\mathrm{total}}` and add
           ``P2 * self._coupling_increment(noisy_rate, h_a, h_b)``.
        6. Store ``rate``, ``noise`` and ``noisy_rate``; write ``noisy_rate``
           to both ``delayed_rate`` and ``instant_rate``; call
           ``self._store_phi_rate`` with the new deterministic rate.
        """
        h = float(u.get_mantissa(brainstate.environ.get_dt() / u.ms))
        state_shape = self.rate.value.shape

        # g and theta are not needed in this method (``_activation`` reads
        # them from ``self``), so they are discarded here.
        tau, sigma, mu, _, _ = self._common_parameters_tanh(state_shape)

        mu_ext, h_a, h_b = self._read_coupling(x)

        dftype = brainstate.environ.dftype()
        rate_prev = jnp.broadcast_to(jnp.asarray(self.rate.value, dtype=dftype), state_shape)

        if noise is None:
            xi = brainstate.random.randn(*state_shape)
        else:
            xi = jnp.broadcast_to(jnp.asarray(noise, dtype=dftype), state_shape)
        noise_now = sigma * xi

        P1 = np.exp(-h / tau)
        # expm1 gives 1 - exp(-h/tau) without cancellation for small h/tau.
        P2 = -np.expm1(-h / tau)
        output_noise_factor = np.sqrt(tau / h)

        # The noisy output is built from the rate *before* this step's update.
        noisy_rate = rate_prev + output_noise_factor * noise_now

        mu_total = mu + mu_ext
        rate_new = P1 * rate_prev + P2 * mu_total
        rate_new = rate_new + P2 * self._coupling_increment(noisy_rate, h_a, h_b)

        self.rate.value = rate_new
        self.noise.value = noise_now
        self.noisy_rate.value = noisy_rate
        # Both outgoing buffers carry the same noisy value.
        self.delayed_rate.value = noisy_rate
        self.instant_rate.value = noisy_rate
        self._store_phi_rate(rate_new)
        return rate_new