Sigmoid
- class braintools.surrogate.Sigmoid(alpha=4.0)
Spike function with the sigmoid-shaped surrogate gradient.
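This class implements a spiking neuron activation whose forward pass is a Heaviside step function and whose backward pass uses a sigmoid-shaped surrogate gradient. In spiking neural networks it lets backpropagation pass through the non-differentiable step function by substituting a smooth approximation of its derivative during training.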
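To see why a surrogate is needed, the sketch below uses plain JAX (no braintools) to differentiate a hard step function. The gradient is zero at every sample point, so no learning signal would reach earlier layers. `step` is a hypothetical helper written only for this illustration.

```python
import jax
import jax.numpy as jnp

def step(x):
    # Hard Heaviside step: 1.0 where x >= 0, else 0.0 (hypothetical helper)
    return jnp.where(x >= 0, 1.0, 0.0)

xs = jnp.array([-1.0, 0.0, 1.0])
# Differentiate elementwise; the comparison carries no gradient information
grads = jax.vmap(jax.grad(step))(xs)
print(step(xs))  # step outputs
print(grads)     # every entry is 0.0
```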
- Parameters:
  alpha (float) – Controls the steepness of the sigmoid used in the surrogate gradient. Higher values make the transition sharper, giving a taller and narrower gradient peak around x = 0. Default is 4.0.
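The effect of alpha can be checked numerically. Because the surrogate gradient given in the Notes is the derivative of σ(αx), it peaks at x = 0 with height α/4 and decays faster away from the threshold as α grows. The sketch below evaluates that formula directly with `jax.nn.sigmoid`; `surrogate_grad` is a hypothetical helper, not a braintools function.

```python
import jax
import jax.numpy as jnp

def surrogate_grad(x, alpha):
    # alpha * (1 - sigma(alpha * x)) * sigma(alpha * x): the derivative of sigma(alpha * x)
    s = jax.nn.sigmoid(alpha * jnp.asarray(x))
    return alpha * (1.0 - s) * s

for alpha in [1.0, 2.0, 4.0]:
    peak = float(surrogate_grad(0.0, alpha))           # height at the threshold x = 0
    ratio = float(surrogate_grad(1.0, alpha)) / peak   # fraction of the peak left at x = 1
    print(f"alpha={alpha}: peak={peak:.4f}, value at x=1 relative to peak={ratio:.4f}")
```

The peaks come out as α/4 (0.25, 0.5, 1.0), while the value at x = 1 drops from roughly 79% of the peak for α = 1 to about 7% for α = 4.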
See also
sigmoid: Function version of this class.
Examples
>>> import braintools
>>> import brainstate
>>> import jax.numpy as jnp
>>>
>>> # Create a Sigmoid surrogate gradient function
>>> sigmoid = braintools.surrogate.Sigmoid(alpha=4.0)
>>>
>>> # Apply to input data
>>> x = jnp.array([-1.0, 0.0, 1.0])
>>> spikes = sigmoid(x)
>>> print(spikes)  # Step function output: [0., 1., 1.]
>>>
>>> # Use in a spiking neural network layer
>>> import brainstate.nn as nn
>>>
>>> class SpikingLayer(nn.Module):
...     def __init__(self, in_features, out_features):
...         super().__init__()
...         self.linear = nn.Linear(in_features, out_features)
...         self.spike_fn = braintools.surrogate.Sigmoid(alpha=4.0)
...
...     # brainstate modules dispatch calls to update()
...     def update(self, x):
...         membrane = self.linear(x)
...         return self.spike_fn(membrane)
>>> import jax
>>> import braintools
>>> import brainstate
>>> import matplotlib.pyplot as plt
>>> xs = jax.numpy.linspace(-2, 2, 1000)
>>> for alpha in [1., 2., 4.]:
...     sigmoid = braintools.surrogate.Sigmoid(alpha=alpha)
...     grads = brainstate.augment.vector_grad(sigmoid)(xs)
...     plt.plot(xs, grads, label=r'$\alpha$=' + str(alpha))
>>> plt.legend()
>>> plt.show()
Notes
The forward pass uses a Heaviside step function (1 for x >= 0, 0 for x < 0), whose true derivative is zero everywhere except at x = 0, where it is undefined. The backward pass replaces that derivative with a smooth, sigmoid-shaped surrogate gradient, defined as:
\[g'(x) = \alpha \cdot (1 - \sigma(\alpha x)) \cdot \sigma(\alpha x)\]
where \(\sigma(x) = 1 / (1 + e^{-x})\) is the sigmoid function.
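The definition above can be sketched in plain JAX with `jax.custom_vjp`: the forward pass returns the step, while `jax.grad` sees the sigmoid-shaped surrogate. This is a minimal illustration that assumes the class behaves as these Notes describe; `sigmoid_spike` and its two helper functions are hypothetical names, not part of braintools, whose own implementation may be structured differently.

```python
from functools import partial

import jax
import jax.numpy as jnp

@partial(jax.custom_vjp, nondiff_argnums=(1,))
def sigmoid_spike(x, alpha):
    # Forward pass: Heaviside step, 1 where x >= 0 and 0 where x < 0
    return (x >= 0).astype(x.dtype)

def _sigmoid_spike_fwd(x, alpha):
    # Keep x as the residual needed by the backward pass
    return sigmoid_spike(x, alpha), x

def _sigmoid_spike_bwd(alpha, x, g):
    # Backward pass: replace the step's derivative with
    # alpha * (1 - sigma(alpha * x)) * sigma(alpha * x)
    s = jax.nn.sigmoid(alpha * x)
    return (g * alpha * (1.0 - s) * s,)

sigmoid_spike.defvjp(_sigmoid_spike_fwd, _sigmoid_spike_bwd)

xs = jnp.array([-1.0, 0.0, 1.0])
print(sigmoid_spike(xs, 4.0))  # step outputs: 0 for -1.0, 1 for 0.0 and 1.0
# The gradient of the summed spikes flows through the surrogate, not the step
grads = jax.grad(lambda v: sigmoid_spike(v, 4.0).sum())(xs)
print(grads)  # largest at x = 0 (alpha / 4 = 1.0), symmetric around it
```

Passing alpha through `nondiff_argnums` keeps it a fixed hyperparameter rather than a quantity the gradient is taken with respect to.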