Sigmoid

class braintools.surrogate.Sigmoid(alpha=4.0)

Spike function with a sigmoid-shaped surrogate gradient.

This class implements a spiking neuron activation with a sigmoid-shaped surrogate gradient for backpropagation. It can be used in spiking neural networks to approximate the non-differentiable step function during training.

Parameters:

alpha (float) – A parameter controlling the steepness of the sigmoid curve in the surrogate gradient. Higher values make the transition sharper. Default is 4.0.

See also

sigmoid

Function version of this class.

Examples

>>> import brainstate
>>> import braintools
>>> import jax.numpy as jnp
>>>
>>> # Create a Sigmoid surrogate gradient function
>>> sigmoid = braintools.surrogate.Sigmoid(alpha=4.0)
>>>
>>> # Apply to input data
>>> x = jnp.array([-1.0, 0.0, 1.0])
>>> spikes = sigmoid(x)
>>> print(spikes)  # Step function output: [0., 1., 1.]
>>>
>>> # Use in a spiking neural network layer
>>> import brainstate.nn as nn
>>>
>>> class SpikingLayer(nn.Module):
...     def __init__(self, in_features, out_features):
...         super().__init__()
...         self.linear = nn.Linear(in_features, out_features)
...         self.spike_fn = braintools.surrogate.Sigmoid(alpha=4.0)
...
...     def forward(self, x):
...         membrane = self.linear(x)
...         return self.spike_fn(membrane)
>>>
>>> # Plot the surrogate gradient for several values of alpha
>>> import matplotlib.pyplot as plt
>>>
>>> xs = jnp.linspace(-2, 2, 1000)
>>> for alpha in [1., 2., 4.]:
...     sigmoid = braintools.surrogate.Sigmoid(alpha=alpha)
...     grads = brainstate.augment.vector_grad(sigmoid)(xs)
...     plt.plot(xs, grads, label=r'$\alpha$=' + str(alpha))
>>> plt.legend()
>>> plt.show()

Notes

The forward pass uses a Heaviside step function (1 for x >= 0, 0 for x < 0), while the backward pass uses a sigmoid-shaped surrogate gradient for smooth optimization. The surrogate gradient is defined as:

\[g'(x) = \alpha \cdot (1 - \sigma(\alpha x)) \cdot \sigma(\alpha x)\]

where \(\sigma\) is the sigmoid function.
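
The split between forward and backward pass described above can be sketched in plain JAX with `jax.custom_vjp`. This is an illustrative sketch only, not the braintools implementation: the helper name `make_sigmoid_spike` is hypothetical, and the step and gradient are written directly from the formulas in these notes.

```python
import jax
import jax.numpy as jnp


def make_sigmoid_spike(alpha=4.0):
    """Hypothetical helper: spike function with a sigmoid-shaped surrogate gradient."""

    @jax.custom_vjp
    def spike(x):
        # Forward pass: Heaviside step, 1 where x >= 0 and 0 elsewhere.
        return (x >= 0).astype(x.dtype)

    def spike_fwd(x):
        # Save the input so the backward pass can evaluate the surrogate gradient.
        return spike(x), x

    def spike_bwd(x, g):
        s = jax.nn.sigmoid(alpha * x)
        # Surrogate gradient: alpha * (1 - sigma(alpha x)) * sigma(alpha x).
        return (g * alpha * (1.0 - s) * s,)

    spike.defvjp(spike_fwd, spike_bwd)
    return spike


spike = make_sigmoid_spike(alpha=4.0)
x = jnp.array([-1.0, 0.0, 1.0])
out = spike(x)                        # step-function output
grads = jax.vmap(jax.grad(spike))(x)  # elementwise surrogate gradient
print(out, grads)
```

Since \(\sigma(0) = 1/2\), the surrogate gradient peaks at \(x = 0\) with value \(\alpha / 4\), so the default `alpha=4.0` gives a peak gradient of 1.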

surrogate_fun(x)

Compute the surrogate function.

Parameters:

x (jax.Array) – The input array.

Returns:

The output of the surrogate function.

Return type:

jax.Array

surrogate_grad(x)

Compute the gradient of the surrogate function.

Parameters:

x (jax.Array) – The input array.

Returns:

The gradient of the surrogate function.

Return type:

jax.Array
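
The relationship between the two methods can be checked numerically. The sketch below assumes the surrogate function is \(\sigma(\alpha x)\), which is consistent with the gradient formula in the Notes but is not stated explicitly on this page. The standalone functions here are hypothetical stand-ins written in pure Python, not the class methods themselves.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def surrogate_fun(x, alpha=4.0):
    # Assumed form of the surrogate function: sigma(alpha * x).
    return sigmoid(alpha * x)


def surrogate_grad(x, alpha=4.0):
    # Gradient formula from the Notes: alpha * (1 - sigma(alpha x)) * sigma(alpha x).
    s = sigmoid(alpha * x)
    return alpha * (1.0 - s) * s


def finite_difference(f, x, eps=1e-6):
    # Central difference approximation of f'(x).
    return (f(x + eps) - f(x - eps)) / (2.0 * eps)


points = [-1.0, -0.25, 0.0, 0.5, 2.0]
max_err = max(
    abs(surrogate_grad(x) - finite_difference(surrogate_fun, x)) for x in points
)
print(max_err)
```

Under that assumption, the analytic gradient and the finite-difference estimate should agree to within numerical error at every test point.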