import jax
import jax.numpy as jnp
import brainstate
import braintools.surrogate as surrogate
import matplotlib.pyplot as plt
# Visualize the SLAYER surrogate gradient from braintools for several
# sharpness values (alpha).
# Left panel: the gradient over a symmetric input range, linear axes.
# Right panel: the positive half on a log-scaled y axis, with the analytic
# exp(-alpha * x) curve overlaid for a subset of alphas. On a log scale,
# exponential decay shows up as a straight line.
ALPHAS = (0.5, 1.0, 2.0, 4.0)
REFERENCE_ALPHAS = (1.0, 2.0)  # subset that also gets the analytic overlay

xs = jnp.linspace(-4, 4, 1000)
xs_pos = jnp.linspace(0, 4, 500)
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 4))

# Color used for each alpha's data curve on ax2, so the matching reference
# curve can reuse it.
curve_colors = {}
for alpha in ALPHAS:
    sg_fn = surrogate.SlayerGrad(alpha=alpha)
    # Build the pointwise gradient function once and reuse it for both grids.
    grad_fn = jax.vmap(jax.grad(sg_fn))
    label = rf'$\alpha={alpha}$'
    ax1.plot(xs, grad_fn(xs), label=label)
    (line,) = ax2.semilogy(xs_pos, grad_fn(xs_pos), label=label)
    curve_colors[alpha] = line.get_color()

# Analytic reference curves. Each reuses its data curve's color and gets a
# legend label; previously these were unlabeled and drawn in unrelated colors.
# Note: the `alpha=` keyword below is matplotlib line opacity, not the
# surrogate's alpha.
for alpha in REFERENCE_ALPHAS:
    theoretical = jnp.exp(-alpha * xs_pos)
    ax2.semilogy(
        xs_pos,
        theoretical,
        '--',
        color=curve_colors[alpha],
        alpha=0.5,
        label=rf'$e^{{-{alpha}x}}$ (theory)',
    )

ax1.set_xlabel('Input (x)')
ax1.set_ylabel('Gradient')
ax1.set_title('SLAYER Surrogate Gradients')
ax1.legend()
ax1.grid(True, alpha=0.3)

ax2.set_xlabel('Distance from threshold')
ax2.set_ylabel('Gradient (log scale)')
ax2.set_title('Exponential Decay Behavior')
ax2.legend()
ax2.grid(True, alpha=0.3, which="both")

plt.tight_layout()
plt.show()
