# Source code for brainpy_state._nest.pp_cond_exp_mc_urbanczik

# Copyright 2026 BrainX Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

# -*- coding: utf-8 -*-

from typing import Callable, Optional

import brainstate
import braintools
import saiunit as u
import jax
import jax.numpy as jnp
import numpy as np
from brainstate.typing import ArrayLike, Size
from brainstate.util import DotDict

from ._base import NESTNeuron
from ._utils import is_tracer, AdaptiveRungeKuttaStep

__all__ = [
    'pp_cond_exp_mc_urbanczik',
]

# Compartment indices
SOMA = 0
DEND = 1
NCOMP = 2


class pp_cond_exp_mc_urbanczik(NESTNeuron):
    r"""Two-compartment point process neuron with conductance-based synapses for Urbanczik-Senn learning.

    ``pp_cond_exp_mc_urbanczik`` implements a two-compartment spiking neuron model
    that combines stochastic point process spike generation with dendritic prediction
    error computation for supervised learning. The soma uses conductance-based synapses
    while the dendrite uses current-based synapses. At each time step, the model
    computes a learning signal (δΠ) based on the mismatch between actual somatic
    spiking and the dendritic prediction, enabling gradient-based synaptic plasticity.

    This is a brainpy.state re-implementation of the NEST simulator model described
    in Urbanczik & Senn (2014) [1]_, using NEST-standard parameterization and
    numerical integration methods.

    Parameters
    ----------
    in_size : Size
        Population shape (tuple of ints or single int). Determines neuron array
        dimensions. Required parameter.
    t_ref : ArrayLike, optional
        Refractory period duration (Quantity, default: 3.0 ms). Neurons cannot
        spike again within this interval after a spike. If 0, no refractory period
        and Poisson spike generation is used.
    phi_max : float, optional
        Maximum firing rate in kHz (dimensionless, default: 0.15). Upper bound
        of the rate function φ(u). Typical range: 0.1-0.2 kHz (100-200 Hz).
    rate_slope : float, optional
        Rate function slope parameter ``k`` (dimensionless, default: 0.5). Controls
        the steepness of the sigmoid rate function. Must be non-negative.
    beta : float, optional
        Rate function steepness in 1/mV (dimensionless, default: 1/3 ≈ 0.333).
        Higher values create sharper transitions around threshold ``theta``.
    theta : float, optional
        Rate function threshold potential in mV (numeric, default: -55.0). At
        ``u = theta`` the rate function equals ``phi_max / (1 + rate_slope)``,
        i.e. two thirds of ``phi_max`` with the default ``rate_slope = 0.5``.
    g_sp : ArrayLike, optional
        Coupling conductance that drives the soma from the dendrite (Quantity,
        default: 600.0 nS). Enters the somatic equation as the current
        ``g_sp * (V_d - V_s)`` and also sets the dendritic prediction ``V*_W``.
    g_ps : ArrayLike, optional
        Coupling conductance that drives the dendrite from the soma (Quantity,
        default: 0.0 nS). Enters the dendritic equation as the current
        ``g_ps * (V_s - V_d)``; with the default of zero the soma does not act
        back on the dendrite.
    soma_g_L : ArrayLike, optional
        Somatic leak conductance (Quantity, default: 30.0 nS). Controls somatic
        resting potential and membrane time constant.
    soma_C_m : ArrayLike, optional
        Somatic membrane capacitance (Quantity, default: 300.0 pF). Together with
        leak conductance determines somatic time constant τ = C_m / g_L.
    soma_E_L : ArrayLike, optional
        Somatic leak reversal potential (Quantity, default: -70.0 mV). Resting
        potential of the soma in absence of inputs.
    soma_E_ex : ArrayLike, optional
        Somatic excitatory reversal potential (Quantity, default: 0.0 mV). Driving
        force for excitatory conductance-based synapses.
    soma_E_in : ArrayLike, optional
        Somatic inhibitory reversal potential (Quantity, default: -75.0 mV). Driving
        force for inhibitory conductance-based synapses.
    soma_tau_syn_ex : ArrayLike, optional
        Somatic excitatory synaptic time constant (Quantity, default: 3.0 ms). Decay
        time constant for excitatory conductance.
    soma_tau_syn_in : ArrayLike, optional
        Somatic inhibitory synaptic time constant (Quantity, default: 3.0 ms). Decay
        time constant for inhibitory conductance.
    soma_I_e : ArrayLike, optional
        Somatic constant external current (Quantity, default: 0.0 pA). DC bias
        current applied to soma at all times.
    dend_g_L : ArrayLike, optional
        Dendritic leak conductance (Quantity, default: 30.0 nS). Controls dendritic
        resting potential and membrane time constant.
    dend_C_m : ArrayLike, optional
        Dendritic membrane capacitance (Quantity, default: 300.0 pF). Together with
        leak conductance determines dendritic time constant τ = C_m / g_L.
    dend_E_L : ArrayLike, optional
        Dendritic leak reversal potential (Quantity, default: -70.0 mV). Resting
        potential of the dendrite in absence of inputs.
    dend_E_ex : ArrayLike, optional
        Dendritic excitatory reversal potential (Quantity, default: 0.0 mV).
        Stored for NEST parameter compatibility only: the dendritic synapses are
        current-based, so this value does not enter the dynamics.
    dend_E_in : ArrayLike, optional
        Dendritic inhibitory reversal potential (Quantity, default: 0.0 mV, note:
        NOT -75.0 mV), matching the NEST default. Stored for NEST parameter
        compatibility only; it does not enter the dynamics.
    dend_tau_syn_ex : ArrayLike, optional
        Dendritic excitatory synaptic time constant (Quantity, default: 3.0 ms).
        Decay time constant for excitatory current.
    dend_tau_syn_in : ArrayLike, optional
        Dendritic inhibitory synaptic time constant (Quantity, default: 3.0 ms).
        Decay time constant for inhibitory current.
    dend_I_e : ArrayLike, optional
        Dendritic constant external current (Quantity, default: 0.0 pA). DC bias
        current applied to dendrite at all times.
    gsl_error_tol : ArrayLike, optional
        Unitless local RKF45 error tolerance (default: 1e-3). Must be strictly
        positive.
    rng_key : jax.Array, optional
        JAX PRNG key for stochastic spike generation (default: None). If None, a
        default key (PRNGKey(0)) is used. For reproducibility, provide explicit key.
    spk_fun : Callable, optional
        Surrogate gradient function for spike differentiation (default:
        braintools.surrogate.ReluGrad()). Used in backpropagation through spikes.
    spk_reset : str, optional
        Spike reset mode (default: 'hard'). Options: 'hard' (stop_gradient) or
        'soft' (subtract threshold). Note: This model has NO voltage reset after
        spikes, so this parameter has limited effect.
    name : str, optional
        Module name (default: None). Used for logging and identification.

    Parameter Mapping
    -----------------
    This table maps NEST C++ parameter names to brainpy.state constructor arguments:

    ================================ =========================== ===============
    **NEST Parameter**               **brainpy.state Parameter** **Default**
    ================================ =========================== ===============
    ``t_ref``                        ``t_ref``                   3.0 ms
    ``phi_max``                      ``phi_max``                 0.15 kHz
    ``rate_slope``                   ``rate_slope``              0.5
    ``beta``                         ``beta``                    0.333 (1/mV)
    ``theta``                        ``theta``                   -55.0 mV
    ``g_sp``                         ``g_sp``                    600.0 nS
    ``g_ps``                         ``g_ps``                    0.0 nS
    ``g_L`` (soma)                   ``soma_g_L``                30.0 nS
    ``C_m`` (soma)                   ``soma_C_m``                300.0 pF
    ``E_L`` (soma)                   ``soma_E_L``                -70.0 mV
    ``E_ex`` (soma)                  ``soma_E_ex``               0.0 mV
    ``E_in`` (soma)                  ``soma_E_in``               -75.0 mV
    ``tau_syn_ex`` (soma)            ``soma_tau_syn_ex``         3.0 ms
    ``tau_syn_in`` (soma)            ``soma_tau_syn_in``         3.0 ms
    ``I_e`` (soma)                   ``soma_I_e``                0.0 pA
    ``g_L`` (dendrite)               ``dend_g_L``                30.0 nS
    ``C_m`` (dendrite)               ``dend_C_m``                300.0 pF
    ``E_L`` (dendrite)               ``dend_E_L``                -70.0 mV
    ``E_ex`` (dendrite)              ``dend_E_ex``               0.0 mV
    ``E_in`` (dendrite)              ``dend_E_in``               0.0 mV
    ``tau_syn_ex`` (dendrite)        ``dend_tau_syn_ex``         3.0 ms
    ``tau_syn_in`` (dendrite)        ``dend_tau_syn_in``         3.0 ms
    ``I_e`` (dendrite)               ``dend_I_e``                0.0 pA
    ================================ =========================== ===============

    Mathematical Formulation
    ------------------------
    **1. Compartment Structure**

    The neuron consists of two compartments:

    * **Soma (s):** Conductance-based synapses, stochastic spike generation
    * **Dendrite (d, also labeled p for "proximal"):** Current-based synapses,
      predictive signal for learning

    **2. Somatic Dynamics**

    The somatic membrane potential evolves according to:

    .. math::

        C_\mathrm{m}^s \frac{dV^s}{dt} = -g_\mathrm{L}^s (V^s - E_\mathrm{L}^s)
            - g_\mathrm{ex}^s (V^s - E_\mathrm{ex}^s)
            - g_\mathrm{in}^s (V^s - E_\mathrm{in}^s)
            + g_\mathrm{sp} (V^p - V^s)
            + I_\mathrm{stim}^s + I_\mathrm{e}^s

    Somatic synaptic conductances decay exponentially:

    .. math::

        \frac{dg_\mathrm{ex}^s}{dt} = -\frac{g_\mathrm{ex}^s}{\tau_\mathrm{syn,ex}^s},
        \qquad
        \frac{dg_\mathrm{in}^s}{dt} = -\frac{g_\mathrm{in}^s}{\tau_\mathrm{syn,in}^s}

    **3. Dendritic Dynamics**

    The dendritic membrane potential evolves according to:

    .. math::

        C_\mathrm{m}^p \frac{dV^p}{dt} = -g_\mathrm{L}^p (V^p - E_\mathrm{L}^p)
            + I_\mathrm{ex}^p + I_\mathrm{in}^p
            + g_\mathrm{ps} (V^s - V^p)

    Dendritic synaptic currents (note: **current-based**, not conductance) decay
    exponentially:

    .. math::

        \frac{dI_\mathrm{ex}^p}{dt} = -\frac{I_\mathrm{ex}^p}{\tau_\mathrm{syn,ex}^p},
        \qquad
        \frac{dI_\mathrm{in}^p}{dt} = -\frac{I_\mathrm{in}^p}{\tau_\mathrm{syn,in}^p}

    **4. Stochastic Spike Generation**

    Spikes are generated stochastically based on the instantaneous rate function:

    .. math::

        \text{rate}(t) = 1000 \cdot \phi(V^s(t)) \quad [\text{Hz}]

    where:

    .. math::

        \phi(u) = \frac{\phi_\mathrm{max}}{1 + k \cdot \exp(\beta (\theta - u))}

    * **With refractory period** (``t_ref > 0``): At most one spike per time step.
      Spike probability is :math:`P_{\mathrm{spike}} = 1 - \exp(-\text{rate} \cdot dt \cdot 10^{-3})`.
      A uniform random number :math:`r \sim U(0,1)` is compared to this probability.
    * **Without refractory period** (``t_ref == 0``): Number of spikes drawn from
      Poisson distribution with mean :math:`\lambda = \text{rate} \cdot dt \cdot 10^{-3}`.

    **Important:** There is **NO membrane potential reset** after a spike. The voltage
    continues to evolve according to the differential equations.

    **5. Urbanczik-Senn Learning Signal**

    At each time step, the model computes a learning signal for synaptic plasticity.
    The dendritic compartment predicts the somatic potential via:

    .. math::

        V^*_W = \frac{E_\mathrm{L}^s \cdot g_\mathrm{L}^s + V^p \cdot g_\mathrm{sp}}{g_\mathrm{sp} + g_\mathrm{L}^s}

    This represents the steady-state somatic voltage given the current dendritic
    voltage, assuming all synaptic inputs are zero.

    The error signal (prediction error) at time step :math:`t` is:

    .. math::

        \delta\Pi(t) = \left(n_\mathrm{spikes}(t) - \phi(V^*_W(t)) \cdot dt\right) \cdot h(V^*_W(t))

    where:

    * :math:`n_{\mathrm{spikes}}(t)` is the number of actual spikes emitted (0 or 1
      with refractory period, ≥0 without)
    * :math:`\phi(V^*_W(t)) \cdot dt` is the expected spike count based on prediction
    * :math:`h(u)` is the learning modulation function:

    .. math::

        h(u) = \frac{15 \cdot \beta}{1 + \frac{1}{k} \cdot \exp(-\beta (\theta - u))}

    The history of :math:`(t, \delta\Pi)` pairs is stored and accessible via
    ``get_urbanczik_history()`` for use by plasticity rules.

    **6. Receptor Types and Synaptic Input Addressing**

    Synaptic inputs are routed to specific compartments and receptor types via
    labeled input channels:

    =================== ====== ============================================
    Receptor Label       Port   Description
    =================== ====== ============================================
    ``soma_exc``         1      Excitatory conductance input to soma (nS)
    ``soma_inh``         2      Inhibitory conductance input to soma (nS)
    ``dend_exc``         3      Excitatory current input to dendrite (pA)
    ``dend_inh``         4      Inhibitory current input to dendrite (pA)
    ``soma`` (current)   5      Direct current injection to soma (pA)
    ``dend`` (current)   6      Direct current injection to dendrite (pA)
    =================== ====== ============================================

    **Implementation Note:** In brainpy.state, use ``add_delta_input()`` with labels
    ``'soma_exc'``, ``'soma_inh'``, ``'dend_exc'``, ``'dend_inh'`` for synaptic
    spikes. Use ``add_current_input()`` with labels ``'soma'`` and ``'dend'`` for
    current injections. All synaptic weights must be **positive**; excitation vs.
    inhibition is determined by the receptor label.

    **7. Numerical Integration**

    The 6-dimensional ODE system (V_s, g_ex_s, g_in_s, V_d, I_ex_d, I_in_d) is
    integrated using an adaptive RKF45 Runge-Kutta-Fehlberg integrator that is
    fully JAX-compatible and differentiable.

    **Update Order per Time Step:**

    1. Integrate ODEs over interval :math:`(t, t + dt]` using current stimulus
       currents from previous step
    2. Add arriving synaptic spike inputs (conductance/current jumps):

       * Soma: :math:`g_{\mathrm{ex}}^s \mathrel{+}= \Delta g_{\mathrm{ex}}`,
         :math:`g_{\mathrm{in}}^s \mathrel{+}= \Delta g_{\mathrm{in}}`
       * Dendrite: :math:`I_{\mathrm{ex}}^p \mathrel{+}= \Delta I_{\mathrm{ex}}`,
         :math:`I_{\mathrm{in}}^p \mathrel{-}= \Delta I_{\mathrm{in}}` (note sign)

    3. Check refractoriness and generate spikes stochastically if not refractory
    4. Compute and store Urbanczik learning signal :math:`\delta\Pi`
    5. Buffer external current inputs for next time step

    Computational Complexity and Performance
    ----------------------------------------
    **Time Complexity:** :math:`O(N \cdot S)` where :math:`N` is the number of neurons
    and :math:`S` is the number of adaptive ODE solver steps per neuron per time step.
    Typically :math:`S \approx 3-10` depending on dynamics.

    **Space Complexity:** :math:`O(N)` for state variables, plus :math:`O(N \cdot T)`
    for Urbanczik history over :math:`T` time steps.

    **Performance Notes:**

    * This model is **significantly slower** than simple LIF neurons due to:
      (1) adaptive RKF45 integration with several sub-steps per time step,
      (2) stochastic spike generation requiring RNG calls, and (3) learning
      signal computation.
    * The ODE right-hand side is evaluated for all neurons at once; the
      per-element learning-signal history is only recorded during plain Python
      execution, not inside ``for_loop`` / JIT.
    * For large networks (>1000 neurons), consider alternative implementations or
      simplified models.
    * History storage grows linearly with simulation time; clear periodically if
      memory is constrained.

    Attributes (State Variables)
    -----------------------------
    V_s : brainstate.HiddenState
        Somatic membrane potential (Quantity, shape: ``varshape``).
        Initialized to ``soma_E_L``. Unit: mV.
    g_ex_s : brainstate.HiddenState
        Somatic excitatory synaptic conductance (Quantity). Initialized to 0. Unit: nS.
    g_in_s : brainstate.HiddenState
        Somatic inhibitory synaptic conductance (Quantity). Initialized to 0. Unit: nS.
    V_d : brainstate.HiddenState
        Dendritic membrane potential (Quantity). Initialized to ``dend_E_L``. Unit: mV.
    I_ex_d : brainstate.HiddenState
        Dendritic excitatory synaptic current (Quantity). Initialized to 0. Unit: pA.
    I_in_d : brainstate.HiddenState
        Dendritic inhibitory synaptic current (Quantity). Initialized to 0. Unit: pA.
    refractory_step_count : brainstate.ShortTermState
        Remaining refractory time steps (int32 array). Counts down to zero. Initialized to 0.
    I_stim_soma : brainstate.ShortTermState
        Buffered soma current for next integration step (Quantity). Unit: pA.
    I_stim_dend : brainstate.ShortTermState
        Buffered dendrite current for next integration step (Quantity). Unit: pA.
    last_spike_time : brainstate.ShortTermState
        Time of last spike emission (Quantity). Initialized to -1e7 ms. Unit: ms.
    integration_step : brainstate.ShortTermState
        Persistent RKF45 substep size estimate (ms).

    Raises
    ------
    ValueError
        If ``rate_slope < 0`` (must be non-negative).
    ValueError
        If ``phi_max < 0`` (must be non-negative).
    ValueError
        If ``t_ref < 0`` (must be non-negative).
    ValueError
        If any capacitance ``C_m`` ≤ 0 (must be strictly positive).
    ValueError
        If any synaptic time constant ≤ 0 (must be strictly positive).
    ValueError
        If ``gsl_error_tol`` ≤ 0 (must be strictly positive).

    Notes
    -----
    * **NEST Compatibility:** All default parameter values match NEST 3.9+ C++ source
      for ``pp_cond_exp_mc_urbanczik``. Notable: dendritic inhibitory reversal
      potential is 0.0 mV (not -75.0 mV).
    * **Stochasticity:** Spike times are non-deterministic unless ``rng_key`` is
      explicitly managed. For reproducibility, provide a fixed PRNG key and re-seed
      appropriately.
    * **No Voltage Reset:** Unlike integrate-and-fire models, there is no discrete
      voltage reset after spiking. The membrane potential evolves continuously.
    * **Urbanczik History:** The learning signal history is stored in a Python dict
      (``_urbanczik_history``) and grows unbounded. For long simulations, periodically
      clear history or implement custom storage.
    * **Surrogate Gradients:** The ``spk_fun`` parameter enables gradient-based
      learning through spike discontinuities, but this model is primarily designed
      for the Urbanczik-Senn rule which uses the stored δΠ signals directly.

    Examples
    --------
    **Basic single neuron simulation:**

    .. code-block:: python

        >>> import brainpy.state as bp
        >>> import saiunit as u
        >>> import brainstate
        >>> import numpy as np
        >>> dt = 0.1 * u.ms
        >>> # The constructor and init_all_states() read the simulation
        >>> # resolution, so build the neuron inside the dt context.
        >>> with brainstate.environ.context(dt=dt):
        ...     neuron = bp.pp_cond_exp_mc_urbanczik(in_size=1)
        ...     neuron.init_all_states()
        ...     spikes = []
        ...     for i in range(1000):  # 100 ms
        ...         spk = neuron.update(x=300.0 * u.pA)  # Strong depolarizing current
        ...         spikes.append(float(spk[0]))
        >>> total = sum(spikes)  # exact spike count depends on rng_key

    **Two-compartment voltage monitoring:**

    .. code-block:: python

        >>> soma_voltages, dend_voltages = [], []
        >>> with brainstate.environ.context(dt=0.1*u.ms):
        ...     neuron = bp.pp_cond_exp_mc_urbanczik(in_size=1, t_ref=0.0*u.ms)
        ...     neuron.init_all_states()
        ...     for i in range(500):
        ...         neuron.update(x=200.0 * u.pA)
        ...         soma_voltages.append(float(neuron.V_s.value[0] / u.mV))
        ...         dend_voltages.append(float(neuron.V_d.value[0] / u.mV))
        >>> # Plot soma_voltages and dend_voltages to visualize dynamics

    **Accessing Urbanczik learning signals:**

    .. code-block:: python

        >>> with brainstate.environ.context(dt=0.1*u.ms):
        ...     neuron = bp.pp_cond_exp_mc_urbanczik(in_size=2)
        ...     neuron.init_all_states()
        ...     for i in range(100):
        ...         neuron.update(x=250.0 * u.pA)
        >>> # Retrieve learning signal history for neuron 0
        >>> history_0 = neuron.get_urbanczik_history(neuron_idx=0)
        >>> print(f"History length: {len(history_0)}")
        History length: 100
        >>> # Each entry is (time_ms, delta_PI); delta_PI values are stochastic
        >>> t, dPI = history_0[-1]

    References
    ----------
    .. [1] Urbanczik R, Senn W (2014). Learning by the Dendritic Prediction of
           Somatic Spiking. Neuron, 81(3):521-528.
           DOI: https://doi.org/10.1016/j.neuron.2013.11.030

    .. [2] NEST Simulator ``pp_cond_exp_mc_urbanczik`` model documentation:
           https://nest-simulator.readthedocs.io/en/stable/models/pp_cond_exp_mc_urbanczik.html

    .. [3] NEST C++ source code: ``models/pp_cond_exp_mc_urbanczik.h`` and
           ``models/pp_cond_exp_mc_urbanczik.cpp`` in NEST 3.9+ distribution.

    See Also
    --------
    gif_cond_exp : Generalized integrate-and-fire with conductance synapses
    pp_psc_delta : Point process neuron with current synapses
    urbanczik_synapse : Synapse model implementing Urbanczik-Senn plasticity rule
    """
    __module__ = 'brainpy.state'

    _MIN_H = 1e-8 * u.ms  # ms
    _MAX_ITERS = 100000

    def __init__(
        self,
        in_size: Size,
        # Global parameters
        t_ref: ArrayLike = 3.0 * u.ms,
        phi_max: float = 0.15,  # kHz
        rate_slope: float = 0.5,  # dimensionless
        beta: float = 1.0 / 3.0,  # 1/mV
        theta: float = -55.0,  # mV
        g_sp: ArrayLike = 600.0 * u.nS,  # drives the soma from the dendrite
        g_ps: ArrayLike = 0.0 * u.nS,  # drives the dendrite from the soma
        # Soma compartment parameters
        soma_g_L: ArrayLike = 30.0 * u.nS,
        soma_C_m: ArrayLike = 300.0 * u.pF,
        soma_E_L: ArrayLike = -70.0 * u.mV,
        soma_E_ex: ArrayLike = 0.0 * u.mV,
        soma_E_in: ArrayLike = -75.0 * u.mV,
        soma_tau_syn_ex: ArrayLike = 3.0 * u.ms,
        soma_tau_syn_in: ArrayLike = 3.0 * u.ms,
        soma_I_e: ArrayLike = 0.0 * u.pA,
        # Dendritic compartment parameters
        dend_g_L: ArrayLike = 30.0 * u.nS,
        dend_C_m: ArrayLike = 300.0 * u.pF,
        dend_E_L: ArrayLike = -70.0 * u.mV,
        dend_E_ex: ArrayLike = 0.0 * u.mV,
        dend_E_in: ArrayLike = 0.0 * u.mV,
        dend_tau_syn_ex: ArrayLike = 3.0 * u.ms,
        dend_tau_syn_in: ArrayLike = 3.0 * u.ms,
        dend_I_e: ArrayLike = 0.0 * u.pA,
        # Integration tolerance
        gsl_error_tol: ArrayLike = 1e-3,
        # RNG and surrogate
        rng_key: Optional[jax.Array] = None,
        spk_fun: Callable = braintools.surrogate.ReluGrad(),
        spk_reset: str = 'hard',
        name: Optional[str] = None,
    ):
        r"""Build the two-compartment Urbanczik-Senn neuron.

        See the class docstring for the meaning, units and defaults of every
        parameter.

        The simulation resolution (``brainstate.environ.get_dt()``) is read
        here: both the RKF45 integrator and the refractory step count are tied
        to the resolution active at construction time.

        Raises
        ------
        ValueError
            If a parameter violates a NEST constraint (negative ``rate_slope``,
            ``phi_max`` or ``t_ref``; non-positive capacitance, synaptic time
            constant or ``gsl_error_tol``).
        """
        super().__init__(in_size, name=name, spk_fun=spk_fun, spk_reset=spk_reset)

        # Global parameters (rate-function constants are stored as plain floats)
        self.t_ref = braintools.init.param(t_ref, self.varshape)
        self.phi_max = float(phi_max)
        self.rate_slope = float(rate_slope)
        self.beta = float(beta)
        self.theta = float(theta)
        self.g_sp = braintools.init.param(g_sp, self.varshape)
        self.g_ps = braintools.init.param(g_ps, self.varshape)

        # Soma parameters
        self.soma_g_L = braintools.init.param(soma_g_L, self.varshape)
        self.soma_C_m = braintools.init.param(soma_C_m, self.varshape)
        self.soma_E_L = braintools.init.param(soma_E_L, self.varshape)
        self.soma_E_ex = braintools.init.param(soma_E_ex, self.varshape)
        self.soma_E_in = braintools.init.param(soma_E_in, self.varshape)
        self.soma_tau_syn_ex = braintools.init.param(soma_tau_syn_ex, self.varshape)
        self.soma_tau_syn_in = braintools.init.param(soma_tau_syn_in, self.varshape)
        self.soma_I_e = braintools.init.param(soma_I_e, self.varshape)

        # Dendritic parameters
        self.dend_g_L = braintools.init.param(dend_g_L, self.varshape)
        self.dend_C_m = braintools.init.param(dend_C_m, self.varshape)
        self.dend_E_L = braintools.init.param(dend_E_L, self.varshape)
        self.dend_E_ex = braintools.init.param(dend_E_ex, self.varshape)
        self.dend_E_in = braintools.init.param(dend_E_in, self.varshape)
        self.dend_tau_syn_ex = braintools.init.param(dend_tau_syn_ex, self.varshape)
        self.dend_tau_syn_in = braintools.init.param(dend_tau_syn_in, self.varshape)
        self.dend_I_e = braintools.init.param(dend_I_e, self.varshape)

        # Integration tolerance (used as the RKF45 absolute tolerance below)
        self.gsl_error_tol = gsl_error_tol

        # RNG seed; init_state() falls back to PRNGKey(0) when this is None
        self._rng_key = rng_key

        self._validate_parameters()

        # Read the resolution once and reuse it for everything tied to it.
        dt = brainstate.environ.get_dt()

        self.integrator = AdaptiveRungeKuttaStep(
            method='RKF45',
            vf=self._vector_field,
            event_fn=self._event_fn,
            min_h=self._MIN_H,
            max_iters=self._MAX_ITERS,
            atol=self.gsl_error_tol,
            dt=dt,
        )

        # Precompute refractory step count, rounding t_ref / dt up as NEST does.
        # NEST performs this conversion in integer tics; a bare float ceil can
        # overshoot by one step when t_ref is an exact multiple of dt, e.g.
        # (3 * 0.1) / 0.1 == 3.0000000000000004 -> ceil gives 4 instead of 3.
        # Shrinking the ratio by a tiny relative margin absorbs that rounding
        # noise while genuinely fractional step counts are still rounded up.
        ratio_rtol = 1e-6
        ditype = brainstate.environ.ditype()
        t_ref_steps = (self.t_ref / dt) * (1.0 - ratio_rtol)
        self.ref_count = u.math.asarray(u.math.ceil(t_ref_steps), dtype=ditype)

    def _validate_parameters(self):
        r"""Validate model parameters against NEST constraints.

        Validation is skipped entirely when any of the checked array-valued
        parameters is a JAX tracer (e.g. while tracing under ``jax.jit``),
        because the concrete comparisons below cannot be evaluated on tracers.

        Raises
        ------
        ValueError
            If ``rate_slope`` or ``phi_max`` is negative, ``t_ref`` is negative,
            a capacitance is not strictly positive, a synaptic time constant is
            not strictly positive, or ``gsl_error_tol`` is not strictly positive.
        """
        # Every array parameter compared with np.any(...) below must be
        # concrete; a tracer in any one of them would make np.any raise.
        checked_arrays = (
            self.t_ref,
            self.soma_C_m,
            self.dend_C_m,
            self.soma_tau_syn_ex,
            self.soma_tau_syn_in,
            self.dend_tau_syn_ex,
            self.dend_tau_syn_in,
        )
        if any(is_tracer(v) for v in checked_arrays):
            return

        if self.rate_slope < 0:
            raise ValueError('Rate slope cannot be negative.')
        if self.phi_max < 0:
            raise ValueError('Maximum rate cannot be negative.')
        if np.any(self.t_ref < 0.0 * u.ms):
            raise ValueError('Refractory time cannot be negative.')
        for label, C_m in (('soma', self.soma_C_m), ('dendritic', self.dend_C_m)):
            if np.any(C_m <= 0.0 * u.pF):
                raise ValueError(f'Capacitance ({label}) must be strictly positive.')
        for tau_ex, tau_in in (
            (self.soma_tau_syn_ex, self.soma_tau_syn_in),
            (self.dend_tau_syn_ex, self.dend_tau_syn_in),
        ):
            if np.any(tau_ex <= 0.0 * u.ms) or np.any(tau_in <= 0.0 * u.ms):
                raise ValueError('All time constants must be strictly positive.')
        if np.any(self.gsl_error_tol <= 0.0):
            raise ValueError('The gsl_error_tol must be strictly positive.')

