GaussianGrad
- class braintools.surrogate.GaussianGrad(sigma=0.5, alpha=0.5)
Spike function (Heaviside step) whose backward pass uses the Gaussian surrogate gradient [1].
The Gaussian gradient surrogate replaces the derivative of the Heaviside step function with a smooth, bell-shaped curve proportional to the Gaussian (normal) density. The forward pass still emits exact binary spikes. Only the backward pass uses the Gaussian, which is infinitely differentiable. This makes it well suited to gradient-based optimization in spiking neural networks.
Forward function:
\[g(x) = \begin{cases} 1, & x \geq 0 \\ 0, & x < 0 \end{cases}\]
Backward (surrogate) gradient:
\[g'(x) = \alpha \cdot \frac{1}{\sigma\sqrt{2\pi}} \exp\left(-\frac{x^2}{2\sigma^2}\right)\]
where g'(x) is a Gaussian density centered at x=0 and scaled by α:
- the standard deviation σ controls the width;
- the scaling factor α multiplies the whole curve, so the area under g'(x) equals α;
- the peak value at x=0 is α/(σ√(2π)).
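One common way to wire a surrogate like this into JAX is a custom VJP. The forward pass returns the exact step g(x), and the backward pass substitutes g'(x) for the step's zero-almost-everywhere derivative. The sketch below assumes that pattern and implements the two formulas above directly. `make_gaussian_surrogate` is a hypothetical helper, not part of braintools, and `GaussianGrad` may be implemented differently internally.

```python
import math

import jax
import jax.numpy as jnp


def make_gaussian_surrogate(sigma=0.5, alpha=0.5):
    """Return a spike function: exact Heaviside forward, Gaussian surrogate backward.

    Hypothetical helper for illustration; not part of braintools.
    """
    # Peak of the surrogate gradient: g'(0) = alpha / (sigma * sqrt(2*pi))
    peak = alpha / (sigma * math.sqrt(2.0 * math.pi))

    @jax.custom_vjp
    def spike(x):
        # Forward pass: g(x) = 1 if x >= 0 else 0
        return jnp.heaviside(x, 1.0)

    def spike_fwd(x):
        # Keep the input as the residual needed by the backward pass.
        return spike(x), x

    def spike_bwd(x, g):
        # Replace the step's true derivative with the scaled Gaussian g'(x).
        sg = peak * jnp.exp(-(x ** 2) / (2.0 * sigma ** 2))
        return (g * sg,)

    spike.defvjp(spike_fwd, spike_bwd)
    return spike


spike = make_gaussian_surrogate(sigma=0.5, alpha=0.5)
x = jnp.array([-1.0, 0.0, 1.0])
print(spike(x))                               # [0. 1. 1.]
print(jax.grad(lambda v: spike(v).sum())(x))  # approx. [0.054 0.399 0.054]
```

The closure keeps σ and α out of the differentiated arguments, so `jax.grad` only ever differentiates with respect to the input.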
>>> import jax
>>> import jax.numpy as jnp
>>> import braintools.surrogate as surrogate
>>> import matplotlib.pyplot as plt
>>>
>>> xs = jnp.linspace(-4, 4, 1000)
>>> fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 4))
>>>
>>> # Plot gradients for different sigma values
>>> alpha = 0.5
>>> for sigma in [0.3, 0.5, 1.0, 2.0]:
...     gg_fn = surrogate.GaussianGrad(sigma=sigma, alpha=alpha)
...     grads = jax.vmap(jax.grad(gg_fn))(xs)
...     ax1.plot(xs, grads, label=rf'$\sigma={sigma}$')
>>>
>>> ax1.set_xlabel('Input (x)')
>>> ax1.set_ylabel('Gradient')
>>> ax1.set_title(f'Gaussian Gradients (α={alpha})')
>>> ax1.legend()
>>> ax1.grid(True, alpha=0.3)
>>>
>>> # Plot gradients for different alpha values
>>> sigma = 0.5
>>> for alpha in [0.25, 0.5, 1.0, 2.0]:
...     gg_fn = surrogate.GaussianGrad(sigma=sigma, alpha=alpha)
...     grads = jax.vmap(jax.grad(gg_fn))(xs)
...     ax2.plot(xs, grads, label=rf'$\alpha={alpha}$')
>>>
>>> ax2.set_xlabel('Input (x)')
>>> ax2.set_ylabel('Gradient')
>>> ax2.set_title(f'Scaling Effect (σ={sigma})')
>>> ax2.legend()
>>> ax2.grid(True, alpha=0.3)
>>> plt.tight_layout()
>>> plt.show()
- Parameters:
sigma (float, optional) – Standard deviation of the Gaussian, i.e. the width of the surrogate gradient around the threshold. Default is 0.5. Smaller values give a narrower, taller gradient; larger values spread the gradient over a wider input range.
alpha (float, optional) – Scale factor applied to the whole gradient curve. Default is 0.5. Together with sigma it sets the peak gradient at x=0, α/(σ√(2π)), and it equals the total area under the gradient curve.
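Because the peak value depends on both parameters while the area under g'(x) depends only on α, different (σ, α) pairs can share a peak but not an area. The plain-Python sketch below evaluates the documented formula to make this concrete. The helper names are hypothetical, and the area is a trapezoidal estimate over a finite interval.

```python
import math


def gaussian_surrogate_grad(x: float, sigma: float, alpha: float) -> float:
    """Surrogate gradient g'(x) as given by the documented formula."""
    return alpha / (sigma * math.sqrt(2.0 * math.pi)) * math.exp(-x * x / (2.0 * sigma * sigma))


def peak(sigma: float, alpha: float) -> float:
    """Maximum of g'(x), reached at x = 0."""
    return gaussian_surrogate_grad(0.0, sigma, alpha)


def area(sigma: float, alpha: float, half_width: float = 10.0, n: int = 20_000) -> float:
    """Trapezoidal estimate of the integral of g'(x) over [-half_width, half_width]."""
    h = 2.0 * half_width / n
    ys = [gaussian_surrogate_grad(-half_width + i * h, sigma, alpha) for i in range(n + 1)]
    return h * (sum(ys) - 0.5 * (ys[0] + ys[-1]))


for sigma, alpha in [(0.5, 0.5), (0.25, 0.5), (1.0, 0.5), (0.5, 1.0)]:
    print(f"sigma={sigma}, alpha={alpha}: peak={peak(sigma, alpha):.4f}, area={area(sigma, alpha):.4f}")
```

With these settings, (σ=0.25, α=0.5) and (σ=0.5, α=1.0) both peak at about 0.798, yet the first integrates to about 0.5 and the second to about 1.0.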
Examples
>>> import jax
>>> import braintools.surrogate as surrogate
>>>
>>> # Create Gaussian gradient surrogate function
>>> gg_fn = surrogate.GaussianGrad(sigma=0.5, alpha=0.5)
>>>
>>> # Apply to input
>>> x = jax.numpy.array([-1., 0., 1.])
>>> spikes = gg_fn(x)
>>> print(spikes)
[0. 1. 1.]
>>>
>>> # Compute gradients
>>> grad_fn = jax.grad(lambda x: gg_fn(x).sum())
>>> grads = grad_fn(x)
>>> print(f"Gradients: {grads}")
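The gradients printed by the example can be cross-checked without relying on library internals. The sketch below uses a different technique from a custom VJP, the stop-gradient trick: α·Φ(x/σ), where Φ is the standard normal CDF, has exactly the derivative g'(x) given above, so it can carry the gradient while the forward value stays the exact step. `gaussian_spike_st` is a hypothetical name. Its output should agree with `GaussianGrad` only insofar as the class implements the documented formula.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm


def gaussian_spike_st(x, sigma=0.5, alpha=0.5):
    """Heaviside forward value with a Gaussian surrogate gradient (stop-gradient trick).

    Hypothetical helper for illustration; not part of braintools.
    """
    x = jnp.asarray(x)
    step = jnp.heaviside(x, 1.0)  # exact binary spikes
    # d/dx [alpha * Phi(x / sigma)] = alpha / (sigma * sqrt(2*pi)) * exp(-x**2 / (2 * sigma**2))
    proxy = alpha * norm.cdf(x, loc=0.0, scale=sigma)
    # (proxy - stop_gradient(proxy)) is exactly 0.0 in the forward pass,
    # but its derivative is d(proxy)/dx, which becomes the surrogate gradient.
    return jax.lax.stop_gradient(step) + (proxy - jax.lax.stop_gradient(proxy))


x = jnp.array([-1.0, 0.0, 1.0])
print(gaussian_spike_st(x))                               # [0. 1. 1.]
print(jax.grad(lambda v: gaussian_spike_st(v).sum())(x))  # approx. [0.054 0.399 0.054]
```

Writing the identity as `step + (proxy - stop_gradient(proxy))`, rather than `proxy + stop_gradient(step - proxy)`, keeps the forward values exactly 0 and 1, since subtracting a finite number from itself yields exactly zero.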
See also
gaussian_grad – Functional version of the Gaussian gradient surrogate.
MultiGaussianGrad – Multi-component Gaussian gradient.
Sigmoid – Sigmoid-based surrogate gradient.
References