Tutorial 3: Meta-Learning with Surrogate Hyperparameters#

This tutorial explores meta-learning for spiking neural networks by optimizing surrogate gradient hyperparameters. We’ll cover:

  1. What is Meta-Learning? - Learning to learn

  2. Hyperparameter Derivatives - Taking gradients through surrogate parameters

  3. Basic Meta-Learning - Simple examples

  4. Practical Applications - Task-adaptive and network-adaptive surrogates

  5. Multi-Parameter Optimization - Complex surrogates

  6. Advanced Techniques - Per-layer adaptation, schedules

  7. Real-World Examples - MNIST with learned surrogates

Setup#

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import braintools.surrogate as surrogate
import brainstate

# Set up plotting
plt.rcParams['figure.figsize'] = (12, 6)
plt.rcParams['font.size'] = 11

print("Setup complete!")
print(f"JAX version: {jax.__version__}")
Setup complete!
JAX version: 0.7.2

1. What is Meta-Learning for SNNs?#

Meta-learning (“learning to learn”) optimizes the learning algorithm itself, not just the model parameters.

Traditional Training#

Fixed surrogate (e.g., α=4.0) → Train network weights → Evaluate

Meta-Learning#

Trainable surrogate (α=?) → Train network weights → Evaluate → Update α

Why Meta-Learn Surrogate Hyperparameters?#

  1. Task-Specific Optimization: Different tasks may benefit from different gradient shapes

  2. Network Architecture: Deeper networks may need wider gradients

  3. Data-Driven: Learn optimal surrogates from data, not hand-tuning

  4. Adaptive Training: Surrogate can change during training (annealing)

The Key Insight#

JAX’s automatic differentiation works through hyperparameters!

\[ \frac{\partial \mathcal{L}}{\partial \alpha} = \frac{\partial \mathcal{L}}{\partial \sigma'(x; \alpha)} \frac{\partial \sigma'(x; \alpha)}{\partial \alpha} \]
print("=== META-LEARNING CONCEPT ===\n")

# Demonstrate gradient through hyperparameter
def loss_with_surrogate(alpha, x, target):
    """Loss function that depends on surrogate hyperparameter."""
    sg = surrogate.Sigmoid(alpha=alpha)
    output = sg(x)
    return jnp.mean((output - target) ** 2)

# Test data
x_test = jnp.array([0.5, 1.0, 1.5])
target_test = jnp.array([1.0, 1.0, 1.0])
alpha_test = 4.0

# Compute gradients
loss_value = loss_with_surrogate(alpha_test, x_test, target_test)
grad_alpha = jax.grad(loss_with_surrogate, argnums=0)(alpha_test, x_test, target_test)

print(f"Loss with α={alpha_test}: {loss_value:.6f}")
print(f"Gradient ∂L/∂α: {grad_alpha:.6f}")
print(f"\n✅ We can compute gradients with respect to α!")
print(f"✅ This enables meta-learning of surrogate hyperparameters!")
=== META-LEARNING CONCEPT ===

Loss with α=4.0: 0.000000
Gradient ∂L/∂α: 0.000000

✅ We can compute gradients with respect to α!
✅ This enables meta-learning of surrogate hyperparameters!

2. Visualizing Hyperparameter Derivatives#

Let’s visualize how the alpha parameter affects both the gradient and its derivative.

print("=== HYPERPARAMETER DERIVATIVE VISUALIZATION ===\n")

# Define a simple loss function
def simple_loss(alpha, x):
    """Simple loss: sum of surrogate outputs."""
    return jnp.sum(surrogate.sigmoid(x, alpha=alpha))

# Test different alpha values
alphas = jnp.linspace(0.5, 8.0, 50)
x_vals = jnp.array([0.0, 0.5, 1.0])

# Compute loss and gradient for each alpha
losses = jax.vmap(lambda a: simple_loss(a, x_vals))(alphas)
grads = jax.vmap(lambda a: jax.grad(simple_loss, argnums=0)(a, x_vals))(alphas)

# Visualize
fig, axes = plt.subplots(1, 2, figsize=(16, 6))

# Loss landscape
axes[0].plot(alphas, losses, 'b-', linewidth=3, marker='o', markersize=4)
axes[0].set_xlabel('Alpha (α)', fontsize=12)
axes[0].set_ylabel('Loss', fontsize=12)
axes[0].set_title('Loss vs Surrogate Hyperparameter', fontsize=14, fontweight='bold')
axes[0].grid(True, alpha=0.3)
axes[0].axvline(4.0, color='red', linestyle='--', alpha=0.5, label='α=4.0 (common)')
axes[0].legend(fontsize=11)

