import jax
import jax.numpy as jnp
import brainstate
import braintools.surrogate as surrogate
import matplotlib.pyplot as plt
# Shared input grid and a side-by-side figure: the left axes (ax1) shows
# surrogate gradients, the right axes (ax2) is filled in further below.
xs = jnp.linspace(-3, 3, 1000)
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 4))

# (alpha, beta) settings whose surrogate gradients are compared on ax1.
grad_settings = [(0.0, 1.0), (0.1, 1.0), (0.3, 1.0), (0.1, 0.5)]
for alpha, beta in grad_settings:
    lr_fn = surrogate.LeakyRelu(alpha=alpha, beta=beta)
    # jax.grad builds d(lr_fn)/dx for a single scalar input; vmap maps that
    # derivative over every point of the grid.
    per_point_grad = jax.grad(lr_fn)
    grads = jax.vmap(per_point_grad)(xs)
    curve_label = rf'$\alpha={alpha}, \beta={beta}$'
    ax1.plot(xs, grads, label=curve_label)

# Axis decorations for the gradient panel.
ax1.set_xlabel('Input (x)')
ax1.set_ylabel('Gradient')
ax1.set_title('Leaky ReLU Surrogate Gradients')
ax1.legend()
ax1.grid(True, alpha=0.3)
ax1.set_ylim([-0.1, 1.2])
# Right panel: plot the forward ("original") function that each LeakyRelu
# surrogate is built from, for several (alpha, beta) settings.
#
# NOTE(review): an earlier version set ``lr_fn.origin = True`` and then called
# ``lr_fn(x)``. That flag comes from older BrainPy surrogates. braintools
# surrogate objects do not appear to read an ``origin`` attribute, so the
# assignment had no effect. Calling the object directly evaluates the spike
# (step) forward pass rather than the leaky-ReLU curve this panel is titled
# after. Calling ``surrogate_fun`` asks for the original function explicitly —
# confirm against the installed braintools version.
for alpha, beta in [(0.1, 1.0), (0.3, 1.0), (0.1, 0.5)]:
    lr_fn = surrogate.LeakyRelu(alpha=alpha, beta=beta)
    # vmap evaluates the original function point-by-point over the input grid.
    ys = jax.vmap(lr_fn.surrogate_fun)(xs)
    ax2.plot(xs, ys, label=rf'$\alpha={alpha}, \beta={beta}$')
ax2.set_xlabel('Input (x)')
ax2.set_ylabel('Output')
ax2.set_title('Leaky ReLU Original Function')
ax2.legend()
ax2.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()
