import brainstate as brainstate
import brainstate.nn as nn
import braintools
import jax
import matplotlib.pyplot as plt
# Plot the gradient of braintools' ``surrogate.Sigmoid`` function over
# x in [-2, 2] for several values of its ``alpha`` parameter, one curve each.
# ``brainstate.augment.vector_grad`` is applied to the surrogate callable and
# evaluated on the whole ``xs`` array at once (presumably yielding the
# element-wise surrogate gradient -- confirm against brainstate docs).
xs = jax.numpy.linspace(-2, 2, 1000)
for alpha in (1.0, 2.0, 4.0):
    sigmoid = braintools.surrogate.Sigmoid(alpha=alpha)
    grads = brainstate.augment.vector_grad(sigmoid)(xs)
    plt.plot(xs, grads, label=rf'$\alpha$={alpha}')
plt.xlabel('x')
plt.ylabel('gradient')
plt.legend()
plt.show()