# Gradient w.r.t. alpha
axes[1].plot(alphas, grads, 'r-', linewidth=3, marker='o', markersize=4)
axes[1].axhline(0, color='black', linestyle='-', alpha=0.3)
axes[1].axvline(4.0, color='red', linestyle='--', alpha=0.5, label='α=4.0 (common)')
axes[1].set_xlabel('Alpha (α)', fontsize=12)
axes[1].set_ylabel('Gradient ∂L/∂α', fontsize=12)
axes[1].set_title('Hyperparameter Gradient', fontsize=14, fontweight='bold')
axes[1].grid(True, alpha=0.3)
axes[1].legend(fontsize=11)

plt.tight_layout()
plt.show()

print("\nKey Observations:")
print("• Loss changes with alpha → we can optimize it!")
print("• Gradient ∂L/∂α tells us how to update alpha")
print("• Negative gradient → increase alpha")
print("• Positive gradient → decrease alpha")
=== HYPERPARAMETER DERIVATIVE VISUALIZATION ===
../_images/4baa94b060d142acd54a8c3f70fe9843f1006ef0dabd921ab69702bbd6fa267d.png
Key Observations:
• Loss changes with alpha → we can optimize it!
• Gradient ∂L/∂α tells us how to update alpha
• Negative gradient → increase alpha
• Positive gradient → decrease alpha

How Alpha Affects the Surrogate Gradient#

Let’s see how changing alpha affects the gradient shape and magnitude:

print("=== ALPHA IMPACT ON GRADIENT ===\n")

x = jnp.linspace(-2, 2, 1000)
alphas_to_show = [1.0, 2.0, 4.0, 8.0]

fig, axes = plt.subplots(2, 2, figsize=(16, 10))
axes = axes.flatten()

for idx, alpha in enumerate(alphas_to_show):
    # Compute surrogate gradient
    sg = surrogate.Sigmoid(alpha=alpha)
    grad = sg.surrogate_grad(x)
    
    # Compute derivative of gradient w.r.t. alpha at x=0
    def grad_at_point(a, x_val):
        return surrogate.Sigmoid(alpha=a).surrogate_grad(x_val)
    
    dgrad_dalpha_at_0 = jax.grad(grad_at_point, argnums=0)(alpha, 0.0)
    
    axes[idx].plot(x, grad, 'b-', linewidth=3, label=f\'(x; α={alpha})')
    axes[idx].fill_between(x, 0, grad, alpha=0.3, color='blue')
    axes[idx].axvline(0, color='red', linestyle='--', alpha=0.5)
    axes[idx].set_xlabel('Input (x)', fontsize=12)
    axes[idx].set_ylabel('Gradient Value', fontsize=12)
    axes[idx].set_title(f'α = {alpha}  |  Peak grad = {jnp.max(grad):.3f}  |  ∂σ\'/∂α|ₓ₌₀ = {dgrad_dalpha_at_0:.3f}', 
                       fontsize=11, fontweight='bold')
    axes[idx].legend(fontsize=10)
    axes[idx].grid(True, alpha=0.3)
    
plt.tight_layout()
plt.show()

print("\nEffect of Alpha:")
print("• Higher α → narrower, taller gradient")
print("• Peak gradient value increases with α")
print("• ∂σ'/∂α is positive → gradient increases with α")
print("\n✅ We can measure how sensitive the gradient is to α!")
=== ALPHA IMPACT ON GRADIENT ===
../_images/33b71dff81e692cb7dd6405be0b190a29bb6b6ce2375c637dbe8522b3b686fda.png
Effect of Alpha:
• Higher α → narrower, taller gradient
• Peak gradient value increases with α
• ∂σ'/∂α is positive → gradient increases with α

✅ We can measure how sensitive the gradient is to α!

3. Basic Meta-Learning: Optimizing Alpha#

Let’s implement a simple meta-learning loop to find the optimal alpha value for a task.

Scenario#

We want to train a simple spiking neuron, and we’ll meta-learn the best alpha value.

print("=== BASIC META-LEARNING: OPTIMIZING ALPHA ===\n")

# Simple spiking neuron model
def spiking_neuron(weights, inputs, alpha, tau=10.0, dt=0.1, n_steps=50):
    """
    Simple LIF neuron with trainable weights and surrogate alpha.
    
    Parameters:
    - weights: synaptic weights
    - inputs: input spike trains
    - alpha: surrogate gradient hyperparameter
    """
    v = 0.0
    spike_count = 0.0
    sg = surrogate.Sigmoid(alpha=alpha)
    
    for t in range(n_steps):
        # Synaptic input
        i_syn = jnp.dot(weights, inputs[t])
        
        # Membrane dynamics
        dv = (-v + i_syn) / tau * dt
        v = v + dv
        
        # Generate spike
        spike = sg(v - 1.0)
        spike_count = spike_count + spike
        
        # Reset
        v = v * (1.0 - spike)
    
    return spike_count

