# Source code for brainpy_state._brainpy.izhikevich

# Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

# -*- coding: utf-8 -*-

from typing import Callable

import brainstate
import braintools
import saiunit as u
import jax
from brainstate.typing import ArrayLike, Size

from brainpy_state._base import Neuron

__all__ = [
    'Izhikevich', 'IzhikevichRef',
]


class Izhikevich(Neuron):
    r"""Izhikevich neuron model.

    This class implements the Izhikevich neuron model, a two-dimensional spiking neuron
    model that can reproduce a wide variety of neuronal firing patterns observed in
    biological neurons. The model combines computational efficiency with biological
    plausibility through a quadratic voltage dynamics and a linear recovery variable.

    The model is characterized by the following differential equations:

    $$
    \frac{dV}{dt} = 0.04 V^2 + 5V + 140 - u + I(t)
    $$

    $$
    \frac{du}{dt} = a(bV - u)
    $$

    Spike condition:
    If $V \geq V_{th}$: emit spike, set $V = c$ and $u = u + d$

    Parameters
    ----------
    in_size : Size
        Size of the input to the neuron.
    a : ArrayLike, default=0.02 / u.ms
        Time scale of the recovery variable u. Smaller values result in slower recovery.
    b : ArrayLike, default=0.2 / u.ms
        Sensitivity of the recovery variable u to the membrane potential V.
    c : ArrayLike, default=-65. * u.mV
        After-spike reset value of the membrane potential.
    d : ArrayLike, default=8. * u.mV / u.ms
        After-spike increment of the recovery variable u.
    V_th : ArrayLike, default=30. * u.mV
        Spike threshold voltage.
    V_initializer : Callable
        Initializer for the membrane potential state.
    u_initializer : Callable
        Initializer for the recovery variable state.
    spk_fun : Callable, default=surrogate.ReluGrad()
        Surrogate gradient function for the non-differentiable spike generation.
    spk_reset : str, default='hard'
        Reset mechanism after spike generation:
        - 'soft': subtract threshold V = V - V_th
        - 'hard': strict reset using stop_gradient

        The value is forwarded to the base class constructor. ``update`` in this
        class does not consult it: it always applies the Izhikevich reset
        (V set to ``c``, ``u`` incremented by ``d``).
    name : str, optional
        Name of the neuron layer.

    Attributes
    ----------
    V : HiddenState
        Membrane potential.
    u : HiddenState
        Recovery variable.

    See Also
    --------
    IzhikevichRef : Izhikevich model with absolute refractory period.

    Notes
    -----
    - The quadratic term in the voltage equation (0.04*V^2) provides a sharp spike
      upstroke similar to biological neurons [1]_.
    - Different combinations of parameters (a, b, c, d) can reproduce various neuronal
      behaviors including regular spiking, intrinsically bursting, chattering, and
      fast spiking [2]_.
    - The model uses a hard reset mechanism where V is set to c and u is incremented
      by d when a spike occurs.
    - Parameter ranges: a ∈ [0.01, 0.1], b ∈ [0.2, 0.3], c ∈ [-65, -50], d ∈ [0.1, 10]

    References
    ----------
    .. [1] Izhikevich, E. M. (2003). Simple model of spiking neurons. IEEE Transactions
           on neural networks, 14(6), 1569-1572.
    .. [2] Izhikevich, E. M. (2004). Which model to use for cortical spiking neurons?.
           IEEE transactions on neural networks, 15(5), 1063-1070.

    Examples
    --------
    .. code-block:: python

        >>> import brainpy
        >>> import brainstate
        >>> import saiunit as u
        >>> # Create an Izhikevich neuron layer with 10 neurons
        >>> izh = brainpy.state.Izhikevich(10)
        >>> # Initialize the state
        >>> izh.init_state(batch_size=1)
        >>> # Apply an input current and update the neuron state
        >>> with brainstate.environ.context(dt=0.1 * u.ms):
        ...     spikes = izh.update(x=10.*u.mV/u.ms)
    """

    __module__ = 'brainpy.state'

    def __init__(
        self,
        in_size: Size,
        a: ArrayLike = 0.02 / u.ms,
        b: ArrayLike = 0.2 / u.ms,
        c: ArrayLike = -65. * u.mV,
        d: ArrayLike = 8. * u.mV / u.ms,
        V_th: ArrayLike = 30. * u.mV,
        V_initializer: Callable = braintools.init.Constant(-65. * u.mV),
        u_initializer: Callable = braintools.init.Constant(0. * u.mV / u.ms),
        spk_fun: Callable = braintools.surrogate.ReluGrad(),
        spk_reset: str = 'hard',
        name: str = None,
    ):
        super().__init__(in_size, name=name, spk_fun=spk_fun, spk_reset=spk_reset)

        # parameters
        self.a = braintools.init.param(a, self.varshape)
        self.b = braintools.init.param(b, self.varshape)
        self.c = braintools.init.param(c, self.varshape)
        self.d = braintools.init.param(d, self.varshape)
        self.V_th = braintools.init.param(V_th, self.varshape)

        # pre-computed coefficients for quadratic equation
        # (units chosen so every term of dV/dt carries mV/ms)
        self.p1 = 0.04 / (u.ms * u.mV)
        self.p2 = 5. / u.ms
        self.p3 = 140. * u.mV / u.ms

        # initializers
        self.V_initializer = V_initializer
        self.u_initializer = u_initializer

    def init_state(self, batch_size: int = None, **kwargs):
        """Create the ``V`` and ``u`` hidden states from their initializers.

        Parameters
        ----------
        batch_size : int, optional
            Forwarded to ``braintools.init.param`` together with ``self.varshape``.
        **kwargs
            Accepted for signature compatibility; not used here.
        """
        self.V = brainstate.HiddenState(
            braintools.init.param(self.V_initializer, self.varshape, batch_size)
        )
        self.u = brainstate.HiddenState(
            braintools.init.param(self.u_initializer, self.varshape, batch_size)
        )

    def reset_state(self, batch_size: int = None, **kwargs):
        """Overwrite the values of ``V`` and ``u`` with freshly initialized ones.

        Parameters
        ----------
        batch_size : int, optional
            Forwarded to ``braintools.init.param`` together with ``self.varshape``.
        **kwargs
            Accepted for signature compatibility; not used here.
        """
        self.V.value = braintools.init.param(self.V_initializer, self.varshape, batch_size)
        self.u.value = braintools.init.param(self.u_initializer, self.varshape, batch_size)

    def get_spike(self, V: ArrayLike = None):
        """Evaluate the spike function on the threshold-normalized potential.

        Parameters
        ----------
        V : ArrayLike, optional
            Membrane potential to test. When ``None``, ``self.V.value`` is used.

        Returns
        -------
        The result of ``self.spk_fun((V - V_th) / V_th)``.
        """
        V = self.V.value if V is None else V
        # dividing by V_th makes the argument of the surrogate function unitless
        v_scaled = (V - self.V_th) / self.V_th
        return self.spk_fun(v_scaled)

    def update(self, x=0. * u.mV / u.ms):
        """Advance the neuron by one step and return its spike output.

        The spike status of the previous potential is evaluated first and the
        Izhikevich reset is applied; then ``V`` and ``u`` are integrated with
        ``brainstate.nn.exp_euler_step`` and written back to the states.

        Parameters
        ----------
        x : ArrayLike, default=0. * u.mV / u.ms
            External input passed to ``self.sum_current_inputs`` along with the
            membrane potential.

        Returns
        -------
        ``self.get_spike(V)`` evaluated on the newly computed membrane potential.
        """
        last_v = self.V.value
        last_u = self.u.value
        last_spk = self.get_spike(last_v)

        # Izhikevich uses hard reset: V → c, u → u + d
        V = u.math.where(last_spk > 0., self.c, last_v)
        u_val = last_u + self.d * last_spk

        # voltage dynamics: dV/dt = 0.04*V^2 + 5*V + 140 - u + I
        def dv(v):
            I_total = self.sum_current_inputs(x, v)
            return self.p1 * v * v + self.p2 * v + self.p3 - u_val + I_total

        # recovery dynamics: du/dt = a(bV - u)
        # NOTE: `V` is looked up when `du` is called, i.e. after `V` has been
        # rebound to the integrated potential below.
        def du(u_):
            return self.a * (self.b * V - u_)

        V = brainstate.nn.exp_euler_step(dv, V)
        V = self.sum_delta_inputs(V)
        u_val = brainstate.nn.exp_euler_step(du, u_val)

        self.V.value = V
        self.u.value = u_val
        return self.get_spike(V)


