QPseudoSpike
- class braintools.surrogate.QPseudoSpike(alpha=2.0)
Judge spiking state with the q-PseudoSpike surrogate function [1].
The q-PseudoSpike surrogate gradient provides a flexible framework for controlling the tail behavior of the gradient function. The parameter q (represented as alpha in the implementation) controls the tail fatness, allowing for various gradient profiles from heavy-tailed to compact support.
The forward function:
\[g(x) = \begin{cases} 1, & x \geq 0 \\ 0, & x < 0 \end{cases}\]
The original function:
\[g_{origin}(x) = \begin{cases} \frac{1}{2}\left(1-\frac{2x}{\alpha-1}\right)^{1-\alpha}, & x < 0 \\ 1 - \frac{1}{2}\left(1+\frac{2x}{\alpha-1}\right)^{1-\alpha}, & x \geq 0. \end{cases}\]
Backward gradient:
\[g'(x) = \left(1+\frac{2|x|}{\alpha-1}\right)^{-\alpha}\]
>>> import jax
>>> import jax.numpy as jnp
>>> import brainstate
>>> import braintools.surrogate as surrogate
>>> import matplotlib.pyplot as plt
>>>
>>> xs = jnp.linspace(-3, 3, 1000)
>>> fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 4))
>>>
>>> # Plot gradients for different alpha values
>>> for alpha in [0.5, 1.0, 2.0, 4.0]:
...     qps_fn = surrogate.QPseudoSpike(alpha=alpha)
...     grads = jax.vmap(jax.grad(qps_fn))(xs)
...     ax1.plot(xs, grads, label=rf'$\alpha={alpha}$')
>>>
>>> ax1.set_xlabel('Input (x)')
>>> ax1.set_ylabel('Gradient')
>>> ax1.set_title('q-PseudoSpike Surrogate Gradients')
>>> ax1.legend()
>>> ax1.grid(True, alpha=0.3)
>>> ax1.set_ylim([0, 1.2])
>>>
>>> # Plot the original function for origin=True
>>> for alpha in [1.5, 2.0, 3.0]:
...     qps_fn = surrogate.QPseudoSpike(alpha=alpha)
...     qps_fn.origin = True
...     ys = jax.vmap(qps_fn)(xs)
...     ax2.plot(xs, ys, label=rf'$\alpha={alpha}$')
>>>
>>> ax2.set_xlabel('Input (x)')
>>> ax2.set_ylabel('Output')
>>> ax2.set_title('q-PseudoSpike Original Function')
>>> ax2.legend()
>>> ax2.grid(True, alpha=0.3)
>>> plt.tight_layout()
>>> plt.show()
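The pairing of a step-function forward pass with the backward gradient formula above can be sketched in plain JAX using `jax.custom_vjp`. This is a minimal illustration only: the helper name `q_pseudo_spike_sketch` is hypothetical, it is not the braintools implementation, and it assumes alpha > 1 so that the division by alpha - 1 in the formula is well defined.

```python
import functools

import jax
import jax.numpy as jnp


# Hypothetical helper for illustration; not part of braintools.
@functools.partial(jax.custom_vjp, nondiff_argnums=(1,))
def q_pseudo_spike_sketch(x, alpha):
    # Forward pass: Heaviside step, 1 where x >= 0 and 0 elsewhere.
    return (x >= 0).astype(x.dtype)


def _fwd(x, alpha):
    # Keep the input so the backward pass can evaluate the surrogate gradient.
    return q_pseudo_spike_sketch(x, alpha), x


def _bwd(alpha, x, g):
    # Backward pass: (1 + 2|x| / (alpha - 1))^(-alpha), assuming alpha > 1.
    surrogate = jnp.power(1.0 + 2.0 * jnp.abs(x) / (alpha - 1.0), -alpha)
    return (g * surrogate,)


q_pseudo_spike_sketch.defvjp(_fwd, _bwd)

x = jnp.array([-1.0, 0.0, 1.0])
spikes = q_pseudo_spike_sketch(x, 2.0)
grads = jax.grad(lambda v: q_pseudo_spike_sketch(v, 2.0).sum())(x)
```

With alpha = 2 the formula gives a gradient of 1 at x = 0 and (1 + 2)^(-2) = 1/9 at x = ±1, while the forward output remains binary.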
- Parameters:
alpha (float, optional) – Parameter controlling the tail fatness of the gradient. Default is 2.0.
alpha < 1: Heavy-tailed gradient (slower decay)
alpha = 1: Exponential-like decay
alpha > 1: Compact support (faster decay)
alpha = 2: Quadratic decay (default)
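To get a feel for how alpha shapes the tail, the backward gradient formula from this page can be evaluated directly. The sketch below is an assumption-laden illustration rather than library code: `q_pseudo_spike_grad` is a hypothetical helper, and it only covers alpha > 1, because the closed-form expression divides by alpha - 1 and is therefore not defined at alpha = 1.

```python
import jax.numpy as jnp


def q_pseudo_spike_grad(x, alpha):
    """Backward-gradient formula from this page (hypothetical helper, alpha > 1)."""
    return jnp.power(1.0 + 2.0 * jnp.abs(x) / (alpha - 1.0), -alpha)


alphas = [1.5, 2.0, 4.0]

# Gradient at the threshold (x = 0) and far out in the tail (x = 10).
peak = [float(q_pseudo_spike_grad(0.0, a)) for a in alphas]
tail = [float(q_pseudo_spike_grad(10.0, a)) for a in alphas]

for a, p, t in zip(alphas, peak, tail):
    print(f"alpha={a}: g'(0)={p:.3f}, g'(10)={t:.2e}")
```

Under this formula every alpha gives a peak of 1 at x = 0, and the value far from the threshold shrinks as alpha grows, so smaller alpha (still above 1) corresponds to a fatter tail.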
Examples
>>> import jax
>>> import braintools.surrogate as surrogate
>>>
>>> # Create q-PseudoSpike surrogate function
>>> qps_fn = surrogate.QPseudoSpike(alpha=2.0)
>>>
>>> # Apply to input
>>> x = jax.numpy.array([-1., 0., 1.])
>>> spikes = qps_fn(x)
>>> print(spikes)
[0. 1. 1.]
>>>
>>> # Compute gradients with different tail behaviors
>>> for alpha in [0.5, 2.0, 4.0]:
...     qps_fn = surrogate.QPseudoSpike(alpha=alpha)
...     grad_fn = jax.grad(lambda x: qps_fn(x).sum())
...     grads = grad_fn(jax.numpy.array([0.5]))
...     print(f"alpha={alpha}: gradient={grads[0]:.4f}")
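As a sanity check on the formulas, the derivative of the original function can be compared against the backward gradient using automatic differentiation. The sketch below uses only JAX; `g_origin` and `g_prime` are hypothetical helpers that transcribe the equations on this page, and alpha > 1 is assumed. Both branches are written in terms of |x| (equivalent to the piecewise definition) so that neither branch of `jnp.where` evaluates a negative base.

```python
import jax
import jax.numpy as jnp


def g_origin(x, alpha):
    """Smooth 'original function' from this page (hypothetical helper, alpha > 1)."""
    # For x < 0, 1 - 2x/(alpha-1) equals 1 + 2|x|/(alpha-1), so one term serves both branches.
    tail = 0.5 * jnp.power(1.0 + 2.0 * jnp.abs(x) / (alpha - 1.0), 1.0 - alpha)
    return jnp.where(x < 0, tail, 1.0 - tail)


def g_prime(x, alpha):
    """Closed-form backward gradient from this page."""
    return jnp.power(1.0 + 2.0 * jnp.abs(x) / (alpha - 1.0), -alpha)


alpha = 3.0
# x = 0 is left out: autodiff of |x| at exactly 0 returns 0, which would hide the peak.
xs = jnp.array([-2.0, -0.5, 0.5, 2.0])

autodiff = jax.vmap(jax.grad(g_origin), in_axes=(0, None))(xs, alpha)
closed_form = g_prime(xs, alpha)
midpoint = float(g_origin(jnp.array(0.0), alpha))
```

The two gradient arrays should agree to floating-point tolerance, and `midpoint` evaluates to 0.5, the value both branches of the original function share at x = 0.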
See also
q_pseudo_spike: Functional version of q-PseudoSpike surrogate gradient.
Sigmoid: Sigmoid-based surrogate gradient.
S2NN: Asymmetric surrogate gradient for single-step networks.
References