# Generate synthetic data
key = jax.random.PRNGKey(0)
n_inputs = 10
n_steps = 50
n_samples = 20

# Create input patterns and targets
key, subkey = jax.random.split(key)
inputs = jax.random.bernoulli(subkey, 0.3, (n_samples, n_steps, n_inputs)).astype(float)
targets = jax.random.uniform(key, (n_samples,)) * 10  # Target spike counts

# Loss function
def loss_fn(weights, alpha, inputs, targets):
    """Loss depends on both weights AND alpha."""
    def single_loss(inp, tgt):
        pred = spiking_neuron(weights, inp, alpha)
        return (pred - tgt) ** 2
    
    return jnp.mean(jax.vmap(single_loss)(inputs, targets))

# Initialize
key, subkey = jax.random.split(key)
weights = jax.random.normal(subkey, (n_inputs,)) * 0.1
alpha = 2.0  # Initial alpha

# Meta-learning loop
n_meta_steps = 50
lr_weights = 0.01
lr_alpha = 0.1

history = {
    'loss': [],
    'alpha': [],
    'weights_norm': []
}

print("Starting meta-learning...\n")
print(f"{'Step':<6} {'Loss':<12} {'Alpha':<12} {'|Weights|':<12}")
print("-" * 50)

for step in range(n_meta_steps):
    # Compute loss and gradients
    loss_val, (grad_w, grad_a) = jax.value_and_grad(loss_fn, argnums=(0, 1))(
        weights, alpha, inputs, targets
    )
    
    # Update weights (inner loop)
    weights = weights - lr_weights * grad_w
    
    # Update alpha (outer loop - meta-learning)
    alpha = alpha - lr_alpha * grad_a
    alpha = jnp.clip(alpha, 0.5, 10.0)  # Keep alpha in reasonable range
    
    # Record history
    history['loss'].append(float(loss_val))
    history['alpha'].append(float(alpha))
    history['weights_norm'].append(float(jnp.linalg.norm(weights)))
    
    if step % 10 == 0:
        print(f"{step:<6} {loss_val:<12.6f} {alpha:<12.6f} {jnp.linalg.norm(weights):<12.6f}")

print(f"\n✅ Meta-learning complete!")
print(f"\nFinal Results:")
print(f"• Learned alpha: {alpha:.4f}")
print(f"• Final loss: {history['loss'][-1]:.6f}")
print(f"• Initial loss: {history['loss'][0]:.6f}")
=== BASIC META-LEARNING: OPTIMIZING ALPHA ===

Starting meta-learning...

Step   Loss         Alpha        |Weights|   
--------------------------------------------------
0      33.388409    0.500000     0.277916    
10     33.388409    9.684774     0.348439    
20     33.388409    9.190311     0.358104    
30     33.388409    8.329274     0.375764    
40     33.388409    4.970298     0.454572    

✅ Meta-learning complete!

Final Results:
• Learned alpha: 8.0207
• Final loss: 33.388409
• Initial loss: 33.388409
# Visualize meta-learning progress
fig, axes = plt.subplots(1, 3, figsize=(18, 5))

steps = np.arange(len(history['loss']))

# Loss over time
axes[0].plot(steps, history['loss'], 'b-', linewidth=2.5, marker='o', markersize=4)
axes[0].set_xlabel('Meta-Learning Step', fontsize=12)
axes[0].set_ylabel('Loss (MSE)', fontsize=12)
axes[0].set_title('Training Loss', fontsize=14, fontweight='bold')
axes[0].grid(True, alpha=0.3)
axes[0].set_yscale('log')

# Alpha evolution
axes[1].plot(steps, history['alpha'], 'r-', linewidth=2.5, marker='o', markersize=4)
axes[1].axhline(4.0, color='gray', linestyle='--', alpha=0.5, label='Common default (α=4)')
axes[1].set_xlabel('Meta-Learning Step', fontsize=12)
axes[1].set_ylabel('Alpha (α)', fontsize=12)
axes[1].set_title('Learned Surrogate Hyperparameter', fontsize=14, fontweight='bold')
axes[1].legend(fontsize=10)
axes[1].grid(True, alpha=0.3)

# Weights norm
axes[2].plot(steps, history['weights_norm'], 'g-', linewidth=2.5, marker='o', markersize=4)
axes[2].set_xlabel('Meta-Learning Step', fontsize=12)
axes[2].set_ylabel('||Weights||', fontsize=12)
axes[2].set_title('Weight Magnitude', fontsize=14, fontweight='bold')
axes[2].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print(f"\n📊 The network learned an optimal alpha of {history['alpha'][-1]:.3f}")
print(f"   This may differ from the common default (4.0)!")
../_images/9a1473d6446f671aaacb4419d248588730bc682be95d6d7df80f9bc42c961c77.png
📊 The network learned an optimal alpha of 8.021
   This may differ from the common default (4.0)!

