SoftSign

class braintools.surrogate.SoftSign(alpha=1.0)

Determine the spiking state using a soft sign surrogate function.

This class implements a surrogate gradient using the soft sign function, which provides a smooth approximation to the step function.

The forward function:

\[g(x) = \begin{cases} 1, & x \geq 0 \\ 0, & x < 0 \end{cases}\]

The original (smooth) function, whose derivative serves as the surrogate gradient:

\[g(x) = \frac{1}{2} (\frac{\alpha x}{1 + |\alpha x|} + 1) = \frac{1}{2} (\frac{x}{\frac{1}{\alpha} + |x|} + 1)\]
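The two forms of the original function are algebraically equal for \(\alpha > 0\). The minimal JAX sketch below checks this numerically. It also shows that the function passes through 0.5 at the threshold and stays inside \((0, 1)\). The helpers `soft_sign_fun` and `soft_sign_fun_alt` are hypothetical names written for illustration. They are not part of braintools.

```python
import jax.numpy as jnp

def soft_sign_fun(x, alpha=1.0):
    # Hypothetical helper: g(x) = 0.5 * (alpha*x / (1 + |alpha*x|) + 1)
    return 0.5 * (alpha * x / (1.0 + jnp.abs(alpha * x)) + 1.0)

def soft_sign_fun_alt(x, alpha=1.0):
    # Hypothetical helper: equivalent form g(x) = 0.5 * (x / (1/alpha + |x|) + 1), alpha > 0
    return 0.5 * (x / (1.0 / alpha + jnp.abs(x)) + 1.0)

x = jnp.linspace(-5.0, 5.0, 11)
print(jnp.allclose(soft_sign_fun(x, 2.0), soft_sign_fun_alt(x, 2.0)))  # True
print(soft_sign_fun(jnp.array(0.0)))  # 0.5
```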

The backward function:

\[g'(x) = \frac{\alpha}{2(1 + |\alpha x|)^{2}} = \frac{1}{2\alpha(\frac{1}{\alpha} + |x|)^{2}}\]
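The sketch below sanity-checks the backward formula. It differentiates the smooth function with `jax.grad`, applied elementwise via `jax.vmap`, and compares the result with the closed form. `soft_sign_fun` and `soft_sign_grad` are hypothetical re-implementations written for illustration. They are not braintools internals.

```python
import jax
import jax.numpy as jnp

def soft_sign_fun(x, alpha=1.0):
    # Smooth surrogate: g(x) = 0.5 * (alpha*x / (1 + |alpha*x|) + 1)
    return 0.5 * (alpha * x / (1.0 + jnp.abs(alpha * x)) + 1.0)

def soft_sign_grad(x, alpha=1.0):
    # Closed form: g'(x) = alpha / (2 * (1 + |alpha*x|)^2)
    return alpha / (2.0 * (1.0 + jnp.abs(alpha * x)) ** 2)

alpha = 2.0
x = jnp.array([-3.0, -0.5, 0.25, 1.0, 4.0])
# Differentiate the scalar function, then map it over every element of x.
auto = jax.vmap(jax.grad(lambda v: soft_sign_fun(v, alpha)))(x)
print(jnp.allclose(auto, soft_sign_grad(x, alpha)))  # True
```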

Parameters:

alpha (float, optional) – Parameter controlling the steepness of the surrogate gradient. Higher values make the transition sharper and raise the peak gradient, which equals \(\alpha/2\) at \(x = 0\). Default is 1.0.
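From the backward function, \(g'(x)/g'(0) = 1/(1 + \alpha|x|)^2\). The gradient therefore falls to a quarter of its peak at \(|x| = 1/\alpha\), so larger `alpha` gives a taller and narrower gradient. The sketch below prints both quantities for several settings. `soft_sign_grad` is a hypothetical helper written from the backward formula. It is not a braintools function.

```python
import jax.numpy as jnp

def soft_sign_grad(x, alpha=1.0):
    # Hypothetical helper: g'(x) = alpha / (2 * (1 + |alpha*x|)^2)
    return alpha / (2.0 * (1.0 + jnp.abs(alpha * x)) ** 2)

for alpha in (0.5, 1.0, 2.0, 4.0):
    peak = soft_sign_grad(jnp.array(0.0), alpha)             # equals alpha / 2
    quarter = soft_sign_grad(jnp.array(1.0 / alpha), alpha)  # equals peak / 4
    print(f"alpha={alpha}: peak={float(peak):.3f}, at x=1/alpha: {float(quarter):.4f}")
```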

See also

soft_sign – Function version of this surrogate.

Examples

>>> import braintools
>>> import brainstate
>>> import jax.numpy as jnp
>>>
>>> # Create a soft sign surrogate
>>> ss_fn = braintools.surrogate.SoftSign(alpha=2.0)
>>>
>>> # Apply to input: the forward pass is a Heaviside step
>>> x = jnp.array([-1.0, 0.0, 1.0])
>>> spikes = ss_fn(x)
>>> print(spikes)  # Binary spike output
>>>
>>> # Use in a spiking layer with an adaptive threshold
>>> class AdaptiveSpikingLayer(brainstate.nn.Module):
...     def __init__(self, n_neurons):
...         super().__init__()
...         self.n = n_neurons
...         self.spike_fn = braintools.surrogate.SoftSign(alpha=2.0)
...         self.threshold = jnp.ones(n_neurons)
...
...     def update(self, membrane_potential):
...         # Spike wherever the potential reaches the threshold
...         spikes = self.spike_fn(membrane_potential - self.threshold)
...         # Raise the threshold of neurons that just spiked
...         self.threshold += 0.01 * spikes
...         return spikes

Notes

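The soft sign surrogate gradient decays polynomially, approximately as \(1/(2\alpha x^2)\) for large \(|x|\), rather than exponentially as sigmoid-based surrogates do. Inputs far from the threshold therefore still receive a small, nonzero gradient, which can be beneficial for learning in deep networks.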
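The sketch below gives a rough comparison of tail behaviour. As the exponential baseline it assumes a sigmoid surrogate \(\sigma(\alpha x)\) with gradient \(\alpha\,\sigma(\alpha x)(1 - \sigma(\alpha x))\). This baseline is chosen only for illustration. Each gradient is divided by its own peak so that the two shapes are comparable. Near the threshold the normalized sigmoid gradient is actually larger. By \(x = 10\), however, the soft sign gradient is more than an order of magnitude larger.

```python
import jax
import jax.numpy as jnp

alpha = 1.0
x = jnp.array([1.0, 2.0, 5.0, 10.0, 20.0])

# Soft sign surrogate gradient, normalized by its peak alpha / 2.
soft = (alpha / (2.0 * (1.0 + jnp.abs(alpha * x)) ** 2)) / (alpha / 2.0)

# Assumed sigmoid baseline: gradient alpha * s * (1 - s) with s = sigmoid(alpha * x),
# normalized by its peak alpha / 4.
s = jax.nn.sigmoid(alpha * x)
sig = (alpha * s * (1.0 - s)) / (alpha / 4.0)

for xi, a, b in zip(x.tolist(), soft.tolist(), sig.tolist()):
    print(f"x={xi:5.1f}  soft sign={a:.2e}  sigmoid={b:.2e}")
```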

surrogate_fun(x)

Compute the soft sign surrogate function.

Parameters:

x (jax.Array) – Input tensor.

Returns:

Output of the soft sign function.

Return type:

jax.Array

surrogate_grad(x)

Compute the gradient of the soft sign function.

Parameters:

x (jax.Array) – Input tensor.

Returns:

Gradient of the soft sign function.

Return type:

jax.Array
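`surrogate_fun` and `surrogate_grad` correspond to the original and backward functions given above. A common way to turn such a pair into a differentiable spiking nonlinearity is a custom VJP. The forward pass emits the Heaviside step, and the backward pass substitutes the soft sign derivative. The sketch below shows this technique with `jax.custom_vjp`. The names `spike`, `spike_fwd`, and `spike_bwd` are hypothetical, and this is not necessarily how braintools implements `SoftSign` internally.

```python
import jax
import jax.numpy as jnp
from functools import partial

# alpha (position 1) is a non-differentiable hyperparameter.
@partial(jax.custom_vjp, nondiff_argnums=(1,))
def spike(x, alpha):
    # Forward: Heaviside step, 1 where x >= 0 and 0 elsewhere.
    return (x >= 0).astype(x.dtype)

def spike_fwd(x, alpha):
    # Same signature as the primal; save x as the residual for the backward pass.
    return spike(x, alpha), x

def spike_bwd(alpha, x, g):
    # Non-differentiable args come first; use the soft sign derivative as the surrogate.
    surrogate = alpha / (2.0 * (1.0 + jnp.abs(alpha * x)) ** 2)
    return (g * surrogate,)

spike.defvjp(spike_fwd, spike_bwd)

x = jnp.array([-1.0, 0.0, 1.0])
print(spike(x, 2.0))  # [0. 1. 1.]
grads = jax.grad(lambda v: spike(v, 2.0).sum())(x)
print(grads)  # approximately [0.111, 1.0, 0.111]
```

The step function itself has zero gradient almost everywhere. Swapping in the smooth derivative only in the backward pass keeps the forward output binary while still letting gradients flow.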