Source code for brainpy_state._brainpy.lif

# Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

# -*- coding: utf-8 -*-

from typing import Callable

import brainstate
import braintools
import saiunit as u
import jax
from brainstate.typing import ArrayLike, Size

from brainpy_state._base import Neuron

__all__ = [
    'IF', 'LIF', 'ExpIF', 'ExpIFRef', 'AdExIF', 'AdExIFRef', 'LIFRef', 'ALIF',
    'QuaIF', 'AdQuaIF', 'AdQuaIFRef', 'Gif', 'GifRef',
]


class IF(Neuron):
    r"""Integrate-and-Fire (IF) neuron model.

    This class implements the classic Integrate-and-Fire neuron model, one of the simplest
    spiking neuron models. It accumulates input current until the membrane potential reaches
    a threshold, at which point it fires a spike and resets the potential.

    The model is characterized by the following differential equation:

    $$
    \tau \frac{dV}{dt} = -V + R \cdot I(t)
    $$

    Spike condition:
    If $V \geq V_{th}$: emit spike and reset $V = V - V_{th}$ (soft reset) or $V = 0$ (hard reset)

    Parameters
    ----------
    in_size : Size
        Size of the input to the neuron.
    R : ArrayLike, default=1. * u.ohm
        Membrane resistance.
    tau : ArrayLike, default=5. * u.ms
        Membrane time constant.
    V_th : ArrayLike, default=1. * u.mV
        Firing threshold voltage (should be positive).
    V_initializer : Callable
        Initializer for the membrane potential state.
    spk_fun : Callable, default=surrogate.ReluGrad()
        Surrogate gradient function for the non-differentiable spike generation.
    spk_reset : str, default='soft'
        Reset mechanism after spike generation:
        - 'soft': subtract threshold V = V - V_th
        - 'hard': strict reset using stop_gradient
    name : str, optional
        Name of the neuron layer.

    Attributes
    ----------
    V : HiddenState
        Membrane potential.

    See Also
    --------
    LIF : Leaky integrate-and-fire with resting potential.
    QuaIF : Quadratic integrate-and-fire with nonlinear dynamics.

    Notes
    -----
    - Unlike the LIF model, the IF model has no leak towards a resting potential.
    - The membrane potential decays exponentially with time constant tau in the absence of input.
    - The time-dependent dynamics are integrated using an exponential Euler method.
    - The IF model is perfect integrator in the sense that it accumulates input indefinitely
      until reaching threshold, without any leak current.
    - The integrate-and-fire model was first introduced by Lapicque [1]_ [2]_.
    - For a comprehensive review of integrate-and-fire models, see [3]_.

    References
    ----------
    .. [1] Lapicque, L. (1907). Recherches quantitatives sur l'excitation électrique
           des nerfs traitée comme une polarisation. Journal de Physiologie et de
           Pathologie Générale, 9, 620-635.
    .. [2] Abbott, L. F. (1999). Lapicque's introduction of the integrate-and-fire
           model neuron (1907). Brain Research Bulletin, 50(5-6), 303-304.
    .. [3] Burkitt, A. N. (2006). A review of the integrate-and-fire neuron model:
           I. Homogeneous synaptic input. Biological cybernetics, 95(1), 1-19.

    Examples
    --------
    .. code-block:: python

        >>> import brainpy
        >>> import brainstate
        >>> import saiunit as u
        >>> # Create an IF neuron layer with 10 neurons
        >>> if_neuron = brainpy.state.IF(10, tau=8*u.ms, V_th=1.2*u.mV)
        >>> # Initialize the state
        >>> if_neuron.init_state(batch_size=1)
        >>> # Apply an input current and update the neuron state
        >>> spikes = if_neuron.update(x=2.0*u.mA)
    """

    __module__ = 'brainpy.state'

    def __init__(
        self,
        in_size: brainstate.typing.Size,
        R: brainstate.typing.ArrayLike = 1. * u.ohm,
        tau: brainstate.typing.ArrayLike = 5. * u.ms,
        V_th: brainstate.typing.ArrayLike = 1. * u.mV,  # should be positive
        V_initializer: Callable = braintools.init.Constant(0. * u.mV),
        spk_fun: Callable = braintools.surrogate.ReluGrad(),
        spk_reset: str = 'soft',
        name: str = None,
    ):
        super().__init__(in_size, name=name, spk_fun=spk_fun, spk_reset=spk_reset)

        # parameters
        self.R = braintools.init.param(R, self.varshape)
        self.tau = braintools.init.param(tau, self.varshape)
        self.V_th = braintools.init.param(V_th, self.varshape)
        self.V_initializer = V_initializer