[docs] def init_state(self, **kwargs): r"""Initialize persistent and short-term state variables. Parameters ---------- **kwargs Unused compatibility parameters accepted by the base-state API. Raises ------ ValueError If an initializer cannot be broadcast to requested shape. TypeError If initializer outputs have incompatible units/dtypes for the corresponding state variables. """ ditype = brainstate.environ.ditype() dftype = brainstate.environ.dftype() dt = brainstate.environ.get_dt() # Membrane potentials initialized to E_L self.V_s = brainstate.HiddenState( u.math.ones(self.varshape, dtype=dftype) * self.soma_E_L ) self.V_d = brainstate.HiddenState( u.math.ones(self.varshape, dtype=dftype) * self.dend_E_L ) # Somatic conductances self.g_ex_s = brainstate.HiddenState( u.math.zeros(self.varshape, dtype=dftype) * u.nS ) self.g_in_s = brainstate.HiddenState( u.math.zeros(self.varshape, dtype=dftype) * u.nS ) # Dendritic currents self.I_ex_d = brainstate.HiddenState( u.math.zeros(self.varshape, dtype=dftype) * u.pA ) self.I_in_d = brainstate.HiddenState( u.math.zeros(self.varshape, dtype=dftype) * u.pA ) # Refractory counter self.refractory_step_count = brainstate.ShortTermState( u.math.full(self.varshape, 0, dtype=ditype) ) # Buffered stimulus currents (per compartment) self.I_stim_soma = brainstate.ShortTermState( u.math.full(self.varshape, 0.0 * u.pA, dtype=dftype) ) self.I_stim_dend = brainstate.ShortTermState( u.math.full(self.varshape, 0.0 * u.pA, dtype=dftype) ) # Last spike time self.last_spike_time = brainstate.ShortTermState( u.math.full(self.varshape, -1e7 * u.ms) ) # Integration step size self.integration_step = brainstate.ShortTermState.init( braintools.init.Constant(dt), self.varshape ) # Urbanczik history: list of (t_ms, dPI) tuples per neuron element # (populated only during Python-loop execution, not inside for_loop / JIT) self._urbanczik_history = {} # Current-step dPI stored as ShortTermState so for_loop bodies can # return it and collect the full 
trace. self._dPI = brainstate.ShortTermState( jnp.zeros(self.varshape, dtype=dftype) ) # RNG state as ShortTermState so jax.lax.scan tracks it correctly. rng_init = self._rng_key if self._rng_key is not None else jax.random.PRNGKey(0) self._rng_state = brainstate.ShortTermState(rng_init)
[docs] def get_spike(self, V: ArrayLike = None): r"""Evaluate surrogate spike output from membrane voltage. Parameters ---------- V : ArrayLike, optional Voltage values with shape broadcastable to ``self.varshape`` and units compatible with mV. If ``None``, uses current state ``self.V_s.value``. Returns ------- ArrayLike Surrogate spike activation produced by ``spk_fun(V / (1.0 * u.mV))``. """ V = self.V_s.value if V is None else V v_scaled = V / (1.0 * u.mV) return self.spk_fun(v_scaled)
def _collect_receptor_delta_inputs(self):
    r"""Sum this step's delta inputs into the four receptor channels.

    Every entry of ``self.delta_inputs`` is routed by its key (converted
    with ``str`` when it is not already a string). The first label among
    ``'soma_exc'``, ``'soma_inh'``, ``'dend_exc'``, ``'dend_inh'`` that occurs
    as a substring of the key selects the channel. Entries whose key
    matches none of them are discarded. Callable entries are invoked and
    left in the dict. Non-callable entries are removed after being read.

    Returns
    -------
    soma_exc, soma_inh, dend_exc, dend_inh : Quantity arrays
        Accumulated inputs shaped like ``self.V_s.value``. The somatic totals
        start from zeros in nS and the dendritic totals from zeros in pA.
    """
    shape = self.V_s.value.shape
    labels = ('soma_exc', 'soma_inh', 'dend_exc', 'dend_inh')
    totals = {
        'soma_exc': u.math.zeros(shape) * u.nS,
        'soma_inh': u.math.zeros(shape) * u.nS,
        'dend_exc': u.math.zeros(shape) * u.pA,
        'dend_inh': u.math.zeros(shape) * u.pA,
    }

    pending = self.delta_inputs
    if pending is not None:
        # Snapshot the keys because non-callable entries are popped below.
        for key in tuple(pending.keys()):
            contribution = pending[key]
            if callable(contribution):
                contribution = contribution()
            else:
                pending.pop(key)
            tag = key if isinstance(key, str) else str(key)
            # First matching label wins (same precedence as an if/elif chain).
            for name in labels:
                if name in tag:
                    totals[name] = totals[name] + contribution
                    break

    return tuple(totals[name] for name in labels)

