Neuron
- class brainpy.state.Neuron(in_size, spk_fun=InvSquareGrad(alpha=100.0), spk_reset='soft', name=None)
Base class for all spiking neuron models.
This abstract class serves as the foundation for implementing spiking neuron models in the BrainPy framework. It extends the brainpy.state.Dynamics class and provides the common functionality for spike generation, membrane potential dynamics, and surrogate gradient handling required for training spiking neural networks. All neuron models (e.g., IF, LIF, LIFRef, ALIF) should inherit from this class and implement the required abstract methods, particularly get_spike(), which defines the spike generation mechanism.
- Parameters:
- in_size (Size) – Size of the input to the neuron layer. Can be an integer for 1D input or a tuple for multi-dimensional input (e.g., 100 or (28, 28)).
- spk_fun (Callable, optional) – Surrogate gradient function for the non-differentiable spike generation operation. Default is braintools.surrogate.InvSquareGrad(). Common alternatives include braintools.surrogate.ReluGrad(), braintools.surrogate.SigmoidGrad(), braintools.surrogate.GaussianGrad(), and braintools.surrogate.ATan().
- spk_reset (str, optional) – Reset mechanism applied after spike generation. Default is 'soft'.
  - 'soft': Subtract the threshold from the membrane potential (V = V - V_th). This allows for more biological realism and better gradient flow.
  - 'hard': Apply a strict reset using jax.lax.stop_gradient to set the voltage to the reset value (V = V_reset).
- name (str, optional) – Name identifier for the neuron layer. If None, an automatic name is generated. Useful for debugging and visualization.
- spk_fun
The surrogate gradient function used for spike generation.
- Type:
Callable
Notes
Surrogate Gradients
The spike generation operation is inherently non-differentiable (a threshold function), which poses challenges for gradient-based learning. Surrogate gradients provide a differentiable approximation during the backward pass while maintaining the discrete spike behavior during the forward pass. This is crucial for training SNNs with backpropagation through time (BPTT).
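To make this concrete, here is a minimal sketch in plain JAX (independent of the braintools implementation, shown only to illustrate the idea): the forward pass keeps a hard Heaviside threshold, while the backward pass substitutes the derivative of a sigmoid as the surrogate gradient.

import jax
import jax.numpy as jnp

@jax.custom_vjp
def spike(v):
    # Forward pass: exact, non-differentiable threshold (Heaviside step).
    return (v >= 0.).astype(jnp.float32)

def spike_fwd(v):
    # Save the pre-threshold voltage for the backward pass.
    return spike(v), v

def spike_bwd(v, g):
    # Backward pass: replace the true (zero-almost-everywhere) gradient
    # with a sigmoid derivative, sigma'(v) = sigma(v) * (1 - sigma(v)).
    s = jax.nn.sigmoid(v)
    return (g * s * (1. - s),)

spike.defvjp(spike_fwd, spike_bwd)

v = jnp.array([-0.5, 0.2, 1.5])
print(spike(v))                                # [0. 1. 1.]  (discrete spikes)
print(jax.grad(lambda x: spike(x).sum())(v))   # smooth surrogate gradients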
Reset Mechanisms
Soft Reset: More biologically plausible as it preserves information about how far above threshold the membrane potential was. This can encode information in the residual voltage and often leads to better gradient flow.
Hard Reset: Provides a clean reset to a fixed value, which can be easier to analyze mathematically but may lead to vanishing gradients in deep networks.
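The two mechanisms can be sketched as follows (a plausible formulation for illustration, not necessarily BrainPy's exact code path; the V_th and V_reset values and the mask handling are assumptions):

import jax
import jax.numpy as jnp

V_th, V_reset = 1.0, 0.0  # illustrative values

def soft_reset(V, spk):
    # Subtract the threshold wherever a spike occurred; the residual
    # above-threshold voltage is preserved (V = V - V_th).
    return V - V_th * spk

def hard_reset(V, spk):
    # Overwrite the voltage with V_reset wherever a spike occurred;
    # stop_gradient keeps the reset mask out of the backward pass,
    # so gradients vanish at the reset entries, as noted above.
    mask = jax.lax.stop_gradient(spk)
    return (1. - mask) * V + mask * V_reset

V = jnp.array([0.4, 1.2, 1.7])
spk = (V >= V_th).astype(jnp.float32)
print(soft_reset(V, spk))  # [0.4 0.2 0.7]
print(hard_reset(V, spk))  # [0.4 0.  0. ]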
State Management
Neuron models typically maintain state variables (e.g., membrane potential V, adaptation current a) as brainstate.HiddenState objects. These states are:
- Initialized via init_state(batch_size=None, **kwargs)
- Reset via reset_state(batch_size=None, **kwargs)
- Updated via update(x), which returns spikes for the current timestep
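A minimal sketch of this lifecycle with a built-in model (the constant input and step count are illustrative, and environment setup such as the simulation time step is omitted):

import brainpy
import saiunit as u

neuron = brainpy.state.LIF(in_size=100, tau=10 * u.ms)
neuron.init_state(batch_size=32)        # allocate hidden states (e.g., V)

for _ in range(50):                     # 50 illustrative timesteps
    spikes = neuron.update(2.0 * u.mA)  # one step; returns the spike tensor

neuron.reset_state(batch_size=32)       # clear states before the next trial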
Examples
Creating a Custom Neuron Model
>>> import brainstate
>>> import saiunit as u
>>> import braintools
>>> import brainpy
>>>
>>> class SimpleNeuron(brainpy.state.Neuron):
...     def __init__(self, in_size, V_th=1.0*u.mV, **kwargs):
...         super().__init__(in_size, **kwargs)
...         self.V_th = V_th
...
...     def init_state(self, batch_size=None, **kwargs):
...         self.V = brainstate.HiddenState(
...             braintools.init.param(
...                 braintools.init.Constant(0.*u.mV),
...                 self.varshape,
...                 batch_size
...             )
...         )
...
...     def reset_state(self, batch_size=None, **kwargs):
...         self.V.value = braintools.init.param(
...             braintools.init.Constant(0.*u.mV),
...             self.varshape,
...             batch_size
...         )
...
...     def get_spike(self, V=None):
...         V = self.V.value if V is None else V
...         return self.spk_fun((V - self.V_th) / self.V_th)
...
...     def update(self, x):
...         self.V.value += x
...         return self.get_spike()
Using Built-in Neuron Models
>>> import brainpy
>>> import brainstate
>>> import braintools
>>> import saiunit as u
>>>
>>> # Create a LIF neuron layer
>>> neuron = brainpy.state.LIF(
...     in_size=100,
...     tau=10*u.ms,
...     V_th=1.0*u.mV,
...     spk_fun=braintools.surrogate.ReluGrad(),
...     spk_reset='soft'
... )
>>>
>>> # Initialize state for batch processing
>>> neuron.init_state(batch_size=32)
>>>
>>> # Process input and get spikes
>>> input_current = 2.0 * u.mA
>>> spikes = neuron.update(input_current)
>>> print(spikes.shape)
(32, 100)
Building a Multi-Layer Spiking Network
>>> import brainpy
>>> import brainstate
>>> import saiunit as u
>>>
>>> # Create a network with multiple neuron types
>>> class SpikingNet(brainstate.nn.Module):
...     def __init__(self):
...         super().__init__()
...         self.layer1 = brainpy.state.LIF(784, tau=5*u.ms)
...         self.fc1 = brainstate.nn.Linear(784, 256)
...         self.layer2 = brainpy.state.ALIF(256, tau=10*u.ms, tau_a=200*u.ms)
...         self.fc2 = brainstate.nn.Linear(256, 10)
...         self.layer3 = brainpy.state.LIF(10, tau=8*u.ms)
...
...     def __call__(self, x):
...         spikes1 = self.layer1.update(x)
...         x1 = self.fc1(spikes1)
...         spikes2 = self.layer2.update(x1)
...         x2 = self.fc2(spikes2)
...         spikes3 = self.layer3.update(x2)
...         return spikes3
- get_spike(*args, **kwargs)
Generate spikes based on neuron state variables.
This abstract method must be implemented by subclasses to define the spike generation mechanism. The method should use the surrogate gradient function self.spk_fun to enable gradient-based learning.
- Parameters:
*args – Positional arguments (typically state variables like membrane potential)
**kwargs – Keyword arguments
- Returns:
Binary spike tensor where 1 indicates a spike and 0 indicates no spike.
- Return type:
ArrayLike
- Raises:
NotImplementedError – If the subclass does not implement this method.