Activation Functions and Normalization#

Activation functions and normalization layers are critical components that enable deep neural networks to learn complex patterns and train stably.

In this tutorial, you will learn:

  • 🎯 Activation Functions - ReLU, GELU, Sigmoid, Tanh and variants

  • 📊 Batch Normalization - Stabilizing training with BatchNorm

  • 📐 Layer Normalization - Alternative normalization for sequences

  • 🔲 Group Normalization - For small batch sizes

  • ⚖️ When to use each - Practical guidelines

Why Are These Important?#

  • Activation functions introduce non-linearity, allowing networks to learn complex patterns.

  • Normalization layers stabilize training, accelerate convergence, and can act as mild regularizers.
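To make the non-linearity point concrete, here is a minimal NumPy sketch. The weight matrices `W1` and `W2` are hypothetical and not part of this tutorial. It checks that two stacked linear layers without an activation collapse into one linear map, and that a ReLU between them breaks that equivalence.

```python
import numpy as np

rng = np.random.default_rng(0)
W1 = rng.normal(size=(4, 8))   # hypothetical first-layer weights
W2 = rng.normal(size=(8, 3))   # hypothetical second-layer weights
x = rng.normal(size=(5, 4))    # batch of 5 inputs

# Two linear layers with no activation equal one linear layer with weights W1 @ W2.
stacked = (x @ W1) @ W2
collapsed = x @ (W1 @ W2)
linear_collapses = np.allclose(stacked, collapsed)

# With a ReLU in between, the result no longer matches that single linear map.
with_relu = np.maximum(x @ W1, 0.0) @ W2
relu_differs = not np.allclose(with_relu, collapsed)

print("Linear stack collapses:", linear_collapses)
print("ReLU stack differs:    ", relu_differs)
```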

import brainstate
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np

1. Activation Functions#

Activation functions determine the output of a neuron given its input.

ReLU Family#

Standard ReLU#

# ReLU: max(0, x)
relu = brainstate.nn.ReLU()

x = jnp.linspace(-3, 3, 100)
y = relu(x)

plt.figure(figsize=(10, 4))
plt.subplot(1, 2, 1)
plt.plot(x, y, linewidth=2, label='ReLU')
plt.axhline(0, color='gray', linestyle='--', alpha=0.3)
plt.axvline(0, color='gray', linestyle='--', alpha=0.3)
plt.grid(alpha=0.3)
plt.xlabel('Input')
plt.ylabel('Output')
plt.title('ReLU Activation', fontweight='bold')
plt.legend()

# Test with data
test_input = jnp.array([-2, -1, 0, 1, 2])
test_output = relu(test_input)
plt.subplot(1, 2, 2)
plt.stem(test_input, test_output, basefmt=' ')
plt.xlabel('Input')
plt.ylabel('Output')
plt.title('ReLU on Sample Points', fontweight='bold')
plt.grid(alpha=0.3)

plt.tight_layout()
plt.show()

print(f"Input:  {test_input}")
print(f"Output: {test_output}")
print("\n✅ Advantages: Simple, fast, no vanishing gradient for positive inputs")
print("⚠️  Issue: 'Dying ReLU' problem, where a neuron outputs 0 for every input and stops learning")
[Figure: ReLU activation curve (left) and ReLU evaluated on sample points (right)]
Input:  [-2 -1  0  1  2]
Output: [0 0 0 1 2]

✅ Advantages: Simple, fast, no vanishing gradient for positive inputs
⚠️  Issue: 'Dying ReLU' problem, where a neuron outputs 0 for every input and stops learning

Leaky ReLU#

# LeakyReLU: max(alpha * x, x)
leaky_relu = brainstate.nn.LeakyReLU(negative_slope=0.1)

y_leaky = leaky_relu(x)

plt.figure(figsize=(8, 5))
plt.plot(x, y, linewidth=2, label='ReLU', alpha=0.5)
plt.plot(x, y_leaky, linewidth=2, label='Leaky ReLU (α=0.1)')
plt.axhline(0, color='gray', linestyle='--', alpha=0.3)
plt.axvline(0, color='gray', linestyle='--', alpha=0.3)
plt.grid(alpha=0.3)
plt.xlabel('Input')
plt.ylabel('Output')
plt.title('ReLU vs Leaky ReLU', fontweight='bold')
plt.legend()
plt.show()

print("✅ Advantage: Allows a small gradient for negative inputs")
print("   Helps prevent the dying ReLU problem")
[Figure: ReLU vs Leaky ReLU]
✅ Advantage: Allows a small gradient for negative inputs
   Helps prevent the dying ReLU problem
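The difference between the two functions is easiest to see in their derivatives. The sketch below uses plain NumPy with hand-written derivative helpers (`relu_grad` and `leaky_relu_grad` are illustrative names, not library functions). ReLU passes no gradient for negative inputs, while Leaky ReLU passes a gradient equal to its negative slope.

```python
import numpy as np

def relu_grad(x):
    # Derivative of max(0, x): 1 for x > 0, 0 otherwise (0 is chosen at x == 0).
    return (x > 0).astype(float)

def leaky_relu_grad(x, negative_slope=0.1):
    # Derivative of max(negative_slope * x, x): 1 for x > 0, negative_slope otherwise.
    return np.where(x > 0, 1.0, negative_slope)

x = np.array([-2.0, -1.0, 0.5, 2.0])
g_relu = relu_grad(x)
g_leaky = leaky_relu_grad(x)

print("ReLU grad:      ", g_relu)
print("Leaky ReLU grad:", g_leaky)
```

A unit whose pre-activations are negative for every training example gets zero gradient under ReLU and cannot recover. With Leaky ReLU it still receives a small update.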

Modern Activation Functions#

GELU (Gaussian Error Linear Unit)#

# GELU: x * Φ(x), with Φ the standard normal CDF; a smooth alternative to ReLU
gelu = brainstate.nn.GELU()

y_gelu = gelu(x)

plt.figure(figsize=(8, 5))
plt.plot(x, y, linewidth=2, label='ReLU', alpha=0.5)
plt.plot(x, y_gelu, linewidth=2, label='GELU')
plt.axhline(0, color='gray', linestyle='--', alpha=0.3)
plt.axvline(0, color='gray', linestyle='--', alpha=0.3)
plt.grid(alpha=0.3)
plt.xlabel('Input')
plt.ylabel('Output')
plt.title('ReLU vs GELU', fontweight='bold')
plt.legend()
plt.show()

print("✅ Used in Transformers (BERT, GPT)")
print("✅ Smooth, differentiable everywhere")
print("✅ Motivated by stochastic regularization: inputs are weighted by Φ(x)")
[Figure: ReLU vs GELU]
✅ Used in Transformers (BERT, GPT)
✅ Smooth, differentiable everywhere
✅ Motivated by stochastic regularization: inputs are weighted by Φ(x)
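GELU can be written as x · Φ(x), where Φ is the standard normal CDF. Many libraries also offer a tanh-based approximation. The sketch below implements both by hand in NumPy (the helper names are illustrative) and measures how far apart they are on the plotted range. The model printout later in this tutorial shows `GELU(approximate=False)`, which suggests the exact form is what `brainstate.nn.GELU()` uses by default.

```python
import math
import numpy as np

def gelu_exact(x):
    # GELU(x) = x * Phi(x), with Phi the standard normal CDF written via erf.
    phi = 0.5 * (1.0 + np.vectorize(math.erf)(x / math.sqrt(2.0)))
    return x * phi

def gelu_tanh(x):
    # Widely used tanh approximation of GELU.
    inner = math.sqrt(2.0 / math.pi) * (x + 0.044715 * x**3)
    return 0.5 * x * (1.0 + np.tanh(inner))

x = np.linspace(-3, 3, 100)
max_gap = np.max(np.abs(gelu_exact(x) - gelu_tanh(x)))
print(f"Max |exact - tanh approximation| on [-3, 3]: {max_gap:.2e}")
```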

Classic Activations#

Sigmoid and Tanh#

# Sigmoid: 1 / (1 + exp(-x))
sigmoid = brainstate.nn.Sigmoid()

# Tanh: (exp(x) - exp(-x)) / (exp(x) + exp(-x))
tanh = brainstate.nn.Tanh()

y_sigmoid = sigmoid(x)
y_tanh = tanh(x)

plt.figure(figsize=(12, 4))

plt.subplot(1, 2, 1)
plt.plot(x, y_sigmoid, linewidth=2, label='Sigmoid', color='blue')
plt.axhline(0.5, color='gray', linestyle='--', alpha=0.3)
plt.axvline(0, color='gray', linestyle='--', alpha=0.3)
plt.grid(alpha=0.3)
plt.xlabel('Input')
plt.ylabel('Output')
plt.title('Sigmoid: σ(x) = 1/(1+e⁻ˣ)', fontweight='bold')
plt.legend()
plt.ylim([-0.1, 1.1])

plt.subplot(1, 2, 2)
plt.plot(x, y_tanh, linewidth=2, label='Tanh', color='green')
plt.axhline(0, color='gray', linestyle='--', alpha=0.3)
plt.axvline(0, color='gray', linestyle='--', alpha=0.3)
plt.grid(alpha=0.3)
plt.xlabel('Input')
plt.ylabel('Output')
plt.title('Tanh: tanh(x)', fontweight='bold')
plt.legend()
plt.ylim([-1.1, 1.1])

plt.tight_layout()
plt.show()

print("Sigmoid:")
print("  ✅ Output range: (0, 1) - good for probabilities")
print("  ⚠️  Vanishing gradients for large |x|")
print("\nTanh:")
print("  ✅ Output range: (-1, 1) - zero-centered")
print("  ⚠️  Also suffers from vanishing gradients")
print("  📝 Common for RNN hidden states and LSTM cell candidates (LSTM gates use sigmoid)")
[Figure: Sigmoid (left) and Tanh (right) activation curves]
Sigmoid:
  ✅ Output range: (0, 1) - good for probabilities
  ⚠️  Vanishing gradients for large |x|

Tanh:
  ✅ Output range: (-1, 1) - zero-centered
  ⚠️  Also suffers from vanishing gradients
  📝 Common for RNN hidden states and LSTM cell candidates (LSTM gates use sigmoid)
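The vanishing-gradient warning can be checked numerically. This NumPy sketch uses hand-written derivatives rather than a library API. It shows that the sigmoid derivative peaks at 0.25, the tanh derivative peaks at 1, and both shrink quickly as |x| grows. It also checks the identity tanh(x) = 2σ(2x) − 1, which shows the two functions are rescaled versions of each other.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def sigmoid_grad(x):
    # d/dx sigmoid(x) = sigmoid(x) * (1 - sigmoid(x))
    s = sigmoid(x)
    return s * (1.0 - s)

def tanh_grad(x):
    # d/dx tanh(x) = 1 - tanh(x)^2
    return 1.0 - np.tanh(x) ** 2

x = np.array([0.0, 2.0, 5.0, 10.0])
sg = sigmoid_grad(x)
tg = tanh_grad(x)
for xi, s, t in zip(x, sg, tg):
    print(f"x={xi:5.1f}  sigmoid'={s:.2e}  tanh'={t:.2e}")

# tanh is a shifted and rescaled sigmoid.
identity_holds = np.allclose(np.tanh(x), 2.0 * sigmoid(2.0 * x) - 1.0)
print("tanh(x) == 2*sigmoid(2x) - 1:", identity_holds)
```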

Comparing All Activations#

# Compare multiple activations
activations = {
    'ReLU': brainstate.nn.ReLU(),
    'LeakyReLU': brainstate.nn.LeakyReLU(0.1),
    'GELU': brainstate.nn.GELU(),
    'ELU': brainstate.nn.ELU(),
    'Sigmoid': brainstate.nn.Sigmoid(),
    'Tanh': brainstate.nn.Tanh(),
}

fig, axes = plt.subplots(2, 3, figsize=(15, 8))
axes = axes.flatten()

for idx, (name, activation) in enumerate(activations.items()):
    y_act = activation(x)
    axes[idx].plot(x, y_act, linewidth=2.5, color=f'C{idx}')
    axes[idx].axhline(0, color='gray', linestyle='--', alpha=0.3)
    axes[idx].axvline(0, color='gray', linestyle='--', alpha=0.3)
    axes[idx].grid(alpha=0.3)
    axes[idx].set_xlabel('Input', fontsize=10)
    axes[idx].set_ylabel('Output', fontsize=10)
    axes[idx].set_title(name, fontsize=12, fontweight='bold')

plt.tight_layout()
plt.show()

print("\n📊 Activation Function Guide:\n")
print("ReLU:        Default choice, fast, works well in most cases")
print("LeakyReLU:   When dying ReLU is a problem")
print("GELU:        Transformers, NLP models")
print("ELU:         Smooth variant of ReLU with negative values")
print("Sigmoid:     Output layer for binary classification")
print("Tanh:        RNN hidden states, LSTM cell candidates, when zero-centered output needed")
[Figure: ReLU, LeakyReLU, GELU, ELU, Sigmoid, and Tanh plotted in a 2x3 grid]
📊 Activation Function Guide:

ReLU:        Default choice, fast, works well in most cases
LeakyReLU:   When dying ReLU is a problem
GELU:        Transformers, NLP models
ELU:         Smooth variant of ReLU with negative values
Sigmoid:     Output layer for binary classification
Tanh:        RNN hidden states, LSTM cell candidates, when zero-centered output needed

Softmax - For Classification#

# Softmax: converts logits to probabilities
softmax = brainstate.nn.Softmax()

# Example logits from a classifier
logits = jnp.array([2.0, 1.0, 0.5, 3.0, 0.1])
probs = softmax(logits)

print("Logits:       ", logits)
print("Probabilities:", probs)
print(f"Sum of probs:  {jnp.sum(probs):.6f} (should be 1.0)")

# Visualize
fig, axes = plt.subplots(1, 2, figsize=(12, 4))

classes = ['Cat', 'Dog', 'Bird', 'Fish', 'Horse']

axes[0].bar(classes, logits, color='steelblue', alpha=0.7)
axes[0].set_ylabel('Logit Value')
axes[0].set_title('Raw Logits', fontweight='bold')
axes[0].grid(axis='y', alpha=0.3)

axes[1].bar(classes, probs, color='coral', alpha=0.7)
axes[1].set_ylabel('Probability')
axes[1].set_title('After Softmax', fontweight='bold')
axes[1].set_ylim([0, 1])
axes[1].grid(axis='y', alpha=0.3)

# Add values on bars
for i, v in enumerate(probs):
    axes[1].text(i, v + 0.02, f'{v:.3f}', ha='center', fontsize=9)

plt.tight_layout()
plt.show()

print("\n✅ Softmax converts logits to a valid probability distribution")
print("   Use for the output layer in multi-class classification")
Logits:        [2.  1.  0.5 3.  0.1]
Probabilities: [0.22427258 0.08250527 0.05004197 0.60963607 0.03354414]
Sum of probs:  1.000000 (should be 1.0)
[Figure: Raw logits (left) and probabilities after softmax (right)]
✅ Softmax converts logits to a valid probability distribution
   Use for the output layer in multi-class classification
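For reference, softmax can be sketched by hand in NumPy. The helper below is illustrative and is not the `brainstate` implementation. It subtracts the largest logit before exponentiating, a common trick that leaves the result unchanged but avoids overflow. It reproduces the probabilities printed above for the same logits.

```python
import numpy as np

def softmax(logits):
    # Subtracting the max does not change the result but keeps exp() from overflowing.
    shifted = logits - np.max(logits, axis=-1, keepdims=True)
    exps = np.exp(shifted)
    return exps / np.sum(exps, axis=-1, keepdims=True)

logits = np.array([2.0, 1.0, 0.5, 3.0, 0.1])
probs = softmax(logits)
print("Probabilities:", probs)

# Adding a constant to every logit leaves the probabilities unchanged,
# even when the raw values would overflow a naive exp().
shifted_probs = softmax(logits + 1000.0)
print("Shift-invariant:", np.allclose(probs, shifted_probs))
```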

2. Normalization Layers#

Normalization stabilizes training by controlling the distribution of activations.

Batch Normalization#

# BatchNorm: Normalizes across batch dimension
brainstate.random.seed(42)
batch_norm = brainstate.nn.BatchNorm0d((10,))

print("BatchNorm0d:")
print(batch_norm)

# Create batch of data with varying statistics
batch_size = 32
features = 10
x = brainstate.random.randn(batch_size, features) * 5 + 10  # mean≈10, std≈5

print()
print("Before BatchNorm:")
print(f"  Mean: {jnp.mean(x, axis=0)[:3]}...")
print(f"  Std:  {jnp.std(x, axis=0)[:3]}...")

# Apply batch norm
with brainstate.environ.context(fit=True):
    y = batch_norm(x)

print()
print("After BatchNorm:")
print(f"  Mean: {jnp.mean(y, axis=0)[:3]}...")
print(f"  Std:  {jnp.std(y, axis=0)[:3]}...")

print()
print("✅ Normalizes to ~mean=0, ~std=1 across batch")
print("✅ Learns scale (γ) and shift (β) parameters")
BatchNorm0d:
BatchNorm0d(
  in_size=(10,),
  out_size=(10,),
  affine=True,
  bias_initializer=Constant(
    value=0.0
  ),
  scale_initializer=Constant(
    value=1.0
  ),
  dtype=<class 'numpy.float32'>,
  track_running_stats=True,
  momentum=Array(0.99, dtype=float32),
  epsilon=Array(1.e-05, dtype=float32),
  use_fast_variance=True,
  feature_axes=(0,),
  axis_name=None,
  axis_index_groups=None,
  running_mean=BatchState(
    value=ShapedArray(float32[10])
  ),
  running_var=BatchState(
    value=ShapedArray(float32[10])
  ),
  weight=NormalizationParamState(
    value={
      'bias': ShapedArray(float32[10]),
      'scale': ShapedArray(float32[10])
    }
  )
)

Before BatchNorm:
  Mean: [10.531552 10.173284  9.421941]...
  Std:  [4.4984775 4.850637  4.945446 ]...

After BatchNorm:
  Mean: [ 2.0861626e-07 -5.5879354e-08  2.3096800e-07]...
  Std:  [0.99999964 0.9999995  0.9999993 ]...

✅ Normalizes to ~mean=0, ~std=1 across batch
✅ Learns scale (γ) and shift (β) parameters
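The training-mode computation can be sketched by hand. This is a NumPy illustration of the standard formula, not the `brainstate` source. Each feature is centered and scaled with statistics taken over the batch axis, then a learned scale γ and shift β are applied. The function name `batch_norm_train` is hypothetical, and γ and β start at 1 and 0 to match the initializers in the printout above.

```python
import numpy as np

def batch_norm_train(x, gamma, beta, eps=1e-5):
    # Per-feature statistics are computed over the batch axis (axis 0).
    mean = x.mean(axis=0)
    var = x.var(axis=0)
    x_hat = (x - mean) / np.sqrt(var + eps)
    return gamma * x_hat + beta, mean, var

rng = np.random.default_rng(42)
x = rng.normal(loc=10.0, scale=5.0, size=(32, 10))  # mean≈10, std≈5
gamma = np.ones(10)    # scale γ, initialised to 1
beta = np.zeros(10)    # shift β, initialised to 0

y, batch_mean, batch_var = batch_norm_train(x, gamma, beta)
print("Output mean:", y.mean(axis=0)[:3])
print("Output std: ", y.std(axis=0)[:3])
```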

Visualizing BatchNorm Effect#

# Generate data with different distributions
brainstate.random.seed(0)
x1 = brainstate.random.randn(1000) * 10 + 50  # High variance, high mean
x2 = brainstate.random.randn(1000) * 2 - 5     # Low variance, negative mean

# Create a single-feature BatchNorm layer; the 1000 samples form the batch dimension
bn = brainstate.nn.BatchNorm0d((1,))

# Reshape to (batch, features)
x1_batch = x1[:, None]
x2_batch = x2[:, None]

# Apply normalization
with brainstate.environ.context(fit=True):
    y1 = bn(x1_batch).flatten()
    y2 = bn(x2_batch).flatten()

fig, axes = plt.subplots(2, 2, figsize=(12, 8))

# Distribution 1 - before
axes[0, 0].hist(np.array(x1), bins=50, alpha=0.7, color='blue', edgecolor='black')
axes[0, 0].set_title('Distribution 1 - Before BN', fontweight='bold')
axes[0, 0].set_xlabel('Value')
axes[0, 0].set_ylabel('Frequency')
axes[0, 0].axvline(jnp.mean(x1), color='red', linestyle='--', label=f'μ={jnp.mean(x1):.1f}')
axes[0, 0].legend()

# Distribution 1 - after
axes[0, 1].hist(np.array(y1), bins=50, alpha=0.7, color='green', edgecolor='black')
axes[0, 1].set_title('Distribution 1 - After BN', fontweight='bold')
axes[0, 1].set_xlabel('Value')
axes[0, 1].set_ylabel('Frequency')
axes[0, 1].axvline(jnp.mean(y1), color='red', linestyle='--', label=f'μ={jnp.mean(y1):.2f}')
axes[0, 1].legend()

# Distribution 2 - before
axes[1, 0].hist(np.array(x2), bins=50, alpha=0.7, color='blue', edgecolor='black')
axes[1, 0].set_title('Distribution 2 - Before BN', fontweight='bold')
axes[1, 0].set_xlabel('Value')
axes[1, 0].set_ylabel('Frequency')
axes[1, 0].axvline(jnp.mean(x2), color='red', linestyle='--', label=f'μ={jnp.mean(x2):.1f}')
axes[1, 0].legend()

# Distribution 2 - after
axes[1, 1].hist(np.array(y2), bins=50, alpha=0.7, color='green', edgecolor='black')
axes[1, 1].set_title('Distribution 2 - After BN', fontweight='bold')
axes[1, 1].set_xlabel('Value')
axes[1, 1].set_ylabel('Frequency')
axes[1, 1].axvline(jnp.mean(y2), color='red', linestyle='--', label=f'μ={jnp.mean(y2):.2f}')
axes[1, 1].legend()

plt.tight_layout()
plt.show()

print("BatchNorm standardizes different distributions to a similar scale")
[Figure: Histograms of both distributions before and after BatchNorm]
BatchNorm standardizes different distributions to a similar scale
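With `track_running_stats=True`, a BatchNorm layer also keeps running estimates of the mean and variance for use at inference time. The NumPy sketch below illustrates the idea under two assumptions. The update rule is taken to be `new = momentum * old + (1 - momentum) * batch_statistic`, which fits the `momentum=0.99` value in the printout but is not confirmed from the `brainstate` source. The running statistics are also assumed to start at 0 and 1.

```python
import numpy as np

momentum = 0.99                        # value shown in the BatchNorm0d printout
eps = 1e-5
running_mean, running_var = 0.0, 1.0   # assumed initial running statistics

rng = np.random.default_rng(0)
for _ in range(1000):
    batch = rng.normal(loc=50.0, scale=10.0, size=64)
    # Assumed convention: exponential moving average of the batch statistics.
    running_mean = momentum * running_mean + (1 - momentum) * batch.mean()
    running_var = momentum * running_var + (1 - momentum) * batch.var()

# Inference: a single sample is normalised with the stored statistics, not its own.
sample = np.array([60.0])
y_eval = (sample - running_mean) / np.sqrt(running_var + eps)
print(f"running_mean={running_mean:.2f}  running_var={running_var:.2f}  y_eval={y_eval[0]:.2f}")
```

Because inference uses the stored statistics, the output for a given input does not depend on which other samples share its batch.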

Layer Normalization#

# LayerNorm: Normalizes across features (not batch)
layer_norm = brainstate.nn.LayerNorm((10,))

print("LayerNorm:")
print(layer_norm)

# Single sample
x_single = brainstate.random.randn(10) * 5 + 10

print()
print("Before LayerNorm (single sample):")
print(f"  Values: {x_single}")
print(f"  Mean: {jnp.mean(x_single):.3f}")
print(f"  Std:  {jnp.std(x_single):.3f}")

y_single = layer_norm(x_single)

print()
print("After LayerNorm:")
print(f"  Values: {y_single}")
print(f"  Mean: {jnp.mean(y_single):.3f}")
print(f"  Std:  {jnp.std(y_single):.3f}")

print()
print("✅ LayerNorm works on single samples")
print("✅ Popular in Transformers and RNNs")
print("✅ Independent of batch size")
LayerNorm:
LayerNorm(
  in_size=(10,),
  out_size=(10,),
  feature_axes=(0,),
  reduction_axes=(-1,),
  axis_name=None,
  axis_index_groups=None,
  weight=NormalizationParamState(
    value={
      'bias': ShapedArray(float32[10]),
      'scale': ShapedArray(float32[10])
    }
  ),
  epsilon=1e-06,
  dtype=<class 'numpy.float32'>,
  use_bias=True,
  use_scale=True,
  bias_init=ZeroInit(
    unit=Unit(10.0^0)
  ),
  scale_init=Constant(
    value=1.0
  ),
  use_fast_variance=True
)

Before LayerNorm (single sample):
  Values: [ 3.0611591 13.874271  17.702465  19.20955    5.685323   5.9649186
  8.997379  13.917359   5.0701327 11.3188305]
  Mean: 10.480
  Std:  5.313

After LayerNorm:
  Values: [-1.3963605   0.6388257   1.3593478   1.6430035  -0.9024544  -0.84983045
 -0.2790769   0.64693546 -1.0182422   0.15785426]
  Mean: 0.000
  Std:  1.000

✅ LayerNorm works on single samples
✅ Popular in Transformers and RNNs
✅ Independent of batch size
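The same computation can be sketched by hand in NumPy. The `layer_norm` helper is illustrative and leaves out the learned scale and bias, which start at 1 and 0. Statistics are taken over the feature axis of each sample, so the sample printed above normalizes to the same values whether it is processed alone or inside a batch.

```python
import numpy as np

def layer_norm(x, eps=1e-6):
    # Statistics are taken over the last (feature) axis, separately for each sample.
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

# The sample printed in the output above.
x_single = np.array([3.0611591, 13.874271, 17.702465, 19.20955, 5.685323,
                     5.9649186, 8.997379, 13.917359, 5.0701327, 11.3188305])
y_single = layer_norm(x_single)
print("First three values:", y_single[:3])

# Stacking the sample with unrelated rows does not change its normalised values.
batch = np.stack([x_single, x_single * 100.0, np.arange(10.0)])
y_batch = layer_norm(batch)
print("Batch independent:", np.allclose(y_batch[0], y_single))
```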

Comparing Normalization Methods#

import pandas as pd

comparison = pd.DataFrame({
    'Method': ['BatchNorm', 'LayerNorm', 'GroupNorm', 'InstanceNorm'],
    'Normalizes Over': ['Batch + Spatial', 'Features', 'Groups', 'Spatial per channel'],
    'Best For': ['CNNs, large batches', 'RNNs, Transformers', 'Small batches', 'Style transfer'],
    'Batch Dependent': ['Yes', 'No', 'No', 'No'],
    'Typical Use': ['Vision', 'NLP', 'Vision (small batch)', 'GANs, Style'],
})

print("\n📊 Normalization Method Comparison:\n")
print(comparison.to_string(index=False))

print("\n\n🎯 Quick Guide:")
print("  • Large batch + CNN → BatchNorm")
print("  • Small batch + CNN → GroupNorm")
print("  • Sequences/RNN/Transformer → LayerNorm")
print("  • Single image inference → LayerNorm or GroupNorm")
📊 Normalization Method Comparison:

      Method     Normalizes Over            Best For Batch Dependent          Typical Use
   BatchNorm     Batch + Spatial CNNs, large batches             Yes               Vision
   LayerNorm            Features  RNNs, Transformers              No                  NLP
   GroupNorm              Groups       Small batches              No Vision (small batch)
InstanceNorm Spatial per channel      Style transfer              No          GANs, Style


🎯 Quick Guide:
  • Large batch + CNN → BatchNorm
  • Small batch + CNN → GroupNorm
  • Sequences/RNN/Transformer → LayerNorm
  • Single image inference → LayerNorm or GroupNorm
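The four methods in the table differ mainly in which axes the mean and variance are computed over. The NumPy sketch below makes that explicit for a channels-last image batch. The `standardize` helper is hypothetical and omits the learned scale and shift. GroupNorm is shown with an arbitrary choice of 4 groups.

```python
import numpy as np

def standardize(x, axes, eps=1e-5):
    # Zero mean and unit variance over the given axes (no learned scale or shift).
    mean = x.mean(axis=axes, keepdims=True)
    var = x.var(axis=axes, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

rng = np.random.default_rng(0)
x = rng.normal(loc=3.0, scale=2.0, size=(8, 16, 16, 32))  # (batch, height, width, channels)

batch_norm = standardize(x, axes=(0, 1, 2))    # per channel, over batch + spatial
layer_norm = standardize(x, axes=(1, 2, 3))    # per sample, over all non-batch axes
instance_norm = standardize(x, axes=(1, 2))    # per sample and channel, over spatial

# GroupNorm: split the 32 channels into 4 groups of 8 and normalise each group per sample.
num_groups = 4
grouped = x.reshape(8, 16, 16, num_groups, 32 // num_groups)
group_norm = standardize(grouped, axes=(1, 2, 4)).reshape(x.shape)

bn_max_mean = np.abs(batch_norm.mean(axis=(0, 1, 2))).max()
gn_max_mean = np.abs(group_norm.reshape(grouped.shape).mean(axis=(1, 2, 4))).max()
print("BatchNorm max |per-channel mean|:", bn_max_mean)
print("GroupNorm max |per-group mean|:  ", gn_max_mean)
```

Only the BatchNorm variant reduces over axis 0, which is why it is the one method marked as batch dependent. In sequence models, LayerNorm usually reduces over the last axis only.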

3. Putting It All Together#

Building a complete network with activations and normalization:

import brainunit as u


class ModernCNN(brainstate.nn.Module):
    """CNN with modern activations and normalization."""

    def __init__(self, num_classes=10):
        super().__init__()

        # Block 1: Conv + BatchNorm + GELU
        self.conv1 = brainstate.nn.Conv2d((32, 32, 3), out_channels=64, kernel_size=(3, 3), padding='SAME')
        self.bn1 = brainstate.nn.BatchNorm2d((32, 32, 64))
        self.act1 = brainstate.nn.GELU()
        self.pool1 = brainstate.nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2))

        # Block 2
        self.conv2 = brainstate.nn.Conv2d((16, 16, 64), out_channels=128, kernel_size=(3, 3), padding='SAME')
        self.bn2 = brainstate.nn.BatchNorm2d((16, 16, 128))
        self.act2 = brainstate.nn.GELU()
        self.pool2 = brainstate.nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2), in_size=self.bn2.out_size)

        # Classifier
        self.fc1 = brainstate.nn.Linear((128 * 8 * 8,), (256,))
        self.ln = brainstate.nn.LayerNorm((256,))
        self.act3 = brainstate.nn.GELU()
        self.dropout = brainstate.nn.Dropout(prob=0.5)
        self.fc2 = brainstate.nn.Linear((256,), (num_classes,))

    def update(self, x):
        # Block 1
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.act1(x)
        x = self.pool1(x)

        # Block 2
        x = self.conv2(x)
        x = self.bn2(x)
        x = self.act2(x)
        x = self.pool2(x)

        # Classifier
        x = u.math.flatten(x, start_axis=1)
        x = self.fc1(x)
        x = self.ln(x)
        x = self.act3(x)
        x = self.dropout(x)
        x = self.fc2(x)

        return x

# Create and test
brainstate.random.seed(0)
model = ModernCNN(num_classes=10)

# Forward pass
x = brainstate.random.randn(4, 32, 32, 3)  # 4 images
with brainstate.environ.context(fit=True):
    logits = model(x)

print("Modern CNN with GELU + BatchNorm + LayerNorm:")
print(model)
print()
print("Input:", x.shape)
print("Output:", logits.shape)
print()
print("Logits:", logits[0])
Modern CNN with GELU + BatchNorm + LayerNorm:
ModernCNN(
  conv1=Conv2d(
    in_size=(32, 32, 3),
    out_size=(32, 32, 64),
    channel_first=False,
    channels_last=True,
    in_channels=3,
    out_channels=64,
    stride=(1, 1),
    kernel_size=(3, 3),
    lhs_dilation=(1, 1),
    rhs_dilation=(1, 1),
    groups=1,
    dimension_numbers=ConvDimensionNumbers(lhs_spec=(0, 3, 1, 2), rhs_spec=(3, 2, 0, 1), out_spec=(0, 3, 1, 2)),
    padding=SAME,
    kernel_shape=(3, 3, 3, 64),
    w_mask=None,
    w_initializer=XavierNormal(
      scale=1.0,
      mode='fan_avg',
      in_axis=-2,
      out_axis=-1,
      distribution='truncated_normal',
      rng=RandomState([1846506472 1430574090]),
      unit=Unit(10.0^0)
    ),
    b_initializer=None,
    weight=ParamState(
      value={
        'weight': ShapedArray(float32[3,3,3,64])
      }
    )
  ),
  bn1=BatchNorm2d(
    in_size=(32, 32, 64),
    out_size=(32, 32, 64),
    affine=True,
    bias_initializer=Constant(
      value=0.0
    ),
    scale_initializer=Constant(
      value=1.0
    ),
    dtype=<class 'numpy.float32'>,
    track_running_stats=True,
    momentum=Array(0.99, dtype=float32),
    epsilon=Array(1.e-05, dtype=float32),
    use_fast_variance=True,
    feature_axes=(2,),
    axis_name=None,
    axis_index_groups=None,
    running_mean=BatchState(
      value=ShapedArray(float32[1,1,64])
    ),
    running_var=BatchState(
      value=ShapedArray(float32[1,1,64])
    ),
    weight=NormalizationParamState(
      value={
        'bias': ShapedArray(float32[1,1,64]),
        'scale': ShapedArray(float32[1,1,64])
      }
    )
  ),
  act1=GELU(approximate=False),
  pool1=MaxPool2d(
    init_value=-inf,
    computation=<function max at 0x000001C17F834220>,
    pool_dim=2,
    return_indices=False,
    kernel_size=(2, 2),
    stride=(2, 2),
    padding=VALID,
    channel_axis=-1
  ),
  conv2=Conv2d(
    in_size=(16, 16, 64),
    out_size=(16, 16, 128),
    channel_first=False,
    channels_last=True,
    in_channels=64,
    out_channels=128,
    stride=(1, 1),
    kernel_size=(3, 3),
    lhs_dilation=(1, 1),
    rhs_dilation=(1, 1),
    groups=1,
    dimension_numbers=ConvDimensionNumbers(lhs_spec=(0, 3, 1, 2), rhs_spec=(3, 2, 0, 1), out_spec=(0, 3, 1, 2)),
    padding=SAME,
    kernel_shape=(3, 3, 64, 128),
    w_mask=None,
    w_initializer=XavierNormal(
      scale=1.0,
      mode='fan_avg',
      in_axis=-2,
      out_axis=-1,
      distribution='truncated_normal',
      rng=RandomState([1846506472 1430574090]),
      unit=Unit(10.0^0)
    ),
    b_initializer=None,
    weight=ParamState(
      value={
        'weight': ShapedArray(float32[3,3,64,128])
      }
    )
  ),
  bn2=BatchNorm2d(
    in_size=(16, 16, 128),
    out_size=(16, 16, 128),
    affine=True,
    bias_initializer=Constant(
      value=0.0
    ),
    scale_initializer=Constant(
      value=1.0
    ),
    dtype=<class 'numpy.float32'>,
    track_running_stats=True,
    momentum=Array(0.99, dtype=float32),
    epsilon=Array(1.e-05, dtype=float32),
    use_fast_variance=True,
    feature_axes=(2,),
    axis_name=None,
    axis_index_groups=None,
    running_mean=BatchState(
      value=ShapedArray(float32[1,1,128])
    ),
    running_var=BatchState(
      value=ShapedArray(float32[1,1,128])
    ),
    weight=NormalizationParamState(
      value={
        'bias': ShapedArray(float32[1,1,128]),
        'scale': ShapedArray(float32[1,1,128])
      }
    )
  ),
  act2=GELU(approximate=False),
  pool2=MaxPool2d(
    in_size=(16, 16, 128),
    out_size=(8, 8, 128),
    init_value=-inf,
    computation=<function max at 0x000001C17F834220>,
    pool_dim=2,
    return_indices=False,
    kernel_size=(2, 2),
    stride=(2, 2),
    padding=VALID,
    channel_axis=-1
  ),
  fc1=Linear(
    in_size=(8192,),
    out_size=(256,),
    w_mask=None,
    weight=ParamState(
      value={
        'bias': ShapedArray(float32[256]),
        'weight': ShapedArray(float32[8192,256])
      }
    )
  ),
  ln=LayerNorm(
    in_size=(256,),
    out_size=(256,),
    feature_axes=(0,),
    reduction_axes=(-1,),
    axis_name=None,
    axis_index_groups=None,
    weight=NormalizationParamState(
      value={
        'bias': ShapedArray(float32[256]),
        'scale': ShapedArray(float32[256])
      }
    ),
    epsilon=1e-06,
    dtype=<class 'numpy.float32'>,
    use_bias=True,
    use_scale=True,
    bias_init=ZeroInit(
      unit=Unit(10.0^0)
    ),
    scale_init=Constant(
      value=1.0
    ),
    use_fast_variance=True
  ),
  act3=GELU(approximate=False),
  dropout=Dropout(
    prob=0.5,
    broadcast_dims=()
  ),
  fc2=Linear(
    in_size=(256,),
    out_size=(10,),
    w_mask=None,
    weight=ParamState(
      value={
        'bias': ShapedArray(float32[10]),
        'weight': ShapedArray(float32[256,10])
      }
    )
  )
)

Input: (4, 32, 32, 3)
Output: (4, 10)

Logits: [ 3.3256447  -1.429272    0.7167457   1.7075744   0.22491284 -1.362806
 -0.27098486 -0.27673492 -1.0504798  -1.4831724 ]
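The `128 * 8 * 8` input size of `fc1` follows from the layer shapes in the printout. The short sketch below retraces that arithmetic in plain Python with two illustrative helpers. They are not library functions and only model the `SAME` convolution and the 2x2 `VALID` pooling used here.

```python
def same_conv(h, w, out_channels):
    # 'SAME' padding with stride 1 keeps the spatial size and sets the channel count.
    return h, w, out_channels

def max_pool_2x2(h, w, c):
    # 2x2 window, stride 2, 'VALID' padding: each spatial dimension is halved.
    return (h - 2) // 2 + 1, (w - 2) // 2 + 1, c

shape = (32, 32, 3)
shape = same_conv(shape[0], shape[1], 64)     # conv1 -> (32, 32, 64)
shape = max_pool_2x2(*shape)                  # pool1 -> (16, 16, 64)
shape = same_conv(shape[0], shape[1], 128)    # conv2 -> (16, 16, 128)
shape = max_pool_2x2(*shape)                  # pool2 -> (8, 8, 128)

flat_features = shape[0] * shape[1] * shape[2]
print(shape, flat_features)
```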

Summary#

In this tutorial, you learned:

Activation Functions

  • ReLU family (ReLU, Leaky ReLU, ELU)

  • Modern activations such as GELU

  • Classic activations (Sigmoid, Tanh)

  • Softmax for classification

Normalization Layers

  • BatchNorm for large-batch training

  • LayerNorm for sequences and transformers

  • GroupNorm for small-batch scenarios

Practical Guidelines

  • When to use each activation

  • When to use each normalization

  • How to combine them effectively

Quick Reference Card#

Task                  Activation                       Normalization
CNN (large batch)     ReLU/GELU                        BatchNorm
CNN (small batch)     ReLU/GELU                        GroupNorm
Transformer/NLP       GELU                             LayerNorm
RNN/LSTM              Tanh (state), Sigmoid (gates)    LayerNorm
Binary output         Sigmoid                          -
Multi-class output    Softmax                          -

Best Practices#

  1. 🎯 Use ReLU/GELU as default activations

  2. 📊 Add normalization after conv/linear layers

  3. Order: Conv/Linear → Norm → Activation

  4. 🔍 Experiment with activation functions

  5. 📝 Use appropriate normalization for your batch size
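The Linear → Norm → Activation ordering from point 3 can be sketched as one small block. This NumPy illustration uses hypothetical weights `W` and `b`, a LayerNorm-style normalization without learned parameters, and ReLU. It stands in for the corresponding `brainstate` layers rather than reproducing them.

```python
import numpy as np

def dense_block(x, W, b, eps=1e-6):
    z = x @ W + b                               # 1. linear
    mean = z.mean(axis=-1, keepdims=True)
    var = z.var(axis=-1, keepdims=True)
    z = (z - mean) / np.sqrt(var + eps)         # 2. normalisation (LayerNorm-style)
    return np.maximum(z, 0.0)                   # 3. activation (ReLU)

rng = np.random.default_rng(0)
x = rng.normal(size=(4, 16)) * 50.0   # deliberately badly scaled input
W = rng.normal(size=(16, 8))          # hypothetical weights
b = np.zeros(8)

out = dense_block(x, W, b)
print(out.shape, float(out.max()))
```

Normalizing before the activation keeps the pre-activations in a predictable range, so the non-linearity sees well-scaled inputs even when the raw inputs are not.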

Next Steps#

Continue with:

  • Recurrent Networks - Handle sequential data

  • Training - Optimize with gradient descent

  • Advanced Architectures - ResNets, Transformers