GaussianGrad

class braintools.surrogate.GaussianGrad(sigma=0.5, alpha=0.5)

Judge spiking state with the Gaussian gradient function [1].

The Gaussian gradient surrogate keeps the Heaviside step function in the forward pass and replaces its derivative in the backward pass with a smooth, bell-shaped function based on the Gaussian (normal) distribution. The surrogate gradient is infinitely differentiable and non-zero everywhere, which makes it well suited to gradient-based optimization of spiking neural networks.

The forward function:

\[\begin{split}g(x) = \begin{cases} 1, & x \geq 0 \\ 0, & x < 0 \\ \end{cases}\end{split}\]

Backward gradient:

\[g'(x) = \alpha \cdot \frac{1}{\sigma\sqrt{2\pi}} \exp\left(-\frac{x^2}{2\sigma^2}\right)\]

where the gradient follows a Gaussian distribution centered at x=0 with:

  • Standard deviation σ controlling the width

  • Scaling factor α controlling the peak height

  • Peak value at x=0: α/(σ√(2π))
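The forward/backward pair above can be sketched directly with `jax.custom_vjp`. This is a minimal illustration of the formulas on this page, not the braintools implementation; the names `make_gaussian_surrogate` and `spike` are hypothetical and exist only for this example.

```python
import math

import jax
import jax.numpy as jnp


def make_gaussian_surrogate(sigma=0.5, alpha=0.5):
    """Hypothetical helper: Heaviside forward pass, Gaussian backward pass."""

    @jax.custom_vjp
    def spike(x):
        # Forward: g(x) = 1 if x >= 0 else 0
        return jnp.where(x >= 0, 1.0, 0.0)

    def spike_fwd(x):
        # Keep x as the residual needed by the backward pass
        return spike(x), x

    def spike_bwd(x, g):
        # Backward: g'(x) = alpha / (sigma * sqrt(2*pi)) * exp(-x^2 / (2*sigma^2))
        scale = alpha / (sigma * math.sqrt(2.0 * math.pi))
        surrogate = scale * jnp.exp(-(x ** 2) / (2.0 * sigma ** 2))
        return (g * surrogate,)

    spike.defvjp(spike_fwd, spike_bwd)
    return spike


spike = make_gaussian_surrogate(sigma=0.5, alpha=0.5)
x = jnp.array([-1.0, 0.0, 1.0])
spikes = spike(x)                               # Heaviside output
grads = jax.grad(lambda v: spike(v).sum())(x)   # Gaussian surrogate gradient
```

Here `spikes` holds the binary step output, while `grads` is symmetric around zero and largest at `x = 0`, where it equals α/(σ√(2π)).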

>>> import jax
>>> import jax.numpy as jnp
>>> import brainstate
>>> import braintools.surrogate as surrogate
>>> import matplotlib.pyplot as plt
>>>
>>> xs = jnp.linspace(-4, 4, 1000)
>>> fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 4))
>>>
>>> # Plot gradients for different sigma values
>>> alpha = 0.5
>>> for sigma in [0.3, 0.5, 1.0, 2.0]:
...     gg_fn = surrogate.GaussianGrad(sigma=sigma, alpha=alpha)
...     grads = jax.vmap(jax.grad(gg_fn))(xs)
...     ax1.plot(xs, grads, label=rf'$\sigma={sigma}$')
>>>
>>> ax1.set_xlabel('Input (x)')
>>> ax1.set_ylabel('Gradient')
>>> ax1.set_title(f'Gaussian Gradients (α={alpha})')
>>> ax1.legend()
>>> ax1.grid(True, alpha=0.3)
>>>
>>> # Plot gradients for different alpha values
>>> sigma = 0.5
>>> for alpha in [0.25, 0.5, 1.0, 2.0]:
...     gg_fn = surrogate.GaussianGrad(sigma=sigma, alpha=alpha)
...     grads = jax.vmap(jax.grad(gg_fn))(xs)
...     ax2.plot(xs, grads, label=rf'$\alpha={alpha}$')
>>>
>>> ax2.set_xlabel('Input (x)')
>>> ax2.set_ylabel('Gradient')
>>> ax2.set_title(f'Scaling Effect (σ={sigma})')
>>> ax2.legend()
>>> ax2.grid(True, alpha=0.3)
>>> plt.tight_layout()
>>> plt.show()

[Figure: Gaussian surrogate gradient curves for several σ values at α=0.5 (left) and several α values at σ=0.5 (right).]

Parameters:
  • sigma (float, optional) – Standard deviation (width) of the Gaussian. Default is 0.5. Smaller values give a narrower, taller gradient peak; larger values give a wider, flatter one.

  • alpha (float, optional) – Scale factor multiplying the Gaussian. Default is 0.5. Together with sigma it sets the maximum gradient value at x=0, which is α/(σ√(2π)).
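Because σ also appears in the normalization term, it changes the peak height as well as the width. The following sketch evaluates the backward-gradient formula from this page in plain Python to make that visible; `gaussian_surrogate_grad` is a hypothetical helper written for this illustration and is not part of braintools.

```python
import math


def gaussian_surrogate_grad(x, sigma=0.5, alpha=0.5):
    """Evaluate alpha / (sigma * sqrt(2*pi)) * exp(-x^2 / (2*sigma^2)) at a scalar x."""
    scale = alpha / (sigma * math.sqrt(2.0 * math.pi))
    return scale * math.exp(-(x ** 2) / (2.0 * sigma ** 2))


# Peak height at x = 0 for the default alpha and several widths
peaks = {
    sigma: gaussian_surrogate_grad(0.0, sigma=sigma)
    for sigma in (0.3, 0.5, 1.0, 2.0)
}
for sigma, peak in peaks.items():
    print(f"sigma={sigma}: peak={peak:.4f}")
```

With α fixed, halving σ doubles the peak, whereas changing α rescales the whole curve without changing its width.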

Examples

>>> import jax
>>> import braintools.surrogate as surrogate
>>>
>>> # Create Gaussian gradient surrogate function
>>> gg_fn = surrogate.GaussianGrad(sigma=0.5, alpha=0.5)
>>>
>>> # Apply to input
>>> x = jax.numpy.array([-1., 0., 1.])
>>> spikes = gg_fn(x)
>>> print(spikes)
[0. 1. 1.]
>>>
>>> # Compute gradients
>>> grad_fn = jax.grad(lambda x: gg_fn(x).sum())
>>> grads = grad_fn(x)
>>> print(f"Gradients: {grads}")

See also

gaussian_grad

Functional version of Gaussian gradient surrogate.

MultiGaussianGrad

Multi-component Gaussian gradient.

Sigmoid

Sigmoid-based surrogate gradient.

References

surrogate_grad(x)

The gradient function of the surrogate function.