import brainstate as brainstate
import braintools
import jax
import matplotlib.pyplot as plt
# Plot the surrogate gradient of ``braintools.surrogate.PiecewiseLeakyRelu``
# over x in [-3, 3], drawing one curve per (c, w) combination so the effect
# of each parameter can be compared on a single figure.
xs = jax.numpy.linspace(-3, 3, 1000)
for c in [0.01, 0.05, 0.1]:
    for w in [1., 2.]:
        # Surrogate spike function for this parameter pair; the exact roles
        # of ``c`` and ``w`` are defined by braintools (see its docs).
        plr_fn = braintools.surrogate.PiecewiseLeakyRelu(c=c, w=w)
        # vector_grad wraps plr_fn so that calling it on xs yields the
        # gradient evaluated at each sample point (presumably same shape as
        # xs, since it is plotted against xs below).
        grads = brainstate.augment.vector_grad(plr_fn)(xs)
        plt.plot(xs, grads, label=f'c={c}, w={w}')
# Label the axes so the figure is self-explanatory.
plt.xlabel('x')
plt.ylabel('surrogate gradient')
plt.title('PiecewiseLeakyRelu surrogate gradient')
plt.legend()
plt.show()
