SoftSign
class braintools.surrogate.SoftSign(alpha=1.0)
Judge spiking state with a soft sign function.
This class implements a surrogate gradient using the soft sign function, which provides a smooth approximation to the step function.
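Here is a minimal scalar sketch, in plain Python, of how the forward and backward passes split. The forward pass is a hard threshold. The backward pass returns the soft sign derivative α/(2(1+|αx|)²) instead of the step function's derivative, which is zero almost everywhere. The helper names `heaviside`, `soft_sign_surrogate_grad` and `spike_with_surrogate` are hypothetical and are not part of braintools. The library applies the same idea to arrays and presumably registers the backward formula with JAX's automatic differentiation.

```python
def heaviside(x: float) -> float:
    """Forward pass: hard threshold, 1.0 for x >= 0 and 0.0 otherwise."""
    return 1.0 if x >= 0 else 0.0


def soft_sign_surrogate_grad(x: float, alpha: float = 1.0) -> float:
    """Backward pass: derivative of the smooth soft sign approximation."""
    return alpha / (2.0 * (1.0 + abs(alpha * x)) ** 2)


def spike_with_surrogate(x: float, alpha: float = 1.0) -> tuple[float, float]:
    """Return (spike, surrogate gradient) for a single input value."""
    return heaviside(x), soft_sign_surrogate_grad(x, alpha)


for x in (-1.0, 0.0, 1.0):
    spike, grad = spike_with_surrogate(x, alpha=2.0)
    # The spike is binary, but the gradient is nonzero on both sides of 0.
    print(f"x={x:+.1f}  spike={spike}  surrogate grad={grad:.4f}")
```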
The forward function:
\[g(x) = \begin{cases} 1, & x \geq 0 \\ 0, & x < 0 \end{cases}\]
The original function:
\[g(x) = \frac{1}{2} \left( \frac{\alpha x}{1 + |\alpha x|} + 1 \right) = \frac{1}{2} \left( \frac{x}{\frac{1}{\alpha} + |x|} + 1 \right)\]
The backward function:
\[g'(x) = \frac{\alpha}{2(1 + |\alpha x|)^{2}} = \frac{1}{2\alpha\left(\frac{1}{\alpha} + |x|\right)^{2}}\]
Parameters:
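alpha (float, optional) – Parameter controlling the steepness of the surrogate gradient. Higher values make the transition sharper: the gradient peak at \(x = 0\) has height \(\alpha/2\) and becomes taller and narrower as \(\alpha\) grows. Default is 1.0.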
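The original and backward formulas can be cross-checked numerically. This sketch uses plain Python with hypothetical helper names. It checks that the two written forms of g'(x) agree and that they match a central finite difference of the smooth function. It also prints how α reshapes the gradient profile.

```python
def smooth_soft_sign(x: float, alpha: float) -> float:
    """Smooth approximation g(x) = 0.5 * (alpha*x / (1 + |alpha*x|) + 1)."""
    return 0.5 * (alpha * x / (1.0 + abs(alpha * x)) + 1.0)


def grad_form_a(x: float, alpha: float) -> float:
    """First form: g'(x) = alpha / (2 * (1 + |alpha*x|)^2)."""
    return alpha / (2.0 * (1.0 + abs(alpha * x)) ** 2)


def grad_form_b(x: float, alpha: float) -> float:
    """Second form: g'(x) = 1 / (2 * alpha * (1/alpha + |x|)^2)."""
    return 1.0 / (2.0 * alpha * (1.0 / alpha + abs(x)) ** 2)


def central_difference(f, x: float, h: float = 1e-6) -> float:
    """Numerical derivative, used only as a sanity check."""
    return (f(x + h) - f(x - h)) / (2.0 * h)


for alpha in (1.0, 2.0, 4.0):
    for x in (-1.0, 0.5, 2.0):
        numeric = central_difference(lambda t: smooth_soft_sign(t, alpha), x)
        assert abs(numeric - grad_form_a(x, alpha)) < 1e-6
        assert abs(grad_form_a(x, alpha) - grad_form_b(x, alpha)) < 1e-12
    # Peak g'(0) = alpha/2 grows with alpha, while g'(1) shrinks for alpha >= 1.
    print(f"alpha={alpha}: g'(0)={grad_form_a(0.0, alpha):.3f}  g'(1)={grad_form_a(1.0, alpha):.3f}")
```

For α = 1, 2, 4 the peak rises from 0.5 to 2.0, while the value at x = 1 falls from 0.125 to 0.08. The bump gets sharper, which matches the parameter description.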
See also
soft_sign: function version.
Examples
>>> import brainstate
>>> import braintools
>>> import jax.numpy as jnp
>>>
>>> # Create a soft sign surrogate
>>> ss_fn = braintools.surrogate.SoftSign(alpha=2.0)
>>>
>>> # Apply to input
>>> x = jnp.array([-1.0, 0.0, 1.0])
>>> spikes = ss_fn(x)
>>> print(spikes)  # Binary spike output: 0 for x < 0, 1 for x >= 0
>>>
>>> # Use in a spiking layer with adaptive threshold
>>> class AdaptiveSpikingLayer(brainstate.nn.Module):
...     def __init__(self, n_neurons):
...         super().__init__()
...         self.n = n_neurons
...         self.spike_fn = braintools.surrogate.SoftSign(alpha=2.0)
...         self.threshold = jnp.ones(n_neurons)
...
...     def forward(self, membrane_potential):
...         # Spike wherever the potential reaches the current threshold
...         spikes = self.spike_fn(membrane_potential - self.threshold)
...         # Update threshold based on spike history
...         self.threshold += 0.01 * spikes
...         return spikes
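The layer above shows only the forward pass. The surrogate matters once a loss is differentiated through `spike_fn`. This toy sketch does that chain rule by hand for a single neuron, using only the Python standard library. All names (`spike`, `surrogate_grad`, `loss_and_grad`) and the numbers are hypothetical. In practice JAX would compute the gradient automatically, and the trainable quantity would typically be a synaptic weight rather than the membrane potential itself.

```python
def spike(v: float, theta: float) -> float:
    """Forward pass: 1.0 if the membrane potential reaches the threshold."""
    return 1.0 if v - theta >= 0 else 0.0


def surrogate_grad(x: float, alpha: float = 2.0) -> float:
    """Backward pass: soft sign surrogate derivative evaluated at x = v - theta."""
    return alpha / (2.0 * (1.0 + abs(alpha * x)) ** 2)


def loss_and_grad(v: float, theta: float, target: float) -> tuple[float, float]:
    """Squared error between the spike and a target, plus dLoss/dv via the surrogate."""
    s = spike(v, theta)
    loss = (s - target) ** 2
    # Chain rule: dL/dv = dL/ds * ds/dv, with ds/dv replaced by the surrogate derivative.
    dloss_dv = 2.0 * (s - target) * surrogate_grad(v - theta)
    return loss, dloss_dv


# Toy setting: the neuron starts below threshold but should spike.
v, theta, target, lr = 0.6, 1.0, 1.0, 0.5
for step in range(10):
    loss, grad = loss_and_grad(v, theta, target)
    print(f"step {step}: v={v:.3f}  spike={spike(v, theta)}  loss={loss:.1f}")
    if loss == 0.0:
        break
    v -= lr * grad  # gradient descent on the potential itself
```

With the exact step derivative the gradient would be zero and `v` would never move. The surrogate pushes the potential over the threshold within a few steps.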
Notes
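The soft sign surrogate gradient decays polynomially, as \(1/(2\alpha x^{2})\) for large \(|x|\), rather than exponentially. Inputs far from the threshold therefore still receive a non-negligible gradient, which can be beneficial for learning in deep networks.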
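A rough numerical illustration of the tail behaviour follows. The sketch compares the soft sign derivative with the derivative of a logistic sigmoid, used here as a generic example of an exponentially decaying surrogate. Its form and the choice α = 1 are assumptions made for the comparison; they do not describe any other braintools class. It also prints the large-|x| asymptote α/(2α²x²) = 1/(2αx²) of the soft sign derivative.

```python
import math


def soft_sign_grad(x: float, alpha: float = 1.0) -> float:
    """Soft sign surrogate derivative: alpha / (2 * (1 + |alpha*x|)^2)."""
    return alpha / (2.0 * (1.0 + abs(alpha * x)) ** 2)


def sigmoid_grad(x: float, alpha: float = 1.0) -> float:
    """Derivative of sigmoid(alpha*x); an exponentially decaying reference curve."""
    s = 1.0 / (1.0 + math.exp(-alpha * x))
    return alpha * s * (1.0 - s)


# Compare the tails for alpha = 1 (x > 0 only, so math.exp never overflows).
for x in (1.0, 5.0, 10.0, 20.0):
    print(
        f"x={x:5.1f}  soft sign={soft_sign_grad(x):.2e}  "
        f"sigmoid={sigmoid_grad(x):.2e}  asymptote 1/(2x^2)={1.0 / (2.0 * x * x):.2e}"
    )
```

Near x = 1 the sigmoid curve is actually larger. By x = 20, however, the soft sign derivative is several orders of magnitude larger and already close to its 1/(2x²) asymptote.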