# Copyright 2026 BrainX Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# -*- coding: utf-8 -*-
from typing import Callable
import brainstate
import braintools
import saiunit as u
import jax
import jax.numpy as jnp
from brainstate.typing import ArrayLike, Size
from ._base import NESTNeuron
__all__ = [
'izhikevich',
]
class izhikevich(NESTNeuron):
r"""Izhikevich neuron model (NEST-compatible).
This model is a brainpy.state re-implementation of the NEST simulator
``izhikevich`` model, using NEST-standard parameterization. It implements
the simple spiking neuron model introduced by Izhikevich [1]_, which
reproduces spiking and bursting behavior of known types of cortical neurons
through a two-dimensional system of ordinary differential equations.
**1. Mathematical Formulation**
The model is defined by the following coupled differential equations:
.. math::
\frac{dV_{\text{m}}}{dt} = 0.04\, V_{\text{m}}^2 + 5\, V_{\text{m}}
+ 140 - U_{\text{m}} + I_{\text{e}}
.. math::
\frac{dU_{\text{m}}}{dt} = a\,(b\, V_{\text{m}} - U_{\text{m}})
where:
- :math:`V_{\text{m}}` is the membrane potential (mV)
- :math:`U_{\text{m}}` is the recovery variable (mV), representing the
combined effects of sodium channel inactivation and potassium channel
activation
- :math:`I_{\text{e}}` is the total input current (pA): external constant
current plus synaptic current
- :math:`a` is the time scale of the recovery variable (dimensionless)
- :math:`b` describes the sensitivity of :math:`U_{\text{m}}` to
subthreshold fluctuations of :math:`V_{\text{m}}` (dimensionless)
**2. Spike Emission and Reset**
A spike is emitted when :math:`V_{\text{m}}` reaches the threshold
:math:`V_{\text{th}}`. At this point the state variables undergo an
instantaneous reset:
.. math::
&\text{if}\; V_m \geq V_{th}:\\
&\quad V_m \leftarrow c\\
&\quad U_m \leftarrow U_m + d
where:
- :math:`c` is the after-spike reset value for :math:`V_{\text{m}}` (mV)
- :math:`d` is the after-spike increment of :math:`U_{\text{m}}` (mV)
Each incoming spike adds to :math:`V_{\text{m}}` by the synaptic weight
associated with the spike (delta-coupling, instantaneous PSC).
**3. Integration Scheme**
This model offers two forms of Euler integration, selected by the boolean
parameter ``consistent_integration``:
- **Standard forward Euler** (``consistent_integration = True``, default):
Both :math:`V_{\text{m}}` and :math:`U_{\text{m}}` are updated based on
their values at the *beginning* of the time step:
.. math::
V_{n+1} &= V_n + h \cdot f(V_n, U_n, I_n) + \Delta V_{\text{syn}}\\
U_{n+1} &= U_n + h \cdot a \cdot (b \cdot V_n - U_n)
where :math:`h` is the time step and :math:`\Delta V_{\text{syn}}` is
the delta synaptic input.
- **Published Izhikevich (2003) numerics** (``consistent_integration =
False``): The membrane potential is updated in two half-steps of size
:math:`h/2`, and the recovery variable uses the *updated*
:math:`V_{\text{m}}`:
.. math::
V_{\text{mid}} &= V_n + \frac{h}{2} \cdot f(V_n, U_n, I_n)\\
V_{n+1} &= V_{\text{mid}} + \frac{h}{2} \cdot f(V_{\text{mid}}, U_n, I_n)\\
U_{n+1} &= U_n + h \cdot a \cdot (b \cdot V_{n+1} - U_n)
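This scheme is recommended only for replicating published results and
requires :math:`h = 1.0\,\text{ms}` for consistency with the original
paper. Note that in this mode the delta synaptic input is added inside
:math:`f` (and is therefore scaled by :math:`h/2` in each half-step)
rather than applied as a direct voltage jump. For a detailed analysis of
the numerical differences, see [2]_.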
**4. Synaptic Input**
Synaptic input enters via two channels:
- **Spike (delta) input** — delivered through ``add_delta_input()`` or the
``delta`` keyword; added directly to :math:`V_{\text{m}}` at the
integration step as an instantaneous voltage jump.
- **Current input** — delivered through the ``x`` argument of
:meth:`update`. Following NEST ring-buffer semantics, the current
applied at simulation step *k* takes effect at step *k + 1* (one-step
delay). This is stored in the ``I`` state variable.
**5. Physical Units and Numerical Assumptions**
The original Izhikevich model uses dimensionless equations with implicit
units. This implementation follows NEST conventions:
- Membrane potential :math:`V_{\text{m}}` in mV
- Recovery variable :math:`U_{\text{m}}` in mV
- Input current :math:`I_{\text{e}}` in pA (with implicit resistance R=1)
- Time constants :math:`a`, :math:`b` are dimensionless
- Time step :math:`h` in ms
The coefficients 0.04 and 5 in the voltage equation have implicit units
that make the equation dimensionally consistent when :math:`V_{\text{m}}`
is in mV and time in ms.
**6. Computational Considerations**
- The quadratic voltage term can lead to numerical instability if the time
step is too large. Use :math:`h \leq 1.0\,\text{ms}` for stability.
- The ``V_min`` parameter prevents unphysical negative voltage divergence.
- Integration variables are cast to the environment's default float dtype
(``brainstate.environ.dftype()``); 64-bit precision must be enabled in the
environment to match NEST's double-precision arithmetic.
Parameters
----------
in_size : int, tuple of int
Number of neurons or shape of the neuron population. Determines the
shape of all state variables and parameters (``varshape``).
a : float, array_like, optional
Time scale of the recovery variable :math:`U_{\text{m}}`.
Dimensionless. Default: 0.02.
Typical values: 0.02 (regular spiking), 0.1 (fast spiking).
b : float, array_like, optional
Sensitivity of :math:`U_{\text{m}}` to subthreshold fluctuations of
:math:`V_{\text{m}}`. Dimensionless. Default: 0.2.
Typical values: 0.2 (regular spiking), 0.25 (low-threshold spiking).
c : Quantity (voltage), array_like, optional
After-spike reset value of :math:`V_{\text{m}}`. Default: -65 mV.
Typical values: -65 mV (regular spiking), -50 mV (chattering).
d : Quantity (voltage), array_like, optional
After-spike increment of :math:`U_{\text{m}}`. Default: 8 mV.
Typical values: 8 mV (regular spiking), 2 mV (fast spiking).
I_e : Quantity (current), array_like, optional
Constant external input current. Default: 0 pA.
Positive values provide tonic excitation.
V_th : Quantity (voltage), array_like, optional
Spike threshold voltage. Default: 30 mV.
NEST uses 30 mV as the practical threshold for the Izhikevich model.
V_min : Quantity (voltage), array_like, optional
Absolute lower bound for :math:`V_{\text{m}}`. Default: None (no bound).
When set, prevents unphysical negative voltage divergence.
Typical value: -100 mV.
consistent_integration : bool, optional
Integration scheme selector. Default: True.
- True: standard forward Euler (recommended).
- False: published Izhikevich (2003) half-step numerics (requires dt=1ms).
V_initializer : callable, optional
Initialization function for :math:`V_{\text{m}}`.
Default: ``Constant(-65 mV)``.
Must accept ``(shape, batch_size)`` and return voltage values.
U_initializer : callable, optional
Initialization function for :math:`U_{\text{m}}`.
Default: None (uses :math:`U_0 = b \cdot V_0`, matching NEST).
Must accept ``(shape, batch_size)`` and return voltage values.
spk_fun : callable, optional
Surrogate gradient function for differentiable spike generation.
Default: ``ReluGrad()``.
Must map ``(V - V_th) / scale`` to [0, 1] with defined gradient.
spk_reset : str, optional
Spike reset mode. Default: 'hard'.
- 'hard': stop gradient at reset (matches NEST dynamics).
- 'soft': allow gradient flow through reset.
name : str, optional
Name of the neuron population. Default: None.
Parameter Mapping
-----------------
========================== ========================== ============================== =================================================================
**NEST Parameter** **brainpy.state** **Math Equivalent** **Description**
========================== ========================== ============================== =================================================================
``a`` ``a`` :math:`a` Time scale of recovery variable :math:`U_{\text{m}}`
``b`` ``b`` :math:`b` Sensitivity of :math:`U_{\text{m}}` to :math:`V_{\text{m}}`
``c`` ``c`` :math:`c` After-spike reset value of :math:`V_{\text{m}}` (mV)
``d`` ``d`` :math:`d` After-spike increment of :math:`U_{\text{m}}` (mV)
``I_e`` ``I_e`` :math:`I_{\text{e}}` Constant input current (pA)
``V_th`` ``V_th`` :math:`V_{\text{th}}` Spike threshold (mV)
``V_min`` ``V_min`` :math:`V_{\text{min}}` Lower bound for :math:`V_{\text{m}}` (mV, optional)
``consistent_integration`` ``consistent_integration`` -- Forward Euler (True) vs. published numerics (False)
========================== ========================== ============================== =================================================================
Attributes
----------
V : HiddenState
Membrane potential :math:`V_{\text{m}}` in mV. Shape: ``(*varshape,)``
or ``(batch_size, *varshape)``.
U : HiddenState
Recovery variable :math:`U_{\text{m}}` in mV. Shape: ``(*varshape,)``
or ``(batch_size, *varshape)``.
I : ShortTermState
Buffered input current from the previous time step in pA (one-step
delayed ring buffer, matching NEST semantics). Shape: ``(*varshape,)``
or ``(batch_size, *varshape)``.
Examples
--------
**Example 1: Regular spiking (RS) neuron**
.. code-block:: python
>>> import brainpy.state as bp
>>> import brainstate
>>> import saiunit as u
>>>
>>> # Create a regular spiking neuron
>>> neuron = bp.izhikevich(1, a=0.02, b=0.2, c=-65*u.mV, d=8*u.mV)
>>> neuron.init_state()
>>>
>>> # Simulate with constant input
>>> with brainstate.environ.context(dt=1.0*u.ms):
... spikes = []
... for _ in range(100):
... spk = neuron.update(x=10.0*u.pA)
... spikes.append(spk)
**Example 2: Fast spiking (FS) neuron**
.. code-block:: python
>>> # Create a fast spiking neuron
>>> neuron = bp.izhikevich(1, a=0.1, b=0.2, c=-65*u.mV, d=2*u.mV)
>>> neuron.init_state()
**Example 3: Chattering (CH) neuron**
.. code-block:: python
>>> # Create a chattering neuron
>>> neuron = bp.izhikevich(1, a=0.02, b=0.2, c=-50*u.mV, d=2*u.mV)
>>> neuron.init_state()
**Example 4: Population with heterogeneous parameters**
.. code-block:: python
>>> import jax
>>> import jax.numpy as jnp
>>>
>>> # Create 100 neurons with random parameter variation
>>> key = jax.random.PRNGKey(0)
>>> a_vals = jax.random.uniform(key, (100,), minval=0.01, maxval=0.1)
>>> neuron = bp.izhikevich(100, a=a_vals, b=0.2, c=-65*u.mV, d=8*u.mV)
>>> neuron.init_state()
References
----------
.. [1] Izhikevich EM. (2003). Simple model of spiking neurons. IEEE
Transactions on Neural Networks, 14:1569–1572.
DOI: https://doi.org/10.1109/TNN.2003.820440
.. [2] Pauli R, Weidel P, Kunkel S, Morrison A (2018). Reproducing
polychronization: A guide to maximizing the reproducibility of
spiking network models. Frontiers in Neuroinformatics, 12:46.
DOI: https://doi.org/10.3389/fninf.2018.00046
See Also
--------
iaf_psc_delta : Leaky integrate-and-fire with delta-shaped PSCs
iaf_psc_exp : Leaky integrate-and-fire with exponential PSCs
mat2_psc_exp : Multi-timescale adaptive threshold with exponential PSCs
aeif_psc_exp : Adaptive exponential integrate-and-fire model
"""
__module__ = 'brainpy.state'
def __init__(
    self,
    in_size: Size,
    a: ArrayLike = 0.02,
    b: ArrayLike = 0.2,
    c: ArrayLike = -65. * u.mV,
    d: ArrayLike = 8. * u.mV,
    I_e: ArrayLike = 0. * u.pA,
    V_th: ArrayLike = 30. * u.mV,
    V_min: ArrayLike = None,
    consistent_integration: bool = True,
    V_initializer: Callable = braintools.init.Constant(-65. * u.mV),
    U_initializer: Callable = None,
    spk_fun: Callable = braintools.surrogate.ReluGrad(),
    spk_reset: str = 'hard',
    name: str = None,
):
    """Construct the neuron population and store its parameters.

    ``in_size``, ``name``, ``spk_fun`` and ``spk_reset`` are forwarded to
    the base-class constructor. ``a``, ``b``, ``c``, ``d``, ``I_e`` and
    ``V_th`` are passed through ``braintools.init.param`` together with
    ``self.varshape``. ``V_min``, ``consistent_integration`` and the two
    initializers are stored exactly as given; ``V_min`` may be ``None``,
    in which case ``update`` applies no lower bound.
    """
    super().__init__(in_size, name=name, spk_fun=spk_fun, spk_reset=spk_reset)

    # Per-neuron model parameters, assigned in declaration order.
    per_neuron = {
        'a': a,
        'b': b,
        'c': c,
        'd': d,
        'I_e': I_e,
        'V_th': V_th,
    }
    for attr, value in per_neuron.items():
        setattr(self, attr, braintools.init.param(value, self.varshape))

    # Settings kept verbatim (not broadcast).
    self.V_min = V_min
    self.consistent_integration = consistent_integration
    self.V_initializer = V_initializer
    self.U_initializer = U_initializer
[docs]
def init_state(self, batch_size=None, **kwargs):
    r"""Create the state variables ``V``, ``U`` and ``I``.

    Parameters
    ----------
    batch_size : int or None, optional
        When given, the buffered current ``I`` is allocated with shape
        ``(batch_size, *varshape)`` and ``batch_size`` is forwarded to
        ``braintools.init.param`` for ``V`` (and ``U``). ``None``
        (default) allocates ``I`` with shape ``varshape``.
    **kwargs
        Accepted but not used by this method.

    Notes
    -----
    - ``V`` is produced by ``V_initializer``.
    - ``U`` is produced by ``U_initializer`` when one was supplied;
      otherwise it is computed as ``b * V`` from the freshly initialized
      ``V`` (the NEST default ``u_ = b * v_``).
    - ``V`` and ``U`` are wrapped in ``brainstate.HiddenState``; ``I`` is
      a ``brainstate.ShortTermState`` holding zeros in pA.
    """
    v_init = braintools.init.param(self.V_initializer, self.varshape, batch_size)

    if self.U_initializer is None:
        # No explicit initializer: derive U from V (dimensionless b times V).
        u_init = self.b * v_init
    else:
        u_init = braintools.init.param(self.U_initializer, self.varshape, batch_size)

    self.V = brainstate.HiddenState(v_init)
    self.U = brainstate.HiddenState(u_init)

    # Zero-filled buffer for the current that update() applies one step later.
    if batch_size is None:
        buffer_shape = self.varshape
    else:
        buffer_shape = (batch_size,) + tuple(self.varshape)
    self.I = brainstate.ShortTermState(u.math.zeros(buffer_shape) * u.pA)
[docs]
def get_spike(self, V: ArrayLike = None):
    r"""Evaluate the surrogate spike function on a membrane potential.

    The potential is normalised as
    :math:`(V - V_{\text{th}}) / (V_{\text{th}} - c)` and the result is
    passed to ``self.spk_fun``.

    Parameters
    ----------
    V : ArrayLike, optional
        Membrane potential to evaluate. When ``None`` (default) the
        current state ``self.V.value`` is used.

    Returns
    -------
    ArrayLike
        Whatever ``self.spk_fun`` returns for the normalised potential.

    Notes
    -----
    The normalisation divides by ``V_th - c``; no guard is applied for
    the case ``V_th == c``.
    """
    if V is None:
        V = self.V.value
    # Distance between threshold and reset value, used as the scale.
    reset_span = self.V_th - self.c
    return self.spk_fun((V - self.V_th) / reset_span)
[docs]
def update(self, x=0. * u.pA):
    r"""Advance the neuron by one simulation step and return its spike output.

    **Sequence of operations**

    1. Read the step size from ``brainstate.environ.get_dt()`` and the
       working float dtype from ``brainstate.environ.dftype()``.
    2. Read ``V``, ``U`` and the buffered current ``I``; divide out the
       units (mV, mV, pA) and cast everything to the working dtype. The
       effective drive is ``I + I_e``.
    3. Obtain the delta input from ``sum_delta_inputs`` (seeded with
       zeros shaped like ``V``) and strip its mV unit.
    4. Integrate:

       - ``consistent_integration`` truthy: one forward-Euler step in
         which both :math:`V` and :math:`U` use the values from the start
         of the step, and the delta input is added to :math:`V` as a
         direct jump.
       - otherwise: two successive half-steps for :math:`V`, with the
         delta input added *inside* the right-hand side (hence scaled by
         :math:`h/2` each time), followed by a :math:`U` update that uses
         the new :math:`V`.
    5. If ``V_min`` is not ``None``, clamp :math:`V` from below.
    6. Where :math:`V \geq V_{\text{th}}`, set :math:`V \leftarrow c` and
       :math:`U \leftarrow U + d`; write both states back.
    7. Store ``sum_current_inputs(x, V_post)`` in ``I`` so that it is
       used by the *next* call.
    8. Return ``get_spike`` evaluated on the pre-reset potential.

    Parameters
    ----------
    x : Quantity (current), array_like, optional
        External current passed to ``sum_current_inputs``. Default: 0 pA.
        It is only buffered in this call and enters the dynamics at the
        following step.

    Returns
    -------
    ArrayLike
        Output of ``get_spike`` for the membrane potential after
        integration and clamping but before the reset.

    Notes
    -----
    - Arithmetic precision follows ``brainstate.environ.dftype()``.
    - The lower bound ``V_min`` is applied before threshold detection.
    """
    dt = brainstate.environ.get_dt()
    float_dtype = brainstate.environ.dftype()

    def as_plain(value):
        # Cast a unit-free value to the environment's float dtype.
        return u.math.asarray(value, dtype=float_dtype)

    step = as_plain(dt / u.ms)

    # State at the start of the step; `buffered` is the current stored by
    # the previous call.
    membrane = self.V.value
    recovery = self.U.value
    buffered = self.I.value

    # Unit-free working copies (mV and pA divided out).
    v0 = as_plain(membrane / u.mV)
    w0 = as_plain(recovery / u.mV)
    drive = as_plain((buffered + self.I_e) / u.pA)
    rate = as_plain(self.a)
    coupling = as_plain(self.b)

    # Delta (spike) input expressed in mV.
    kick = as_plain(self.sum_delta_inputs(u.math.zeros_like(membrane)) / u.mV)

    def dv_dt(volt):
        # Right-hand side of the membrane equation; the recovery variable
        # is always taken from the start of the step.
        return 0.04 * volt * volt + 5.0 * volt + 140.0 - w0 + drive

    if not self.consistent_integration:
        # Published Izhikevich (2003) numerics: two half-steps for V with
        # the delta input inside the derivative, then U from the new V.
        half_step = step * 0.5
        v1 = v0 + half_step * (dv_dt(v0) + kick)
        v1 = v1 + half_step * (dv_dt(v1) + kick)
        w1 = w0 + step * rate * (coupling * v1 - w0)
    else:
        # Forward Euler: V and U both advance from start-of-step values;
        # the delta input is a direct voltage jump.
        v1 = v0 + step * dv_dt(v0) + kick
        w1 = w0 + step * rate * (coupling * v0 - w0)

    # Optional lower bound, applied before threshold detection.
    if self.V_min is not None:
        v1 = jnp.maximum(v1, as_plain(self.V_min / u.mV))

    # Re-attach units.
    V_pre = v1 * u.mV
    U_pre = w1 * u.mV

    # Reset wherever the threshold was reached.
    fired = V_pre >= self.V_th
    V_post = u.math.where(fired, self.c, V_pre)
    U_post = u.math.where(fired, U_pre + self.d, U_pre)
    self.V.value = V_post
    self.U.value = U_post

    # Buffer the incoming current; it is consumed on the next call.
    self.I.value = self.sum_current_inputs(x, V_post)

    # The spike output is computed from the pre-reset potential.
    return self.get_spike(V_pre)