import jax
import jax.numpy as jnp
import brainstate
import braintools.surrogate as surrogate
import matplotlib.pyplot as plt
# Visual demo of braintools' inverse-square surrogate gradient.
#   Left panel : the surrogate gradient over [-1, 1] for several alpha values.
#   Right panel: |gradient| on a log axis over [-3, 3], contrasting the
#                inverse-square surrogate with a Gaussian surrogate.

# Dense input grid for the alpha sweep.
xs = jnp.linspace(-1, 1, 1000)
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 4))

# ---- Left panel: one gradient curve per alpha -----------------------------
for alpha in (10., 50., 100., 200.):
    isg_fn = surrogate.InvSquareGrad(alpha=alpha)
    # jax.grad differentiates the per-scalar call; vmap maps it across xs.
    grads = jax.vmap(jax.grad(isg_fn))(xs)
    ax1.plot(xs, grads, label=f'$\\alpha={alpha}$')

ax1.set(xlabel='Input (x)', ylabel='Gradient', title='Inverse-Square Gradients')
ax1.legend()
ax1.grid(True, alpha=0.3)

# ---- Right panel: tail comparison on a wider input range ------------------
xs_wide = jnp.linspace(-3, 3, 1000)
isg_fn = surrogate.InvSquareGrad(alpha=100.)
gg_fn = surrogate.GaussianGrad(sigma=0.1, alpha=1.0)
grads_inv, grads_gauss = (jax.vmap(jax.grad(fn))(xs_wide) for fn in (isg_fn, gg_fn))

# Magnitudes are plotted on a log y-axis so the differing tail decay is visible.
ax2.semilogy(xs_wide, jnp.abs(grads_inv), label='Inverse-Square', linewidth=2)
ax2.semilogy(xs_wide, jnp.abs(grads_gauss), '--', label='Gaussian', alpha=0.7)

ax2.set(xlabel='Input (x)', ylabel='|Gradient| (log scale)', title='Tail Behavior Comparison')
ax2.legend()
ax2.grid(True, alpha=0.3, which="both")

plt.tight_layout()
plt.show()
