import brainstate as brainstate
import braintools
import jax
import matplotlib.pyplot as plt
# Plot the gradient of braintools' PiecewiseExp surrogate over x in [-3, 3]
# for several values of its ``alpha`` hyper-parameter, one curve per alpha.
xs = jax.numpy.linspace(-3, 3, 1000)
for alpha in [0.5, 1., 2., 4.]:
    pe_fn = braintools.surrogate.PiecewiseExp(alpha=alpha)
    # vector_grad differentiates pe_fn with respect to every entry of ``xs``;
    # assuming the surrogate acts element-wise (TODO confirm in braintools
    # docs), this yields the gradient value at each sample point.
    grads = brainstate.augment.vector_grad(pe_fn)(xs)
    plt.plot(xs, grads, label=rf'$\alpha$={alpha}')
plt.xlabel('x')
plt.ylabel('surrogate gradient')
plt.legend()
plt.show()