def _vector_field(self, state, extra):
    """Unit-aware right-hand side of the two-compartment ODE system.

    Evaluated for all neurons at once.

    Parameters
    ----------
    state : DotDict
        ODE variables ``V_s``, ``g_ex_s``, ``g_in_s``, ``V_d``, ``I_ex_d``,
        ``I_in_d``.
    extra : DotDict
        Auxiliary data carried through the integrator (``spike_mask``,
        ``r``, ``i_stim_soma``). Only ``i_stim_soma`` is read here.

    Returns
    -------
    DotDict
        Time derivatives under the same keys as ``state``.
    """
    v_soma = state.V_s
    v_dend = state.V_d

    # Somatic membrane currents.
    leak_soma = self.soma_g_L * (v_soma - self.soma_E_L)
    exc_soma = state.g_ex_s * (v_soma - self.soma_E_ex)
    inh_soma = state.g_in_s * (v_soma - self.soma_E_in)

    # Inter-compartment coupling, one conductance per direction.
    dend_to_soma = self.g_sp * (v_dend - v_soma)
    soma_to_dend = self.g_ps * (v_soma - v_dend)

    soma_net = (
        -leak_soma - exc_soma - inh_soma + dend_to_soma
        + extra.i_stim_soma + self.soma_I_e
    )
    # NOTE(review): unlike the soma, no stimulus or bias current enters the
    # dendrite here, although init_state allocates I_stim_dend; confirm this
    # is intended.
    dend_net = (
        -self.dend_g_L * (v_dend - self.dend_E_L)
        + state.I_ex_d + state.I_in_d + soma_to_dend
    )

    return DotDict(
        V_s=soma_net / self.soma_C_m,
        g_ex_s=-state.g_ex_s / self.soma_tau_syn_ex,
        g_in_s=-state.g_in_s / self.soma_tau_syn_in,
        V_d=dend_net / self.dend_C_m,
        I_ex_d=-state.I_ex_d / self.dend_tau_syn_ex,
        I_in_d=-state.I_in_d / self.dend_tau_syn_in,
    )

