# Copyright 2026 BrainX Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# -*- coding: utf-8 -*-
import brainstate
import braintools
import saiunit as u
import jax.numpy as jnp
import numpy as np
from brainstate.typing import ArrayLike, Size
from ._base import NESTNeuron
from ._utils import is_tracer
# Public API of this module: only the neuron class is exported.
__all__ = ["ignore_and_fire"]
class ignore_and_fire(NESTNeuron):
    r"""Ignore-and-fire neuron model for generating spikes at fixed intervals.

    Description
    -----------
    The ``ignore_and_fire`` neuron is a neuron model that generates spikes at
    a predefined ``rate`` with a constant inter-spike interval ("fire"),
    irrespective of its inputs ("ignore"). In this simplest version of the
    ``ignore_and_fire`` neuron, the inputs from other neurons or devices are
    not processed at all.

    This is a brainpy.state re-implementation of the NEST simulator model of
    the same name, using NEST-standard parameterization.

    **1. Model equations and dynamics**

    The model's internal state variable, the ``phase``, describes the time to
    the next spike relative to the firing period (the inverse of the
    ``rate``). The firing period in simulation time steps is

    .. math::

       T_{\text{fire}} = \operatorname{round}\!\left(
           \frac{1000 / \text{rate}}{dt}\right),

    and the initial phase countdown in simulation time steps is

    .. math::

       N_{\text{phase}} = \operatorname{round}\!\left(
           \frac{1000 \cdot \text{phase} / \text{rate}}{dt}\right),

    where ``rate`` is in spikes/s (Hz), ``dt`` is in ms, and
    :math:`\operatorname{round}` is nearest-integer rounding (NEST maps these
    times onto the simulation grid via ``Time::get_steps()``).

    In each update step, the model checks whether the countdown has reached
    zero:

    - If ``phase_steps == 0``: a spike is emitted and the countdown is reset
      to ``firing_period_steps - 1``.
    - Otherwise: the countdown is decremented by 1.

    Consequently, with ``phase = 1`` the first spike is emitted on the update
    call with index :math:`T_{\text{fire}}` (counting from 0), and every
    :math:`T_{\text{fire}}` steps thereafter. To create asynchronous activity
    for a population of ``ignore_and_fire`` neurons, the firing phases can be
    randomly initialized.

    **2. Assumptions, constraints, and computational implications**

    - The model assumes unit-compatible parameters and broadcast-compatible
      shapes against ``self.varshape``.
    - The ``phase`` parameter must satisfy :math:`0 < \text{phase} \le 1`,
      representing a fractional position within the firing period.
    - The ``rate`` parameter must be positive. In addition,
      :meth:`init_state` raises ``ValueError`` if the firing period rounds to
      fewer than one simulation step (rate too high for ``dt``; the neuron
      would otherwise never fire again after its first spike) or to more
      steps than an ``int32`` can hold (rate too low for ``dt``).
    - Per-step compute is :math:`O(\prod \mathrm{varshape})` with vectorized
      elementwise operations (phase countdown and spike emission).
    - All inputs to :meth:`update` are completely ignored; the model fires
      deterministically based solely on its internal clock.

    .. note::

       The ``ignore_and_fire`` neuron is primarily used for neuronal-network
       model verification and validation purposes ("benchmarking"), in
       particular, to evaluate the correctness and performance of
       connectivity generation and inter-neuron communication. It permits an
       easy scaling of the network size and/or connectivity without affecting
       the output spike statistics. The amount of network traffic is
       predefined by the user, and therefore fully controllable and
       predictable, irrespective of the network size and structure.

    .. note::

       The model has no membrane potential, no threshold-based spike
       generation, and no subthreshold dynamics. Surrogate gradients and
       spike reset mechanisms are therefore not applicable.

    Parameters
    ----------
    in_size : Size
        Population shape specification. All neuron parameters are broadcast
        to ``self.varshape`` derived from ``in_size``.
    phase : ArrayLike, optional
        Initial fractional position within the firing period. Must satisfy
        :math:`0 < \text{phase} \le 1`. Values small enough that
        :math:`N_{\text{phase}}` rounds to 0 fire on the first update call;
        ``1.0`` delays the first spike by a full period. Scalar or array
        broadcast-compatible with ``varshape``. Unitless. Default is ``1.0``.
    rate : ArrayLike, optional
        Firing rate in spikes per second. Must be positive. Scalar or array
        broadcast-compatible with ``varshape``. Default is ``10. * u.Hz``.
    name : str or None, optional
        Optional node name passed to the base-class constructor.

    Parameter Mapping
    -----------------
    .. list-table:: Parameter mapping to NEST ``ignore_and_fire``
       :header-rows: 1
       :widths: 22 18 22 38

       * - Parameter
         - Default
         - Math symbol
         - Semantics
       * - ``phase``
         - ``1.0``
         - :math:`\phi \in (0, 1]`
         - Fractional position in firing period; 1.0 = full period delay.
       * - ``rate``
         - ``10.0`` Hz
         - :math:`f` (Hz)
         - Spike rate; firing period :math:`T = 1/f`.

    Attributes
    ----------
    phase_steps : ShortTermState
        Integer countdown to next spike (in simulation time steps).
        Decrements each step; fires when reaching zero.
    firing_period_steps : ShortTermState
        Integer duration of firing period (in simulation time steps). Not
        modified by :meth:`update`.

    Examples
    --------
    Create an ``ignore_and_fire`` neuron with 10 Hz firing rate:

    .. code-block:: python

        >>> import brainpy
        >>> import brainstate
        >>> import saiunit as u
        >>>
        >>> # Create an ignore_and_fire neuron with 10 Hz firing rate
        >>> neuron = brainpy.state.ignore_and_fire(1, rate=10.0 * u.Hz)
        >>>
        >>> # Step the neuron and check for spikes
        >>> with brainstate.environ.context(dt=0.1 * u.ms, t=0.0 * u.ms):
        ...     neuron.init_state()
        ...     spike = neuron.update()
        ...     print(f"Spike: {spike}")

    Create a population with random phases for asynchronous activity:

    .. code-block:: python

        >>> import numpy as np
        >>> # Create 100 neurons with random initial phases
        >>> phases = np.random.uniform(0.01, 1.0, size=100)
        >>> neurons = brainpy.state.ignore_and_fire(100, phase=phases, rate=20.0 * u.Hz)

    References
    ----------
    .. [1] NEST Simulator, ``ignore_and_fire`` model.
           https://nest-simulator.readthedocs.io/en/stable/models/ignore_and_fire.html
    """
    __module__ = 'brainpy.state'

    def __init__(
        self,
        in_size: Size,
        phase: ArrayLike = 1.0,
        rate: ArrayLike = 10. * u.Hz,
        name: str = None,
    ):
        r"""Initialize the ignore_and_fire neuron model.

        Stores parameters, initializes the base class, and validates
        parameter constraints. Does not create the internal state variables
        (``phase_steps``, ``firing_period_steps``); call :meth:`init_state`
        before simulation.

        Parameters
        ----------
        in_size : Size
            Population shape specification passed to the base-class
            constructor. Determines ``self.varshape``.
        phase : ArrayLike, optional
            Initial fractional position within the firing period. Must
            satisfy :math:`0 < \text{phase} \le 1`. Broadcast to ``varshape``
            via :func:`braintools.init.param`. Default is ``1.0``.
        rate : ArrayLike, optional
            Firing rate in spikes per second. Must be positive. Broadcast to
            ``varshape`` via :func:`braintools.init.param`. Default is
            ``10. * u.Hz``.
        name : str or None, optional
            Optional node name passed to the base-class constructor.

        Raises
        ------
        ValueError
            If ``phase`` violates :math:`0 < \text{phase} \le 1` or ``rate``
            is not strictly positive, raised by :meth:`_validate_parameters`.

        See Also
        --------
        init_state : Initialize state variables before simulation.
        update : Perform one simulation time step.
        """
        super().__init__(in_size, name=name)
        # Broadcast user parameters to the population shape.
        self.phase = braintools.init.param(phase, self.varshape)
        self.rate = braintools.init.param(rate, self.varshape)
        # Fail fast on out-of-range parameters.
        self._validate_parameters()

    def _validate_parameters(self):
        r"""Validate ``phase`` and ``rate`` parameters after initialization.

        Ensures that:

        - ``phase`` satisfies :math:`0 < \text{phase} \le 1` element-wise.
        - ``rate`` is strictly positive element-wise.

        Called automatically from :meth:`__init__` after parameter storage.
        Validation is skipped entirely when either parameter is a JAX tracer
        (e.g. inside ``jit``), because concrete values are not available.

        Raises
        ------
        ValueError
            If any element of ``phase`` is :math:`\le 0` or :math:`> 1`, with
            message "Phase must be > 0 and <= 1."
        ValueError
            If any element of ``rate`` is :math:`\le 0`, with message
            "Firing rate must be > 0."

        Notes
        -----
        The magnitude of ``rate`` is extracted with
        :func:`saiunit.get_magnitude`. No unit conversion is performed, which
        is sufficient because positivity does not depend on the unit scale.
        ``phase`` is validated as a unitless scalar or array.
        """
        # Concrete values are unavailable for tracers, so we cannot check them.
        if any(is_tracer(v) for v in (self.phase, self.rate)):
            return
        phase_val = np.asarray(self.phase)
        if np.any(phase_val <= 0.0) or np.any(phase_val > 1.0):
            raise ValueError("Phase must be > 0 and <= 1.")
        rate_val = np.asarray(u.get_magnitude(self.rate))
        if np.any(rate_val <= 0.0):
            raise ValueError("Firing rate must be > 0.")

    def _calc_initial_variables(self, batch_size=None):
        r"""Compute firing period and phase countdown in simulation time steps.

        Replicates NEST's ``calc_initial_variables_`` semantics, which
        convert rate and phase parameters to integer step counts with
        ``Time::get_steps()`` (nearest-integer rounding to the simulation
        grid):

        .. math::

           T_{\text{fire}} &= \operatorname{round}\!\left(\frac{1000}{\text{rate}} / dt\right) \\
           N_{\text{phase}} &= \operatorname{round}\!\left(\frac{1000 \cdot \text{phase}}{\text{rate}} / dt\right)

        where ``rate`` is in Hz, ``dt`` is in ms, and both results are
        integer step counts.

        Parameters
        ----------
        batch_size : int or None, optional
            Not used by the implementation; accepted for API compatibility
            with :meth:`init_state`.

        Returns
        -------
        firing_period_steps : ndarray
            ``int32`` array with shape ``varshape`` containing firing period
            durations in simulation time steps.
        phase_steps : ndarray
            ``int32`` array with shape ``varshape`` containing initial
            countdown values in simulation time steps.

        Raises
        ------
        ValueError
            If any firing period rounds to fewer than one simulation step.
            With a zero-step period, :meth:`update` would reset the countdown
            to -1 after the first spike, and the neuron would never fire
            again.
        ValueError
            If any firing period exceeds the ``int32`` range, where the cast
            would otherwise silently wrap to a meaningless value.

        Notes
        -----
        NEST computes these as:

        .. code-block:: cpp

            firing_period_steps = Time(Time::ms(1.0 / rate * 1000.0)).get_steps()
            phase_steps = Time(Time::ms(phase / rate * 1000.0)).get_steps()

        We replicate this by computing the period/phase in ms, dividing by
        ``dt``, then rounding to the nearest integer via ``np.rint``.
        """
        dt = brainstate.environ.get_dt()
        dt_ms = u.get_magnitude(u.maybe_decimal(dt / u.ms))
        rate_hz = u.get_magnitude(u.maybe_decimal(self.rate / u.Hz))
        phase_val = np.asarray(self.phase)

        # Period and phase delay in ms: 1/rate * 1000 and phase/rate * 1000.
        period_ms = 1.0 / rate_hz * 1000.0
        phase_ms = phase_val / rate_hz * 1000.0

        # Round to the nearest grid step (Time::get_steps()), still as floats
        # so we can range-check before the integer cast.
        period_steps_f = np.rint(np.asarray(period_ms) / dt_ms)
        phase_steps_f = np.rint(np.asarray(phase_ms) / dt_ms)

        if np.any(period_steps_f < 1):
            raise ValueError(
                "Firing period is shorter than one simulation step; "
                "reduce the rate or the simulation dt."
            )
        if np.any(period_steps_f > np.iinfo(np.int32).max):
            raise ValueError(
                "Firing period exceeds the int32 step range; "
                "increase the rate or the simulation dt."
            )
        # 0 < phase <= 1 guarantees 0 <= phase_steps <= firing_period_steps,
        # so the phase cast cannot overflow once the period check passed.
        firing_period_steps = period_steps_f.astype(np.int32)
        phase_steps = phase_steps_f.astype(np.int32)
        return firing_period_steps, phase_steps

    def init_state(self, batch_size=None, **kwargs):
        r"""Initialize internal state variables for simulation.

        Computes ``firing_period_steps`` and ``phase_steps`` from the
        ``rate`` and ``phase`` parameters via
        :meth:`_calc_initial_variables`. Stores both as
        :class:`brainstate.ShortTermState` arrays of the environment's
        default integer dtype (``brainstate.environ.ditype()``).

        Parameters
        ----------
        batch_size : int or None, optional
            If provided, states are broadcast to shape
            ``(batch_size, *varshape)``. ``None`` keeps unbatched state.
            Default is ``None``.
        **kwargs
            Unused compatibility parameters accepted by the base-state API.

        Raises
        ------
        ValueError
            If the firing period rounds to fewer than one simulation step,
            or exceeds the ``int32`` step range, for the current ``dt``
            (raised by :meth:`_calc_initial_variables`).

        Side Effects
        ------------
        Creates or overwrites the following instance attributes:

        - ``self.firing_period_steps`` : :class:`brainstate.ShortTermState`
        - ``self.phase_steps`` : :class:`brainstate.ShortTermState`

        Notes
        -----
        This method must be called before the first :meth:`update` call, with
        the simulation ``dt`` already available from
        ``brainstate.environ.get_dt()``.
        """
        firing_period_steps, phase_steps = self._calc_initial_variables()
        ditype = brainstate.environ.ditype()
        fps_arr = jnp.asarray(firing_period_steps, dtype=ditype)
        ps_arr = jnp.asarray(phase_steps, dtype=ditype)
        if batch_size is not None:
            batch_shape = (batch_size,) + tuple(self.varshape)
            fps_arr = jnp.broadcast_to(fps_arr, batch_shape)
            ps_arr = jnp.broadcast_to(ps_arr, batch_shape)
        self.firing_period_steps = brainstate.ShortTermState(fps_arr)
        self.phase_steps = brainstate.ShortTermState(ps_arr)

    def update(self, x=None):
        r"""Update the ignore_and_fire neuron for one simulation time step.

        Decrements the internal phase countdown and emits spikes when the
        countdown reaches zero. All external inputs are completely ignored.

        The update logic follows NEST's deterministic firing schedule:

        1. Check if ``phase_steps == 0``:

           - If yes: emit spike (output 1.0), reset countdown to
             ``firing_period_steps - 1``.
           - If no: emit no spike (output 0.0), decrement countdown by 1.

        2. Write the new countdown back to ``self.phase_steps``.

        Parameters
        ----------
        x : ArrayLike or None, optional
            Input signal (ignored). Accepted for API compatibility with other
            neuron models but never read.

        Returns
        -------
        spike : jnp.ndarray
            Float array with shape ``varshape`` (or
            ``(batch_size, *varshape)`` if initialized with batching).
            Contains ``1.0`` where a spike is emitted this step and ``0.0``
            elsewhere, in JAX's default floating dtype.

        Notes
        -----
        This method must be called after :meth:`init_state`; before that,
        ``self.phase_steps`` does not exist and the attribute access fails.

        The spike output is suitable for direct use as delta-synapse input
        (units of spikes/step) or as a binary event indicator for recording.
        """
        phase_steps = self.phase_steps.value
        firing_period_steps = self.firing_period_steps.value
        # Threshold crossing: a zero countdown means fire on this step.
        fire = phase_steps == 0
        spike = jnp.where(fire, 1.0, 0.0)
        # On fire, restart the countdown so the next spike is exactly
        # firing_period_steps updates later; otherwise count down by one.
        new_phase_steps = jnp.where(
            fire,
            firing_period_steps - 1,
            phase_steps - 1,
        )
        self.phase_steps.value = new_phase_steps
        return spike