import brainstate
import braintools
import jax
import matplotlib.pyplot as plt
# Plot the gradient of braintools' SquarewaveFourierSeries surrogate for
# several values of `n`, sampled at 1000 evenly spaced points in [-3, 3].
xs = jax.numpy.linspace(-3, 3, 1000)
for n in [2, 4, 8]:
    sfs_fn = braintools.surrogate.SquarewaveFourierSeries(n=n)
    # vector_grad differentiates sfs_fn with respect to its array input; for an
    # element-wise function this gives d sfs_fn(x) / dx at every sample in xs.
    grads = brainstate.augment.vector_grad(sfs_fn)(xs)
    plt.plot(xs, grads, label=f'n={n}')
plt.xlabel('x')
plt.ylabel('surrogate gradient')
plt.legend()
plt.show()
