import braintools
import brainstate
import jax
import matplotlib.pyplot as plt
# Plot the gradient of ``braintools.surrogate.PiecewiseQuadratic`` over
# x in [-3, 3] for several values of its ``alpha`` hyper-parameter, one
# curve per alpha.
xs = jax.numpy.linspace(-3, 3, 1000)
for alpha in [0.5, 1., 2., 4.]:
    pq_fn = braintools.surrogate.PiecewiseQuadratic(alpha=alpha)
    # vector_grad is applied to the whole sample vector at once; the result is
    # plotted against xs, so it is expected to have one value per sample point.
    grads = brainstate.augment.vector_grad(pq_fn)(xs)
    plt.plot(xs, grads, label=rf'$\alpha$={alpha}')
plt.xlabel('x')
plt.ylabel('surrogate gradient')
plt.legend()
plt.show()