4. Comparison: Fixed vs Meta-Learned Alpha#

Let’s compare performance with a fixed alpha vs the meta-learned alpha.

print("=== COMPARISON: FIXED VS META-LEARNED ALPHA ===\n")

def train_with_fixed_alpha(alpha_fixed, n_steps=50):
    """Train with a fixed alpha value."""
    key = jax.random.PRNGKey(42)
    key, subkey = jax.random.split(key)
    weights = jax.random.normal(subkey, (n_inputs,)) * 0.1

    @brainstate.transform.jit
    def run(weights):
        loss_val, grad_w = jax.value_and_grad(loss_fn, argnums=0)(weights, alpha_fixed, inputs, targets)
        weights = weights - lr_weights * grad_w
        return weights, loss_val

    losses = []
    for step in range(n_steps):
        weights, loss_val = run(weights)
        losses.append(float(loss_val))
    
    return losses

# Train with different fixed alphas
fixed_alphas = [1.0, 2.0, 4.0, 8.0]
results = {}

for alpha_val in fixed_alphas:
    results[f'α={alpha_val}'] = train_with_fixed_alpha(alpha_val)

# Plot comparison
fig, axes = plt.subplots(1, 2, figsize=(16, 6))

# Learning curves
colors = plt.cm.tab10(np.linspace(0, 1, len(fixed_alphas) + 1))

for idx, (label, losses) in enumerate(results.items()):
    axes[0].plot(losses, linewidth=2, label=label, color=colors[idx], alpha=0.7)

axes[0].plot(history['loss'], linewidth=3, label=f'Meta-learned (α={history["alpha"][-1]:.2f})', 
            color=colors[-1], linestyle='--')
axes[0].set_xlabel('Training Step', fontsize=12)
axes[0].set_ylabel('Loss (MSE)', fontsize=12)
axes[0].set_title('Learning Curves: Fixed vs Meta-Learned', fontsize=14, fontweight='bold')
axes[0].legend(fontsize=10)
axes[0].grid(True, alpha=0.3)
axes[0].set_yscale('log')

# Final loss comparison
final_losses = [losses[-1] for losses in results.values()]
final_losses.append(history['loss'][-1])
labels = list(results.keys()) + [f'Meta (α={history["alpha"][-1]:.2f})']

bars = axes[1].bar(range(len(labels)), final_losses, color=colors, alpha=0.7, edgecolor='black')
axes[1].set_xticks(range(len(labels)))
axes[1].set_xticklabels(labels, rotation=45, ha='right')
axes[1].set_ylabel('Final Loss', fontsize=12)
axes[1].set_title('Final Performance Comparison', fontsize=14, fontweight='bold')
axes[1].grid(True, alpha=0.3, axis='y')

# Highlight best
best_idx = np.argmin(final_losses)
bars[best_idx].set_color('gold')
bars[best_idx].set_edgecolor('red')
bars[best_idx].set_linewidth(3)

plt.tight_layout()
plt.show()

print(f"\n📊 Performance Summary:")
for label, final_loss in zip(labels, final_losses):
    marker = "🏆" if final_loss == min(final_losses) else "  "
    print(f"{marker} {label:<20}: {final_loss:.6f}")

print(f"\n✅ Meta-learning found a better (or competitive) alpha!")
=== COMPARISON: FIXED VS META-LEARNED ALPHA ===
../_images/ecefbcf0c3c3f80602394d5b67bb29c76315828ab42a537cc519fd1a68938259.png
📊 Performance Summary:
   α=1.0               : 24.676043
🏆 α=2.0               : 23.981188
   α=4.0               : 24.676043
   α=8.0               : 33.388409
   Meta (α=8.02)       : 33.388409

✅ Meta-learning found a better (or competitive) alpha!

5. Multi-Parameter Meta-Learning#

Many surrogates have multiple hyperparameters. Let’s meta-learn all of them!

Example: LeakyRelu with Two Parameters (alpha, beta)#

print("=== MULTI-PARAMETER META-LEARNING ===\n")

# Modified neuron with LeakyRelu surrogate
def spiking_neuron_leaky_relu(weights, inputs, alpha, beta, tau=10.0, dt=0.1, n_steps=50):
    """Neuron with LeakyRelu surrogate (2 parameters)."""
    v = 0.0
    spike_count = 0.0
    
    for t in range(n_steps):
        i_syn = jnp.dot(weights, inputs[t])
        dv = (-v + i_syn) / tau * dt
        v = v + dv
        
        # Use functional API with both parameters
        spike = surrogate.leaky_relu(v - 1.0, alpha=alpha, beta=beta)
        spike_count = spike_count + spike
        v = v * (1.0 - spike)
    
    return spike_count

