import jax
import jax.numpy as jnp
import brainstate
import braintools.surrogate as surrogate
import matplotlib.pyplot as plt
# Compare braintools surrogate-gradient shapes on a dense grid of inputs
# spanning [-3, 3]. Each gradient curve is jax.grad(surrogate) evaluated
# pointwise via jax.vmap.
xs = jnp.linspace(-3, 3, 1000)
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 4))

# ---- Left panel: default MultiGaussianGrad vs. a single GaussianGrad -------
mgg_fn = surrogate.MultiGaussianGrad()
gg_fn = surrogate.GaussianGrad(sigma=0.5, alpha=0.5)
grads = jax.vmap(jax.grad(mgg_fn))(xs)
grads_single = jax.vmap(jax.grad(gg_fn))(xs)

# Plot order determines colour-cycle assignment and legend order.
ax1.plot(xs, grads, label='Multi-Gaussian', linewidth=2)
ax1.plot(xs, grads_single, '--', label='Single Gaussian', alpha=0.7)

# ---- Right panel: vary h, keeping s, sigma and scale fixed -----------------
# After this loop, `mgg_fn`, `grads` and `h` hold the values for h = 0.5.
for h in (0.0, 0.15, 0.3, 0.5):
    mgg_fn = surrogate.MultiGaussianGrad(h=h, s=6.0, sigma=0.5, scale=0.5)
    grads = jax.vmap(jax.grad(mgg_fn))(xs)
    ax2.plot(xs, grads, label=rf'$h={h}$')

# ---- Shared styling: both panels differ only in their title ----------------
panel_titles = (
    (ax1, 'Multi-Gaussian vs Single Gaussian'),
    (ax2, 'Effect of h Parameter'),
)
for panel, panel_title in panel_titles:
    panel.set_xlabel('Input (x)')
    panel.set_ylabel('Gradient')
    panel.set_title(panel_title)
    panel.legend()
    panel.grid(True, alpha=0.3)
    # Zero reference line; it has no label, so the legend is unaffected.
    panel.axhline(y=0, color='k', linestyle='-', linewidth=0.5)

plt.tight_layout()
plt.show()
