Module System Protocol#

The module system is the foundation for building neural networks in BrainState. It provides a clean, object-oriented interface for organizing stateful computations.

In this tutorial, you will learn:

  • 🏗️ The Module base class and its role

  • 🔨 How to create custom modules

  • 🧩 Module composition and nesting

  • 🎯 Parameter management and initialization

  • 📦 Working with module hierarchies

Why Modules?#

Modules (via brainstate.nn.Module) provide:

Automatic state management - States are tracked automatically
Clean abstractions - Encapsulate related computations
Reusability - Build once, use everywhere
Composability - Combine simple modules into complex systems

import brainstate
import jax.numpy as jnp
import matplotlib.pyplot as plt

1. The Module Base Class#

brainstate.nn.Module is the base class for all modules in BrainState. It provides:

  • Automatic registration of child modules

  • State collection and management

  • Pretty printing and inspection

  • Integration with JAX transformations

Creating Your First Module#

The simplest module inherits from Module and implements update(). Calling the module instance, as in module(x), runs its update() method:

class SimpleModule(brainstate.nn.Module):
    """Stateless module that shifts its input by a fixed offset.

    Args:
        constant: Value added to every input element (default ``1.0``).
    """

    def __init__(self, constant=1.0):
        # Initialize the Module base class before assigning any attributes.
        super().__init__()
        self.constant = constant

    def update(self, x):
        """Return ``x + constant``."""
        shifted = x + self.constant
        return shifted

# Build the module, then run a single forward pass on a small vector
module = SimpleModule(constant=5.0)
result = module(jnp.array([1.0, 2.0, 3.0]))

for label, value in (("Input:", jnp.array([1.0, 2.0, 3.0])), ("Output:", result)):
    print(label, value)
print("\nModule:")
print(module)
Input: [1. 2. 3.]
Output: [6. 7. 8.]

Module:
SimpleModule(
  constant=5.0
)

Adding States to Modules#

Modules become more powerful when they hold states, whose values persist across calls:

class Counter(brainstate.nn.Module):
    """Stateful module that remembers how many times it has been called.

    Each call increments the stored count and returns ``x * count``.
    """

    def __init__(self):
        super().__init__()
        # Call count kept in a ShortTermState so it persists between calls.
        self.count = brainstate.ShortTermState(jnp.array(0))

    def update(self, x):
        """Bump the counter, then scale ``x`` by the new count."""
        self.count.value += 1
        return x * self.count.value

# Call the counter five times and watch its state grow
counter = Counter()
print("Initial count:", counter.count.value)

for call_no in range(1, 6):
    result = counter(jnp.array(10.0))
    print(f"Call {call_no}: count={counter.count.value}, result={result}")
Initial count: 0
Call 1: count=1, result=10.0
Call 2: count=2, result=20.0
Call 3: count=3, result=30.0
Call 4: count=4, result=40.0
Call 5: count=5, result=50.0

2. Creating Custom Modules#

Let’s build a complete linear layer from scratch to understand module design:

Example: Custom Linear Layer#

class Linear(brainstate.nn.Module):
    """A hand-written fully-connected layer: ``y = x @ W + b``.

    The weight is stored with shape ``(in_features, out_features)`` and is
    applied on the right, so inputs are row vectors (or stacks of them).

    Args:
        in_features: Size of the last input dimension.
        out_features: Size of the last output dimension.
        use_bias: Whether to add a learnable bias vector. When ``False`` no
            ``bias`` attribute is created.
    """

    def __init__(self, in_features, out_features, use_bias=True):
        super().__init__()

        self.in_features = in_features
        self.out_features = out_features
        self.use_bias = use_bias

        # Xavier/Glorot normal initialization: std = sqrt(2 / (fan_in + fan_out)).
        # The weights are drawn from the global brainstate RNG, so seeding it
        # beforehand makes the initialization reproducible.
        std = jnp.sqrt(2.0 / (in_features + out_features))
        self.weight = brainstate.ParamState(
            brainstate.random.randn(in_features, out_features) * std
        )

        # The bias starts at zero.
        if use_bias:
            self.bias = brainstate.ParamState(jnp.zeros(out_features))

    def update(self, x):
        """Forward pass.

        Args:
            x: Input tensor of shape (..., in_features)

        Returns:
            Output tensor of shape (..., out_features)
        """
        # Right-multiplication: (..., in) @ (in, out) -> (..., out).
        out = x @ self.weight.value
        if self.use_bias:
            out = out + self.bias.value
        return out

    def __repr__(self):
        return f"Linear(in_features={self.in_features}, out_features={self.out_features}, use_bias={self.use_bias})"

# Create and test the linear layer
# Seed the global RNG first: Linear.__init__ draws its weights from
# brainstate.random, so seeding makes the printed output reproducible.
brainstate.random.seed(42)
linear = Linear(in_features=5, out_features=3)

# Forward pass
# One unbatched sample of shape (in_features,) = (5,).
x = jnp.ones(5)
y = linear(x)

print("Module:")
print(linear)  # rendered by Linear.__repr__
print(f"\nWeight shape: {linear.weight.value.shape}")
print(f"Bias shape: {linear.bias.value.shape}")  # bias exists because use_bias defaults to True
print(f"\nInput shape: {x.shape}")
print(f"Output shape: {y.shape}")
print(f"Output: {y}")
Module:
Linear(in_features=5, out_features=3, use_bias=True)

Weight shape: (5, 3)
Bias shape: (3,)

Input shape: (5,)
Output shape: (3,)
Output: [ 0.3793956  -0.9351347  -0.94997764]

Example: Custom Activation Module#

class LeakyReLU(brainstate.nn.Module):
    """Leaky ReLU: ``y = x`` where ``x > 0``, otherwise ``y = negative_slope * x``.

    This is written with ``where`` rather than the shorthand
    ``max(negative_slope * x, x)``, which is only equivalent when
    ``negative_slope <= 1``. ``negative_slope=0.0`` gives the standard ReLU.

    Args:
        negative_slope: Multiplier applied to non-positive inputs.
    """

    def __init__(self, negative_slope=0.01):
        super().__init__()
        self.negative_slope = negative_slope

    def update(self, x):
        # x == 0 takes the scaled branch, which yields 0 for any slope.
        return jnp.where(x > 0, x, self.negative_slope * x)

    def __repr__(self):
        return f"LeakyReLU(negative_slope={self.negative_slope})"

# Try the activation on a few hand-picked points
activation = LeakyReLU(negative_slope=0.1)
x = jnp.array([-2.0, -1.0, 0.0, 1.0, 2.0])
y = activation(x)

print("Activation:", activation)
print(f"Input:  {x}")
print(f"Output: {y}")

# Plot the response curve over [-3, 3] with the object-oriented Matplotlib API
x_plot = jnp.linspace(-3, 3, 100)
y_plot = activation(x_plot)

fig, ax = plt.subplots(figsize=(8, 5))
ax.plot(x_plot, y_plot, linewidth=2, label='LeakyReLU(0.1)')
for add_ref_line in (ax.axhline, ax.axvline):
    add_ref_line(0, color='gray', linestyle='--', alpha=0.3)
ax.grid(alpha=0.3)
ax.set_xlabel('Input')
ax.set_ylabel('Output')
ax.set_title('Leaky ReLU Activation Function')
ax.legend()
plt.show()
Activation: LeakyReLU(negative_slope=0.1)
Input:  [-2. -1.  0.  1.  2.]
Output: [-0.2 -0.1  0.   1.   2. ]
../../_images/5052958a2984d4df1137c32830e04642e1193ac35e5c667c6574cf975a794be5.png

3. Module Composition and Nesting#

The real power of modules comes from composing them into larger networks.

Sequential Composition#

Build a network by stacking layers sequentially:

class MLP(brainstate.nn.Module):
    """Feed-forward stack of ``Linear`` layers with ``LeakyReLU`` in between.

    Args:
        layer_sizes: Feature sizes ``[in, hidden_1, ..., out]``. One ``Linear``
            is created per consecutive pair, and an activation follows every
            ``Linear`` except the last.
        activation: ``'relu'`` selects ``LeakyReLU(0.0)`` (plain ReLU). Any
            other value selects ``LeakyReLU(0.01)``.
    """

    def __init__(self, layer_sizes, activation='relu'):
        super().__init__()

        # Ordered list used by ``update``. Each module is also registered under
        # its own attribute name (``layer_<i>`` / ``activation_<i>``).
        self.layers = []
        slope = 0.0 if activation == 'relu' else 0.01
        last = len(layer_sizes) - 2  # index of the final Linear layer

        for idx, (n_in, n_out) in enumerate(zip(layer_sizes, layer_sizes[1:])):
            dense = Linear(n_in, n_out)
            setattr(self, f'layer_{idx}', dense)
            self.layers.append(dense)

            # The output layer is left without an activation.
            if idx == last:
                continue
            act = LeakyReLU(negative_slope=slope)
            setattr(self, f'activation_{idx}', act)
            self.layers.append(act)

    def update(self, x):
        """Apply every stored layer in order."""
        out = x
        for module in self.layers:
            out = module(out)
        return out

# Create a 3-layer MLP
# Seeding fixes the weight draws made by each Linear layer, so the printed
# output values are reproducible.
brainstate.random.seed(0)
mlp = MLP(layer_sizes=[10, 64, 32, 5])

# Forward pass
# The input vector is drawn after the weights, from the same seeded RNG.
x = brainstate.random.randn(10)
y = mlp(x)

print("MLP Architecture:")
# Each sub-module appears twice in the printout: once inside ``layers`` and
# once under its ``layer_<i>`` / ``activation_<i>`` attribute.
print(mlp)
print(f"\nInput shape: {x.shape}")
print(f"Output shape: {y.shape}")
print(f"Output: {y}")
MLP Architecture:
MLP(
  layers=[
    Linear(in_features=10, out_features=64, use_bias=True),
    LeakyReLU(negative_slope=0.0),
    Linear(in_features=64, out_features=32, use_bias=True),
    LeakyReLU(negative_slope=0.0),
    Linear(in_features=32, out_features=5, use_bias=True)
  ],
  layer_0=Linear(in_features=10, out_features=64, use_bias=True),
  activation_0=LeakyReLU(negative_slope=0.0),
  layer_1=Linear(in_features=64, out_features=32, use_bias=True),
  activation_1=LeakyReLU(negative_slope=0.0),
  layer_2=Linear(in_features=32, out_features=5, use_bias=True)
)

Input shape: (10,)
Output shape: (5,)
Output: [-0.49218872  0.5558434  -0.6296929   0.25295696  0.37388656]

Residual Connections#

Implement skip connections for deeper networks:

class ResidualBlock(brainstate.nn.Module):
    """Two-layer residual unit computing ``linear2(relu(linear1(x))) + x``.

    Input and output share the feature size ``dim`` so the skip connection
    can be added element-wise.

    Args:
        dim: Feature size of both the input and the output.
    """

    def __init__(self, dim):
        super().__init__()
        # Creation order (linear1, activation, linear2) sets the order of the
        # random weight draws and of the attributes in the printed module.
        self.linear1 = Linear(dim, dim)
        self.activation = LeakyReLU(0.0)
        self.linear2 = Linear(dim, dim)

    def update(self, x):
        """Return the transformed input plus the unmodified input."""
        hidden = self.activation(self.linear1(x))
        return self.linear2(hidden) + x

class ResNet(brainstate.nn.Module):
    """Input projection, a stack of residual blocks, then an output projection.

    Args:
        input_dim: Feature size of the input.
        hidden_dim: Feature size used inside every residual block.
        output_dim: Feature size of the output.
        n_blocks: Number of ``ResidualBlock`` modules (default 3).
    """

    def __init__(self, input_dim, hidden_dim, output_dim, n_blocks=3):
        super().__init__()

        # Map inputs into the hidden width shared by all residual blocks.
        self.input_proj = Linear(input_dim, hidden_dim)

        # Keep the blocks in an ordered list for ``update`` and also register
        # each one under its own ``block_<i>`` attribute.
        self.blocks = []
        for idx in range(n_blocks):
            blk = ResidualBlock(hidden_dim)
            self.blocks.append(blk)
            setattr(self, f'block_{idx}', blk)

        # Map the hidden representation to the requested output width.
        self.output_proj = Linear(hidden_dim, output_dim)

    def update(self, x):
        """Project in, apply each residual block in turn, then project out."""
        h = self.input_proj(x)
        for blk in self.blocks:
            h = blk(h)
        return self.output_proj(h)

# Create ResNet
# Seeded so the Linear weight draws inside every block are reproducible.
brainstate.random.seed(0)
resnet = ResNet(input_dim=10, hidden_dim=32, output_dim=5, n_blocks=3)

# Forward pass
x = brainstate.random.randn(10)
y = resnet(x)

print("ResNet:")
# Blocks held both in ``blocks`` and as ``block_<i>`` are printed in full once
# and abbreviated as ``ResidualBlock(...)`` the second time.
print(resnet)
print(f"\nOutput shape: {y.shape}")
ResNet:
ResNet(
  input_proj=Linear(in_features=10, out_features=32, use_bias=True),
  blocks=[
    ResidualBlock(
      linear1=Linear(in_features=32, out_features=32, use_bias=True),
      activation=LeakyReLU(negative_slope=0.0),
      linear2=Linear(in_features=32, out_features=32, use_bias=True)
    ),
    ResidualBlock(
      linear1=Linear(in_features=32, out_features=32, use_bias=True),
      activation=LeakyReLU(negative_slope=0.0),
      linear2=Linear(in_features=32, out_features=32, use_bias=True)
    ),
    ResidualBlock(
      linear1=Linear(in_features=32, out_features=32, use_bias=True),
      activation=LeakyReLU(negative_slope=0.0),
      linear2=Linear(in_features=32, out_features=32, use_bias=True)
    )
  ],
  block_0=ResidualBlock(...),
  block_1=ResidualBlock(...),
  block_2=ResidualBlock(...),
  output_proj=Linear(in_features=32, out_features=5, use_bias=True)
)

Output shape: (5,)

4. Automatic Input/Output Size Inference#

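One of BrainState’s most powerful features is automatic input/output size inference. Every brainstate.nn.Module instance has in_size and out_size properties that describe the shape of data flowing through the module (excluding the batch dimension). Built-in layers such as Linear, Conv2d, MaxPool2d, Flatten and Sequential fill these in for you. A custom Module that merely wraps other layers leaves them as None unless it sets them itself (see the CNN example below).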

Key Concepts#

in_size: Input shape without batch dimension
out_size: Output shape without batch dimension (automatically inferred)
Automatic propagation: For shape-aware layers, out_size is computed automatically once in_size is known (element-wise layers such as ReLU report None for both)
Sequential composition: Output size of one layer becomes input size of next layer

This mechanism eliminates the need to manually calculate dimensions through network layers, making it much easier to build complex architectures.

Example 1: Basic Size Inference#

# Build a Linear layer whose per-sample input/output shapes are given explicitly
layer = brainstate.nn.Linear(in_size=(10,), out_size=(5,))

print("Layer:", layer)
for line in (
    f"Input size:  {layer.in_size}",
    f"Output size: {layer.out_size}",
):
    print(line)

# The batch axis is added at call time and is not part of in_size/out_size
x = brainstate.random.randn(32, 10)  # (batch_size, in_features)
y = layer(x)

for line in (
    f"\nInput shape:  {x.shape}  (batch_size=32, in_features=10)",
    f"Output shape: {y.shape}  (batch_size=32, out_features=5)",
    "\nNote: in_size and out_size DO NOT include the batch dimension!",
):
    print(line)
Layer: Linear(
  in_size=(10,),
  out_size=(5,),
  w_mask=None,
  weight=ParamState(
    value={
      'bias': ShapedArray(float32[5]),
      'weight': ShapedArray(float32[10,5])
    }
  )
)
Input size:  (10,)
Output size: (5,)

Input shape:  (32, 10)  (batch_size=32, in_features=10)
Output shape: (32, 5)  (batch_size=32, out_features=5)

Note: in_size and out_size DO NOT include the batch dimension!

Example 2: Size Inference with Convolution#

Convolution layers automatically compute output spatial dimensions based on:

  • Input spatial size

  • Kernel size

  • Stride

  • Padding mode

# Arguments shared by both convolution layers; in_size is (height, width, channels)
conv_kwargs = dict(in_size=(28, 28, 3), out_channels=32, kernel_size=3)

# 'SAME' padding with stride 1
conv = brainstate.nn.Conv2d(**conv_kwargs, stride=1, padding='SAME')

for line in (
    "Conv2d Layer:",
    f"  in_size:  {conv.in_size}",
    f"  out_size: {conv.out_size}",
    f"\n  Input:  (H, W, C) = {conv.in_size}",
    f"  Output: (H', W', C') = {conv.out_size}",
    "\nWith 'SAME' padding and stride=1, spatial dimensions are preserved!",
):
    print(line)

# 'VALID' padding with stride 2
conv_valid = brainstate.nn.Conv2d(**conv_kwargs, stride=2, padding='VALID')

for line in (
    "\nWith 'VALID' padding and stride=2:",
    f"  in_size:  {conv_valid.in_size}",
    f"  out_size: {conv_valid.out_size}",
    "  Spatial dimensions are reduced!",
):
    print(line)
Conv2d Layer:
  in_size:  (28, 28, 3)
  out_size: (28, 28, 32)

  Input:  (H, W, C) = (28, 28, 3)
  Output: (H', W', C') = (28, 28, 32)

With 'SAME' padding and stride=1, spatial dimensions are preserved!

With 'VALID' padding and stride=2:
  in_size:  (28, 28, 3)
  out_size: (13, 13, 32)
  Spatial dimensions are reduced!

Example 3: Size Inference with Pooling and Flatten#

Pooling layers reduce spatial dimensions, and Flatten layers convert multi-dimensional tensors to 1D vectors. BrainState tracks all these transformations automatically.

# A 2x2 max-pool with stride 2 halves the height and width of a channels-last map
pool = brainstate.nn.MaxPool2d(
    in_size=(28, 28, 32),
    kernel_size=(2, 2),
    stride=(2, 2),
    channel_axis=-1,
)

for line in (
    "MaxPool2d Layer:",
    f"  in_size:  {pool.in_size}  (H=28, W=28, C=32)",
    f"  out_size: {pool.out_size}  (H=14, W=14, C=32)",
    "  Spatial dimensions reduced by 2x!",
):
    print(line)

# Flatten collapses the per-sample (14, 14, 32) map into one vector
flatten = brainstate.nn.Flatten(in_size=(14, 14, 32))

for line in (
    "\nFlatten Layer:",
    f"  in_size:  {flatten.in_size}  (3D tensor)",
    f"  out_size: {flatten.out_size}  (1D vector)",
    f"  Total elements: {14 * 14 * 32} = {flatten.out_size[0]}",
):
    print(line)
MaxPool2d Layer:
  in_size:  (28, 28, 32)  (H=28, W=28, C=32)
  out_size: (14, 14, 32)  (H=14, W=14, C=32)
  Spatial dimensions reduced by 2x!

Flatten Layer:
  in_size:  (14, 14, 32)  (3D tensor)
  out_size: (6272,)  (1D vector)
  Total elements: 6272 = 6272

5. Sequential Composition and Deep Networks#

brainstate.nn.Sequential is a powerful container that chains multiple modules together. It automatically propagates out_size from one layer to the in_size of the next layer, enabling effortless construction of deep networks.

The .desc() Pattern#

For layers that need to infer their in_size from the previous layer, BrainState provides the .desc() method, which creates a layer descriptor that will be instantiated when the input size becomes available.

# Instead of:
# (an eagerly constructed layer, which needs its in_size up front)
brainstate.nn.Linear(in_size=(10,), out_size=(5,))

# Use descriptor in Sequential:
# (a descriptor; Sequential supplies in_size from the previous layer's out_size)
brainstate.nn.Linear.desc(out_size=5)  # in_size will be inferred!

Example 1: Simple Sequential Network#

# Build a simple MLP with Sequential
brainstate.random.seed(42)

# Only the first Linear is built eagerly with an explicit in_size. Each
# `.desc(...)` entry is instantiated by Sequential with its in_size taken from
# the preceding layer (compare the in_size values in the printout below).
mlp = brainstate.nn.Sequential(
    brainstate.nn.Linear((10,), (64,)),        # First layer needs explicit in_size
    brainstate.nn.ReLU(),                      # Element-wise, preserves shape
    brainstate.nn.Linear.desc(out_size=32),    # in_size inferred from previous layer
    brainstate.nn.ReLU(),
    brainstate.nn.Linear.desc(out_size=5)      # Final output layer
)

print("Sequential MLP:")
print(mlp)
# Sequential exposes the first layer's in_size and the last layer's out_size.
print(f"\nInput size:  {mlp.in_size}")
print(f"Output size: {mlp.out_size}")

# Test forward pass
x = brainstate.random.randn(8, 10)  # batch of 8 samples
y = mlp(x)
print(f"\nForward pass:")
print(f"  Input:  {x.shape}")
print(f"  Output: {y.shape}")
Sequential MLP:
Sequential(
  in_size=(10,),
  out_size=(5,),
  layers=[
    Linear(
      in_size=(10,),
      out_size=(64,),
      w_mask=None,
      weight=ParamState(
        value={
          'bias': ShapedArray(float32[64]),
          'weight': ShapedArray(float32[10,64])
        }
      )
    ),
    ReLU(),
    Linear(
      in_size=(64,),
      out_size=(32,),
      w_mask=None,
      weight=ParamState(
        value={
          'bias': ShapedArray(float32[32]),
          'weight': ShapedArray(float32[64,32])
        }
      )
    ),
    ReLU(),
    Linear(
      in_size=(32,),
      out_size=(5,),
      w_mask=None,
      weight=ParamState(
        value={
          'bias': ShapedArray(float32[5]),
          'weight': ShapedArray(float32[32,5])
        }
      )
    )
  ]
)

Input size:  (10,)
Output size: (5,)

Forward pass:
  Input:  (8, 10)
  Output: (8, 5)

Example 2: CNN Network with Automatic Size Propagation#

Let’s build a complete CNN for image classification, demonstrating how in_size and out_size propagate through convolutional, pooling, flattening, and fully-connected layers.

class CNNNet(brainstate.nn.Module):
    """Image classifier: two conv/ReLU/max-pool stages and a 3-layer MLP head.

    Only the first ``Conv2d`` receives ``in_size``. Every later shape-dependent
    layer is a ``.desc()`` descriptor whose ``in_size`` is filled in by
    ``Sequential``. The wrapper itself does not set ``in_size``/``out_size``,
    so read them from ``self.layer``.

    Args:
        in_size: Per-sample input shape, channels last, e.g. ``(28, 28, 3)``.
    """

    def __init__(self, in_size):
        super().__init__()

        # Stage 1: 32 filters with 'SAME' padding, then a 2x2 max-pool (halves H and W).
        stage1 = [
            brainstate.nn.Conv2d(in_size, out_channels=32, kernel_size=(3, 3),
                                 stride=(1, 1), padding='SAME'),
            brainstate.nn.ReLU(),
            brainstate.nn.MaxPool2d.desc(kernel_size=(2, 2), stride=(2, 2), channel_axis=-1),
        ]
        # Stage 2: same pattern with 64 filters.
        stage2 = [
            brainstate.nn.Conv2d.desc(out_channels=64, kernel_size=(3, 3),
                                      stride=(1, 1), padding='SAME'),
            brainstate.nn.ReLU(),
            brainstate.nn.MaxPool2d.desc(kernel_size=(2, 2), stride=(2, 2), channel_axis=-1),
        ]
        # Head: flatten the feature map, then 1024 -> 512 -> 10 units.
        head = [
            brainstate.nn.Flatten.desc(),
            brainstate.nn.Linear.desc(out_size=1024),
            brainstate.nn.ReLU(),
            brainstate.nn.Linear.desc(out_size=512),
            brainstate.nn.ReLU(),
            brainstate.nn.Linear.desc(out_size=10),
        ]
        self.layer = brainstate.nn.Sequential(*stage1, *stage2, *head)

    def update(self, x):
        """Forward ``x`` through the whole layer stack."""
        logits = self.layer(x)
        return logits

# Create CNN with image size (28, 28, 3)
# Only the shape of this sample is used: it becomes the network's in_size.
example_image = brainstate.random.normal(size=(28, 28, 3))
cnn = CNNNet(example_image.shape)

print("CNN Network Architecture:")
print(cnn)
# NOTE: CNNNet.__init__ never sets in_size/out_size on the wrapper itself, so
# these two lines print None. The inferred sizes live on the inner Sequential:
# use cnn.layer.in_size / cnn.layer.out_size to get (28, 28, 3) and (10,).
print(f"\nNetwork input size:  {cnn.in_size}")
print(f"Network output size: {cnn.out_size}")

# Trace size transformations through the network
print("\n" + "="*60)
print("Size transformations through the network:")
print("="*60)
for i, layer in enumerate(cnn.layer.layers):
    # Every layer here passes this check. Element-wise layers such as ReLU
    # simply report None for both sizes.
    if hasattr(layer, 'in_size') and hasattr(layer, 'out_size'):
        print(f"Layer {i:2d} ({layer.__class__.__name__:15s}): "
              f"{str(layer.in_size):20s} -> {str(layer.out_size):20s}")
CNN Network Architecture:
CNNNet(
  layer=Sequential(
    in_size=(28, 28, 3),
    out_size=(10,),
    layers=[
      Conv2d(
        in_size=(28, 28, 3),
        out_size=(28, 28, 32),
        channel_first=False,
        channels_last=True,
        in_channels=3,
        out_channels=32,
        stride=(1, 1),
        kernel_size=(3, 3),
        lhs_dilation=(1, 1),
        rhs_dilation=(1, 1),
        groups=1,
        dimension_numbers=ConvDimensionNumbers(lhs_spec=(0, 3, 1, 2), rhs_spec=(3, 2, 0, 1), out_spec=(0, 3, 1, 2)),
        padding=SAME,
        kernel_shape=(3, 3, 3, 32),
        w_mask=None,
        w_initializer=XavierNormal(
          scale=1.0,
          mode='fan_avg',
          in_axis=-2,
          out_axis=-1,
          distribution='truncated_normal',
          rng=RandomState([1825841970 3512247751]),
          unit=Unit(10.0^0)
        ),
        b_initializer=None,
        weight=ParamState(
          value={
            'weight': ShapedArray(float32[3,3,3,32])
          }
        )
      ),
      ReLU(),
      MaxPool2d(
        in_size=(28, 28, 32),
        out_size=(14, 14, 32),
        init_value=-inf,
        computation=<function max at 0x0000012125AD8360>,
        pool_dim=2,
        return_indices=False,
        kernel_size=(2, 2),
        stride=(2, 2),
        padding=VALID,
        channel_axis=-1
      ),
      Conv2d(
        in_size=(14, 14, 32),
        out_size=(14, 14, 64),
        channel_first=False,
        channels_last=True,
        in_channels=32,
        out_channels=64,
        stride=(1, 1),
        kernel_size=(3, 3),
        lhs_dilation=(1, 1),
        rhs_dilation=(1, 1),
        groups=1,
        dimension_numbers=ConvDimensionNumbers(lhs_spec=(0, 3, 1, 2), rhs_spec=(3, 2, 0, 1), out_spec=(0, 3, 1, 2)),
        padding=SAME,
        kernel_shape=(3, 3, 32, 64),
        w_mask=None,
        w_initializer=XavierNormal(
          scale=1.0,
          mode='fan_avg',
          in_axis=-2,
          out_axis=-1,
          distribution='truncated_normal',
          rng=RandomState([1825841970 3512247751]),
          unit=Unit(10.0^0)
        ),
        b_initializer=None,
        weight=ParamState(
          value={
            'weight': ShapedArray(float32[3,3,32,64])
          }
        )
      ),
      ReLU(),
      MaxPool2d(
        in_size=(14, 14, 64),
        out_size=(7, 7, 64),
        init_value=-inf,
        computation=<function max at 0x0000012125AD8360>,
        pool_dim=2,
        return_indices=False,
        kernel_size=(2, 2),
        stride=(2, 2),
        padding=VALID,
        channel_axis=-1
      ),
      Flatten(
        in_size=(7, 7, 64),
        out_size=(3136,),
        start_axis=0,
        end_axis=-1
      ),
      Linear(
        in_size=(3136,),
        out_size=(1024,),
        w_mask=None,
        weight=ParamState(
          value={
            'bias': ShapedArray(float32[1024]),
            'weight': ShapedArray(float32[3136,1024])
          }
        )
      ),
      ReLU(),
      Linear(
        in_size=(1024,),
        out_size=(512,),
        w_mask=None,
        weight=ParamState(
          value={
            'bias': ShapedArray(float32[512]),
            'weight': ShapedArray(float32[1024,512])
          }
        )
      ),
      ReLU(),
      Linear(
        in_size=(512,),
        out_size=(10,),
        w_mask=None,
        weight=ParamState(
          value={
            'bias': ShapedArray(float32[10]),
            'weight': ShapedArray(float32[512,10])
          }
        )
      )
    ]
  )
)

Network input size:  None
Network output size: None

============================================================
Size transformations through the network:
============================================================
Layer  0 (Conv2d         ): (28, 28, 3)          -> (28, 28, 32)        
Layer  1 (ReLU           ): None                 -> None                
Layer  2 (MaxPool2d      ): (28, 28, 32)         -> (14, 14, 32)        
Layer  3 (Conv2d         ): (14, 14, 32)         -> (14, 14, 64)        
Layer  4 (ReLU           ): None                 -> None                
Layer  5 (MaxPool2d      ): (14, 14, 64)         -> (7, 7, 64)          
Layer  6 (Flatten        ): (7, 7, 64)           -> (3136,)             
Layer  7 (Linear         ): (3136,)              -> (1024,)             
Layer  8 (ReLU           ): None                 -> None                
Layer  9 (Linear         ): (1024,)              -> (512,)              
Layer 10 (ReLU           ): None                 -> None                
Layer 11 (Linear         ): (512,)               -> (10,)               

Example 3: Forward Pass Through CNN#

Let’s actually run a forward pass and see how data flows through the network.

# Create a batch of images
# Reuses the `cnn` built above. The leading batch axis is not part of its in_size.
# No seed is set here, so the logits depend on the RNG state left by earlier cells.
batch_size = 4
batch_images = brainstate.random.normal(size=(batch_size, 28, 28, 3))

print(f"Input batch shape: {batch_images.shape}")
print(f"  (batch_size, height, width, channels) = ({batch_size}, 28, 28, 3)")

# Forward pass
output = cnn(batch_images)

print(f"\nOutput shape: {output.shape}")
print(f"  (batch_size, num_classes) = ({batch_size}, 10)")
# The last layer is a Linear with no softmax, so these are raw logits.
print(f"\nOutput logits for first sample:")
print(output[0])
Input batch shape: (4, 28, 28, 3)
  (batch_size, height, width, channels) = (4, 28, 28, 3)

Output shape: (4, 10)
  (batch_size, num_classes) = (4, 10)

Output logits for first sample:
[ 0.13889284  0.49220082 -0.6353385   0.36826375 -0.55741405 -0.22296685
  1.5445015   0.7295152   0.04205686 -0.02874903]

Benefits of Automatic Size Inference#

The automatic in_size/out_size inference system provides several key advantages:

  1. 🎯 No manual dimension calculations: You don’t need to compute output sizes after each layer

  2. 🔧 Easy architecture modifications: Change one layer without updating all subsequent layers

  3. 🐛 Early error detection: Shape mismatches are caught at construction time

  4. 📊 Built-in documentation: Network architecture is self-documenting with size information

  5. 🚀 Rapid prototyping: Quickly experiment with different architectures

Key Pattern: .desc() for Layer Descriptors#

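When building networks with Sequential, give the first layer an explicit in_size and use the .desc() pattern for every later shape-dependent layer. Element-wise layers such as ReLU() can be passed as ordinary instances: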

# Schematic only: FirstLayer / SecondLayer / ThirdLayer are placeholders, not
# real BrainState classes. Shape-agnostic layers such as brainstate.nn.ReLU()
# may be inserted anywhere as plain instances.
brainstate.nn.Sequential(
    FirstLayer(in_size, ...),         # Explicit in_size
    SecondLayer.desc(...),            # in_size inferred
    ThirdLayer.desc(...),             # in_size inferred
    # ...
)

This pattern ensures that:

  • The first layer knows the input size

  • All subsequent layers automatically infer their input sizes

  • The network construction is clean and maintainable

Example 4: Complex Architecture with Mixed Layer Types#

# Build a more complex network with different layer types
class ComplexNet(brainstate.nn.Module):
    """Convolutional feature extractor followed by a separate MLP classifier.

    The two stages are independent ``Sequential`` containers, so the
    classifier's ``Flatten`` receives the extractor's inferred ``out_size``
    explicitly rather than through ``.desc()``.

    Args:
        in_size: Per-sample input shape, channels last, e.g. ``(32, 32, 3)``.
    """

    def __init__(self, in_size):
        super().__init__()

        feature_layers = [
            # Initial conv block
            brainstate.nn.Conv2d(in_size, out_channels=16, kernel_size=3, padding='SAME'),
            brainstate.nn.ReLU(),
            # Strided conv (reduces spatial size)
            brainstate.nn.Conv2d.desc(out_channels=32, kernel_size=3, stride=2, padding='SAME'),
            brainstate.nn.ReLU(),
            # Another conv + pool
            brainstate.nn.Conv2d.desc(out_channels=64, kernel_size=3, padding='SAME'),
            brainstate.nn.ReLU(),
            brainstate.nn.MaxPool2d.desc(kernel_size=(2, 2), stride=(2, 2), channel_axis=-1),
        ]
        self.features = brainstate.nn.Sequential(*feature_layers)

        # Built after self.features so its out_size is already known here.
        classifier_layers = [
            brainstate.nn.Flatten(in_size=self.features.out_size),
            brainstate.nn.Linear.desc(out_size=256),
            brainstate.nn.ReLU(),
            brainstate.nn.Linear.desc(out_size=10),
        ]
        self.classifier = brainstate.nn.Sequential(*classifier_layers)

    def update(self, x):
        """Run the feature extractor, then the classifier."""
        return self.classifier(self.features(x))

# Instantiate the network for (32, 32, 3) inputs and report the inferred sizes
net = ComplexNet(in_size=(32, 32, 3))

flatten_layer = net.classifier.layers[0]
for line in (
    "Complex Network:",
    f"Input size: {net.features.in_size}",
    f"After features: {net.features.out_size}",
    f"After flatten: {flatten_layer.out_size}",
    f"Final output: {net.classifier.out_size}",
):
    print(line)

# Forward a batch of two random images
x = brainstate.random.randn(2, 32, 32, 3)
y = net(x)
print(f"\nForward pass: {x.shape} -> {y.shape}")
Complex Network:
Input size: (32, 32, 3)
After features: (8, 8, 64)
After flatten: (4096,)
Final output: (10,)

Forward pass: (2, 32, 32, 3) -> (2, 10)