class IzhikevichRef(Neuron):
    r"""Izhikevich neuron model with refractory period.

    This class implements the Izhikevich neuron model with an absolute refractory period.
    During the refractory period after a spike, the neuron cannot fire regardless of input,
    which better captures the behavior of biological neurons that exhibit a recovery period
    after action potential generation.

    The model is characterized by the following equations:

    When not in refractory period:

    $$
    \frac{dV}{dt} = 0.04 V^2 + 5V + 140 - u + I(t)
    $$

    $$
    \frac{du}{dt} = a(bV - u)
    $$

    During refractory period:

    $$
    V = c, \quad u = u
    $$

    Spike condition:
    If $V \geq V_{th}$ and not in refractory period: emit spike, set $V = c$, $u = u + d$,
    and enter refractory period for $\tau_{ref}$

    Parameters
    ----------
    in_size : Size
        Size of the input to the neuron.
    a : ArrayLike, default=0.02 / u.ms
        Time scale of the recovery variable u.
    b : ArrayLike, default=0.2 / u.ms
        Sensitivity of the recovery variable u to the membrane potential V.
    c : ArrayLike, default=-65. * u.mV
        After-spike reset value of the membrane potential.
    d : ArrayLike, default=8. * u.mV / u.ms
        After-spike increment of the recovery variable u.
    V_th : ArrayLike, default=30. * u.mV
        Spike threshold voltage.
    tau_ref : ArrayLike, default=0. * u.ms
        Refractory period duration.
    V_initializer : Callable
        Initializer for the membrane potential state.
    u_initializer : Callable
        Initializer for the recovery variable state.
    spk_fun : Callable, default=surrogate.ReluGrad()
        Surrogate gradient function for the non-differentiable spike generation.
    spk_reset : str, default='hard'
        Reset mechanism after spike generation. The value is forwarded to the base
        class constructor; ``update`` in this class always applies the Izhikevich
        reset (V set to ``c``, ``u`` incremented by ``d``).
    ref_var : bool, default=False
        Whether to expose a boolean refractory state variable.
    name : str, optional
        Name of the neuron layer.

    Attributes
    ----------
    V : HiddenState
        Membrane potential.
    u : HiddenState
        Recovery variable.
    last_spike_time : ShortTermState
        Time of the last spike, used to implement refractory period.
    refractory : HiddenState
        Neuron refractory state (if ref_var=True).

    See Also
    --------
    Izhikevich : Izhikevich model without refractory period.

    Notes
    -----
    - The refractory period is implemented by tracking the time of the last spike
      and preventing membrane potential updates if the elapsed time is less than tau_ref.
    - During the refractory period, the membrane potential remains at the reset value c
      regardless of input current strength.
    - Refractory periods prevent high-frequency repetitive firing and are critical
      for realistic neural dynamics [1]_.
    - The simulation environment time variable 't' is used to track the refractory state.
    - For parameter selection guidelines, see [2]_.

    References
    ----------
    .. [1] Izhikevich, E. M. (2003). Simple model of spiking neurons. IEEE Transactions
           on neural networks, 14(6), 1569-1572.
    .. [2] Izhikevich, E. M. (2004). Which model to use for cortical spiking neurons?.
           IEEE transactions on neural networks, 15(5), 1063-1070.

    Examples
    --------
    .. code-block:: python

        >>> import brainpy
        >>> import brainstate
        >>> import saiunit as u
        >>> # Create an IzhikevichRef neuron layer with 10 neurons
        >>> izh_ref = brainpy.state.IzhikevichRef(10, tau_ref=2.*u.ms)
        >>> # Initialize the state
        >>> izh_ref.init_state(batch_size=1)
        >>> # Generate inputs and run simulation
        >>> time_steps = 100
        >>> inputs = brainstate.random.randn(time_steps, 1, 10) * u.mV / u.ms
        >>> with brainstate.environ.context(dt=0.1 * u.ms):
        ...     for t in range(time_steps):
        ...         with brainstate.environ.context(t=t*0.1*u.ms):
        ...             spikes = izh_ref.update(x=inputs[t])
    """

    __module__ = 'brainpy.state'

    def __init__(
        self,
        in_size: Size,
        a: ArrayLike = 0.02 / u.ms,
        b: ArrayLike = 0.2 / u.ms,
        c: ArrayLike = -65. * u.mV,
        d: ArrayLike = 8. * u.mV / u.ms,
        V_th: ArrayLike = 30. * u.mV,
        tau_ref: ArrayLike = 0. * u.ms,
        V_initializer: Callable = braintools.init.Constant(-65. * u.mV),
        u_initializer: Callable = braintools.init.Constant(0. * u.mV / u.ms),
        spk_fun: Callable = braintools.surrogate.ReluGrad(),
        spk_reset: str = 'hard',
        ref_var: bool = False,
        name: str = None,
    ):
        super().__init__(in_size, name=name, spk_fun=spk_fun, spk_reset=spk_reset)

        # parameters
        self.a = braintools.init.param(a, self.varshape)
        self.b = braintools.init.param(b, self.varshape)
        self.c = braintools.init.param(c, self.varshape)
        self.d = braintools.init.param(d, self.varshape)
        self.V_th = braintools.init.param(V_th, self.varshape)
        self.tau_ref = braintools.init.param(tau_ref, self.varshape)

        # pre-computed coefficients for quadratic equation
        # (units chosen so every term of dV/dt carries mV/ms)
        self.p1 = 0.04 / (u.ms * u.mV)
        self.p2 = 5. / u.ms
        self.p3 = 140. * u.mV / u.ms

        # initializers
        self.V_initializer = V_initializer
        self.u_initializer = u_initializer
        self.ref_var = ref_var

    def init_state(self, batch_size: int = None, **kwargs):
        """Create the ``V``, ``u`` and ``last_spike_time`` states.

        ``last_spike_time`` starts at ``-1e7 * u.ms``. When ``self.ref_var`` is
        true, a ``refractory`` hidden state initialized to ``False`` is created
        as well.

        Parameters
        ----------
        batch_size : int, optional
            Forwarded to ``braintools.init.param`` together with ``self.varshape``.
        **kwargs
            Accepted for signature compatibility; not used here.
        """
        self.V = brainstate.HiddenState(
            braintools.init.param(self.V_initializer, self.varshape, batch_size)
        )
        self.u = brainstate.HiddenState(
            braintools.init.param(self.u_initializer, self.varshape, batch_size)
        )
        self.last_spike_time = brainstate.ShortTermState(
            braintools.init.param(braintools.init.Constant(-1e7 * u.ms), self.varshape, batch_size)
        )
        if self.ref_var:
            self.refractory = brainstate.HiddenState(
                braintools.init.param(braintools.init.Constant(False), self.varshape, batch_size)
            )

    def reset_state(self, batch_size: int = None, **kwargs):
        """Overwrite the values of all states with freshly initialized ones.

        Resets ``V``, ``u`` and ``last_spike_time`` (back to ``-1e7 * u.ms``),
        and ``refractory`` (back to ``False``) when ``self.ref_var`` is true.

        Parameters
        ----------
        batch_size : int, optional
            Forwarded to ``braintools.init.param`` together with ``self.varshape``.
        **kwargs
            Accepted for signature compatibility; not used here.
        """
        self.V.value = braintools.init.param(self.V_initializer, self.varshape, batch_size)
        self.u.value = braintools.init.param(self.u_initializer, self.varshape, batch_size)
        self.last_spike_time.value = braintools.init.param(
            braintools.init.Constant(-1e7 * u.ms), self.varshape, batch_size
        )
        if self.ref_var:
            self.refractory.value = braintools.init.param(
                braintools.init.Constant(False), self.varshape, batch_size
            )

    def get_spike(self, V: ArrayLike = None):
        """Evaluate the spike function on the threshold-normalized potential.

        Parameters
        ----------
        V : ArrayLike, optional
            Membrane potential to test. When ``None``, ``self.V.value`` is used.

        Returns
        -------
        The result of ``self.spk_fun((V - V_th) / V_th)``.
        """
        V = self.V.value if V is None else V
        # dividing by V_th makes the argument of the surrogate function unitless
        v_scaled = (V - self.V_th) / self.V_th
        return self.spk_fun(v_scaled)

    def update(self, x=0. * u.mV / u.ms):
        """Advance the neuron by one step, honoring the refractory period.

        The current time ``t`` is read from ``brainstate.environ``. Candidate
        values for ``V`` and ``u`` are integrated with
        ``brainstate.nn.exp_euler_step``; elements whose elapsed time since
        ``last_spike_time`` is below ``tau_ref`` keep their post-reset values
        instead. ``last_spike_time`` is then set to ``t`` wherever the new
        ``V`` is at or above ``V_th``.

        Parameters
        ----------
        x : ArrayLike, default=0. * u.mV / u.ms
            External input passed to ``self.sum_current_inputs`` along with the
            membrane potential.

        Returns
        -------
        ``self.get_spike()`` evaluated on the stored membrane potential.
        """
        t = brainstate.environ.get('t')
        last_v = self.V.value
        last_u = self.u.value
        last_spk = self.get_spike(last_v)

        # Izhikevich uses hard reset: V → c, u → u + d
        v_reset = u.math.where(last_spk > 0., self.c, last_v)
        u_reset = last_u + self.d * last_spk

        # voltage dynamics: dV/dt = 0.04*V^2 + 5*V + 140 - u + I
        def dv(v):
            I_total = self.sum_current_inputs(x, v)
            return self.p1 * v * v + self.p2 * v + self.p3 - u_reset + I_total

        V_candidate = brainstate.nn.exp_euler_step(dv, v_reset)
        V_candidate = self.sum_delta_inputs(V_candidate)

        # recovery dynamics: du/dt = a(bV - u), driven by the candidate potential
        def du(u_):
            return self.a * (self.b * V_candidate - u_)

        u_candidate = brainstate.nn.exp_euler_step(du, u_reset)

        # apply refractory period
        refractory = (t - self.last_spike_time.value) < self.tau_ref
        self.V.value = u.math.where(refractory, v_reset, V_candidate)
        self.u.value = u.math.where(refractory, u_reset, u_candidate)

        # spike time evaluation
        spike_cond = self.V.value >= self.V_th
        self.last_spike_time.value = jax.lax.stop_gradient(
            u.math.where(spike_cond, t, self.last_spike_time.value)
        )
        if self.ref_var:
            self.refractory.value = jax.lax.stop_gradient(
                u.math.logical_or(refractory, spike_cond)
            )
        return self.get_spike()