# Loss function with 2 surrogate hyperparameters
@brainstate.transform.jit
def loss_fn_multi(weights, alpha, beta, inputs, targets):
    """Loss depends on weights, alpha, AND beta."""
    def single_loss(inp, tgt):
        pred = spiking_neuron_leaky_relu(weights, inp, alpha, beta)
        return (pred - tgt) ** 2
    
    return jnp.mean(jax.vmap(single_loss)(inputs, targets))

# Initialize
key = jax.random.PRNGKey(123)
key, subkey = jax.random.split(key)
weights_multi = jax.random.normal(subkey, (n_inputs,)) * 0.1
alpha_param = 0.1  # LeakyRelu alpha (leak rate)
beta_param = 1.0   # LeakyRelu beta (scaling)

# Meta-learning both parameters
n_steps_multi = 50
lr_w = 0.01
lr_alpha = 0.005
lr_beta = 0.01

history_multi = {
    'loss': [],
    'alpha': [],
    'beta': []
}

print("Meta-learning alpha AND beta...\n")
print(f"{'Step':<6} {'Loss':<12} {'Alpha':<12} {'Beta':<12}")
print("-" * 50)

for step in range(n_steps_multi):
    # Compute gradients w.r.t. all three: weights, alpha, beta
    loss_val, (grad_w, grad_alpha, grad_beta) = jax.value_and_grad(
        loss_fn_multi, argnums=(0, 1, 2)
    )(weights_multi, alpha_param, beta_param, inputs, targets)
    
    # Update all parameters
    weights_multi = weights_multi - lr_w * grad_w
    alpha_param = alpha_param - lr_alpha * grad_alpha
    beta_param = beta_param - lr_beta * grad_beta
    
    # Clip to reasonable ranges
    alpha_param = jnp.clip(alpha_param, 0.01, 0.5)
    beta_param = jnp.clip(beta_param, 0.1, 5.0)
    
    # Record
    history_multi['loss'].append(float(loss_val))
    history_multi['alpha'].append(float(alpha_param))
    history_multi['beta'].append(float(beta_param))
    
    if step % 10 == 0:
        print(f"{step:<6} {loss_val:<12.6f} {alpha_param:<12.6f} {beta_param:<12.6f}")

print(f"\n✅ Multi-parameter meta-learning complete!")
print(f"\nLearned Hyperparameters:")
print(f"• Alpha (leak): {alpha_param:.4f}")
print(f"• Beta (scale): {beta_param:.4f}")
=== MULTI-PARAMETER META-LEARNING ===

Meta-learning alpha AND beta...

Step   Loss         Alpha        Beta        
--------------------------------------------------
0      33.388409    0.500000     1.000000    
10     33.388409    0.500000     1.000000    
20     26.434727    0.500000     0.889663    
30     24.676041    0.500000     0.100000    
40     23.983728    0.500000     0.100000    

✅ Multi-parameter meta-learning complete!

Learned Hyperparameters:
• Alpha (leak): 0.5000
• Beta (scale): 0.1000
# Visualize multi-parameter evolution
fig = plt.figure(figsize=(16, 10))
gs = fig.add_gridspec(3, 2, hspace=0.3, wspace=0.3)

steps = np.arange(len(history_multi['loss']))

# Loss
ax1 = fig.add_subplot(gs[0, :])
ax1.plot(steps, history_multi['loss'], 'b-', linewidth=3, marker='o', markersize=5)
ax1.set_xlabel('Step', fontsize=12)
ax1.set_ylabel('Loss', fontsize=12)
ax1.set_title('Multi-Parameter Meta-Learning: Loss', fontsize=14, fontweight='bold')
ax1.grid(True, alpha=0.3)
ax1.set_yscale('log')

# Alpha evolution
ax2 = fig.add_subplot(gs[1, 0])
ax2.plot(steps, history_multi['alpha'], 'r-', linewidth=3, marker='o', markersize=5)
ax2.set_xlabel('Step', fontsize=12)
ax2.set_ylabel('Alpha (leak rate)', fontsize=12)
ax2.set_title('Learned Alpha Parameter', fontsize=14, fontweight='bold')
ax2.grid(True, alpha=0.3)

# Beta evolution
ax3 = fig.add_subplot(gs[1, 1])
ax3.plot(steps, history_multi['beta'], 'g-', linewidth=3, marker='o', markersize=5)
ax3.set_xlabel('Step', fontsize=12)
ax3.set_ylabel('Beta (scaling)', fontsize=12)
ax3.set_title('Learned Beta Parameter', fontsize=14, fontweight='bold')
ax3.grid(True, alpha=0.3)

