ReluGrad

class braintools.surrogate.ReluGrad(alpha=0.3, width=1.0)

Compute the spiking state with a Heaviside step and backpropagate through it with the ReLU gradient surrogate [1].

The ReLU gradient surrogate replaces the derivative of the step with a triangular gradient of finite support: it peaks at x = 0, decreases linearly toward the edges, and is zero outside a band of half-width `width` around zero. This makes it simple and cheap to compute.

The forward function:

\[\begin{split}g(x) = \begin{cases} 1, & x \geq 0 \\ 0, & x < 0 \\ \end{cases}\end{split}\]

Backward gradient:

\[g'(x) = \text{ReLU}(\alpha \cdot (\text{width} - |x|)) = \max(0, \alpha \cdot (\text{width} - |x|))\]

This creates a triangular gradient centered at x=0 with:

  • Peak value: α × width at x=0

  • Linear decrease to 0 at x=±width

  • Zero gradient for |x| > width
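The triangular shape can be checked numerically by evaluating the two formulas above directly. The sketch below uses plain NumPy; the helper names `heaviside_forward` and `relu_surrogate_grad` are hypothetical and are not part of braintools. With the defaults (alpha = 0.3, width = 1.0) the gradient should be 0.3 at x = 0, 0.15 at x = ±0.5, and 0 for |x| ≥ 1.

```python
import numpy as np


def heaviside_forward(x):
    """Forward pass g(x): 1 where x >= 0, else 0 (hypothetical helper, not braintools API)."""
    x = np.asarray(x, dtype=float)
    return (x >= 0).astype(x.dtype)


def relu_surrogate_grad(x, alpha=0.3, width=1.0):
    """Backward pass g'(x) = max(0, alpha * (width - |x|)) (hypothetical helper)."""
    x = np.asarray(x, dtype=float)
    return np.maximum(0.0, alpha * (width - np.abs(x)))


xs = np.array([-2.0, -1.0, -0.5, 0.0, 0.5, 1.0, 2.0])
print(heaviside_forward(xs))    # step output: 0 left of zero, 1 from zero onward
print(relu_surrogate_grad(xs))  # triangle: peak alpha * width at 0, zero for |x| >= width
```

Doubling `width` while keeping `alpha` fixed both widens the triangle and doubles its peak, since the peak is alpha × width.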

>>> import jax
>>> import jax.numpy as jnp
>>> import braintools.surrogate as surrogate
>>> import matplotlib.pyplot as plt
>>>
>>> xs = jnp.linspace(-3, 3, 1000)
>>> fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 4))
>>>
>>> # Plot gradients for different parameter combinations
>>> for alpha in [0.3, 0.5, 1.0]:
...     for width in [1.0, 2.0]:
...         rg_fn = surrogate.ReluGrad(alpha=alpha, width=width)
...         # jax.grad differentiates the scalar call; vmap maps it over xs
...         grads = jax.vmap(jax.grad(rg_fn))(xs)
...         ax1.plot(xs, grads, label=rf'$\alpha={alpha}, w={width}$')
>>>
>>> ax1.set_xlabel('Input (x)')
>>> ax1.set_ylabel('Gradient')
>>> ax1.set_title('ReLU Surrogate Gradients')
>>> ax1.legend()
>>> ax1.grid(True, alpha=0.3)
>>>
>>> # Show the effect of the width parameter at a fixed alpha
>>> alpha_fixed = 0.5
>>> for width in [0.5, 1.0, 1.5, 2.0]:
...     rg_fn = surrogate.ReluGrad(alpha=alpha_fixed, width=width)
...     grads = jax.vmap(jax.grad(rg_fn))(xs)
...     ax2.plot(xs, grads, label=rf'$width={width}$')
>>>
>>> ax2.set_xlabel('Input (x)')
>>> ax2.set_ylabel('Gradient')
>>> ax2.set_title(f'Width Effect (α={alpha_fixed})')
>>> ax2.legend()
>>> ax2.grid(True, alpha=0.3)
>>> plt.tight_layout()
>>> plt.show()

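[Figure: surrogate gradient versus input x. Left panel "ReLU Surrogate Gradients" (α ∈ {0.3, 0.5, 1.0}, width ∈ {1.0, 2.0}); right panel "Width Effect (α=0.5)" (width ∈ {0.5, 1.0, 1.5, 2.0}).]
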
Parameters:
  • alpha (float, optional) – Slope of the triangular gradient; it scales the gradient magnitude. Default is 0.3. The peak gradient value, reached at x = 0, is alpha × width.

  • width (float, optional) – Half-width of the gradient support. Default is 1.0. The gradient is non-zero only for |x| < width.
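The two parameters interact: the gradient is a triangle with base 2 × width and height alpha × width, so its integral is alpha × width². The true derivative of the step is a Dirac delta with unit mass, so alpha × width² can be read as the total gradient the surrogate passes back. With the defaults it is 0.3, and choosing alpha = 1 / width² would give unit mass. Whether unit mass is desirable is a modelling choice this page does not prescribe. The NumPy sketch below checks the closed form with a trapezoidal integral; `total_gradient_mass` is a hypothetical helper.

```python
import numpy as np


def relu_surrogate_grad(x, alpha=0.3, width=1.0):
    # Triangular surrogate gradient: max(0, alpha * (width - |x|))
    return np.maximum(0.0, alpha * (width - np.abs(x)))


def total_gradient_mass(alpha, width, n=200_001, span=5.0):
    # Trapezoidal integral of the surrogate gradient over [-span, span].
    # span must exceed width so the whole support is covered (hypothetical helper).
    xs = np.linspace(-span, span, n)
    ys = relu_surrogate_grad(xs, alpha, width)
    return np.sum((ys[1:] + ys[:-1]) * np.diff(xs)) / 2.0


for alpha, width in [(0.3, 1.0), (0.5, 2.0), (0.25, 2.0)]:
    numeric = total_gradient_mass(alpha, width)
    closed_form = alpha * width ** 2
    print(f"alpha={alpha}, width={width}: numeric={numeric:.6f}, alpha*width^2={closed_form:.6f}")
```

The last pair, alpha = 0.25 with width = 2.0, illustrates the alpha = 1 / width² choice.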

Examples

>>> import jax
>>> import jax.numpy as jnp
>>> import braintools.surrogate as surrogate
>>>
>>> # Create a ReLU gradient surrogate function
>>> rg_fn = surrogate.ReluGrad(alpha=0.3, width=1.0)
>>>
>>> # Forward pass: Heaviside step, 1 where x >= 0
>>> x = jnp.array([-2., -0.5, 0., 0.5, 2.])
>>> spikes = rg_fn(x)
>>> print(spikes)
[0. 0. 1. 1. 1.]
>>>
>>> # Backward pass: sum the outputs so jax.grad differentiates a scalar
>>> grad_fn = jax.grad(lambda v: rg_fn(v).sum())
>>> grads = grad_fn(x)
>>> print(grads)  # approximately [0, 0.15, 0.3, 0.15, 0]
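To see the mechanism in isolation, the sketch below builds a step function whose backward pass is replaced by the triangular formula, using `jax.custom_vjp`. This is an assumption about one reasonable way to build such an operator, not the braintools implementation; `make_relu_grad_spike` is a hypothetical name. It should reproduce the forward output and the gradients of the example above.

```python
import jax
import jax.numpy as jnp


def make_relu_grad_spike(alpha=0.3, width=1.0):
    """Heaviside spike with a triangular surrogate gradient.

    Hypothetical helper for illustration; not the braintools implementation.
    """

    @jax.custom_vjp
    def spike(x):
        # Forward: exact Heaviside step, 1.0 where x >= 0
        return jnp.where(x >= 0, 1.0, 0.0).astype(x.dtype)

    def spike_fwd(x):
        # Save the input as a residual; the backward pass needs it
        return spike(x), x

    def spike_bwd(x, g):
        # Replace the true derivative (zero almost everywhere) with
        # the triangular surrogate max(0, alpha * (width - |x|))
        surrogate = jnp.maximum(0.0, alpha * (width - jnp.abs(x)))
        return (g * surrogate,)

    spike.defvjp(spike_fwd, spike_bwd)
    return spike


spike = make_relu_grad_spike(alpha=0.3, width=1.0)
x = jnp.array([-2.0, -0.5, 0.0, 0.5, 2.0])
out = spike(x)
grads = jax.grad(lambda v: spike(v).sum())(x)
print(out)
print(grads)
```

Closing over `alpha` and `width` keeps them as fixed Python constants, so `jax.grad` only differentiates with respect to the input.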

See also

relu_grad

Functional version of ReLU gradient surrogate.

LeakyRelu

Leaky ReLU surrogate gradient.

PiecewiseLinear

General piecewise linear surrogate gradient.

References

surrogate_grad(x)

Return the surrogate gradient evaluated at x, that is, max(0, alpha × (width − |x|)).