SlayerGrad

class braintools.surrogate.SlayerGrad(alpha=1.0)

Determine the spiking state using the SLAYER surrogate gradient function [1].

The SLAYER (Spike LAYer Error Reassignment) surrogate replaces the undefined derivative of the spike function with an exponentially decaying function of the distance from the threshold, which enables error backpropagation in spiking neural networks. The gradient has the shape of an unnormalized Laplace density: it peaks at the threshold, decays symmetrically on both sides, and costs a single exponential to evaluate.
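
To make the Laplace comparison concrete, here is a short worked relation. It assumes the backward gradient \(g'(x) = \exp(-\alpha |x|)\) stated under "Backward gradient" in this entry; the symbol \(f\) for the Laplace density with location 0 and scale \(1/\alpha\) is notation introduced here for illustration only.

\[g'(x) = \exp(-\alpha |x|) = \frac{2}{\alpha}\, f(x), \qquad f(x) = \frac{\alpha}{2} \exp(-\alpha |x|), \qquad \int_{-\infty}^{\infty} g'(x)\, dx = \frac{2}{\alpha}\]

Under these assumptions the surrogate is a Laplace density rescaled so that its peak equals 1, and it integrates to 1 only when α = 2.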

The forward function:

\[\begin{split}g(x) = \begin{cases} 1, & x \geq 0 \\ 0, & x < 0 \\ \end{cases}\end{split}\]

Backward gradient:

\[g'(x) = \exp(-\alpha \cdot |x|)\]

This creates an exponentially decaying gradient with:

  • Peak value of 1 at x=0

  • Exponential decay rate controlled by α

  • Symmetric profile around the threshold
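
The two formulas above can be checked by hand with a minimal plain-Python sketch. It does not use braintools or JAX and is not the library's implementation; the helper names `spike` and `slayer_surrogate_grad` are hypothetical and chosen only for this illustration.

```python
import math


def spike(x: float) -> float:
    """Forward pass (hypothetical helper): Heaviside step, 1 when x >= 0, else 0."""
    return 1.0 if x >= 0 else 0.0


def slayer_surrogate_grad(x: float, alpha: float = 1.0) -> float:
    """Backward pass (hypothetical helper): exp(-alpha * |x|)."""
    return math.exp(-alpha * abs(x))


xs = [-2.0, -1.0, 0.0, 1.0, 2.0]
spikes = [spike(x) for x in xs]
grads = [slayer_surrogate_grad(x, alpha=1.0) for x in xs]

print(spikes)                        # [0.0, 0.0, 1.0, 1.0, 1.0]
print([round(g, 4) for g in grads])  # [0.1353, 0.3679, 1.0, 0.3679, 0.1353]
```

The printed gradients show the three listed properties: the value is 1 at x=0, the entries for x and -x are equal, and each unit step away from the threshold multiplies the value by exp(-α).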

>>> import jax
>>> import jax.numpy as jnp
>>> import brainstate
>>> import braintools.surrogate as surrogate
>>> import matplotlib.pyplot as plt
>>>
>>> xs = jnp.linspace(-4, 4, 1000)
>>> fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 4))
>>>
>>> # Plot gradients for different alpha values
>>> for alpha in [0.5, 1.0, 2.0, 4.0]:
...     sg_fn = surrogate.SlayerGrad(alpha=alpha)
...     grads = jax.vmap(jax.grad(sg_fn))(xs)
...     ax1.plot(xs, grads, label=rf'$\alpha={alpha}$')
>>>
>>> ax1.set_xlabel('Input (x)')
>>> ax1.set_ylabel('Gradient')
>>> ax1.set_title('SLAYER Surrogate Gradients')
>>> ax1.legend()
>>> ax1.grid(True, alpha=0.3)
>>>
>>> # Compare decay rates on semi-log plot
>>> xs_pos = jnp.linspace(0, 4, 500)
>>> for alpha in [0.5, 1.0, 2.0, 4.0]:
...     sg_fn = surrogate.SlayerGrad(alpha=alpha)
...     grads = jax.vmap(jax.grad(sg_fn))(xs_pos)
...     ax2.semilogy(xs_pos, grads, label=rf'$\alpha={alpha}$')
>>>
>>> # Add theoretical exponential decay lines
>>> for alpha in [1.0, 2.0]:
...     theoretical = jnp.exp(-alpha * xs_pos)
...     ax2.semilogy(xs_pos, theoretical, '--', alpha=0.5)
>>>
>>> ax2.set_xlabel('Distance from threshold')
>>> ax2.set_ylabel('Gradient (log scale)')
>>> ax2.set_title('Exponential Decay Behavior')
>>> ax2.legend()
>>> ax2.grid(True, alpha=0.3, which="both")
>>> plt.tight_layout()
>>> plt.show()

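[Figure: two panels. Left, "SLAYER Surrogate Gradients": gradient versus input x for α = 0.5, 1.0, 2.0, 4.0. Right, "Exponential Decay Behavior": gradient on a log scale versus distance from threshold.]
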
Parameters:

alpha (float, optional) –

Controls the decay rate of the surrogate gradient. Default is 1.0.

  • Larger α creates faster decay (sharper gradients)

  • Smaller α creates slower decay (wider gradients)

  • Decay length scale = 1/α
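
A small sketch of what the length scale 1/α means in practice, assuming only the backward formula exp(-α·|x|) from this entry. The helper `decay_distance` is hypothetical and not part of braintools; it inverts the formula to find how far from the threshold the gradient falls to a given fraction of its peak.

```python
import math


def decay_distance(alpha: float, fraction: float) -> float:
    """Hypothetical helper: distance |x| at which exp(-alpha * |x|) equals
    `fraction`, for 0 < fraction <= 1."""
    return -math.log(fraction) / alpha


rows = []
for alpha in (0.5, 1.0, 2.0, 4.0):
    length_scale = 1.0 / alpha                       # decay length scale
    grad_at_scale = math.exp(-alpha * length_scale)  # equals exp(-1) for every alpha
    half_distance = decay_distance(alpha, 0.5)       # equals ln(2) / alpha
    rows.append((alpha, length_scale, grad_at_scale, half_distance))
    print(f"alpha={alpha}: 1/alpha={length_scale:.3f}, "
          f"grad at 1/alpha={grad_at_scale:.4f}, half-peak distance={half_distance:.4f}")
```

For every α the gradient at |x| = 1/α is exp(-1), roughly 0.368 of the peak, so doubling α halves the width of the region that receives a meaningful gradient.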

Examples

>>> import jax
>>> import braintools.surrogate as surrogate
>>>
>>> # Create SLAYER gradient surrogate
>>> sg_fn = surrogate.SlayerGrad(alpha=1.0)
>>>
>>> # Apply to input
>>> x = jax.numpy.array([-2., -1., 0., 1., 2.])
>>> spikes = sg_fn(x)
>>> print(spikes)
[0. 0. 1. 1. 1.]
>>>
>>> # Compute gradients showing exponential decay
>>> grad_fn = jax.grad(lambda x: sg_fn(x).sum())
>>> grads = grad_fn(x)
>>> # With alpha=1.0 the gradients follow exp(-|x|)
>>> print(f"Gradients: {grads}")

See also

slayer_grad

Functional version of SLAYER gradient.

GaussianGrad

Gaussian-based surrogate gradient.

InvSquareGrad

Power-law decay surrogate gradient.

References

surrogate_grad(x)

The gradient function of the surrogate function.