def _event_fn(self, state, extra, accept):
    """No-op in-loop event hook for the adaptive integrator.

    Spikes in this model are drawn stochastically after integration and
    there is no voltage reset, so nothing happens inside the integration
    loop.

    Parameters
    ----------
    state : DotDict
        ODE variables ``V_s``, ``g_ex_s``, ``g_in_s``, ``V_d``, ``I_ex_d``,
        ``I_in_d``.
    extra : DotDict
        Auxiliary data (``spike_mask``, ``r``, ``i_stim_soma``).
    accept : array of bool
        Mask of neurons whose RK substep was accepted (unused).

    Returns
    -------
    (state, extra)
        The inputs, returned unchanged.
    """
    return state, extra
def update(self, x=0.0 * u.pA):
    r"""Advance the neuron population by one simulation step.

    Integrates the two-compartment ODEs, applies arriving synaptic events,
    draws stochastic somatic spikes, computes the Urbanczik-Senn learning
    signal :math:`\delta\Pi`, and buffers the somatic current input for the
    next step. All of this is done with vectorized JAX operations for the
    whole population at once.

    Parameters
    ----------
    x : Quantity, optional
        External current (pA) passed to ``sum_current_inputs`` together with
        the pre-integration somatic potential. The resulting current drives
        the soma on the *next* step (one-step delay). Default ``0.0 * u.pA``.

    Returns
    -------
    jax.Array
        Spike indicator with the population shape and dtype
        ``brainstate.environ.dftype()``. It is ``1.0`` where at least one
        spike was emitted this step and ``0.0`` otherwise. In Poisson mode
        (``t_ref == 0``) several spikes may be drawn in one step. They are
        collapsed to ``1.0`` in this output, but the full count enters
        :math:`\delta\Pi`.

    Notes
    -----
    **Order of operations per step**

    1. ``sum_current_inputs(x, V_s)`` is evaluated on the pre-integration
       state. It is written to ``I_stim_soma`` at the end of the step and is
       used by the next step.
    2. ``(V_s, g_ex_s, g_in_s, V_d, I_ex_d, I_in_d)`` are integrated over the
       step with ``self.integrator``. The adaptive step size is carried in
       ``integration_step``, and the stimulus is the one buffered by the
       previous step.
    3. Receptor delta inputs are added: ``g_ex_s += Δg_ex``,
       ``g_in_s += Δg_in``, ``I_ex_d += ΔI_ex``, ``I_in_d -= ΔI_in``. The
       inhibitory dendritic input is subtracted.
    4. Stochastic spiking applies to non-refractory neurons whose rate
       ``rate = 1000 * phi(V_s)`` [Hz] is positive, where
       ``phi(V) = phi_max / (1 + rate_slope * exp(beta * (theta - V)))``.

       * ``t_ref > 0``: one spike if ``U(0, 1) <= 1 - exp(-rate * dt * 1e-3)``.
       * ``t_ref == 0``: spike count drawn from ``Poisson(rate * dt * 1e-3)``.

       A spiking neuron's refractory counter is set to ``self.ref_count``.
       Otherwise the counter is decremented towards 0. No voltage reset is
       applied.
    5. The learning signal is computed from the dendritic prediction of the
       somatic potential:

       .. math::

          V^*_W = \frac{E_{L,s}\,g_{L,s} + V_d\,g_{sp}}{g_{sp} + g_{L,s}},
          \qquad
          \delta\Pi = \bigl(n_{\mathrm{spikes}} - \phi(V^*_W)\,dt\bigr)\,
                      h(V^*_W)

       Here ``h(V) = 15 * beta / (1 + exp(-beta * (theta - V)) / rate_slope)``.
       The result is stored in ``self._dPI`` every step.
    6. If ``t`` and :math:`\delta\Pi` are concrete (not JAX tracers, i.e. a
       plain Python loop rather than JIT / ``for_loop``), ``(t + dt`` in ms,
       :math:`\delta\Pi)` is appended per neuron to ``_urbanczik_history``.

    **Side effects**

    * Writes all ODE states, ``refractory_step_count``,
      ``integration_step``, ``I_stim_soma``, ``last_spike_time`` (set to
      ``t + dt`` for spiking neurons) and ``_dPI``.
    * Advances ``_rng_state`` on every call, whether or not a spike occurs.
    * Pops non-callable entries from ``delta_inputs``.

    Exponent arguments of the rate and ``h`` functions are clipped to
    ``[-500, 500]``. Invalid parameters (e.g. zero capacitance) or extreme
    inputs can still produce NaN/Inf.
    """
    t = brainstate.environ.get('t')
    dt = brainstate.environ.get_dt()
    dftype = brainstate.environ.dftype()
    ditype = brainstate.environ.ditype()
    # dt_ms as a JAX scalar (concrete when dt is a static environ value).
    dt_ms = u.get_mantissa(dt / u.ms)
    v_shape = self.varshape

    # Read state variables with their natural units.
    V_s = self.V_s.value  # mV
    g_ex_s = self.g_ex_s.value  # nS
    g_in_s = self.g_in_s.value  # nS
    V_d = self.V_d.value  # mV
    I_ex_d = self.I_ex_d.value  # pA
    I_in_d = self.I_in_d.value  # pA
    r = self.refractory_step_count.value  # int
    i_stim_soma = self.I_stim_soma.value  # pA
    h = self.integration_step.value  # ms

    # Current input for the next step (one-step delay), evaluated at the
    # pre-integration somatic potential.
    new_i_stim_soma = self.sum_current_inputs(x, self.V_s.value)  # pA

    # Adaptive integration via the generic integrator.
    ode_state = DotDict(
        V_s=V_s, g_ex_s=g_ex_s, g_in_s=g_in_s,
        V_d=V_d, I_ex_d=I_ex_d, I_in_d=I_in_d,
    )
    extra = DotDict(
        spike_mask=jnp.zeros(v_shape, dtype=jnp.bool_),
        r=r,
        i_stim_soma=i_stim_soma,
    )
    ode_state, h, extra = self.integrator(state=ode_state, h=h, extra=extra)
    V_s = ode_state.V_s
    g_ex_s = ode_state.g_ex_s
    g_in_s = ode_state.g_in_s
    V_d = ode_state.V_d
    I_ex_d = ode_state.I_ex_d
    I_in_d = ode_state.I_in_d
    r = extra.r

    # Apply synaptic spike inputs after integration.
    d_soma_exc, d_soma_inh, d_dend_exc, d_dend_inh = self._collect_receptor_delta_inputs()
    g_ex_s = g_ex_s + d_soma_exc
    g_in_s = g_in_s + d_soma_inh
    I_ex_d = I_ex_d + d_dend_exc
    I_in_d = I_in_d - d_dend_inh  # inhibitory dendritic input is subtracted

    # --- Stochastic spike generation (vectorized, JIT-compatible) ---
    # Advance the RNG through a ShortTermState so jax.lax.scan tracks it.
    new_rng, subkey = jax.random.split(self._rng_state.value)
    self._rng_state.value = new_rng

    # Instantaneous firing rate (Hz).
    V_s_mv = u.get_mantissa(V_s / u.mV)
    rate = 1000.0 * self.phi_max / (
        1.0 + self.rate_slope * jnp.exp(
            jnp.clip(self.beta * (self.theta - V_s_mv), -500.0, 500.0)
        )
    )
    not_refractory = (r == 0)

    # Dead-time mode (t_ref > 0) vs Poisson mode (t_ref == 0), per neuron.
    t_ref_ms = u.get_mantissa(self.t_ref / u.ms)
    has_dead_time = jnp.broadcast_to(t_ref_ms > 0.0, v_shape)

    # Expected spike count in this step.
    lam = rate * dt_ms * 1e-3

    # Dead-time mode: at most one spike per step.
    rand_vals = jax.random.uniform(subkey, shape=v_shape, dtype=dftype)
    has_spike_dead = rand_vals <= -jnp.expm1(-lam)

    # Poisson mode: spike count.
    # NOTE(review): subkey is split again after being consumed by uniform();
    # JAX advises against reusing keys. Each neuron only ever uses one of the
    # two draws, and this is kept as-is to preserve existing seeded streams.
    subkey_p, _ = jax.random.split(subkey)
    n_spikes_poisson = jax.random.poisson(subkey_p, lam, shape=v_shape, dtype=ditype)
    has_spike_poisson = n_spikes_poisson > 0

    spike_now_if_active = jnp.where(has_dead_time, has_spike_dead, has_spike_poisson)
    spike_mask = not_refractory & (rate > 0.0) & spike_now_if_active

    # Spike count (float) for the dPI formula.
    n_spikes_float = jnp.where(
        spike_mask,
        jnp.where(has_dead_time,
                  jnp.ones(v_shape, dtype=dftype),
                  n_spikes_poisson.astype(dftype)),
        jnp.zeros(v_shape, dtype=dftype),
    )

    # Refractory counter: reload on spike, otherwise count down to 0.
    new_r = jnp.where(
        spike_mask,
        jnp.broadcast_to(u.get_mantissa(self.ref_count), v_shape),
        jnp.maximum(0, r - 1),
    )

    # --- Urbanczik learning signal ---
    V_d_mv = u.get_mantissa(V_d / u.mV)
    g_sp_nS = jnp.broadcast_to(u.get_mantissa(self.g_sp / u.nS), v_shape)
    g_L_s_nS = jnp.broadcast_to(u.get_mantissa(self.soma_g_L / u.nS), v_shape)
    E_L_s_mV = jnp.broadcast_to(u.get_mantissa(self.soma_E_L / u.mV), v_shape)
    # Dendritic prediction of the somatic potential.
    V_W_star = (E_L_s_mV * g_L_s_nS + V_d_mv * g_sp_nS) / (g_sp_nS + g_L_s_nS)
    phi_val = self.phi_max / (
        1.0 + self.rate_slope * jnp.exp(
            jnp.clip(self.beta * (self.theta - V_W_star), -500.0, 500.0)
        )
    )
    h_val = 15.0 * self.beta / (
        1.0 + (1.0 / self.rate_slope) * jnp.exp(
            jnp.clip(-self.beta * (self.theta - V_W_star), -500.0, 500.0)
        )
    )
    dPI = (n_spikes_float - phi_val * dt_ms) * h_val

    # Current-step dPI as a state (readable from for_loop bodies).
    self._dPI.value = dPI

    # Python-side history only when values are concrete (plain Python loop);
    # inside JIT / for_loop, t or dPI are tracers and cannot be converted.
    t_ms = u.math.asarray(t / u.ms)
    if not is_tracer(t_ms) and not is_tracer(dPI):
        t_ms_val = float(np.asarray(t_ms)) + float(np.asarray(dt_ms))
        history = self._urbanczik_history
        # Flat (C-order) index == np.ravel_multi_index over v_shape.
        for flat_idx, dpi_val in enumerate(np.asarray(dPI).reshape(-1).tolist()):
            history.setdefault(flat_idx, []).append((t_ms_val, dpi_val))

    # Write back state.
    self.V_s.value = V_s
    self.g_ex_s.value = g_ex_s
    self.g_in_s.value = g_in_s
    self.V_d.value = V_d
    self.I_ex_d.value = I_ex_d
    self.I_in_d.value = I_in_d
    self.refractory_step_count.value = jnp.asarray(new_r, dtype=ditype)
    self.integration_step.value = h
    self.I_stim_soma.value = new_i_stim_soma + u.math.zeros(v_shape) * u.pA
    last_spike_time = u.math.where(spike_mask, t + dt, self.last_spike_time.value)
    self.last_spike_time.value = jax.lax.stop_gradient(last_spike_time)
    return jnp.asarray(spike_mask, dtype=dftype)