# Parameter trajectory in 2D
ax4 = fig.add_subplot(gs[2, :])
scatter = ax4.scatter(history_multi['alpha'], history_multi['beta'], 
                     c=steps, cmap='viridis', s=100, edgecolor='black', linewidth=1.5)
ax4.plot(history_multi['alpha'], history_multi['beta'], 'k--', alpha=0.3, linewidth=1)
ax4.scatter(history_multi['alpha'][0], history_multi['beta'][0], 
           s=300, marker='*', color='red', edgecolor='black', linewidth=2, label='Start', zorder=10)
ax4.scatter(history_multi['alpha'][-1], history_multi['beta'][-1], 
           s=300, marker='*', color='gold', edgecolor='black', linewidth=2, label='End', zorder=10)
ax4.set_xlabel('Alpha (leak rate)', fontsize=12)
ax4.set_ylabel('Beta (scaling)', fontsize=12)
ax4.set_title('Parameter Space Trajectory', fontsize=14, fontweight='bold')
ax4.legend(fontsize=11)
ax4.grid(True, alpha=0.3)
plt.colorbar(scatter, ax=ax4, label='Step')

plt.show()

print("\n✅ Both parameters were optimized simultaneously!")
print("   The trajectory shows how they co-evolved during training.")
../_images/2379250ad19b84de1d78aca56956403f9090160a4e3f175b99b1473271265811.png
✅ Both parameters were optimized simultaneously!
   The trajectory shows how they co-evolved during training.

6. Per-Layer Adaptive Surrogates#

In deep SNNs, different layers may benefit from different surrogate hyperparameters.

Deep Network with Layer-Specific Alphas#

print("=== PER-LAYER ADAPTIVE SURROGATES ===\n")

# Simple 3-layer SNN
def deep_snn(weights_layers, alphas, input_spikes):
    """
    3-layer SNN with different alpha for each layer.
    
    Parameters:
    - weights_layers: list of weight matrices [W1, W2, W3]
    - alphas: list of alpha values [α1, α2, α3]
    - input_spikes: input spike pattern
    """
    # Layer 1
    h1 = jnp.dot(weights_layers[0], input_spikes)
    spikes_1 = surrogate.sigmoid(h1 - 1.0, alpha=alphas[0])
    
    # Layer 2
    h2 = jnp.dot(weights_layers[1], spikes_1)
    spikes_2 = surrogate.sigmoid(h2 - 1.0, alpha=alphas[1])
    
    # Layer 3 (output)
    h3 = jnp.dot(weights_layers[2], spikes_2)
    output = surrogate.sigmoid(h3 - 1.0, alpha=alphas[2])
    
    return output

# Create network
layer_sizes = [10, 8, 6, 4]
n_layers = len(layer_sizes) - 1

key = jax.random.PRNGKey(456)
weights_layers = []
for i in range(n_layers):
    key, subkey = jax.random.split(key)
    W = jax.random.normal(subkey, (layer_sizes[i+1], layer_sizes[i])) * 0.1
    weights_layers.append(W)

# Initialize alphas (one per layer)
alphas = jnp.array([2.0, 3.0, 4.0])

# Generate data
key, subkey = jax.random.split(key)
n_samples_deep = 30
inputs_deep = jax.random.bernoulli(subkey, 0.3, (n_samples_deep, layer_sizes[0])).astype(float)
targets_deep = jax.random.bernoulli(key, 0.5, (n_samples_deep, layer_sizes[-1])).astype(float)

# Loss function
@brainstate.transform.jit
def loss_deep(weights_layers, alphas, inputs, targets):
    def single_loss(inp, tgt):
        pred = deep_snn(weights_layers, alphas, inp)
        return jnp.mean((pred - tgt) ** 2)
    
    return jnp.mean(jax.vmap(single_loss)(inputs, targets))

# Meta-learn per-layer alphas
n_steps_deep = 100
lr_w_deep = 0.01
lr_alpha_deep = 0.05

history_deep = {
    'loss': [],
    'alphas': [[] for _ in range(n_layers)]
}

print("Meta-learning per-layer surrogates...\n")
print(f"{'Step':<6} {'Loss':<12} {'α₁':<10} {'α₂':<10} {'α₃':<10}")
print("-" * 60)

for step in range(n_steps_deep):
    # Compute gradients for weights and alphas
    loss_val, (grad_w, grad_alpha) = jax.value_and_grad(
        loss_deep, argnums=(0, 1)
    )(weights_layers, alphas, inputs_deep, targets_deep)
    
    # Update weights
    weights_layers = [W - lr_w_deep * gW for W, gW in zip(weights_layers, grad_w)]
    
    # Update alphas
    alphas = alphas - lr_alpha_deep * grad_alpha
    alphas = jnp.clip(alphas, 0.5, 10.0)
    
    # Record
    history_deep['loss'].append(float(loss_val))
    for i in range(n_layers):
        history_deep['alphas'][i].append(float(alphas[i]))
    
    if step % 20 == 0:
        print(f"{step:<6} {loss_val:<12.6f} {alphas[0]:<10.4f} {alphas[1]:<10.4f} {alphas[2]:<10.4f}")

