Neuron

class brainpy.state.Neuron(in_size, spk_fun=InvSquareGrad(alpha=100.0), spk_reset='soft', name=None)

Base class for all spiking neuron models.

This abstract class serves as the foundation for implementing various spiking neuron models in the BrainPy framework. It extends the brainpy.state.Dynamics class and provides common functionality for spike generation, membrane potential dynamics, and surrogate gradient handling required for training spiking neural networks.

All neuron models (e.g., IF, LIF, LIFRef, ALIF) should inherit from this class and implement the required abstract methods, particularly get_spike() which defines the spike generation mechanism.

Parameters:
  • in_size (Size) – Size of the input to the neuron layer. Can be an integer for 1D input or a tuple for multi-dimensional input (e.g., 100 or (28, 28)).

  • spk_fun (Callable, optional) –

    Surrogate gradient function for the non-differentiable spike generation operation. Default is braintools.surrogate.InvSquareGrad(). Common alternatives include:

    • braintools.surrogate.ReluGrad()

    • braintools.surrogate.SigmoidGrad()

    • braintools.surrogate.GaussianGrad()

    • braintools.surrogate.ATan()

  • spk_reset (str, optional) –

    Reset mechanism applied after spike generation. Default is 'soft'.

    • 'soft': Subtract threshold from membrane potential (V = V - V_th). This allows for more biological realism and better gradient flow.

    • 'hard': Set the membrane potential to the reset value (V = V_reset), with jax.lax.stop_gradient applied to the reset term.

  • name (str, optional) – Name identifier for the neuron layer. If None, an automatic name will be generated. Useful for debugging and visualization.

spk_reset

The reset mechanism used by the neuron.

Type:

str

spk_fun

The surrogate gradient function used for spike generation.

Type:

Callable

Notes

Surrogate Gradients

The spike generation operation is inherently non-differentiable (a threshold function), which poses challenges for gradient-based learning. Surrogate gradients provide a differentiable approximation during the backward pass while maintaining the discrete spike behavior during the forward pass. This is crucial for training SNNs with backpropagation through time (BPTT).
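The surrogate-gradient idea can be sketched in plain JAX with jax.custom_jvp. The forward pass returns an exact Heaviside step, and the derivative rule substitutes a smooth bump centred on the threshold. The name spike_fn, the inverse-square shape 1 / (alpha * |x| + 1)^2, and alpha = 100.0 are assumptions chosen to resemble the default InvSquareGrad(alpha=100.0). This is an illustration, not the braintools implementation.

```python
import jax
import jax.numpy as jnp

# Hypothetical sketch; not braintools' InvSquareGrad. ALPHA is an assumed sharpness:
# larger values give a narrower surrogate peak around the threshold.
ALPHA = 100.0


@jax.custom_jvp
def spike_fn(x):
    """Forward pass: Heaviside step, 1.0 where x >= 0 and 0.0 elsewhere."""
    x = jnp.asarray(x, dtype=jnp.float32)
    return (x >= 0).astype(jnp.float32)


@spike_fn.defjvp
def spike_fn_jvp(primals, tangents):
    (x,), (x_dot,) = primals, tangents
    y = spike_fn(x)
    # Replace the step's zero/undefined derivative with a smooth
    # inverse-square bump that peaks at x = 0 (the threshold)
    surrogate = 1.0 / (ALPHA * jnp.abs(x) + 1.0) ** 2
    return y, surrogate * x_dot


# x plays the role of a normalized distance to threshold, e.g. (V - V_th) / V_th
x = jnp.array([-0.5, 0.0, 0.1])
print(spike_fn(x))
print(jax.vmap(jax.grad(spike_fn))(x))
```

jax.custom_vjp would work equally well here. custom_jvp is a natural fit because the rule is elementwise and it also supports forward-mode differentiation.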

Reset Mechanisms

  • Soft Reset: More biologically plausible as it preserves information about how far above threshold the membrane potential was. This can encode information in the residual voltage and often leads to better gradient flow.

  • Hard Reset: Provides a clean reset to a fixed value, which can be easier to analyze mathematically but may lead to vanishing gradients in deep networks.
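Both rules can be written as a single update on the membrane potential, given the spike emitted at the current step. The sketch below is a hypothetical helper in plain JAX. The name apply_reset, the dimensionless voltages, and the V_reset term in the soft branch are assumptions, not BrainPy API. With V_reset = 0, the soft branch reduces to V = V - V_th. In the hard branch, jax.lax.stop_gradient detaches the voltage inside the reset term.

```python
import jax
import jax.numpy as jnp


def apply_reset(V, spk, V_th=1.0, V_reset=0.0, mode="soft"):
    """Hypothetical helper: reset V where spk == 1 and leave it unchanged where spk == 0."""
    if mode == "soft":
        # Subtract the threshold distance; the overshoot above V_th survives
        return V - (V_th - V_reset) * spk
    if mode == "hard":
        # Jump exactly to V_reset; stop_gradient keeps the voltage inside
        # the reset term out of backpropagation
        return V - (jax.lax.stop_gradient(V) - V_reset) * spk
    raise ValueError(f"unknown reset mode: {mode!r}")


V = jnp.array([1.3, 0.4])    # first neuron crossed V_th = 1.0, second did not
spk = jnp.array([1.0, 0.0])
print(apply_reset(V, spk, mode="soft"))  # first neuron keeps its overshoot above V_th
print(apply_reset(V, spk, mode="hard"))  # first neuron lands exactly on V_reset
```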

State Management

Neuron models typically maintain state variables (e.g., membrane potential V, adaptation current a) as brainstate.HiddenState objects. These states are:

  • Initialized via init_state(batch_size=None, **kwargs)

  • Reset via reset_state(batch_size=None, **kwargs)

  • Updated via update(x) which returns spikes for the current timestep
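This life cycle can be mimicked without BrainPy to make the state shapes concrete. The sketch assumes that a batched state has shape (batch_size, *in_size) and an unbatched one has shape in_size. That is consistent with the (32, 100) output in the LIF example below. ToyIF is a hypothetical class, and a plain attribute stands in for brainstate.HiddenState.

```python
import jax.numpy as jnp


class ToyIF:
    """Hypothetical integrate-and-fire neuron following the init/reset/update protocol."""

    def __init__(self, in_size, V_th=1.0):
        # Normalize in_size (int or tuple) to a shape tuple, playing the role of varshape
        self.varshape = (in_size,) if isinstance(in_size, int) else tuple(in_size)
        self.V_th = V_th
        self.V = None  # stand-in for a brainstate.HiddenState

    def _zeros(self, batch_size):
        # Batched states get a leading batch axis
        shape = self.varshape if batch_size is None else (batch_size, *self.varshape)
        return jnp.zeros(shape)

    def init_state(self, batch_size=None, **kwargs):
        self.V = self._zeros(batch_size)

    def reset_state(self, batch_size=None, **kwargs):
        self.V = self._zeros(batch_size)

    def update(self, x):
        # Integrate the input, emit binary spikes, then soft-reset the spiking units
        V = self.V + x
        spk = (V >= self.V_th).astype(V.dtype)
        self.V = V - self.V_th * spk
        return spk


neuron = ToyIF((28, 28))
neuron.init_state(batch_size=4)
for _ in range(2):
    spikes = neuron.update(0.6)
print(spikes.shape, float(spikes.sum()))
```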

Examples

Creating a Custom Neuron Model

>>> import brainstate
>>> import saiunit as u
>>> import braintools
>>> import brainpy
>>>
>>> class SimpleNeuron(brainpy.state.Neuron):
...     def __init__(self, in_size, V_th=1.0*u.mV, **kwargs):
...         super().__init__(in_size, **kwargs)
...         self.V_th = V_th
...
...     def init_state(self, batch_size=None, **kwargs):
...         self.V = brainstate.HiddenState(
...             braintools.init.param(
...                 braintools.init.Constant(0.*u.mV),
...                 self.varshape,
...                 batch_size
...             )
...         )
...
...     def reset_state(self, batch_size=None, **kwargs):
...         self.V.value = braintools.init.param(
...             braintools.init.Constant(0.*u.mV),
...             self.varshape,
...             batch_size
...         )
...
...     def get_spike(self, V=None):
...         V = self.V.value if V is None else V
...         return self.spk_fun((V - self.V_th) / self.V_th)
...
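...     def update(self, x):
...         # Reset units that spiked on the previous step, according to spk_reset
...         last_spk = self.get_spike()
...         if self.spk_reset == 'soft':
...             V = self.V.value - self.V_th * last_spk    # V = V - V_th
...         else:
...             V = self.V.value * (1. - last_spk)         # hard reset to 0 mV
...         # Integrate the input (assumed to carry mV units) and emit spikes
...         self.V.value = V + x
...         return self.get_spike()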

Using Built-in Neuron Models

>>> import brainpy
>>> import brainstate
>>> import braintools
>>> import saiunit as u
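>>>
>>> # The neuron update integrates over the global time step, so set dt first
>>> brainstate.environ.set(dt=0.1 * u.ms)
>>>
>>> # Create a LIF neuron layer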
>>> neuron = brainpy.state.LIF(
...     in_size=100,
...     tau=10*u.ms,
...     V_th=1.0*u.mV,
...     spk_fun=braintools.surrogate.ReluGrad(),
...     spk_reset='soft'
... )
>>>
>>> # Initialize state for batch processing
>>> neuron.init_state(batch_size=32)
>>>
>>> # Process input and get spikes
>>> input_current = 2.0 * u.mA
>>> spikes = neuron.update(input_current)
>>> print(spikes.shape)
(32, 100)

Building a Multi-Layer Spiking Network

>>> import brainpy
>>> import brainstate
>>> import saiunit as u
>>>
>>> # Create a network with multiple neuron types
>>> class SpikingNet(brainstate.nn.Module):
...     def __init__(self):
...         super().__init__()
...         self.layer1 = brainpy.state.LIF(784, tau=5*u.ms)
...         self.fc1 = brainstate.nn.Linear(784, 256)
...         self.layer2 = brainpy.state.ALIF(256, tau=10*u.ms, tau_a=200*u.ms)
...         self.fc2 = brainstate.nn.Linear(256, 10)
...         self.layer3 = brainpy.state.LIF(10, tau=8*u.ms)
...
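...     def __call__(self, x):
...         # x: input current with a unit (e.g. mA) and trailing size 784
...         spikes1 = self.layer1.update(x)
...         # Linear layers return dimensionless values; attach a current
...         # unit before feeding them to the next neuron layer
...         x1 = self.fc1(spikes1) * u.mA
...         spikes2 = self.layer2.update(x1)
...         x2 = self.fc2(spikes2) * u.mA
...         spikes3 = self.layer3.update(x2)
...         return spikes3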


get_spike(*args, **kwargs)

Generate spikes based on neuron state variables.

This abstract method must be implemented by subclasses to define the spike generation mechanism. The method should use the surrogate gradient function self.spk_fun to enable gradient-based learning.

Parameters:
  • *args – Positional arguments (typically state variables like membrane potential)

  • **kwargs – Keyword arguments

Returns:

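Spike tensor with binary values in the forward pass (1 indicates a spike, 0 indicates no spike). In the backward pass, its gradient is given by the surrogate function self.spk_fun.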

Return type:

ArrayLike

Raises:

NotImplementedError – If the subclass does not implement this method.
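A concrete get_spike usually normalizes the membrane potential to a dimensionless distance from threshold before passing it to self.spk_fun. The SimpleNeuron example above does this with (V - V_th) / V_th. The standalone sketch below shows the abstract-method pattern and that normalization. The class names, the division by (V_th - V_reset), and the plain Heaviside used as a stand-in for spk_fun are all assumptions. A real spk_fun would also supply the surrogate gradient.

```python
import jax.numpy as jnp


class BaseNeuronSketch:
    """Hypothetical base class: subclasses must override get_spike."""

    def __init__(self, spk_fun):
        self.spk_fun = spk_fun

    def get_spike(self, *args, **kwargs):
        raise NotImplementedError


class ThresholdNeuronSketch(BaseNeuronSketch):
    """Hypothetical subclass implementing get_spike from a membrane potential."""

    def __init__(self, spk_fun, V_th=1.0, V_reset=0.0):
        super().__init__(spk_fun)
        self.V_th, self.V_reset = V_th, V_reset

    def get_spike(self, V):
        # Dimensionless distance to threshold; >= 0 means the threshold is reached
        v_scaled = (V - self.V_th) / (self.V_th - self.V_reset)
        return self.spk_fun(v_scaled)


def heaviside(x):
    # Stand-in for a surrogate function: forward pass only, no custom gradient
    return (x >= 0).astype(jnp.float32)


neuron = ThresholdNeuronSketch(heaviside)
print(neuron.get_spike(jnp.array([0.5, 1.0, 1.5])))
```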