def get_urbanczik_history(self, neuron_idx=0):
    r"""Return the recorded Urbanczik-Senn learning-signal history of a neuron.

    Each entry pairs a simulation time with the error signal
    :math:`\delta\Pi` computed by ``update()`` for that step. These signals
    can drive the Urbanczik-Senn synaptic plasticity rule.

    Parameters
    ----------
    neuron_idx : int, optional
        Flat (C-order, raveled) index of the neuron within the population
        (default: 0). For a population of shape ``(N, M)`` valid indices are
        ``0 .. N*M - 1``. Use ``np.ravel_multi_index((i, j), shape)`` to
        convert a multi-dimensional index.

    Returns
    -------
    list of tuple
        A new list (a copy, so mutating it does not alter the stored history)
        of ``(time_ms, delta_PI)`` tuples in the order ``update()`` was
        called.

        * ``time_ms`` (float): end of the step, ``t + dt``, in milliseconds.
        * ``delta_PI`` (float): the learning signal for that step.

        Returns an empty list if nothing has been recorded for this index.

    Notes
    -----
    Each value is

    .. math::

        \delta\Pi(t) = \left(n_{\mathrm{spikes}}(t)
                       - \phi(V^*_W(t))\,dt\right)\,h(V^*_W(t))

    Here :math:`n_{\mathrm{spikes}}` is the number of spikes emitted in the
    step, :math:`\phi(V^*_W)\,dt` is the spike count predicted from the
    dendritic potential, and :math:`h(V^*_W)` is the voltage-dependent
    learning modulation. A typical plasticity rule integrates these signals
    against a presynaptic trace :math:`x_i(t)`:

    .. math::

        \Delta w_i = \eta \sum_t \delta\Pi(t)\,x_i(t)

    * History is appended only when ``update()`` runs with a concrete time
      ``t`` (plain Python loop). Inside JIT / ``for_loop`` nothing is
      recorded here. Collect the per-step ``_dPI`` state from the loop body
      instead.
    * ``init_state()`` replaces the history with an empty dict. Otherwise
      it grows linearly with simulation length and population size.

    Examples
    --------
    .. code-block:: python

        >>> import brainpy.state as bp
        >>> import brainstate
        >>> import saiunit as u
        >>> import matplotlib.pyplot as plt
        >>> neuron = bp.pp_cond_exp_mc_urbanczik(in_size=1)
        >>> dt = 0.1 * u.ms
        >>> with brainstate.environ.context(dt=dt):
        ...     neuron.init_all_states()
        ...     for i in range(1000):
        ...         with brainstate.environ.context(t=i * dt):
        ...             _ = neuron.update(x=250.0 * u.pA)
        >>> history = neuron.get_urbanczik_history(neuron_idx=0)
        >>> times, dPIs = zip(*history)
        >>> plt.plot(times, dPIs)

    For a multi-dimensional population:

    .. code-block:: python

        >>> import numpy as np
        >>> flat_idx = np.ravel_multi_index((3, 7), (10, 10))
        >>> history_3_7 = neuron_pop.get_urbanczik_history(neuron_idx=flat_idx)

    See Also
    --------
    urbanczik_synapse : Synapse model that uses these signals for plasticity.
    """
    # Copy so callers cannot mutate the internally stored history.
    return list(self._urbanczik_history.get(neuron_idx, []))