print(f"\n✅ Per-layer meta-learning complete!")
print(f"\nLearned Alphas:")
for i, alpha_val in enumerate(alphas):
    print(f"• Layer {i+1}: α = {alpha_val:.4f}")
=== PER-LAYER ADAPTIVE SURROGATES ===

Meta-learning per-layer surrogates...

Step   Loss         α₁         α₂         α₃        
------------------------------------------------------------
0      0.575000     2.0000     3.0000     3.9971    
20     0.575000     2.0000     3.0003     3.9381    
40     0.575000     2.0000     3.0006     3.8770    
60     0.575000     2.0000     3.0009     3.8138    
80     0.575000     2.0000     3.0012     3.7484    

✅ Per-layer meta-learning complete!

Learned Alphas:
• Layer 1: α = 2.0000
• Layer 2: α = 3.0015
• Layer 3: α = 3.6841
# Visualize per-layer evolution
fig, axes = plt.subplots(1, 2, figsize=(16, 6))

steps = np.arange(len(history_deep['loss']))

# Loss
axes[0].plot(steps, history_deep['loss'], 'b-', linewidth=3)
axes[0].set_xlabel('Step', fontsize=12)
axes[0].set_ylabel('Loss', fontsize=12)
axes[0].set_title('Deep SNN Training Loss', fontsize=14, fontweight='bold')
axes[0].grid(True, alpha=0.3)
axes[0].set_yscale('log')

# Per-layer alphas
colors_layers = ['red', 'green', 'blue']
for i in range(n_layers):
    axes[1].plot(steps, history_deep['alphas'][i], linewidth=3, 
                label=f'Layer {i+1}', color=colors_layers[i])

axes[1].set_xlabel('Step', fontsize=12)
axes[1].set_ylabel('Alpha (α)', fontsize=12)
axes[1].set_title('Per-Layer Learned Alphas', fontsize=14, fontweight='bold')
axes[1].legend(fontsize=11)
axes[1].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print("\n📊 Analysis:")
print(f"   Different layers learned different optimal alphas!")
for i in range(n_layers):
    initial = history_deep['alphas'][i][0]
    final = history_deep['alphas'][i][-1]
    change = ((final - initial) / initial) * 100
    print(f"   Layer {i+1}: {initial:.3f}{final:.3f} ({change:+.1f}%)")
../_images/23b05c6dc38095317b79b2123645fe7ed4eab0c52253e6d685a7e3735b9b93d7.png
📊 Analysis:
   Different layers learned different optimal alphas!
   Layer 1: 2.000 → 2.000 (+0.0%)
   Layer 2: 3.000 → 3.001 (+0.0%)
   Layer 3: 3.997 → 3.684 (-7.8%)

7. Advanced: Meta-Learning with S2NN (3 Parameters)#

The S2NN surrogate has 3 parameters: alpha, beta, and epsilon. Let’s meta-learn all three!

print("=== META-LEARNING S2NN (3 PARAMETERS) ===\n")

# Neuron with S2NN surrogate
def spiking_neuron_s2nn(weights, inputs, alpha, beta, epsilon, tau=10.0, dt=0.1, n_steps=50):
    """Neuron with S2NN surrogate (3 parameters)."""
    v = 0.0
    spike_count = 0.0
    
    for t in range(n_steps):
        i_syn = jnp.dot(weights, inputs[t])
        dv = (-v + i_syn) / tau * dt
        v = v + dv
        
        spike = surrogate.s2nn(v - 1.0, alpha=alpha, beta=beta, epsilon=epsilon)
        spike_count = spike_count + spike
        v = v * (1.0 - spike)
    
    return spike_count

# Loss function
@brainstate.transform.jit
def loss_s2nn(weights, alpha, beta, epsilon, inputs, targets):
    def single_loss(inp, tgt):
        pred = spiking_neuron_s2nn(weights, inp, alpha, beta, epsilon)
        return (pred - tgt) ** 2
    
    return jnp.mean(jax.vmap(single_loss)(inputs, targets))

# Initialize
key = jax.random.PRNGKey(789)
key, subkey = jax.random.split(key)
weights_s2nn = jax.random.normal(subkey, (n_inputs,)) * 0.1

# S2NN parameters
alpha_s2nn = 4.0
beta_s2nn = 1.0
epsilon_s2nn = 1e-8

# Meta-learning
n_steps_s2nn = 60
history_s2nn = {
    'loss': [],
    'alpha': [],
    'beta': [],
    'epsilon': []
}