[docs] def init_state(self, batch_size: int = None, **kwargs): self.V = brainstate.HiddenState(braintools.init.param(self.V_initializer, self.varshape, batch_size))
[docs] def reset_state(self, batch_size: int = None, **kwargs): self.V.value = braintools.init.param(self.V_initializer, self.varshape, batch_size)
[docs] def get_spike(self, V=None): V = self.V.value if V is None else V v_scaled = (V - self.V_th) / self.V_th return self.spk_fun(v_scaled)
def update(self, x=0. * u.mA): # reset last_V = self.V.value last_spike = self.get_spike(self.V.value) V_th = self.V_th if self.spk_reset == 'soft' else jax.lax.stop_gradient(last_V) V = last_V - V_th * last_spike # membrane potential dv = lambda v: (-v + self.R * self.sum_current_inputs(x, v)) / self.tau V = brainstate.nn.exp_euler_step(dv, V) V = self.sum_delta_inputs(V) self.V.value = V return self.get_spike(V) class LIF(Neuron): r"""Leaky Integrate-and-Fire (LIF) neuron model. This class implements the Leaky Integrate-and-Fire neuron model, which extends the basic Integrate-and-Fire model by adding a leak term. The leak causes the membrane potential to decay towards a resting value in the absence of input, making the model more biologically plausible. The model is characterized by the following differential equation: $$ \tau \frac{dV}{dt} = -(V - V_{rest}) + R \cdot I(t) $$ Spike condition: If $V \geq V_{th}$: emit spike and reset $V = V_{reset}$ Parameters ---------- in_size : Size Size of the input to the neuron. R : ArrayLike, default=1. * u.ohm Membrane resistance. tau : ArrayLike, default=5. * u.ms Membrane time constant. V_th : ArrayLike, default=1. * u.mV Firing threshold voltage. V_reset : ArrayLike, default=0. * u.mV Reset voltage after spike. V_rest : ArrayLike, default=0. * u.mV Resting membrane potential. V_initializer : Callable Initializer for the membrane potential state. spk_fun : Callable, default=surrogate.ReluGrad() Surrogate gradient function for the non-differentiable spike generation. spk_reset : str, default='soft' Reset mechanism after spike generation: - 'soft': subtract threshold V = V - V_th - 'hard': strict reset using stop_gradient name : str, optional Name of the neuron layer. Attributes ---------- V : HiddenState Membrane potential. See Also -------- IF : Perfect integrator without leak. LIFRef : LIF with absolute refractory period. ALIF : Adaptive LIF with spike-frequency adaptation. ExpIF : LIF with exponential spike initiation. Notes ----- - The leak term causes the membrane potential to decay exponentially towards V_rest with time constant tau when no input is present. - The time-dependent dynamics are integrated using an exponential Euler method. - Spike generation is non-differentiable, so surrogate gradients are used for backpropagation during training. - For a detailed treatment of LIF models, see [1]_ and [2]_. References ---------- .. [1] Gerstner, W., Kistler, W. M., Naud, R., & Paninski, L. (2014). Neuronal dynamics: From single neurons to networks and models of cognition. Cambridge University Press. .. [2] Burkitt, A. N. (2006). A review of the integrate-and-fire neuron model: I. Homogeneous synaptic input. Biological cybernetics, 95(1), 1-19. Examples -------- .. code-block:: python >>> import brainpy >>> import brainstate >>> import saiunit as u >>> # Create a LIF neuron layer with 10 neurons >>> lif = brainpy.state.LIF(10, tau=10*u.ms, V_th=0.8*u.mV) >>> # Initialize the state >>> lif.init_state(batch_size=1) >>> # Apply an input current and update the neuron state >>> spikes = lif.update(x=1.5*u.mA) """ __module__ = 'brainpy.state' def __init__( self, in_size: Size, R: ArrayLike = 1. * u.ohm, tau: ArrayLike = 5. * u.ms, V_th: ArrayLike = 1. * u.mV, V_reset: ArrayLike = 0. * u.mV, V_rest: ArrayLike = 0. * u.mV, V_initializer: Callable = braintools.init.Constant(0. * u.mV), spk_fun: Callable = braintools.surrogate.ReluGrad(), spk_reset: str = 'soft', name: str = None, ): super().__init__(in_size, name=name, spk_fun=spk_fun, spk_reset=spk_reset) # parameters self.R = braintools.init.param(R, self.varshape) self.tau = braintools.init.param(tau, self.varshape) self.V_th = braintools.init.param(V_th, self.varshape) self.V_rest = braintools.init.param(V_rest, self.varshape) self.V_reset = braintools.init.param(V_reset, self.varshape) self.V_initializer = V_initializer
[docs] def init_state(self, batch_size: int = None, **kwargs): self.V = brainstate.HiddenState(braintools.init.param(self.V_initializer, self.varshape, batch_size))
[docs] def reset_state(self, batch_size: int = None, **kwargs): self.V.value = braintools.init.param(self.V_initializer, self.varshape, batch_size)
[docs] def get_spike(self, V: ArrayLike = None): V = self.V.value if V is None else V v_scaled = (V - self.V_th) / (self.V_th - self.V_reset) return self.spk_fun(v_scaled)
def update(self, x=0. * u.mA): last_v = self.V.value last_spk = self.get_spike(last_v) V_th = self.V_th if self.spk_reset == 'soft' else jax.lax.stop_gradient(last_v) V = last_v - (V_th - self.V_reset) * last_spk # membrane potential dv = lambda v: (-v + self.V_rest + self.R * self.sum_current_inputs(x, v)) / self.tau V = brainstate.nn.exp_euler_step(dv, V) V = self.sum_delta_inputs(V) self.V.value = V return self.get_spike(V) class ExpIF(Neuron): r"""Exponential Integrate-and-Fire (ExpIF) neuron model. This model augments the LIF neuron by adding an exponential spike-initiation term, which provides a smooth approximation of the action potential onset and improves biological plausibility for cortical pyramidal cells. The membrane potential dynamics follow: $$ \tau \frac{dV}{dt} = -(V - V_{rest}) + \Delta_T \exp\left(\frac{V - V_T}{\Delta_T}\right) + R \cdot I(t) $$ Spike condition: If $V \geq V_{th}$: emit spike and reset $V = V_{reset}$ (hard reset) or $V = V - (V_{th} - V_{reset})$ (soft reset). Parameters ---------- in_size : Size Size of the input to the neuron. R : ArrayLike, default=1. * u.ohm Membrane resistance. tau : ArrayLike, default=10. * u.ms Membrane time constant. V_th : ArrayLike, default=-30. * u.mV Numerical firing threshold voltage. V_reset : ArrayLike, default=-68. * u.mV Reset voltage after spike. V_rest : ArrayLike, default=-65. * u.mV Resting membrane potential. V_T : ArrayLike, default=-59.9 * u.mV Threshold potential of the exponential term. delta_T : ArrayLike, default=3.48 * u.mV Spike slope factor controlling the sharpness of spike initiation. V_initializer : Callable Initializer for the membrane potential state. spk_fun : Callable, default=surrogate.ReluGrad() Surrogate gradient function for the spike generation. spk_reset : str, default='soft' Reset mechanism after spike generation. name : str, optional Name of the neuron layer. Attributes ---------- V : HiddenState Membrane potential. See Also -------- ExpIFRef : ExpIF with absolute refractory period. AdExIF : Adaptive exponential integrate-and-fire. LIF : Simpler leaky integrate-and-fire without exponential term. Notes ----- - The model was first introduced by Nicolas Fourcaud-Trocmé, David Hansel, Carl van Vreeswijk and Nicolas Brunel [1]_. The exponential nonlinearity was later confirmed by Badel et al. [3]_. It is one of the prominent examples of a precise theoretical prediction in computational neuroscience that was later confirmed by experimental neuroscience. - The right-hand side of the above equation contains a nonlinearity that can be directly extracted from experimental data [3]_. In this sense the exponential nonlinearity is not an arbitrary choice but directly supported by experimental evidence. - Even though it is a nonlinear model, it is simple enough to calculate the firing rate for constant input, and the linear response to fluctuations, even in the presence of input noise [4]_. - For a comprehensive treatment of this model, see [2]_ and [5]_. References ---------- .. [1] Fourcaud-Trocmé, Nicolas, et al. "How spike generation mechanisms determine the neuronal response to fluctuating inputs." Journal of Neuroscience 23.37 (2003): 11628-11640. .. [2] Gerstner, W., Kistler, W. M., Naud, R., & Paninski, L. (2014). Neuronal dynamics: From single neurons to networks and models of cognition. Cambridge University Press. .. [3] Badel, Laurent, Sandrine Lefort, Romain Brette, Carl CH Petersen, Wulfram Gerstner, and Magnus JE Richardson. "Dynamic IV curves are reliable predictors of naturalistic pyramidal-neuron voltage traces." Journal of Neurophysiology 99, no. 2 (2008): 656-666. .. [4] Richardson, Magnus JE. "Firing-rate response of linear and nonlinear integrate-and-fire neurons to modulated current-based and conductance-based synaptic drive." Physical Review E 76, no. 2 (2007): 021919. .. [5] https://en.wikipedia.org/wiki/Exponential_integrate-and-fire Examples -------- .. code-block:: python >>> import brainpy >>> import brainstate >>> import saiunit as u >>> # Create an ExpIF neuron layer with 10 neurons >>> expif = brainpy.state.ExpIF(10, tau=10*u.ms, V_th=-30*u.mV) >>> # Initialize the state >>> expif.init_state(batch_size=1) >>> # Apply an input current and update the neuron state >>> spikes = expif.update(x=1.5*u.mA) """ __module__ = 'brainpy.state' def __init__( self, in_size: Size, R: ArrayLike = 1. * u.ohm, tau: ArrayLike = 10. * u.ms, V_th: ArrayLike = -30. * u.mV, V_reset: ArrayLike = -68. * u.mV, V_rest: ArrayLike = -65. * u.mV, V_T: ArrayLike = -59.9 * u.mV, delta_T: ArrayLike = 3.48 * u.mV, V_initializer: Callable = braintools.init.Constant(-65. * u.mV), spk_fun: Callable = braintools.surrogate.ReluGrad(), spk_reset: str = 'soft', name: str = None, ): super().__init__(in_size, name=name, spk_fun=spk_fun, spk_reset=spk_reset) # parameters self.R = braintools.init.param(R, self.varshape) self.tau = braintools.init.param(tau, self.varshape) self.V_th = braintools.init.param(V_th, self.varshape) self.V_reset = braintools.init.param(V_reset, self.varshape) self.V_rest = braintools.init.param(V_rest, self.varshape) self.V_T = braintools.init.param(V_T, self.varshape) self.delta_T = braintools.init.param(delta_T, self.varshape) self.V_initializer = V_initializer
[docs] def init_state(self, batch_size: int = None, **kwargs): self.V = brainstate.HiddenState(braintools.init.param(self.V_initializer, self.varshape, batch_size))
[docs] def reset_state(self, batch_size: int = None, **kwargs): self.V.value = braintools.init.param(self.V_initializer, self.varshape, batch_size)
[docs] def get_spike(self, V: ArrayLike = None): V = self.V.value if V is None else V v_scaled = (V - self.V_th) / (self.V_th - self.V_reset) return self.spk_fun(v_scaled)
def update(self, x=0. * u.mA): last_v = self.V.value last_spk = self.get_spike(last_v) V_th = self.V_th if self.spk_reset == 'soft' else jax.lax.stop_gradient(last_v) V = last_v - (V_th - self.V_reset) * last_spk def dv(v): exp_term = self.delta_T * u.math.exp((v - self.V_T) / self.delta_T) return (-(v - self.V_rest) + exp_term + self.R * self.sum_current_inputs(x, v)) / self.tau V = brainstate.nn.exp_euler_step(dv, V) V = self.sum_delta_inputs(V) self.V.value = V return self.get_spike(V) class ExpIFRef(Neuron): r"""Exponential Integrate-and-Fire neuron model with refractory mechanism. This neuron adds an absolute refractory period to :class:`ExpIF`. While the exponential spike-initiation term keeps the membrane potential dynamics smooth, the refractory mechanism prevents the neuron from firing within ``tau_ref`` after a spike. Parameters ---------- in_size : Size Size of the input to the neuron. R : ArrayLike, default=1. * u.ohm Membrane resistance. tau : ArrayLike, default=10. * u.ms Membrane time constant. tau_ref : ArrayLike, default=1.7 * u.ms Absolute refractory period duration. V_th : ArrayLike, default=-30. * u.mV Numerical firing threshold voltage. V_reset : ArrayLike, default=-68. * u.mV Reset voltage after spike. V_rest : ArrayLike, default=-65. * u.mV Resting membrane potential. V_T : ArrayLike, default=-59.9 * u.mV Threshold potential of the exponential term. delta_T : ArrayLike, default=3.48 * u.mV Spike slope factor controlling spike initiation sharpness. V_initializer : Callable Initializer for the membrane potential state. spk_fun : Callable, default=surrogate.ReluGrad() Surrogate gradient function for the spike generation. spk_reset : str, default='soft' Reset mechanism after spike generation. ref_var : bool, default=False Whether to expose a boolean refractory state variable. name : str, optional Name of the neuron layer. Attributes ---------- V : HiddenState Membrane potential. last_spike_time : ShortTermState Last spike time recorder. refractory : HiddenState Neuron refractory state. See Also -------- ExpIF : ExpIF without refractory period. AdExIFRef : Adaptive ExpIF with refractory period. Notes ----- - The refractory mechanism prevents the neuron from firing within ``tau_ref`` after a spike by holding the membrane potential at the reset value. - The simulation environment time variable ``t`` must be available via ``brainstate.environ.get('t')`` for refractory tracking. References ---------- .. [1] Fourcaud-Trocme, N., et al. (2003). How spike generation mechanisms determine the neuronal response to fluctuating inputs. Journal of Neuroscience, 23(37), 11628-11640. Examples -------- .. code-block:: python >>> import brainpy >>> import brainstate >>> import saiunit as u >>> # Create an ExpIFRef neuron layer with 10 neurons >>> expif = brainpy.state.ExpIFRef(10, tau=10*u.ms, tau_ref=1.7*u.ms) >>> expif.init_state(batch_size=1) """ __module__ = 'brainpy.state' def __init__( self, in_size: Size, R: ArrayLike = 1. * u.ohm, tau: ArrayLike = 10. * u.ms, tau_ref: ArrayLike = 1.7 * u.ms, V_th: ArrayLike = -30. * u.mV, V_reset: ArrayLike = -68. * u.mV, V_rest: ArrayLike = -65. * u.mV, V_T: ArrayLike = -59.9 * u.mV, delta_T: ArrayLike = 3.48 * u.mV, V_initializer: Callable = braintools.init.Constant(-65. * u.mV), spk_fun: Callable = braintools.surrogate.ReluGrad(), spk_reset: str = 'soft', ref_var: bool = False, name: str = None, ): super().__init__(in_size, name=name, spk_fun=spk_fun, spk_reset=spk_reset) # parameters self.R = braintools.init.param(R, self.varshape) self.tau = braintools.init.param(tau, self.varshape) self.tau_ref = braintools.init.param(tau_ref, self.varshape) self.V_th = braintools.init.param(V_th, self.varshape) self.V_reset = braintools.init.param(V_reset, self.varshape) self.V_rest = braintools.init.param(V_rest, self.varshape) self.V_T = braintools.init.param(V_T, self.varshape) self.delta_T = braintools.init.param(delta_T, self.varshape) self.V_initializer = V_initializer self.ref_var = ref_var
[docs] def init_state(self, batch_size: int = None, **kwargs): self.V = brainstate.HiddenState(braintools.init.param(self.V_initializer, self.varshape, batch_size)) self.last_spike_time = brainstate.ShortTermState( braintools.init.param(braintools.init.Constant(-1e7 * u.ms), self.varshape, batch_size) ) if self.ref_var: self.refractory = brainstate.HiddenState( braintools.init.param(braintools.init.Constant(False), self.varshape, batch_size) )
[docs] def reset_state(self, batch_size: int = None, **kwargs): self.V.value = braintools.init.param(self.V_initializer, self.varshape, batch_size) self.last_spike_time.value = braintools.init.param( braintools.init.Constant(-1e7 * u.ms), self.varshape, batch_size ) if self.ref_var: self.refractory.value = braintools.init.param( braintools.init.Constant(False), self.varshape, batch_size )
[docs] def get_spike(self, V: ArrayLike = None): V = self.V.value if V is None else V v_scaled = (V - self.V_th) / (self.V_th - self.V_reset) return self.spk_fun(v_scaled)
def update(self, x=0. * u.mA): t = brainstate.environ.get('t') last_v = self.V.value last_spk = self.get_spike(last_v) V_th = self.V_th if self.spk_reset == 'soft' else jax.lax.stop_gradient(last_v) v_reset = last_v - (V_th - self.V_reset) * last_spk def dv(v): exp_term = self.delta_T * u.math.exp((v - self.V_T) / self.delta_T) return (-(v - self.V_rest) + exp_term + self.R * self.sum_current_inputs(x, v)) / self.tau V_candidate = brainstate.nn.exp_euler_step(dv, v_reset) V_candidate = self.sum_delta_inputs(V_candidate) refractory = (t - self.last_spike_time.value) < self.tau_ref self.V.value = u.math.where(refractory, v_reset, V_candidate) spike_cond = self.V.value >= self.V_th self.last_spike_time.value = jax.lax.stop_gradient( u.math.where(spike_cond, t, self.last_spike_time.value) ) if self.ref_var: self.refractory.value = jax.lax.stop_gradient( u.math.logical_or(refractory, spike_cond) ) return self.get_spike() class AdExIF(Neuron): r"""Adaptive exponential Integrate-and-Fire (AdExIF) neuron model. This model extends :class:`ExpIF` by adding an adaptation current ``w`` that is incremented after each spike and relaxes with time constant ``tau_w``. The membrane dynamics are governed by two coupled differential equations [1]_: $$ \tau \frac{dV}{dt} = -(V - V_{rest}) + \Delta_T \exp\left(\frac{V - V_T}{\Delta_T}\right) - R w + R \cdot I(t) $$ $$ \tau_w \frac{dw}{dt} = a (V - V_{rest}) - w $$ After each spike the membrane potential is reset and the adaptation current increases by ``b``. This simple mechanism generates rich firing patterns such as spike-frequency adaptation and bursting. Parameters ---------- in_size : Size Size of the input to the neuron. R : ArrayLike, default=1. * u.ohm Membrane resistance. tau : ArrayLike, default=10. * u.ms Membrane time constant. tau_w : ArrayLike, default=30. * u.ms Adaptation current time constant. V_th : ArrayLike, default=-55. * u.mV Spike threshold used for reset. V_reset : ArrayLike, default=-68. * u.mV Reset potential after spike. V_rest : ArrayLike, default=-65. * u.mV Resting membrane potential. V_T : ArrayLike, default=-59.9 * u.mV Threshold of the exponential term. delta_T : ArrayLike, default=3.48 * u.mV Spike slope factor controlling the sharpness of spike initiation. a : ArrayLike, default=1. * u.siemens Coupling strength from voltage to adaptation current. b : ArrayLike, default=1. * u.mA Increment of the adaptation current after a spike. V_initializer : Callable Initializer for the membrane potential state. w_initializer : Callable Initializer for the adaptation current. spk_fun : Callable, default=surrogate.ReluGrad() Surrogate gradient function for the spike generation. spk_reset : str, default='soft' Reset mechanism after spike generation. name : str, optional Name of the neuron layer. Attributes ---------- V : HiddenState Membrane potential. w : HiddenState Adaptation current. See Also -------- AdExIFRef : AdExIF with absolute refractory period. ExpIF : Exponential IF without adaptation. LIF : Simpler leaky integrate-and-fire. Notes ----- - The AdEx model can reproduce a wide variety of neuronal firing patterns including regular spiking, bursting, and spike-frequency adaptation. - For detailed information about this model and its parameters, see [1]_ and [2]_. References ---------- .. [1] Fourcaud-Trocmé, Nicolas, et al. "How spike generation mechanisms determine the neuronal response to fluctuating inputs." Journal of Neuroscience 23.37 (2003): 11628-11640. .. [2] http://www.scholarpedia.org/article/Adaptive_exponential_integrate-and-fire_model .. seealso:: :class:`brainpy.dyn.AdExIF` for the dynamical-system counterpart. Examples -------- .. code-block:: python >>> import brainpy >>> import brainstate >>> import saiunit as u >>> # Create an AdExIF neuron layer with 10 neurons >>> adexif = brainpy.state.AdExIF(10, tau=10*u.ms) >>> # Initialize the state >>> adexif.init_state(batch_size=1) >>> # Apply an input current and update the neuron state >>> spikes = adexif.update(x=1.5*u.mA) """ __module__ = 'brainpy.state' def __init__( self, in_size: Size, R: ArrayLike = 1. * u.ohm, tau: ArrayLike = 10. * u.ms, tau_w: ArrayLike = 30. * u.ms, V_th: ArrayLike = -55. * u.mV, V_reset: ArrayLike = -68. * u.mV, V_rest: ArrayLike = -65. * u.mV, V_T: ArrayLike = -59.9 * u.mV, delta_T: ArrayLike = 3.48 * u.mV, a: ArrayLike = 1. * u.siemens, b: ArrayLike = 1. * u.mA, V_initializer: Callable = braintools.init.Constant(-65. * u.mV), w_initializer: Callable = braintools.init.Constant(0. * u.mA), spk_fun: Callable = braintools.surrogate.ReluGrad(), spk_reset: str = 'soft', name: str = None, ): super().__init__(in_size, name=name, spk_fun=spk_fun, spk_reset=spk_reset) # parameters self.R = braintools.init.param(R, self.varshape) self.tau = braintools.init.param(tau, self.varshape) self.tau_w = braintools.init.param(tau_w, self.varshape) self.V_th = braintools.init.param(V_th, self.varshape) self.V_reset = braintools.init.param(V_reset, self.varshape) self.V_rest = braintools.init.param(V_rest, self.varshape) self.V_T = braintools.init.param(V_T, self.varshape) self.delta_T = braintools.init.param(delta_T, self.varshape) self.a = braintools.init.param(a, self.varshape) self.b = braintools.init.param(b, self.varshape) # initializers self.V_initializer = V_initializer self.w_initializer = w_initializer
[docs] def init_state(self, batch_size: int = None, **kwargs): self.V = brainstate.HiddenState(braintools.init.param(self.V_initializer, self.varshape, batch_size)) self.w = brainstate.HiddenState(braintools.init.param(self.w_initializer, self.varshape, batch_size))
[docs] def reset_state(self, batch_size: int = None, **kwargs): self.V.value = braintools.init.param(self.V_initializer, self.varshape, batch_size) self.w.value = braintools.init.param(self.w_initializer, self.varshape, batch_size)
[docs] def get_spike(self, V: ArrayLike = None): V = self.V.value if V is None else V v_scaled = (V - self.V_th) / (self.V_th - self.V_reset) return self.spk_fun(v_scaled)
def update(self, x=0. * u.mA): last_v = self.V.value last_w = self.w.value last_spk = self.get_spike(last_v) V_th = self.V_th if self.spk_reset == 'soft' else jax.lax.stop_gradient(last_v) V = last_v - (V_th - self.V_reset) * last_spk w = last_w + self.b * last_spk def dv(v): exp_term = self.delta_T * u.math.exp((v - self.V_T) / self.delta_T) I_total = self.sum_current_inputs(x, v) return (-(v - self.V_rest) + exp_term - self.R * w + self.R * I_total) / self.tau V = brainstate.nn.exp_euler_step(dv, V) V = self.sum_delta_inputs(V) def dw_func(w_val): return (self.a * (V - self.V_rest) - w_val) / self.tau_w w = brainstate.nn.exp_euler_step(dw_func, w) self.V.value = V self.w.value = w return self.get_spike(self.V.value) class AdExIFRef(Neuron): r"""Adaptive exponential Integrate-and-Fire neuron model with refractory mechanism. This model extends :class:`AdExIF` by adding an absolute refractory period. While the exponential spike-initiation term and adaptation current keep the membrane potential dynamics biologically realistic, the refractory mechanism prevents the neuron from firing within ``tau_ref`` after a spike. The membrane dynamics are governed by two coupled differential equations: $$ \tau \frac{dV}{dt} = -(V - V_{rest}) + \Delta_T \exp\left(\frac{V - V_T}{\Delta_T}\right) - R w + R \cdot I(t) $$ $$ \tau_w \frac{dw}{dt} = a (V - V_{rest}) - w $$ After each spike the membrane potential is reset and the adaptation current increases by ``b``. During the refractory period, the membrane potential remains at the reset value. Parameters ---------- in_size : Size Size of the input to the neuron. R : ArrayLike, default=1. * u.ohm Membrane resistance. tau : ArrayLike, default=10. * u.ms Membrane time constant. tau_w : ArrayLike, default=30. * u.ms Adaptation current time constant. tau_ref : ArrayLike, default=1.7 * u.ms Absolute refractory period duration. V_th : ArrayLike, default=-55. * u.mV Spike threshold used for reset. V_reset : ArrayLike, default=-68. * u.mV Reset potential after spike. V_rest : ArrayLike, default=-65. * u.mV Resting membrane potential. V_T : ArrayLike, default=-59.9 * u.mV Threshold of the exponential term. delta_T : ArrayLike, default=3.48 * u.mV Spike slope factor controlling the sharpness of spike initiation. a : ArrayLike, default=1. * u.siemens Coupling strength from voltage to adaptation current. b : ArrayLike, default=1. * u.mA Increment of the adaptation current after a spike. V_initializer : Callable Initializer for the membrane potential state. w_initializer : Callable Initializer for the adaptation current. spk_fun : Callable, default=surrogate.ReluGrad() Surrogate gradient function for the spike generation. spk_reset : str, default='soft' Reset mechanism after spike generation. ref_var : bool, default=False Whether to expose a boolean refractory state variable. name : str, optional Name of the neuron layer. Attributes ---------- V : HiddenState Membrane potential. w : HiddenState Adaptation current. last_spike_time : ShortTermState Last spike time recorder. refractory : HiddenState Neuron refractory state (if ref_var=True). See Also -------- AdExIF : AdExIF without refractory period. ExpIFRef : ExpIF with refractory period but no adaptation. Notes ----- - The AdExIF model with refractory period combines adaptation dynamics with an absolute refractory period for more biologically realistic behavior. - For detailed information about this model, see [1]_ and [2]_. References ---------- .. [1] Fourcaud-Trocmé, Nicolas, et al. "How spike generation mechanisms determine the neuronal response to fluctuating inputs." Journal of Neuroscience 23.37 (2003): 11628-11640. .. [2] http://www.scholarpedia.org/article/Adaptive_exponential_integrate-and-fire_model .. seealso:: :class:`brainpy.dyn.AdExIFRef` for the dynamical-system counterpart. Examples -------- .. code-block:: python >>> import brainpy >>> import brainstate >>> import saiunit as u >>> # Create an AdExIFRef neuron layer with 10 neurons >>> adexif_ref = brainpy.state.AdExIFRef(10, tau=10*u.ms, tau_ref=2*u.ms) >>> # Initialize the state >>> adexif_ref.init_state(batch_size=1) """ __module__ = 'brainpy.state' def __init__( self, in_size: Size, R: ArrayLike = 1. * u.ohm, tau: ArrayLike = 10. * u.ms, tau_w: ArrayLike = 30. * u.ms, tau_ref: ArrayLike = 1.7 * u.ms, V_th: ArrayLike = -55. * u.mV, V_reset: ArrayLike = -68. * u.mV, V_rest: ArrayLike = -65. * u.mV, V_T: ArrayLike = -59.9 * u.mV, delta_T: ArrayLike = 3.48 * u.mV, a: ArrayLike = 1. * u.siemens, b: ArrayLike = 1. * u.mA, V_initializer: Callable = braintools.init.Constant(-65. * u.mV), w_initializer: Callable = braintools.init.Constant(0. * u.mA), spk_fun: Callable = braintools.surrogate.ReluGrad(), spk_reset: str = 'soft', ref_var: bool = False, name: str = None, ): super().__init__(in_size, name=name, spk_fun=spk_fun, spk_reset=spk_reset) # parameters self.R = braintools.init.param(R, self.varshape) self.tau = braintools.init.param(tau, self.varshape) self.tau_w = braintools.init.param(tau_w, self.varshape) self.tau_ref = braintools.init.param(tau_ref, self.varshape) self.V_th = braintools.init.param(V_th, self.varshape) self.V_reset = braintools.init.param(V_reset, self.varshape) self.V_rest = braintools.init.param(V_rest, self.varshape) self.V_T = braintools.init.param(V_T, self.varshape) self.delta_T = braintools.init.param(delta_T, self.varshape) self.a = braintools.init.param(a, self.varshape) self.b = braintools.init.param(b, self.varshape) # initializers self.V_initializer = V_initializer self.w_initializer = w_initializer self.ref_var = ref_var
[docs] def init_state(self, batch_size: int = None, **kwargs): self.V = brainstate.HiddenState(braintools.init.param(self.V_initializer, self.varshape, batch_size)) self.w = brainstate.HiddenState(braintools.init.param(self.w_initializer, self.varshape, batch_size)) self.last_spike_time = brainstate.ShortTermState( braintools.init.param(braintools.init.Constant(-1e7 * u.ms), self.varshape, batch_size) ) if self.ref_var: self.refractory = brainstate.HiddenState( braintools.init.param(braintools.init.Constant(False), self.varshape, batch_size) )
[docs] def reset_state(self, batch_size: int = None, **kwargs): self.V.value = braintools.init.param(self.V_initializer, self.varshape, batch_size) self.w.value = braintools.init.param(self.w_initializer, self.varshape, batch_size) self.last_spike_time.value = braintools.init.param( braintools.init.Constant(-1e7 * u.ms), self.varshape, batch_size ) if self.ref_var: self.refractory.value = braintools.init.param( braintools.init.Constant(False), self.varshape, batch_size )
[docs] def get_spike(self, V: ArrayLike = None): V = self.V.value if V is None else V v_scaled = (V - self.V_th) / (self.V_th - self.V_reset) return self.spk_fun(v_scaled)
def update(self, x=0. * u.mA): t = brainstate.environ.get('t') last_v = self.V.value last_w = self.w.value last_spk = self.get_spike(last_v) V_th = self.V_th if self.spk_reset == 'soft' else jax.lax.stop_gradient(last_v) v_reset = last_v - (V_th - self.V_reset) * last_spk w_reset = last_w + self.b * last_spk def dv(v): exp_term = self.delta_T * u.math.exp((v - self.V_T) / self.delta_T) I_total = self.sum_current_inputs(x, v) return (-(v - self.V_rest) + exp_term - self.R * w_reset + self.R * I_total) / self.tau V_candidate = brainstate.nn.exp_euler_step(dv, v_reset) V_candidate = self.sum_delta_inputs(V_candidate) def dw_func(w_val): return (self.a * (V_candidate - self.V_rest) - w_val) / self.tau_w w_candidate = brainstate.nn.exp_euler_step(dw_func, w_reset) refractory = (t - self.last_spike_time.value) < self.tau_ref self.V.value = u.math.where(refractory, v_reset, V_candidate) self.w.value = u.math.where(refractory, w_reset, w_candidate) spike_cond = self.V.value >= self.V_th self.last_spike_time.value = jax.lax.stop_gradient( u.math.where(spike_cond, t, self.last_spike_time.value) ) if self.ref_var: self.refractory.value = jax.lax.stop_gradient( u.math.logical_or(refractory, spike_cond) ) return self.get_spike() class LIFRef(Neuron): r"""Leaky Integrate-and-Fire neuron model with refractory period. This class implements a Leaky Integrate-and-Fire neuron model that includes a refractory period after spiking, during which the neuron cannot fire regardless of input. This better captures the behavior of biological neurons that exhibit a recovery period after action potential generation. The model is characterized by the following equations: When not in refractory period: $$ \tau \frac{dV}{dt} = -(V - V_{rest}) + R \cdot I(t) $$ During refractory period: $$ V = V_{reset} $$ Spike condition: If $V \geq V_{th}$: emit spike, set $V = V_{reset}$, and enter refractory period for $\tau_{ref}$ Parameters ---------- in_size : Size Size of the input to the neuron. R : ArrayLike, default=1. * u.ohm Membrane resistance. tau : ArrayLike, default=5. * u.ms Membrane time constant. tau_ref : ArrayLike, default=5. * u.ms Refractory period duration. V_th : ArrayLike, default=1. * u.mV Firing threshold voltage. V_reset : ArrayLike, default=0. * u.mV Reset voltage after spike. V_rest : ArrayLike, default=0. * u.mV Resting membrane potential. V_initializer : Callable Initializer for the membrane potential state. spk_fun : Callable, default=surrogate.ReluGrad() Surrogate gradient function for the non-differentiable spike generation. spk_reset : str, default='soft' Reset mechanism after spike generation: - 'soft': subtract threshold V = V - V_th - 'hard': strict reset using stop_gradient name : str, optional Name of the neuron layer. Attributes ---------- V : HiddenState Membrane potential. last_spike_time : ShortTermState Time of the last spike, used to implement refractory period. See Also -------- LIF : LIF without refractory period. ALIF : Adaptive LIF with spike-frequency adaptation. Notes ----- - The refractory period is implemented by tracking the time of the last spike and preventing membrane potential updates if the elapsed time is less than tau_ref. - During the refractory period, the membrane potential remains at the reset value regardless of input current strength. - Refractory periods prevent high-frequency repetitive firing and are critical for realistic neural dynamics [3]_. - The time-dependent dynamics are integrated using an exponential Euler method. - The simulation environment time variable 't' is used to track the refractory state. - For a comprehensive treatment of LIF models with refractory periods, see [1]_ and [2]_. References ---------- .. [1] Gerstner, W., Kistler, W. M., Naud, R., & Paninski, L. (2014). Neuronal dynamics: From single neurons to networks and models of cognition. Cambridge University Press. .. [2] Burkitt, A. N. (2006). A review of the integrate-and-fire neuron model: I. Homogeneous synaptic input. Biological cybernetics, 95(1), 1-19. .. [3] Izhikevich, E. M. (2003). Simple model of spiking neurons. IEEE Transactions on neural networks, 14(6), 1569-1572. Examples -------- .. code-block:: python >>> import brainpy >>> import brainstate >>> import saiunit as u >>> # Create a LIFRef neuron layer with 10 neurons >>> lifref = brainpy.state.LIFRef(10, tau=10*u.ms, tau_ref=5*u.ms, ... V_th=0.8*u.mV) >>> # Initialize the state >>> lifref.init_state(batch_size=1) >>> # Apply an input current and update the neuron state >>> spikes = lifref.update(x=1.5*u.mA) """ __module__ = 'brainpy.state' def __init__( self, in_size: Size, R: ArrayLike = 1. * u.ohm, tau: ArrayLike = 5. * u.ms, tau_ref: ArrayLike = 5. * u.ms, V_th: ArrayLike = 1. * u.mV, V_reset: ArrayLike = 0. * u.mV, V_rest: ArrayLike = 0. * u.mV, V_initializer: Callable = braintools.init.Constant(0. * u.mV), spk_fun: Callable = braintools.surrogate.ReluGrad(), spk_reset: str = 'soft', name: str = None, ): super().__init__(in_size, name=name, spk_fun=spk_fun, spk_reset=spk_reset) # parameters self.R = braintools.init.param(R, self.varshape) self.tau = braintools.init.param(tau, self.varshape) self.tau_ref = braintools.init.param(tau_ref, self.varshape) self.V_th = braintools.init.param(V_th, self.varshape) self.V_rest = braintools.init.param(V_rest, self.varshape) self.V_reset = braintools.init.param(V_reset, self.varshape) self.V_initializer = V_initializer
[docs] def init_state(self, batch_size: int = None, **kwargs): self.V = brainstate.HiddenState(braintools.init.param(self.V_initializer, self.varshape, batch_size)) self.last_spike_time = brainstate.ShortTermState( braintools.init.param(braintools.init.Constant(-1e7 * u.ms), self.varshape, batch_size) )
[docs] def reset_state(self, batch_size: int = None, **kwargs): self.V.value = braintools.init.param(self.V_initializer, self.varshape, batch_size) self.last_spike_time.value = braintools.init.param( braintools.init.Constant(-1e7 * u.ms), self.varshape, batch_size )
[docs] def get_spike(self, V: ArrayLike = None): V = self.V.value if V is None else V v_scaled = (V - self.V_th) / (self.V_th - self.V_reset) return self.spk_fun(v_scaled)
def update(self, x=0. * u.mA): t = brainstate.environ.get('t') last_v = self.V.value last_spk = self.get_spike(last_v) V_th = self.V_th if self.spk_reset == 'soft' else jax.lax.stop_gradient(last_v) last_v = last_v - (V_th - self.V_reset) * last_spk # membrane potential dv = lambda v: (-v + self.V_rest + self.R * self.sum_current_inputs(x, v)) / self.tau V = brainstate.nn.exp_euler_step(dv, last_v) V = self.sum_delta_inputs(V) self.V.value = u.math.where(t - self.last_spike_time.value < self.tau_ref, last_v, V) # spike time evaluation last_spk_time = u.math.where( self.V.value >= self.V_th, brainstate.environ.get('t'), self.last_spike_time.value) self.last_spike_time.value = jax.lax.stop_gradient(last_spk_time) return self.get_spike() class ALIF(Neuron): r"""Adaptive Leaky Integrate-and-Fire (ALIF) neuron model. This class implements the Adaptive Leaky Integrate-and-Fire neuron model, which extends the basic LIF model by adding an adaptation variable. This adaptation mechanism increases the effective firing threshold after each spike, allowing the neuron to exhibit spike-frequency adaptation - a common feature in biological neurons that reduces firing rate during sustained stimulation. The model is characterized by the following differential equations: $$ \tau \frac{dV}{dt} = -(V - V_{rest}) + R \cdot I(t) $$ $$ \tau_a \frac{da}{dt} = -a $$ Spike condition: If $V \geq V_{th} + \beta \cdot a$: emit spike, set $V = V_{reset}$, and increment $a = a + 1$ Parameters ---------- in_size : Size Size of the input to the neuron. R : ArrayLike, default=1. * u.ohm Membrane resistance. tau : ArrayLike, default=5. * u.ms Membrane time constant. tau_a : ArrayLike, default=100. * u.ms Adaptation time constant (typically much longer than tau). V_th : ArrayLike, default=1. * u.mV Base firing threshold voltage. V_reset : ArrayLike, default=0. * u.mV Reset voltage after spike. V_rest : ArrayLike, default=0. * u.mV Resting membrane potential. beta : ArrayLike, default=0.1 * u.mV Adaptation coupling parameter that scales the effect of the adaptation variable. spk_fun : Callable Surrogate gradient function for the non-differentiable spike generation. spk_reset : str, default='soft' Reset mechanism after spike generation: - 'soft': subtract threshold V = V - V_th - 'hard': strict reset using stop_gradient V_initializer : Callable Initializer for the membrane potential state. a_initializer : Callable Initializer for the adaptation variable. name : str, optional Name of the neuron layer. Attributes ---------- V : HiddenState Membrane potential. a : HiddenState Adaptation variable that increases after each spike and decays exponentially. See Also -------- LIF : Standard LIF without adaptation. LIFRef : LIF with refractory period. AdExIF : Adaptive exponential integrate-and-fire. Notes ----- - The adaptation variable 'a' increases by 1 with each spike and decays exponentially with time constant tau_a between spikes. - The effective threshold increases by beta*a, making it progressively harder for the neuron to fire when it has recently been active. - This adaptation mechanism creates spike-frequency adaptation [2]_, allowing the neuron to respond strongly to input onset but then reduce its firing rate even if the input remains constant. - The adaptation time constant tau_a is typically much larger than the membrane time constant tau, creating a longer-lasting adaptation effect. - The time-dependent dynamics are integrated using an exponential Euler method. - For detailed analysis of adaptive integrate-and-fire models, see [1]_ and [3]_. References ---------- .. [1] Gerstner, W., Kistler, W. M., Naud, R., & Paninski, L. (2014). Neuronal dynamics: From single neurons to networks and models of cognition. Cambridge University Press. .. [2] Brette, R., & Gerstner, W. (2005). Adaptive exponential integrate-and-fire model as an effective description of neuronal activity. Journal of neurophysiology, 94(5), 3637-3642. .. [3] Naud, R., Marcille, N., Clopath, C., & Gerstner, W. (2008). Firing patterns in the adaptive exponential integrate-and-fire model. Biological cybernetics, 99(4), 335-347. Examples -------- .. code-block:: python >>> import brainpy >>> import brainstate >>> import saiunit as u >>> # Create an ALIF neuron layer with 10 neurons >>> alif = brainpy.state.ALIF(10, tau=10*u.ms, tau_a=200*u.ms, ... beta=0.2*u.mV) >>> # Initialize the state >>> alif.init_state(batch_size=1) >>> # Apply an input current and update the neuron state >>> spikes = alif.update(x=1.5*u.mA) """ __module__ = 'brainpy.state' def __init__( self, in_size: Size, R: ArrayLike = 1. * u.ohm, tau: ArrayLike = 5. * u.ms, tau_a: ArrayLike = 100. * u.ms, V_th: ArrayLike = 1. * u.mV, V_reset: ArrayLike = 0. * u.mV, V_rest: ArrayLike = 0. * u.mV, beta: ArrayLike = 0.1 * u.mV, spk_fun: Callable = braintools.surrogate.ReluGrad(), spk_reset: str = 'soft', V_initializer: Callable = braintools.init.Constant(0. * u.mV), a_initializer: Callable = braintools.init.Constant(0.), name: str = None, ): super().__init__(in_size, name=name, spk_fun=spk_fun, spk_reset=spk_reset) # parameters self.R = braintools.init.param(R, self.varshape) self.tau = braintools.init.param(tau, self.varshape) self.tau_a = braintools.init.param(tau_a, self.varshape) self.V_th = braintools.init.param(V_th, self.varshape) self.V_reset = braintools.init.param(V_reset, self.varshape) self.V_rest = braintools.init.param(V_rest, self.varshape) self.beta = braintools.init.param(beta, self.varshape) # functions self.V_initializer = V_initializer self.a_initializer = a_initializer
[docs] def init_state(self, batch_size: int = None, **kwargs): self.V = brainstate.HiddenState(braintools.init.param(self.V_initializer, self.varshape, batch_size)) self.a = brainstate.HiddenState(braintools.init.param(self.a_initializer, self.varshape, batch_size))
[docs] def reset_state(self, batch_size: int = None, **kwargs): self.V.value = braintools.init.param(self.V_initializer, self.varshape, batch_size) self.a.value = braintools.init.param(self.a_initializer, self.varshape, batch_size)
[docs] def get_spike(self, V=None, a=None): V = self.V.value if V is None else V a = self.a.value if a is None else a v_scaled = (V - self.V_th - self.beta * a) / (self.V_th - self.V_reset) return self.spk_fun(v_scaled)
def update(self, x=0. * u.mA): last_v = self.V.value last_a = self.a.value lst_spk = self.get_spike(last_v, last_a) V_th = self.V_th if self.spk_reset == 'soft' else jax.lax.stop_gradient(last_v) V = last_v - (V_th - self.V_reset) * lst_spk a = last_a + lst_spk # membrane potential dv = lambda v: (-v + self.V_rest + self.R * self.sum_current_inputs(x, v)) / self.tau da = lambda a: -a / self.tau_a V = brainstate.nn.exp_euler_step(dv, V) a = brainstate.nn.exp_euler_step(da, a) self.V.value = self.sum_delta_inputs(V) self.a.value = a return self.get_spike(self.V.value, self.a.value) class QuaIF(Neuron): r"""Quadratic Integrate-and-Fire (QuaIF) neuron model. This model extends the basic integrate-and-fire neuron by adding a quadratic nonlinearity in the voltage dynamics. The quadratic term creates a soft spike initiation, making the model more biologically realistic than the linear IF model. The model is characterized by the following differential equation: $$ \tau \frac{dV}{dt} = c(V - V_{rest})(V - V_c) + R \cdot I(t) $$ Spike condition: If $V \geq V_{th}$: emit spike and reset $V = V_{reset}$ Parameters ---------- in_size : Size Size of the input to the neuron. R : ArrayLike, default=1. * u.ohm Membrane resistance. tau : ArrayLike, default=10. * u.ms Membrane time constant. V_th : ArrayLike, default=-30. * u.mV Firing threshold voltage. V_reset : ArrayLike, default=-68. * u.mV Reset voltage after spike. V_rest : ArrayLike, default=-65. * u.mV Resting membrane potential. V_c : ArrayLike, default=-50. * u.mV Critical voltage for spike initiation. Must be larger than V_rest. c : ArrayLike, default=0.07 / u.mV Coefficient describing membrane potential update. Larger than 0. V_initializer : Callable Initializer for the membrane potential state. spk_fun : Callable, default=surrogate.ReluGrad() Surrogate gradient function for the spike generation. spk_reset : str, default='soft' Reset mechanism after spike generation. name : str, optional Name of the neuron layer. Attributes ---------- V : HiddenState Membrane potential. See Also -------- AdQuaIF : Adaptive quadratic integrate-and-fire. AdQuaIFRef : Adaptive quadratic IF with refractory period. ExpIF : Exponential integrate-and-fire (alternative nonlinear model). Notes ----- - The quadratic nonlinearity provides a more realistic spike initiation compared to LIF. - The critical voltage V_c determines the onset of spike generation. - When V approaches V_c, the quadratic term causes rapid acceleration toward threshold. - This model can exhibit Type I excitability (continuous f-I curve) [1]_. References ---------- .. [1] P. E. Latham, B.J. Richmond, P. Nelson and S. Nirenberg (2000) Intrinsic dynamics in neuronal networks. I. Theory. J. Neurophysiology 83, pp. 808–827. Examples -------- .. code-block:: python >>> import brainpy >>> import brainstate >>> import saiunit as u >>> # Create a QuaIF neuron layer with 10 neurons >>> quaif = brainpy.state.QuaIF(10, tau=10*u.ms, V_th=-30*u.mV, ... V_c=-50*u.mV) >>> # Initialize the state >>> quaif.init_state(batch_size=1) >>> # Apply an input current and update the neuron state >>> spikes = quaif.update(x=2.5*u.mA) """ __module__ = 'brainpy.state' def __init__( self, in_size: Size, R: ArrayLike = 1. * u.ohm, tau: ArrayLike = 10. * u.ms, V_th: ArrayLike = -30. * u.mV, V_reset: ArrayLike = -68. * u.mV, V_rest: ArrayLike = -65. * u.mV, V_c: ArrayLike = -50. * u.mV, c: ArrayLike = 0.07 / u.mV, V_initializer: Callable = braintools.init.Constant(-65. * u.mV), spk_fun: Callable = braintools.surrogate.ReluGrad(), spk_reset: str = 'soft', name: str = None, ): super().__init__(in_size, name=name, spk_fun=spk_fun, spk_reset=spk_reset) # parameters self.R = braintools.init.param(R, self.varshape) self.tau = braintools.init.param(tau, self.varshape) self.V_th = braintools.init.param(V_th, self.varshape) self.V_reset = braintools.init.param(V_reset, self.varshape) self.V_rest = braintools.init.param(V_rest, self.varshape) self.V_c = braintools.init.param(V_c, self.varshape) self.c = braintools.init.param(c, self.varshape) self.V_initializer = V_initializer
[docs] def init_state(self, batch_size: int = None, **kwargs): self.V = brainstate.HiddenState(braintools.init.param(self.V_initializer, self.varshape, batch_size))
[docs] def reset_state(self, batch_size: int = None, **kwargs): self.V.value = braintools.init.param(self.V_initializer, self.varshape, batch_size)
[docs] def get_spike(self, V: ArrayLike = None): V = self.V.value if V is None else V v_scaled = (V - self.V_th) / (self.V_th - self.V_reset) return self.spk_fun(v_scaled)
def update(self, x=0. * u.mA): last_v = self.V.value last_spk = self.get_spike(last_v) V_th = self.V_th if self.spk_reset == 'soft' else jax.lax.stop_gradient(last_v) V = last_v - (V_th - self.V_reset) * last_spk def dv(v): return (self.c * (v - self.V_rest) * (v - self.V_c) + self.R * self.sum_current_inputs(x, v)) / self.tau V = brainstate.nn.exp_euler_step(dv, V) V = self.sum_delta_inputs(V) self.V.value = V return self.get_spike(V) class AdQuaIF(Neuron): r"""Adaptive Quadratic Integrate-and-Fire (AdQuaIF) neuron model. This model extends the QuaIF model by adding an adaptation current that increases after each spike and decays exponentially between spikes. The adaptation mechanism produces spike-frequency adaptation and enables the neuron to exhibit various firing patterns. The model is characterized by the following differential equations: $$ \tau \frac{dV}{dt} = c(V - V_{rest})(V - V_c) - w + R \cdot I(t) $$ $$ \tau_w \frac{dw}{dt} = a(V - V_{rest}) - w $$ After a spike: $V \rightarrow V_{reset}$, $w \rightarrow w + b$ Parameters ---------- in_size : Size Size of the input to the neuron. R : ArrayLike, default=1. * u.ohm Membrane resistance. tau : ArrayLike, default=10. * u.ms Membrane time constant. tau_w : ArrayLike, default=10. * u.ms Adaptation current time constant. V_th : ArrayLike, default=-30. * u.mV Firing threshold voltage. V_reset : ArrayLike, default=-68. * u.mV Reset voltage after spike. V_rest : ArrayLike, default=-65. * u.mV Resting membrane potential. V_c : ArrayLike, default=-50. * u.mV Critical voltage for spike initiation. c : ArrayLike, default=0.07 / u.mV Coefficient describing membrane potential update. a : ArrayLike, default=1. * u.siemens Coupling strength from voltage to adaptation current. b : ArrayLike, default=0.1 * u.mA Increment of adaptation current after a spike. V_initializer : Callable Initializer for the membrane potential state. w_initializer : Callable Initializer for the adaptation current. spk_fun : Callable, default=surrogate.ReluGrad() Surrogate gradient function. spk_reset : str, default='soft' Reset mechanism after spike generation. name : str, optional Name of the neuron layer. Attributes ---------- V : HiddenState Membrane potential. w : HiddenState Adaptation current. See Also -------- QuaIF : Quadratic IF without adaptation. AdQuaIFRef : Adaptive quadratic IF with refractory period. AdExIF : Adaptive exponential IF (alternative adaptive model). Notes ----- - The adaptation current w provides negative feedback, reducing firing rate. - Parameter 'a' controls subthreshold adaptation (coupling from V to w). - Parameter 'b' controls spike-triggered adaptation (increment after spike). - With appropriate parameters, can exhibit regular spiking, bursting, etc. [1]_. - The adaptation time constant tau_w determines adaptation speed. - For a detailed bifurcation analysis of this model class, see [2]_. References ---------- .. [1] Izhikevich, E. M. (2004). Which model to use for cortical spiking neurons?. IEEE transactions on neural networks, 15(5), 1063-1070. .. [2] Touboul, Jonathan. "Bifurcation analysis of a general class of nonlinear integrate-and-fire neurons." SIAM Journal on Applied Mathematics 68, no. 4 (2008): 1045-1079. Examples -------- .. code-block:: python >>> import brainpy >>> import brainstate >>> import saiunit as u >>> # Create an AdQuaIF neuron layer with 10 neurons >>> adquaif = brainpy.state.AdQuaIF(10, tau=10*u.ms, tau_w=100*u.ms, ... a=1.0*u.siemens, b=0.1*u.mA) >>> # Initialize the state >>> adquaif.init_state(batch_size=1) >>> # Apply an input current and observe spike-frequency adaptation >>> spikes = adquaif.update(x=3.0*u.mA) """ __module__ = 'brainpy.state' def __init__( self, in_size: Size, R: ArrayLike = 1. * u.ohm, tau: ArrayLike = 10. * u.ms, tau_w: ArrayLike = 10. * u.ms, V_th: ArrayLike = -30. * u.mV, V_reset: ArrayLike = -68. * u.mV, V_rest: ArrayLike = -65. * u.mV, V_c: ArrayLike = -50. * u.mV, c: ArrayLike = 0.07 / u.mV, a: ArrayLike = 1. * u.siemens, b: ArrayLike = 0.1 * u.mA, V_initializer: Callable = braintools.init.Constant(-65. * u.mV), w_initializer: Callable = braintools.init.Constant(0. * u.mA), spk_fun: Callable = braintools.surrogate.ReluGrad(), spk_reset: str = 'soft', name: str = None, ): super().__init__(in_size, name=name, spk_fun=spk_fun, spk_reset=spk_reset) # parameters self.R = braintools.init.param(R, self.varshape) self.tau = braintools.init.param(tau, self.varshape) self.tau_w = braintools.init.param(tau_w, self.varshape) self.V_th = braintools.init.param(V_th, self.varshape) self.V_reset = braintools.init.param(V_reset, self.varshape) self.V_rest = braintools.init.param(V_rest, self.varshape) self.V_c = braintools.init.param(V_c, self.varshape) self.c = braintools.init.param(c, self.varshape) self.a = braintools.init.param(a, self.varshape) self.b = braintools.init.param(b, self.varshape) self.V_initializer = V_initializer self.w_initializer = w_initializer
[docs] def init_state(self, batch_size: int = None, **kwargs): self.V = brainstate.HiddenState(braintools.init.param(self.V_initializer, self.varshape, batch_size)) self.w = brainstate.HiddenState(braintools.init.param(self.w_initializer, self.varshape, batch_size))
[docs] def reset_state(self, batch_size: int = None, **kwargs): self.V.value = braintools.init.param(self.V_initializer, self.varshape, batch_size) self.w.value = braintools.init.param(self.w_initializer, self.varshape, batch_size)
[docs] def get_spike(self, V: ArrayLike = None): V = self.V.value if V is None else V v_scaled = (V - self.V_th) / (self.V_th - self.V_reset) return self.spk_fun(v_scaled)
def update(self, x=0. * u.mA): last_v = self.V.value last_w = self.w.value last_spk = self.get_spike(last_v) V_th = self.V_th if self.spk_reset == 'soft' else jax.lax.stop_gradient(last_v) V = last_v - (V_th - self.V_reset) * last_spk w = last_w + self.b * last_spk def dv(v): return (self.c * (v - self.V_rest) * (v - self.V_c) - self.R * w + self.R * self.sum_current_inputs(x, v)) / self.tau def dw_func(w_val): return (self.a * (V - self.V_rest) - w_val) / self.tau_w V = brainstate.nn.exp_euler_step(dv, V) V = self.sum_delta_inputs(V) w = brainstate.nn.exp_euler_step(dw_func, w) self.V.value = V self.w.value = w return self.get_spike(V) class AdQuaIFRef(Neuron): r"""Adaptive Quadratic Integrate-and-Fire neuron model with refractory mechanism. This model extends AdQuaIF by adding an absolute refractory period during which the neuron cannot fire regardless of input. The combination of adaptation and refractory period creates realistic firing patterns. Parameters ---------- in_size : Size Size of the input to the neuron. R : ArrayLike, default=1. * u.ohm Membrane resistance. tau : ArrayLike, default=10. * u.ms Membrane time constant. tau_w : ArrayLike, default=10. * u.ms Adaptation current time constant. tau_ref : ArrayLike, default=1.7 * u.ms Absolute refractory period duration. V_th : ArrayLike, default=-30. * u.mV Firing threshold voltage. V_reset : ArrayLike, default=-68. * u.mV Reset voltage after spike. V_rest : ArrayLike, default=-65. * u.mV Resting membrane potential. V_c : ArrayLike, default=-50. * u.mV Critical voltage for spike initiation. c : ArrayLike, default=0.07 / u.mV Coefficient describing membrane potential update. a : ArrayLike, default=1. * u.siemens Coupling strength from voltage to adaptation current. b : ArrayLike, default=0.1 * u.mA Increment of adaptation current after a spike. V_initializer : Callable Initializer for the membrane potential state. w_initializer : Callable Initializer for the adaptation current. spk_fun : Callable, default=surrogate.ReluGrad() Surrogate gradient function. spk_reset : str, default='soft' Reset mechanism after spike generation. ref_var : bool, default=False Whether to expose a boolean refractory state variable. name : str, optional Name of the neuron layer. Attributes ---------- V : HiddenState Membrane potential. w : HiddenState Adaptation current. last_spike_time : ShortTermState Last spike time recorder. refractory : HiddenState Neuron refractory state (if ref_var=True). See Also -------- AdQuaIF : Adaptive quadratic IF without refractory period. QuaIF : Quadratic IF without adaptation or refractory. Notes ----- - Combines spike-frequency adaptation with absolute refractory period. - During refractory period, neuron state is held at reset values. - Set ref_var=True to track refractory state as a boolean variable. - Refractory period prevents unrealistically high firing rates. - More biologically realistic than AdQuaIF without refractory period. References ---------- .. [1] Touboul, Jonathan. "Bifurcation analysis of a general class of nonlinear integrate-and-fire neurons." SIAM Journal on Applied Mathematics 68, no. 4 (2008): 1045-1079. Examples -------- .. code-block:: python >>> import brainpy >>> import brainstate >>> import saiunit as u >>> # Create an AdQuaIFRef neuron layer with refractory period >>> adquaif_ref = brainpy.state.AdQuaIFRef(10, tau=10*u.ms, ... tau_w=100*u.ms, tau_ref=2.0*u.ms, ref_var=True) >>> # Initialize the state >>> adquaif_ref.init_state(batch_size=1) """ __module__ = 'brainpy.state' def __init__( self, in_size: Size, R: ArrayLike = 1. * u.ohm, tau: ArrayLike = 10. * u.ms, tau_w: ArrayLike = 10. * u.ms, tau_ref: ArrayLike = 1.7 * u.ms, V_th: ArrayLike = -30. * u.mV, V_reset: ArrayLike = -68. * u.mV, V_rest: ArrayLike = -65. * u.mV, V_c: ArrayLike = -50. * u.mV, c: ArrayLike = 0.07 / u.mV, a: ArrayLike = 1. * u.siemens, b: ArrayLike = 0.1 * u.mA, V_initializer: Callable = braintools.init.Constant(-65. * u.mV), w_initializer: Callable = braintools.init.Constant(0. * u.mA), spk_fun: Callable = braintools.surrogate.ReluGrad(), spk_reset: str = 'soft', ref_var: bool = False, name: str = None, ): super().__init__(in_size, name=name, spk_fun=spk_fun, spk_reset=spk_reset) # parameters self.R = braintools.init.param(R, self.varshape) self.tau = braintools.init.param(tau, self.varshape) self.tau_w = braintools.init.param(tau_w, self.varshape) self.tau_ref = braintools.init.param(tau_ref, self.varshape) self.V_th = braintools.init.param(V_th, self.varshape) self.V_reset = braintools.init.param(V_reset, self.varshape) self.V_rest = braintools.init.param(V_rest, self.varshape) self.V_c = braintools.init.param(V_c, self.varshape) self.c = braintools.init.param(c, self.varshape) self.a = braintools.init.param(a, self.varshape) self.b = braintools.init.param(b, self.varshape) self.V_initializer = V_initializer self.w_initializer = w_initializer self.ref_var = ref_var
[docs] def init_state(self, batch_size: int = None, **kwargs): self.V = brainstate.HiddenState(braintools.init.param(self.V_initializer, self.varshape, batch_size)) self.w = brainstate.HiddenState(braintools.init.param(self.w_initializer, self.varshape, batch_size)) self.last_spike_time = brainstate.ShortTermState( braintools.init.param(braintools.init.Constant(-1e7 * u.ms), self.varshape, batch_size) ) if self.ref_var: self.refractory = brainstate.HiddenState( braintools.init.param(braintools.init.Constant(False), self.varshape, batch_size) )
[docs] def reset_state(self, batch_size: int = None, **kwargs): self.V.value = braintools.init.param(self.V_initializer, self.varshape, batch_size) self.w.value = braintools.init.param(self.w_initializer, self.varshape, batch_size) self.last_spike_time.value = braintools.init.param( braintools.init.Constant(-1e7 * u.ms), self.varshape, batch_size ) if self.ref_var: self.refractory.value = braintools.init.param( braintools.init.Constant(False), self.varshape, batch_size )
[docs] def get_spike(self, V: ArrayLike = None): V = self.V.value if V is None else V v_scaled = (V - self.V_th) / (self.V_th - self.V_reset) return self.spk_fun(v_scaled)
def update(self, x=0. * u.mA): t = brainstate.environ.get('t') last_v = self.V.value last_w = self.w.value last_spk = self.get_spike(last_v) V_th = self.V_th if self.spk_reset == 'soft' else jax.lax.stop_gradient(last_v) v_reset = last_v - (V_th - self.V_reset) * last_spk w_reset = last_w + self.b * last_spk def dv(v): return (self.c * (v - self.V_rest) * (v - self.V_c) - self.R * w_reset + self.R * self.sum_current_inputs(x, v)) / self.tau V_candidate = brainstate.nn.exp_euler_step(dv, v_reset) V_candidate = self.sum_delta_inputs(V_candidate) def dw_func(w_val): return (self.a * (V_candidate - self.V_rest) - w_val) / self.tau_w w_candidate = brainstate.nn.exp_euler_step(dw_func, w_reset) refractory = (t - self.last_spike_time.value) < self.tau_ref self.V.value = u.math.where(refractory, v_reset, V_candidate) self.w.value = u.math.where(refractory, w_reset, w_candidate) spike_cond = self.V.value >= self.V_th self.last_spike_time.value = jax.lax.stop_gradient( u.math.where(spike_cond, t, self.last_spike_time.value) ) if self.ref_var: self.refractory.value = jax.lax.stop_gradient( u.math.logical_or(refractory, spike_cond) ) return self.get_spike() class Gif(Neuron): r"""Generalized Integrate-and-Fire (Gif) neuron model. This model extends the basic integrate-and-fire neuron by adding internal currents and a dynamic threshold. The model can reproduce diverse firing patterns observed in biological neurons. The model is characterized by the following equations: $$ \frac{dI_1}{dt} = -k_1 I_1 $$ $$ \frac{dI_2}{dt} = -k_2 I_2 $$ $$ \tau \frac{dV}{dt} = -(V - V_{rest}) + R(I_1 + I_2 + I(t)) $$ $$ \frac{dV_{th}}{dt} = a(V - V_{rest}) - b(V_{th} - V_{th\infty}) $$ When $V \geq V_{th}$: - $I_1 \leftarrow R_1 I_1 + A_1$ - $I_2 \leftarrow R_2 I_2 + A_2$ - $V \leftarrow V_{reset}$ - $V_{th} \leftarrow \max(V_{th_{reset}}, V_{th})$ Parameters ---------- in_size : Size Size of the input to the neuron. R : ArrayLike, default=20. * u.ohm Membrane resistance. tau : ArrayLike, default=20. * u.ms Membrane time constant. V_rest : ArrayLike, default=-70. * u.mV Resting potential. V_reset : ArrayLike, default=-70. * u.mV Reset potential after spike. V_th_inf : ArrayLike, default=-50. * u.mV Target value of threshold potential updating. V_th_reset : ArrayLike, default=-60. * u.mV Free parameter, should be larger than V_reset. V_th_initializer : Callable Initializer for the threshold potential. a : ArrayLike, default=0. / u.ms Coefficient describes dependence of V_th on membrane potential. b : ArrayLike, default=0.01 / u.ms Coefficient describes V_th update. k1 : ArrayLike, default=0.2 / u.ms Constant of I1. k2 : ArrayLike, default=0.02 / u.ms Constant of I2. R1 : ArrayLike, default=0. Free parameter describing dependence of I1 reset value on I1 before spiking. R2 : ArrayLike, default=1. Free parameter describing dependence of I2 reset value on I2 before spiking. A1 : ArrayLike, default=0. * u.mA Free parameter. A2 : ArrayLike, default=0. * u.mA Free parameter. V_initializer : Callable Initializer for the membrane potential state. I1_initializer : Callable Initializer for internal current I1. I2_initializer : Callable Initializer for internal current I2. spk_fun : Callable, default=surrogate.ReluGrad() Surrogate gradient function. spk_reset : str, default='soft' Reset mechanism after spike generation. name : str, optional Name of the neuron layer. Attributes ---------- V : HiddenState Membrane potential. I1 : HiddenState Internal current 1. I2 : HiddenState Internal current 2. V_th : HiddenState Spiking threshold potential. See Also -------- GifRef : Gif with absolute refractory period. ALIF : Simpler adaptive LIF model. AdExIF : Adaptive exponential IF (alternative complex model). Notes ----- - The Gif model uses internal currents (I1, I2) for complex dynamics [1]_. - Dynamic threshold V_th adapts based on membrane potential and its own dynamics. - Can reproduce diverse firing patterns: regular spiking, bursting, adaptation. - Parameters a and b control threshold adaptation. - Parameters k1, k2, R1, R2, A1, A2 control internal current dynamics. - More flexible than simpler IF models for matching biological data [2]_. References ---------- .. [1] Mihalaş, Ştefan, and Ernst Niebur. "A generalized linear integrate-and-fire neural model produces diverse spiking behaviors." Neural computation 21.3 (2009): 704-718. .. [2] Teeter, Corinne, Ramakrishnan Iyer, Vilas Menon, Nathan Gouwens, David Feng, Jim Berg, Aaron Szafer et al. "Generalized leaky integrate-and-fire models classify multiple neuron types." Nature communications 9, no. 1 (2018): 1-15. Examples -------- .. code-block:: python >>> import brainpy >>> import brainstate >>> import saiunit as u >>> # Create a Gif neuron layer with dynamic threshold >>> gif = brainpy.state.Gif(10, tau=20*u.ms, k1=0.2/u.ms, ... k2=0.02/u.ms, a=0.005/u.ms, b=0.01/u.ms) >>> # Initialize the state >>> gif.init_state(batch_size=1) >>> # Apply input and observe diverse firing patterns >>> spikes = gif.update(x=1.5*u.mA) """ __module__ = 'brainpy.state' def __init__( self, in_size: Size, R: ArrayLike = 20. * u.ohm, tau: ArrayLike = 20. * u.ms, V_rest: ArrayLike = -70. * u.mV, V_reset: ArrayLike = -70. * u.mV, V_th_inf: ArrayLike = -50. * u.mV, V_th_reset: ArrayLike = -60. * u.mV, V_th_initializer: Callable = braintools.init.Constant(-50. * u.mV), a: ArrayLike = 0. / u.ms, b: ArrayLike = 0.01 / u.ms, k1: ArrayLike = 0.2 / u.ms, k2: ArrayLike = 0.02 / u.ms, R1: ArrayLike = 0., R2: ArrayLike = 1., A1: ArrayLike = 0. * u.mA, A2: ArrayLike = 0. * u.mA, V_initializer: Callable = braintools.init.Constant(-70. * u.mV), I1_initializer: Callable = braintools.init.Constant(0. * u.mA), I2_initializer: Callable = braintools.init.Constant(0. * u.mA), spk_fun: Callable = braintools.surrogate.ReluGrad(), spk_reset: str = 'soft', name: str = None, ): super().__init__(in_size, name=name, spk_fun=spk_fun, spk_reset=spk_reset) # parameters self.R = braintools.init.param(R, self.varshape) self.tau = braintools.init.param(tau, self.varshape) self.V_rest = braintools.init.param(V_rest, self.varshape) self.V_reset = braintools.init.param(V_reset, self.varshape) self.V_th_inf = braintools.init.param(V_th_inf, self.varshape) self.V_th_reset = braintools.init.param(V_th_reset, self.varshape) self.a = braintools.init.param(a, self.varshape) self.b = braintools.init.param(b, self.varshape) self.k1 = braintools.init.param(k1, self.varshape) self.k2 = braintools.init.param(k2, self.varshape) self.R1 = braintools.init.param(R1, self.varshape) self.R2 = braintools.init.param(R2, self.varshape) self.A1 = braintools.init.param(A1, self.varshape) self.A2 = braintools.init.param(A2, self.varshape) self.V_initializer = V_initializer self.I1_initializer = I1_initializer self.I2_initializer = I2_initializer self.V_th_initializer = V_th_initializer
[docs] def init_state(self, batch_size: int = None, **kwargs): self.V = brainstate.HiddenState(braintools.init.param(self.V_initializer, self.varshape, batch_size)) self.I1 = brainstate.HiddenState(braintools.init.param(self.I1_initializer, self.varshape, batch_size)) self.I2 = brainstate.HiddenState(braintools.init.param(self.I2_initializer, self.varshape, batch_size)) self.V_th = brainstate.HiddenState(braintools.init.param(self.V_th_initializer, self.varshape, batch_size))
[docs] def reset_state(self, batch_size: int = None, **kwargs): self.V.value = braintools.init.param(self.V_initializer, self.varshape, batch_size) self.I1.value = braintools.init.param(self.I1_initializer, self.varshape, batch_size) self.I2.value = braintools.init.param(self.I2_initializer, self.varshape, batch_size) self.V_th.value = braintools.init.param(self.V_th_initializer, self.varshape, batch_size)
[docs] def get_spike(self, V: ArrayLike = None, V_th: ArrayLike = None): V = self.V.value if V is None else V V_th = self.V_th.value if V_th is None else V_th v_scaled = (V - V_th) / (V_th - self.V_reset) return self.spk_fun(v_scaled)
def update(self, x=0. * u.mA): last_v = self.V.value last_i1 = self.I1.value last_i2 = self.I2.value last_v_th = self.V_th.value last_spk = self.get_spike(last_v, last_v_th) # Apply spike effects V_th_val = last_v_th if self.spk_reset == 'soft' else jax.lax.stop_gradient(last_v) V = last_v - (V_th_val - self.V_reset) * last_spk I1 = last_i1 + last_spk * (self.R1 * last_i1 + self.A1 - last_i1) I2 = last_i2 + last_spk * (self.R2 * last_i2 + self.A2 - last_i2) V_th = last_v_th + last_spk * (u.math.maximum(self.V_th_reset, last_v_th) - last_v_th) # Update dynamics def dI1(i1): return -self.k1 * i1 def dI2(i2): return -self.k2 * i2 def dV_th_func(v_th): return self.a * (V - self.V_rest) - self.b * (v_th - self.V_th_inf) def dv(v): return (-(v - self.V_rest) + self.R * (I1 + I2 + self.sum_current_inputs(x, v))) / self.tau I1 = brainstate.nn.exp_euler_step(dI1, I1) I2 = brainstate.nn.exp_euler_step(dI2, I2) V_th = brainstate.nn.exp_euler_step(dV_th_func, V_th) V = brainstate.nn.exp_euler_step(dv, V) V = self.sum_delta_inputs(V) self.V.value = V self.I1.value = I1 self.I2.value = I2 self.V_th.value = V_th return self.get_spike(V, V_th) class GifRef(Neuron): r"""Generalized Integrate-and-Fire neuron model with refractory mechanism. This model extends Gif by adding an absolute refractory period during which the neuron cannot fire. This creates more realistic firing patterns and prevents unrealistic high-frequency firing. Parameters ---------- in_size : Size Size of the input to the neuron. R : ArrayLike, default=20. * u.ohm Membrane resistance. tau : ArrayLike, default=20. * u.ms Membrane time constant. tau_ref : ArrayLike, default=1.7 * u.ms Absolute refractory period duration. V_rest : ArrayLike, default=-70. * u.mV Resting potential. V_reset : ArrayLike, default=-70. * u.mV Reset potential after spike. V_th_inf : ArrayLike, default=-50. * u.mV Target value of threshold potential updating. V_th_reset : ArrayLike, default=-60. * u.mV Free parameter, should be larger than V_reset. V_th_initializer : Callable Initializer for the threshold potential. a : ArrayLike, default=0. / u.ms Coefficient describes dependence of V_th on membrane potential. b : ArrayLike, default=0.01 / u.ms Coefficient describes V_th update. k1 : ArrayLike, default=0.2 / u.ms Constant of I1. k2 : ArrayLike, default=0.02 / u.ms Constant of I2. R1 : ArrayLike, default=0. Free parameter. R2 : ArrayLike, default=1. Free parameter. A1 : ArrayLike, default=0. * u.mA Free parameter. A2 : ArrayLike, default=0. * u.mA Free parameter. V_initializer : Callable Initializer for the membrane potential state. I1_initializer : Callable Initializer for internal current I1. I2_initializer : Callable Initializer for internal current I2. spk_fun : Callable, default=surrogate.ReluGrad() Surrogate gradient function. spk_reset : str, default='soft' Reset mechanism after spike generation. ref_var : bool, default=False Whether to expose a boolean refractory state variable. name : str, optional Name of the neuron layer. Attributes ---------- V : HiddenState Membrane potential. I1 : HiddenState Internal current 1. I2 : HiddenState Internal current 2. V_th : HiddenState Spiking threshold potential. last_spike_time : ShortTermState Last spike time recorder. refractory : HiddenState Neuron refractory state (if ref_var=True). See Also -------- Gif : Gif without refractory period. AdExIFRef : Adaptive exponential IF with refractory period. Notes ----- - Combines Gif model's rich dynamics with absolute refractory period. - During refractory period, all state variables are held at reset values. - Set ref_var=True to track refractory state as a boolean variable. - More biologically realistic than Gif without refractory mechanism. - Can still exhibit diverse firing patterns: regular, bursting, adaptation. - Refractory period prevents unrealistically high firing rates. References ---------- .. [1] Mihalas, S., & Niebur, E. (2009). A generalized linear integrate-and-fire neural model produces diverse spiking behaviors. Neural computation, 21(3), 704-718. Examples -------- .. code-block:: python >>> import brainpy >>> import brainstate >>> import saiunit as u >>> # Create a GifRef neuron layer with refractory period >>> gif_ref = brainpy.state.GifRef(10, tau=20*u.ms, tau_ref=2.0*u.ms, ... k1=0.2/u.ms, k2=0.02/u.ms, ref_var=True) >>> # Initialize the state >>> gif_ref.init_state(batch_size=1) """ __module__ = 'brainpy.state' def __init__( self, in_size: Size, R: ArrayLike = 20. * u.ohm, tau: ArrayLike = 20. * u.ms, tau_ref: ArrayLike = 1.7 * u.ms, V_rest: ArrayLike = -70. * u.mV, V_reset: ArrayLike = -70. * u.mV, V_th_inf: ArrayLike = -50. * u.mV, V_th_reset: ArrayLike = -60. * u.mV, V_th_initializer: Callable = braintools.init.Constant(-50. * u.mV), a: ArrayLike = 0. / u.ms, b: ArrayLike = 0.01 / u.ms, k1: ArrayLike = 0.2 / u.ms, k2: ArrayLike = 0.02 / u.ms, R1: ArrayLike = 0., R2: ArrayLike = 1., A1: ArrayLike = 0. * u.mA, A2: ArrayLike = 0. * u.mA, V_initializer: Callable = braintools.init.Constant(-70. * u.mV), I1_initializer: Callable = braintools.init.Constant(0. * u.mA), I2_initializer: Callable = braintools.init.Constant(0. * u.mA), spk_fun: Callable = braintools.surrogate.ReluGrad(), spk_reset: str = 'soft', ref_var: bool = False, name: str = None, ): super().__init__(in_size, name=name, spk_fun=spk_fun, spk_reset=spk_reset) # parameters self.R = braintools.init.param(R, self.varshape) self.tau = braintools.init.param(tau, self.varshape) self.tau_ref = braintools.init.param(tau_ref, self.varshape) self.V_rest = braintools.init.param(V_rest, self.varshape) self.V_reset = braintools.init.param(V_reset, self.varshape) self.V_th_inf = braintools.init.param(V_th_inf, self.varshape) self.V_th_reset = braintools.init.param(V_th_reset, self.varshape) self.a = braintools.init.param(a, self.varshape) self.b = braintools.init.param(b, self.varshape) self.k1 = braintools.init.param(k1, self.varshape) self.k2 = braintools.init.param(k2, self.varshape) self.R1 = braintools.init.param(R1, self.varshape) self.R2 = braintools.init.param(R2, self.varshape) self.A1 = braintools.init.param(A1, self.varshape) self.A2 = braintools.init.param(A2, self.varshape) self.V_initializer = V_initializer self.I1_initializer = I1_initializer self.I2_initializer = I2_initializer self.V_th_initializer = V_th_initializer self.ref_var = ref_var
[docs] def init_state(self, batch_size: int = None, **kwargs): self.V = brainstate.HiddenState(braintools.init.param(self.V_initializer, self.varshape, batch_size)) self.I1 = brainstate.HiddenState(braintools.init.param(self.I1_initializer, self.varshape, batch_size)) self.I2 = brainstate.HiddenState(braintools.init.param(self.I2_initializer, self.varshape, batch_size)) self.V_th = brainstate.HiddenState(braintools.init.param(self.V_th_initializer, self.varshape, batch_size)) self.last_spike_time = brainstate.ShortTermState( braintools.init.param(braintools.init.Constant(-1e7 * u.ms), self.varshape, batch_size) ) if self.ref_var: self.refractory = brainstate.HiddenState( braintools.init.param(braintools.init.Constant(False), self.varshape, batch_size) )
[docs] def reset_state(self, batch_size: int = None, **kwargs): self.V.value = braintools.init.param(self.V_initializer, self.varshape, batch_size) self.I1.value = braintools.init.param(self.I1_initializer, self.varshape, batch_size) self.I2.value = braintools.init.param(self.I2_initializer, self.varshape, batch_size) self.V_th.value = braintools.init.param(self.V_th_initializer, self.varshape, batch_size) self.last_spike_time.value = braintools.init.param( braintools.init.Constant(-1e7 * u.ms), self.varshape, batch_size ) if self.ref_var: self.refractory.value = braintools.init.param( braintools.init.Constant(False), self.varshape, batch_size )
[docs] def get_spike(self, V: ArrayLike = None, V_th: ArrayLike = None): V = self.V.value if V is None else V V_th = self.V_th.value if V_th is None else V_th v_scaled = (V - V_th) / (V_th - self.V_reset) return self.spk_fun(v_scaled)
def update(self, x=0. * u.mA): t = brainstate.environ.get('t') last_v = self.V.value last_i1 = self.I1.value last_i2 = self.I2.value last_v_th = self.V_th.value last_spk = self.get_spike(last_v, last_v_th) # Apply spike effects V_th_val = last_v_th if self.spk_reset == 'soft' else jax.lax.stop_gradient(last_v) v_reset = last_v - (V_th_val - self.V_reset) * last_spk i1_reset = last_i1 + last_spk * (self.R1 * last_i1 + self.A1 - last_i1) i2_reset = last_i2 + last_spk * (self.R2 * last_i2 + self.A2 - last_i2) v_th_reset = last_v_th + last_spk * (u.math.maximum(self.V_th_reset, last_v_th) - last_v_th) # Update dynamics def dI1(i1): return -self.k1 * i1 def dI2(i2): return -self.k2 * i2 def dV_th_func(v_th): return self.a * (v_reset - self.V_rest) - self.b * (v_th - self.V_th_inf) def dv(v): return (-(v - self.V_rest) + self.R * (i1_reset + i2_reset + self.sum_current_inputs(x, v))) / self.tau I1_candidate = brainstate.nn.exp_euler_step(dI1, i1_reset) I2_candidate = brainstate.nn.exp_euler_step(dI2, i2_reset) V_th_candidate = brainstate.nn.exp_euler_step(dV_th_func, v_th_reset) V_candidate = brainstate.nn.exp_euler_step(dv, v_reset) V_candidate = self.sum_delta_inputs(V_candidate) refractory = (t - self.last_spike_time.value) < self.tau_ref self.V.value = u.math.where(refractory, v_reset, V_candidate) self.I1.value = u.math.where(refractory, i1_reset, I1_candidate) self.I2.value = u.math.where(refractory, i2_reset, I2_candidate) self.V_th.value = u.math.where(refractory, v_th_reset, V_th_candidate) spike_cond = self.V.value >= self.V_th.value self.last_spike_time.value = jax.lax.stop_gradient( u.math.where(spike_cond, t, self.last_spike_time.value) ) if self.ref_var: self.refractory.value = jax.lax.stop_gradient( u.math.logical_or(refractory, spike_cond) ) return self.get_spike()