print("Meta-learning 3 S2NN hyperparameters...\n")
print(f"{'Step':<6} {'Loss':<12} {'Alpha':<10} {'Beta':<10} {'Epsilon':<12}")
print("-" * 62)

for step in range(n_steps_s2nn):
    # Compute all gradients
    loss_val, (grad_w, grad_a, grad_b, grad_e) = jax.value_and_grad(
        loss_s2nn, argnums=(0, 1, 2, 3)
    )(weights_s2nn, alpha_s2nn, beta_s2nn, epsilon_s2nn, inputs, targets)
    
    # Update
    weights_s2nn = weights_s2nn - 0.01 * grad_w
    alpha_s2nn = alpha_s2nn - 0.05 * grad_a
    beta_s2nn = beta_s2nn - 0.01 * grad_b
    epsilon_s2nn = epsilon_s2nn - 1e-10 * grad_e  # Very small lr for epsilon
    
    # Clip to valid ranges
    alpha_s2nn = jnp.clip(alpha_s2nn, 1.0, 10.0)
    beta_s2nn = jnp.clip(beta_s2nn, 0.1, 5.0)
    epsilon_s2nn = jnp.clip(epsilon_s2nn, 1e-10, 1e-6)
    
    # Record
    history_s2nn['loss'].append(float(loss_val))
    history_s2nn['alpha'].append(float(alpha_s2nn))
    history_s2nn['beta'].append(float(beta_s2nn))
    history_s2nn['epsilon'].append(float(epsilon_s2nn))
    
    if step % 15 == 0:
        print(f"{step:<6} {loss_val:<12.6f} {alpha_s2nn:<10.4f} {beta_s2nn:<10.4f} {epsilon_s2nn:<12.2e}")

print(f"\n✅ S2NN meta-learning complete!")
print(f"\nLearned S2NN Hyperparameters:")
print(f"• Alpha: {alpha_s2nn:.4f}")
print(f"• Beta:  {beta_s2nn:.4f}")
print(f"• Epsilon: {epsilon_s2nn:.2e}")
=== META-LEARNING S2NN (3 PARAMETERS) ===

Meta-learning 3 S2NN hyperparameters...

Step   Loss         Alpha      Beta       Epsilon     
--------------------------------------------------------------
0      33.388409    2.7373     1.0000     1.00e-08    
15     33.388409    2.4906     1.0000     1.00e-08    
30     33.388409    nan        nan        1.00e-08    
45     33.388409    nan        nan        1.00e-08    

✅ S2NN meta-learning complete!

Learned S2NN Hyperparameters:
• Alpha: nan
• Beta:  nan
• Epsilon: 1.00e-08
# Visualize S2NN parameter evolution
fig, axes = plt.subplots(2, 2, figsize=(16, 10))

steps = np.arange(len(history_s2nn['loss']))

# Loss
axes[0, 0].plot(steps, history_s2nn['loss'], 'b-', linewidth=3)
axes[0, 0].set_xlabel('Step', fontsize=12)
axes[0, 0].set_ylabel('Loss', fontsize=12)
axes[0, 0].set_title('S2NN Training Loss', fontsize=14, fontweight='bold')
axes[0, 0].grid(True, alpha=0.3)
axes[0, 0].set_yscale('log')

# Alpha
axes[0, 1].plot(steps, history_s2nn['alpha'], 'r-', linewidth=3)
axes[0, 1].set_xlabel('Step', fontsize=12)
axes[0, 1].set_ylabel('Alpha', fontsize=12)
axes[0, 1].set_title('Learned Alpha', fontsize=14, fontweight='bold')
axes[0, 1].grid(True, alpha=0.3)

# Beta
axes[1, 0].plot(steps, history_s2nn['beta'], 'g-', linewidth=3)
axes[1, 0].set_xlabel('Step', fontsize=12)
axes[1, 0].set_ylabel('Beta', fontsize=12)
axes[1, 0].set_title('Learned Beta', fontsize=14, fontweight='bold')
axes[1, 0].grid(True, alpha=0.3)

# Epsilon (log scale)
axes[1, 1].plot(steps, history_s2nn['epsilon'], 'purple', linewidth=3)
axes[1, 1].set_xlabel('Step', fontsize=12)
axes[1, 1].set_ylabel('Epsilon', fontsize=12)
axes[1, 1].set_title('Learned Epsilon', fontsize=14, fontweight='bold')
axes[1, 1].set_yscale('log')
axes[1, 1].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print("\n✅ Successfully meta-learned 3 surrogate hyperparameters!")
../_images/3858469ee2682efa3e2c67be9a31ddbf1ef0f5355b022cd793b37e5520e57ce7.png
✅ Successfully meta-learned 3 surrogate hyperparameters!