Other Transforms

This tutorial covers utilities for optimization, introspection, and debugging: two from brainstate.transform and one from JAX itself.

  1. checkpoint: Memory-efficient gradient computation through rematerialization

  2. make_jaxpr and StatefulFunction: Inspect the traced computation graph (Jaxpr) of a function

  3. jax.debug.print: Runtime debugging in JIT-compiled code

Most examples use BrainState State objects to show how these tools handle stateful code, which is where BrainState differs from plain JAX.

Imports and Setup

import jax
import jax.numpy as jnp
import brainstate
from brainstate.transform import checkpoint, make_jaxpr, StatefulFunction

1. checkpoint: Memory-Efficient Gradient Computation

checkpoint (also known as rematerialization or gradient checkpointing) is useful when training deep neural networks or processing long sequences would otherwise exhaust memory. It trades computation for memory during backpropagation.

How Gradient Computation Works

Without checkpointing:

  • Forward pass: Computes outputs and stores all intermediate activations

  • Backward pass: Uses stored activations to compute gradients

  • Memory usage: O(n) where n is the number of layers/steps

With checkpointing:

  • Forward pass: Computes outputs, stores only inputs at checkpoints

  • Backward pass: Recomputes intermediate activations from checkpoints as needed

  • Memory usage: O(√n) when the n layers are split into about √n checkpointed segments

  • Computation: roughly one extra forward pass (activations are recomputed during the backward pass)

Key principle: Trade extra computation for reduced memory
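The O(√n) figure follows from a simple counting argument. The sketch below is an illustrative model, not a measurement. It assumes that every layer produces one activation of equal size, that only segment inputs are kept after the forward pass, and that the backward pass rematerializes one segment at a time. The helper `stored_activations` is hypothetical.

```python
import math

def stored_activations(n_layers: int, segment: int) -> int:
    """Peak activations held under a simple counting model (hypothetical helper).

    Segment inputs are kept from the forward pass; during the backward pass
    one segment at a time is recomputed, holding `segment` activations.
    """
    n_segments = math.ceil(n_layers / segment)
    return n_segments + segment

n = 100
without_checkpoint = n  # every layer activation is stored
best_segment = min(range(1, n + 1), key=lambda k: stored_activations(n, k))
print(f"no checkpoint: {without_checkpoint} activations")
print(f"best segment length: {best_segment} (about sqrt({n}))")
print(f"with checkpoint: {stored_activations(n, best_segment)} activations")
```

The minimum sits near k = √n because n/k + k is smallest there. The price is that the forward pass of every segment runs a second time.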

1.1 Basic Usage with Gradient Computation

# Example 1: Memory-efficient gradient computation
print("=== Example 1: Basic checkpoint usage ===")

# Without checkpoint: stores all intermediate activations
def expensive_forward(x):
    """Chain of expensive operations."""
    y = jnp.sin(x)
    z = jnp.exp(y)
    w = jnp.tanh(z)
    return jnp.sum(w ** 2)

# With checkpoint: only stores inputs, recomputes during backward
@checkpoint
def checkpointed_forward(x):
    """Same computation, but memory-efficient."""
    y = jnp.sin(x)
    z = jnp.exp(y)
    w = jnp.tanh(z)
    return jnp.sum(w ** 2)

x = jnp.linspace(0, 10, 1000)

# Both produce same results
value1, grad1 = jax.value_and_grad(expensive_forward)(x)
value2, grad2 = jax.value_and_grad(checkpointed_forward)(x)

print(f"Values match: {jnp.allclose(value1, value2)}")
print(f"Gradients match: {jnp.allclose(grad1, grad2)}")
print("\nMemory: checkpoint keeps only the input x, not the intermediate results of sin, exp and tanh")
print("Cost: those operations run a second time during the backward pass")
=== Example 1: Basic checkpoint usage ===
Values match: True
Gradients match: True

Memory: checkpoint keeps only the input x, not the intermediate results of sin, exp and tanh
Cost: those operations run a second time during the backward pass

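1.2 Checkpointing Stateful Computations

BrainState’s checkpoint is state-aware. The checkpointed function can read State objects such as ParamState directly, and gradients with respect to those states are computed through the rematerialized block.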

# Example 2: Checkpoint with stateful neural network
print("\n=== Example 2: Checkpointed neural network ===")

class DeepNetwork(brainstate.nn.Module):
    """Deep network with many layers."""
    def __init__(self, layer_sizes):
        super().__init__()
        self.layers = []
        for i in range(len(layer_sizes) - 1):
            self.layers.append(
                brainstate.ParamState(jax.random.normal(
                    jax.random.PRNGKey(i), 
                    (layer_sizes[i], layer_sizes[i+1])
                ))
            )
    
    def forward(self, x, use_checkpoint=False):
        """Forward pass through all layers."""
        def layer_fn(x):
            h = x
            for W in self.layers[:-1]:
                h = jnp.tanh(h @ W.value)
            # Output layer (no activation)
            return h @ self.layers[-1].value
        
        if use_checkpoint:
            return checkpoint(layer_fn)(x)
        else:
            return layer_fn(x)

# Create a deep network: 10 layers
net = DeepNetwork([128, 256, 256, 256, 256, 256, 256, 256, 256, 128, 10])
x_batch = jax.random.normal(jax.random.PRNGKey(42), (32, 128))

# Define loss function
def loss_fn(use_checkpoint):
    y_pred = net.forward(x_batch, use_checkpoint=use_checkpoint)
    return jnp.mean(y_pred ** 2)

# Get parameters
params = net.states(brainstate.ParamState)

# Compute gradients with and without checkpoint
grads_normal = brainstate.transform.grad(lambda: loss_fn(False), params)()
grads_checkpointed = brainstate.transform.grad(lambda: loss_fn(True), params)()

# Compare
print(f"Number of layers: {len(net.layers)}")
print(f"Gradient shapes match: {jax.tree.map(lambda a, b: a.shape == b.shape, grads_normal, grads_checkpointed)}")
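print("\nWithout checkpoint: stores the activations of all 10 layers for the backward pass")
print("With checkpoint: stores only the block inputs and recomputes the activations during backward")
print("Caveat: a single checkpoint rematerializes every layer at once, so on its own it barely lowers peak memory; checkpointing in segments (Section 1.3) does")
=== Example 2: Checkpointed neural network ===
Number of layers: 10
Gradient shapes match: {('layers', 0): True, ('layers', 1): True, ('layers', 2): True, ('layers', 3): True, ('layers', 4): True, ('layers', 5): True, ('layers', 6): True, ('layers', 7): True, ('layers', 8): True, ('layers', 9): True}

Without checkpoint: stores the activations of all 10 layers for the backward pass
With checkpoint: stores only the block inputs and recomputes the activations during backward
Caveat: a single checkpoint rematerializes every layer at once, so on its own it barely lowers peak memory; checkpointing in segments (Section 1.3) does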
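For comparison, here is a minimal sketch of the same pattern in plain JAX. The weights are passed as an explicit list instead of living in ParamState. It assumes only jax.checkpoint and jax.grad, and the names mlp and make_loss are illustrative. Unlike the shape check above, it compares gradient values.

```python
import jax
import jax.numpy as jnp

def mlp(params, x):
    """Plain-JAX forward pass: tanh hidden layers, linear output layer."""
    h = x
    for W in params[:-1]:
        h = jnp.tanh(h @ W)
    return h @ params[-1]

def make_loss(forward):
    """Build a mean-squared-output loss around a forward function."""
    def loss(params, x):
        return jnp.mean(forward(params, x) ** 2)
    return loss

sizes = [16, 32, 32, 8]
keys = jax.random.split(jax.random.PRNGKey(0), len(sizes) - 1)
params = [
    jax.random.normal(k, (m, n)) * 0.1
    for k, m, n in zip(keys, sizes[:-1], sizes[1:])
]
x = jax.random.normal(jax.random.PRNGKey(1), (4, 16))

grads_plain = jax.grad(make_loss(mlp))(params, x)
# jax.checkpoint keeps only the inputs (params, x) as residuals and
# recomputes the hidden activations during the backward pass
grads_remat = jax.grad(make_loss(jax.checkpoint(mlp)))(params, x)

for g1, g2 in zip(grads_plain, grads_remat):
    print(g1.shape, bool(jnp.allclose(g1, g2, rtol=1e-5, atol=1e-6)))
```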

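1.3 Sequential Layer Checkpointing

For very deep networks, split the layers into consecutive segments and checkpoint each segment. Only the segment inputs are kept after the forward pass. During the backward pass, each segment's internal activations are recomputed, one segment at a time.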

# Example 3: Checkpointing groups of layers
print("\n=== Example 3: Granular checkpointing ===")

class CheckpointedDeepNetwork(brainstate.nn.Module):
    """Network that checkpoints consecutive groups of layers."""
    def __init__(self, layer_sizes, checkpoint_every=2):
        super().__init__()
        self.checkpoint_every = checkpoint_every
        self.weights = []
        for i in range(len(layer_sizes) - 1):
            self.weights.append(
                brainstate.ParamState(jax.random.normal(
                    jax.random.PRNGKey(i), 
                    (layer_sizes[i], layer_sizes[i+1])
                ) * 0.1)
            )
    
    def __call__(self, x):
        h = x
        # Split the layers into segments of `checkpoint_every` layers
        for start in range(0, len(self.weights), self.checkpoint_every):
            segment = self.weights[start:start + self.checkpoint_every]

            def segment_forward(h):
                # Only the segment input is kept from the forward pass;
                # activations inside the segment are recomputed during backward
                for W in segment:
                    h = jnp.tanh(h @ W.value)
                return h

            h = checkpoint(segment_forward)(h)
        return h

# Create network: checkpoint every 2 layers
ckpt_net = CheckpointedDeepNetwork(
    [64, 128, 128, 128, 128, 128, 32],  # 6 layers
    checkpoint_every=2
)

x_in = jax.random.normal(jax.random.PRNGKey(123), (16, 64))

# Forward and backward
def forward_loss():
    return jnp.sum(ckpt_net(x_in) ** 2)

grads, value = brainstate.transform.grad(
    forward_loss, 
    ckpt_net.states(brainstate.ParamState),
    return_value=True
)()

print(f"Network depth: {len(ckpt_net.weights)} layers")
print(f"Checkpoint frequency: every {ckpt_net.checkpoint_every} layers")
print(f"Checkpoints created: {len(ckpt_net.weights) // ckpt_net.checkpoint_every}")
print(f"Loss: {value:.4f}")
print(f"\nMemory usage: O(checkpoints) instead of O(layers)")
=== Example 3: Granular checkpointing ===
Network depth: 6 layers
Checkpoint frequency: every 2 layers
Checkpoints created: 3
Loss: 85.1143

Memory usage: O(checkpoints) instead of O(layers)
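When all layers share a shape, a common JAX idiom is to stack the weights, loop with jax.lax.scan, and wrap the scan body in jax.checkpoint. The sketch below uses plain JAX rather than brainstate.transform, and the names layer and loss are hypothetical. It checks that the checkpointed version produces the same gradients. Checkpointing every step is the finest granularity. Grouping several layers per step moves toward the √n segment sizes mentioned earlier.

```python
import jax
import jax.numpy as jnp

n_layers, width = 8, 32
# Same-shaped layers stacked along a leading axis so lax.scan can loop over them
W_stack = jax.random.normal(jax.random.PRNGKey(0), (n_layers, width, width)) * 0.1
x = jax.random.normal(jax.random.PRNGKey(1), (4, width))

def layer(h, W):
    """One tanh layer in lax.scan form: (carry, x) -> (new_carry, output)."""
    return jnp.tanh(h @ W), None

def loss(W_stack, x, remat):
    # With remat=True each scan step stores only its input carry and
    # recomputes the layer's activation during the backward pass
    step = jax.checkpoint(layer) if remat else layer
    h, _ = jax.lax.scan(step, x, W_stack)
    return jnp.sum(h ** 2)

grads_plain = jax.grad(lambda W: loss(W, x, remat=False))(W_stack)
grads_remat = jax.grad(lambda W: loss(W, x, remat=True))(W_stack)
print(grads_remat.shape, bool(jnp.allclose(grads_plain, grads_remat, rtol=1e-5, atol=1e-6)))
```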

1.4 Memory-Computation Tradeoff

Checkpointing pays off only when activation memory is the bottleneck; for small models it just adds recomputation.

# Example 4: Weighing the tradeoff
print("\n=== Example 4: When to use checkpoint ===")

class BenchmarkNet(brainstate.nn.Module):
    def __init__(self, n_layers, hidden_size):
        super().__init__()
        self.layers = []
        for i in range(n_layers):
            self.layers.append(
                brainstate.ParamState(jax.random.normal(
                    jax.random.PRNGKey(i), 
                    (hidden_size, hidden_size)
                ) * 0.1)
            )
    
    def forward_normal(self, x):
        h = x
        for W in self.layers:
            h = jnp.tanh(h @ W.value)
        return jnp.sum(h)
    
    def forward_checkpointed(self, x):
        def layer_block(h):
            for W in self.layers:
                h = jnp.tanh(h @ W.value)
            return jnp.sum(h)
        return checkpoint(layer_block)(x)

# Small network: checkpoint overhead is usually not worth it
small_net = BenchmarkNet(n_layers=3, hidden_size=64)
x_small = jax.random.normal(jax.random.PRNGKey(0), (64,))

# Large network: checkpoint can save significant activation memory
large_net = BenchmarkNet(n_layers=20, hidden_size=512)
x_large = jax.random.normal(jax.random.PRNGKey(0), (512,))

print(f"Small network ({len(small_net.layers)} layers, {small_net.layers[0].value.shape[0]} hidden):")
print("  → Normal gradient: Fast, low memory")
print("  → Checkpoint: Overhead not justified\n")

print(f"Large network ({len(large_net.layers)} layers, {large_net.layers[0].value.shape[0]} hidden):")
print("  → Normal gradient: Stores ~20 activations (high memory)")
print("  → Checkpoint: Recomputes activations (saves memory)")
print("  → Recommended: Use checkpoint for deep/wide networks\n")

print("Rough heuristic (profile your own workload):")
print("  Consider checkpoint when activations dominate memory: deep stacks, long sequences, large batches")
print("  Skip checkpoint when the model already fits comfortably (e.g. shallow networks, < 5 layers)")
=== Example 4: When to use checkpoint ===
Small network (3 layers, 64 hidden):
  → Normal gradient: Fast, low memory
  → Checkpoint: Overhead not justified

Large network (20 layers, 512 hidden):
  → Normal gradient: Stores ~20 activations (high memory)
  → Checkpoint: Recomputes activations (saves memory)
  → Recommended: Use checkpoint for deep/wide networks

Rough heuristic (profile your own workload):
  Consider checkpoint when activations dominate memory: deep stacks, long sequences, large batches
  Skip checkpoint when the model already fits comfortably (e.g. shallow networks, < 5 layers)
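Checkpointing does not have to be all-or-nothing. In plain JAX, jax.checkpoint accepts a policy that decides which intermediates are saved. jax.ad_checkpoint.print_saved_residuals lists what a function keeps for the backward pass. The sketch below applies these plain-JAX APIs to a hypothetical two-layer block. Whether brainstate.transform.checkpoint forwards the same policy argument is an assumption to verify in its API reference.

```python
import jax
import jax.numpy as jnp
from jax.ad_checkpoint import print_saved_residuals

def block(W1, W2, x):
    """Hypothetical two-layer block: two matmuls, each followed by tanh."""
    h = jnp.tanh(x @ W1)
    return jnp.sum(jnp.tanh(h @ W2))

W1 = jnp.full((16, 16), 0.05)
W2 = jnp.full((16, 16), 0.05)
x = jnp.ones((4, 16))

variants = {
    # Plain autodiff: keeps whatever residuals the VJP needs
    "no checkpoint": block,
    # Full rematerialization: keep only the inputs
    "nothing_saveable": jax.checkpoint(
        block, policy=jax.checkpoint_policies.nothing_saveable),
    # Keep matmul results (no batch dims), recompute the cheap tanh
    "dots saveable": jax.checkpoint(
        block, policy=jax.checkpoint_policies.dots_with_no_batch_dims_saveable),
}

grads = {}
for name, fn in variants.items():
    print(f"--- {name} ---")
    print_saved_residuals(fn, W1, W2, x)
    grads[name] = jax.grad(fn, argnums=(0, 1))(W1, W2, x)
```

All three variants compute the same gradients. Only the printed residual lists should differ. The nothing_saveable variant should keep little beyond the inputs, while the dots policy also keeps the matmul outputs.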

2. make_jaxpr and StatefulFunction: Inspecting Compiled Code

make_jaxpr converts a function into its JAX intermediate representation (Jaxpr). A Jaxpr is the sequence of primitive operations that JAX records while tracing your code, before XLA compiles and optimizes it. StatefulFunction is the underlying mechanism that makes BrainState transformations state-aware.

What is Jaxpr?

Jaxpr is JAX’s intermediate representation, based on a simply-typed, first-order lambda calculus with let-bindings. It shows:

  • Primitive operations (add, mul, sin, etc.)

  • Data dependencies between operations

  • The shape and dtype of every input, intermediate value, and output

  • With BrainState’s make_jaxpr, state reads and writes as explicit inputs and outputs

2.1 Basic Jaxpr Inspection

# Example 1: Simple function jaxpr
print("=== Example 1: Basic jaxpr ===")

def simple_fn(x):
    y = jnp.sin(x)
    z = jnp.cos(y)
    return z * 2

# Create jaxpr
jaxpr_fn = make_jaxpr(simple_fn)
jaxpr, states = jaxpr_fn(3.0)

print("Function: z = cos(sin(x)) * 2")
print("\nJaxpr representation:")
print(jaxpr)
print(f"\nStates used: {len(states)} (none for this simple function)")
=== Example 1: Basic jaxpr ===
Function: z = cos(sin(x)) * 2

Jaxpr representation:
{ lambda ; a:f32[]. let
    b:f32[] = sin a
    c:f32[] = cos b
    d:f32[] = mul c 2.0:f32[]
  in (d,) }

States used: 0 (none for this simple function)
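Besides printing a Jaxpr, you can walk it programmatically, which helps when counting operations or spotting unexpected primitives. The sketch below uses jax.make_jaxpr directly, which returns a ClosedJaxpr. The jaxprs that BrainState produces expose the same .jaxpr.invars and .jaxpr.eqns attributes used in Example 3 below, so the same traversal should carry over; treat that as an assumption.

```python
import jax
import jax.numpy as jnp

def simple_fn(x):
    y = jnp.sin(x)
    z = jnp.cos(y)
    return z * 2

closed = jax.make_jaxpr(simple_fn)(3.0)  # a ClosedJaxpr
jaxpr = closed.jaxpr

# Each equation binds one primitive to its input and output variables
for eqn in jaxpr.eqns:
    print(eqn.primitive.name, len(eqn.invars), "inputs ->", len(eqn.outvars), "output")

primitive_names = [eqn.primitive.name for eqn in jaxpr.eqns]
print(primitive_names)
```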

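2.2 Stateful Jaxpr: Tracking State Reads and Writes

BrainState’s make_jaxpr also returns the states a function accesses. Tracing does not execute the function, so the state values printed below are still their initial values.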

# Example 2: Jaxpr with states
print("\n=== Example 2: Stateful jaxpr ===")

# Create states
counter = brainstate.ShortTermState(jnp.array(0))
accumulator = brainstate.ShortTermState(jnp.array(0.0))

def stateful_fn(x):
    # Read states
    count = counter.value
    accum = accumulator.value
    
    # Update states
    counter.value = count + 1
    accumulator.value = accum + x
    
    return accumulator.value / counter.value

# Inspect jaxpr
jaxpr_fn = make_jaxpr(stateful_fn)
jaxpr, states = jaxpr_fn(5.0)

print("Function: running average tracker")
print(f"\nStates accessed: {len(states)}")
for i, state in enumerate(states):
    print(f"  [{i}] {type(state).__name__}: {state.value}")

print("\nJaxpr (state operations visible):")
print(jaxpr)
print("\nNote: Jaxpr shows state reads as inputs, writes as outputs")
=== Example 2: Stateful jaxpr ===
Function: running average tracker

States accessed: 2
  [0] ShortTermState: 0
  [1] ShortTermState: 0.0

Jaxpr (state operations visible):
{ lambda ; a:f32[] b:i32[] c:f32[]. let
    d:i32[] = add b 1:i32[]
    e:f32[] = add c a
    f:f32[] = convert_element_type[new_dtype=float32 weak_type=True] d
    g:f32[] = div e f
  in (g, d, e) }

Note: Jaxpr shows state reads as inputs, writes as outputs
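Conceptually, the jaxpr above is the trace of a pure function. The two state values arrive as extra inputs (b, c) and their new values leave as extra outputs (d, e). The hypothetical running_average_pure below writes that function by hand in plain JAX and threads the state explicitly across calls. It illustrates the idea; it is not how BrainState implements it.

```python
import jax
import jax.numpy as jnp

def running_average_pure(x, count, accum):
    """Hypothetical pure version of stateful_fn: state in, new state out."""
    new_count = count + 1
    new_accum = accum + x
    return new_accum / new_count, new_count, new_accum

closed = jax.make_jaxpr(running_average_pure)(5.0, jnp.array(0), jnp.array(0.0))
print(closed)  # 3 inputs (x, count, accum) and 3 outputs (average, new count, new accum)

# The caller, not the function, carries the state from one call to the next
count, accum = jnp.array(0), jnp.array(0.0)
for x in [5.0, 3.0, 1.0]:
    avg, count, accum = running_average_pure(x, count, accum)
print(avg, count, accum)
```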

2.3 Understanding StatefulFunction

StatefulFunction is the core abstraction behind BrainState’s state-aware transformations. It:

  1. Identifies the states read and written during function execution

  2. Traces the function into a Jaxpr with explicit state inputs and outputs

  3. Manages state values before and after execution

  4. Caches traced Jaxprs so that repeated calls with the same input shapes skip retracing

# Example 3: Using StatefulFunction directly
print("\n=== Example 3: StatefulFunction mechanics ===")

# Create a module with state
class NeuralCell(brainstate.nn.Module):
    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.W = brainstate.ParamState(jax.random.normal(
            jax.random.PRNGKey(0), (input_size, hidden_size)
        ))
        self.h = brainstate.ShortTermState(jnp.zeros(hidden_size))
    
    def __call__(self, x):
        # Update hidden state
        self.h.value = jnp.tanh(x @ self.W.value + self.h.value)
        return self.h.value

cell = NeuralCell(input_size=10, hidden_size=20)

# Wrap in StatefulFunction
sf = StatefulFunction(cell)

# Example input
x = jax.random.normal(jax.random.PRNGKey(1), (10,))

# Step 1: Compile and inspect
sf.make_jaxpr(x)
print("Step 1: Compilation")
print(f"  Compiled for input shape: {x.shape}")

# Step 2: Get tracked states
states = sf.get_states(x)
read_states = sf.get_read_states(x)
write_states = sf.get_write_states(x)

print(f"\nStep 2: State identification")
print(f"  Total states: {len(states)}")
print(f"  Read states: {len(read_states)}")
for s in read_states:
    print(f"    - {type(s).__name__}: shape {s.value.shape}")
print(f"  Write states: {len(write_states)}")
for s in write_states:
    print(f"    - {type(s).__name__}: shape {s.value.shape}")

# Step 3: Get jaxpr
jaxpr = sf.get_jaxpr(x)
print(f"\nStep 3: Jaxpr compilation")
print(f"  Jaxpr variables: {len(jaxpr.jaxpr.invars)} inputs, {len(jaxpr.jaxpr.outvars)} outputs")
print(f"  Jaxpr equations: {len(jaxpr.jaxpr.eqns)} operations")

# Step 4: Execute
output = sf(x)
print(f"\nStep 4: Execution")
print(f"  Output shape: {output.shape}")
print(f"  Hidden state updated: {cell.h.value.shape}")
=== Example 3: StatefulFunction mechanics ===
Step 1: Compilation
  Compiled for input shape: (10,)

Step 2: State identification
  Total states: 2
  Read states: 1
    - ParamState: shape (10, 20)
  Write states: 1
    - ShortTermState: shape (20,)

Step 3: Jaxpr compilation
  Jaxpr variables: 3 inputs, 2 outputs
  Jaxpr equations: 3 operations

Step 4: Execution
  Output shape: (20,)
  Hidden state updated: (20,)
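The mechanics in Example 3 boil down to functionalization:

- swap the state values for function inputs,
- run the function,
- read the written values back as outputs,
- restore the states.

The toy sketch below illustrates that idea in plain JAX. ToyState and functionalize are hypothetical, not BrainState APIs. The real StatefulFunction also discovers the states automatically and caches traces.

```python
import jax
import jax.numpy as jnp

class ToyState:
    """Hypothetical minimal state container (not brainstate.State)."""
    def __init__(self, value):
        self.value = value

def functionalize(fn, states):
    """Hypothetical helper: turn `fn`, which reads and writes `states`, into
    a pure function (state_values, *args) -> (output, new_state_values)."""
    def pure(state_values, *args):
        saved = [s.value for s in states]
        for s, v in zip(states, state_values):
            s.value = v  # inject the (possibly traced) values as inputs
        try:
            out = fn(*args)
            new_values = [s.value for s in states]  # collect writes as outputs
        finally:
            for s, v in zip(states, saved):
                s.value = v  # restore, so no tracer leaks into the states
        return out, new_values
    return pure

h = ToyState(jnp.zeros(3))

def step(x):
    h.value = jnp.tanh(x + h.value)
    return h.value

pure_step = jax.jit(functionalize(step, [h]))  # traced once, then cached
out, (new_h,) = pure_step([h.value], jnp.ones(3))
h.value = new_h  # write the result back outside the compiled function
print(out, h.value)
```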

2.4 Jaxpr for Gradient Computation

Inspect how autodiff transforms your code.

# Example 4: Gradient jaxpr
print("\n=== Example 4: Gradient computation jaxpr ===")

# Simple loss function
params = brainstate.ParamState(jnp.array([1.0, 2.0, 3.0]))

def loss_fn(x):
    return jnp.sum((params.value - x) ** 2)

# Original function jaxpr
print("Original function jaxpr:")
jaxpr_orig, _ = make_jaxpr(loss_fn)(jnp.array([0.5, 1.0, 1.5]))
print(jaxpr_orig)

# Gradient function jaxpr
print("\nGradient function jaxpr:")
grad_fn = brainstate.transform.grad(loss_fn, params)
jaxpr_grad, _ = make_jaxpr(grad_fn)(jnp.array([0.5, 1.0, 1.5]))
print(jaxpr_grad)

print("\nNote: Gradient jaxpr includes:")
print("  - Forward pass operations (the unused loss value is bound to _)")
print("  - Backward pass (VJP) operations")
print("  - 7 equations instead of the original 3")
=== Example 4: Gradient computation jaxpr ===
Original function jaxpr:
{ lambda ; a:f32[3] b:f32[3]. let
    c:f32[3] = sub b a
    d:f32[3] = integer_pow[y=2] c
    e:f32[] = reduce_sum[axes=(0,) out_sharding=None] d
  in (e, b) }

Gradient function jaxpr:
{ lambda ; a:f32[3] b:f32[3]. let
    c:f32[3] = sub b a
    d:f32[3] = integer_pow[y=2] c
    e:f32[3] = integer_pow[y=1] c
    f:f32[3] = mul 2.0:f32[] e
    _:f32[] = reduce_sum[axes=(0,) out_sharding=None] d
    g:f32[3] = broadcast_in_dim[
      broadcast_dimensions=()
      shape=(3,)
      sharding=None
    ] 1.0:f32[]
    h:f32[3] = mul g f
  in (h, b) }

Note: Gradient jaxpr includes:
  - Forward pass operations (the unused loss value is bound to _)
  - Backward pass (VJP) operations
  - 7 equations instead of the original 3
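The same comparison can be reproduced with plain jax.make_jaxpr and jax.grad. This helps check whether a BrainState wrapper adds anything of its own. The sketch below passes the parameter as an ordinary argument rather than a ParamState.

```python
import jax
import jax.numpy as jnp

x = jnp.array([0.5, 1.0, 1.5])
p = jnp.array([1.0, 2.0, 3.0])

def loss(p):
    return jnp.sum((p - x) ** 2)  # x is closed over and becomes a constant of the jaxpr

fwd = jax.make_jaxpr(loss)(p)
bwd = jax.make_jaxpr(jax.grad(loss))(p)
print(f"forward: {len(fwd.jaxpr.eqns)} equations, gradient: {len(bwd.jaxpr.eqns)} equations")

g = jax.grad(loss)(p)
print(g)  # analytically 2 * (p - x)
```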

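2.5 Jaxpr for Transformed Functions

See how a transformation changes the traced program. In the vmapped jaxpr below, the squaring is applied to the whole f32[2,2] batch at once. The jaxpr also contains random-key operations (random_seed, random_split) and a second integer_pow, whose results are bound to _ (unused). These do not come from simple_fn, which uses no randomness, and presumably reflect vmap2’s own bookkeeping.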

# Example 5: Transformation jaxpr
print("\n=== Example 5: Transformed function jaxpr ===")

def simple_fn(x):
    return x ** 2

# Original
print("Original function:")
jaxpr1, _ = make_jaxpr(simple_fn)(jnp.array([1.0, 2.0, 3.0]))
print(jaxpr1)

# Vmapped version
print("\nVmapped function:")
vmapped_fn = brainstate.transform.vmap2(simple_fn)
jaxpr2, _ = make_jaxpr(vmapped_fn)(jnp.array([[1.0, 2.0], [3.0, 4.0]]))
print(jaxpr2)

print("\nNote: vmap adds batching dimensions to operations")
=== Example 5: Transformed function jaxpr ===
Original function:
{ lambda ; a:f32[3]. let b:f32[3] = integer_pow[y=2] a in (b,) }

Vmapped function:
{ lambda ; a:f32[2,2]. let
    b:key<fry>[] = random_seed[impl=fry] 0:i32[]
    c:u32[2] = random_unwrap b
    d:key<fry>[] = random_wrap[impl=fry] c
    e:key<fry>[2] = random_split[shape=(2,)] d
    _:u32[2,2] = random_unwrap e
    _:f32[2,2] = integer_pow[y=2] a
    f:f32[2,2] = integer_pow[y=2] a
  in (f,) }

Note: vmap adds batching dimensions to operations
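For contrast, plain jax.vmap on the same function traces to a single batched integer_pow with no random-key operations. This supports the reading that the extra equations above come from vmap2 itself. A minimal check:

```python
import jax
import jax.numpy as jnp

def simple_fn(x):
    return x ** 2

closed = jax.make_jaxpr(jax.vmap(simple_fn))(jnp.array([[1.0, 2.0], [3.0, 4.0]]))
print(closed)
names = [eqn.primitive.name for eqn in closed.jaxpr.eqns]
print(names)
```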

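2.6 StatefulFunction Caching

StatefulFunction caches traced jaxprs for efficiency. In the run below, reuse shows up in the cache size: it stays at 1 after the same-shape call and grows to 2 only when a new input shape is traced. The hits and misses counters stay at 0 throughout, and so does the printed hit rate. The entry count is therefore the signal to watch here.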

# Example 6: Understanding compilation caching
print("\n=== Example 6: Compilation caching ===")

state = brainstate.ShortTermState(jnp.array(0.0))

def cached_fn(x):
    state.value = state.value + jnp.sum(x)
    return state.value

sf = StatefulFunction(cached_fn)

# First call: compile
sf.make_jaxpr(jnp.array([1.0, 2.0]))
stats1 = sf.get_cache_stats()
print("After first compilation:")
print(f"  Jaxpr cache: {stats1['jaxpr_cache']}")

# Same shape: cache hit
sf.make_jaxpr(jnp.array([3.0, 4.0]))
stats2 = sf.get_cache_stats()
print("\nAfter same-shape call:")
print(f"  Jaxpr cache: {stats2['jaxpr_cache']}")
print(f"  Hit rate: {stats2['jaxpr_cache']['hit_rate']:.1f}%")

# Different shape: new compilation
sf.make_jaxpr(jnp.array([1.0, 2.0, 3.0]))
stats3 = sf.get_cache_stats()
print("\nAfter different-shape call:")
print(f"  Jaxpr cache: {stats3['jaxpr_cache']}")
print(f"  Cache size: {stats3['jaxpr_cache']['size']} entries")

print("\nCaching strategy:")
print("  - Different shapes → new compilation")
print("  - Same shapes → cache reuse")
print("  - Cache size limited to 128 entries (LRU)")
=== Example 6: Compilation caching ===
After first compilation:
  Jaxpr cache: {'size': 1, 'maxsize': 128, 'hits': 0, 'misses': 0, 'hit_rate': 0.0}

After same-shape call:
  Jaxpr cache: {'size': 1, 'maxsize': 128, 'hits': 0, 'misses': 0, 'hit_rate': 0.0}
  Hit rate: 0.0%

After different-shape call:
  Jaxpr cache: {'size': 2, 'maxsize': 128, 'hits': 0, 'misses': 0, 'hit_rate': 0.0}
  Cache size: 2 entries

Caching strategy:
  - Different shapes → new compilation
  - Same shapes → cache reuse
  - Cache size limited to 128 entries (LRU)
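jax.jit caches by input shape and dtype in the same way. A Python side effect inside the function makes retracing visible, because it runs only while tracing. In the small sketch below, trace_count is a hypothetical counter kept in a list so the function can mutate it.

```python
import jax
import jax.numpy as jnp

trace_count = [0]  # hypothetical counter; a list so the function can mutate it

@jax.jit
def summed(x):
    trace_count[0] += 1  # Python side effect: runs only while tracing
    return jnp.sum(x)

summed(jnp.array([1.0, 2.0]))       # new shape (2,): trace and compile
summed(jnp.array([3.0, 4.0]))       # same shape and dtype: cached
summed(jnp.array([1.0, 2.0, 3.0]))  # new shape (3,): trace again
print("traces:", trace_count[0])
```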

3. Debugging with jax.debug.print

jax.debug.print enables runtime debugging in JIT-compiled code. Unlike regular print, it:

  • Executes during runtime (not tracing)

  • Works inside @jit, vmap, grad, etc.

  • Supports formatted output

  • Can print array values and shapes

Key principle: Debug prints happen at execution time, not trace time

3.1 Basic Debug Printing

# Example 1: Basic debug printing in JIT
print("=== Example 1: Debug printing in JIT ===")

@brainstate.transform.jit
def compute_with_debug(x):
    jax.debug.print("Input: {x}", x=x)
    y = x ** 2
    jax.debug.print("After square: {y}", y=y)
    z = jnp.sum(y)
    jax.debug.print("Sum: {z}", z=z)
    return z

result = compute_with_debug(jnp.array([1.0, 2.0, 3.0]))
print(f"\nFinal result: {result}")
print("\nNote: Debug prints appear during execution, not compilation")
=== Example 1: Debug printing in JIT ===
Input: [1. 2. 3.]
After square: [1. 4. 9.]
Sum: 14.0

Final result: 14.0

Note: Debug prints appear during execution, not compilation
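The difference between trace time and run time can be made testable by logging instead of printing. In this sketch, a Python list append runs only while tracing. jax.debug.callback, the general mechanism behind jax.debug.print, runs on every execution. jax.effects_barrier() waits for pending callbacks before the logs are inspected. The names double, trace_log and run_log are illustrative.

```python
import jax
import jax.numpy as jnp

trace_log = []  # appended by plain Python code: trace time only
run_log = []    # appended by jax.debug.callback: every execution

@jax.jit
def double(x):
    trace_log.append("traced")
    jax.debug.callback(lambda s: run_log.append(float(s)), jnp.sum(x))
    return x * 2

double(jnp.array([1.0, 2.0]))
double(jnp.array([3.0, 4.0]))  # same shape: no retrace, but the callback runs
jax.effects_barrier()          # wait for pending callbacks
print(trace_log, sorted(run_log))
```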

3.2 Debugging State Updates

# Example 2: Debugging stateful computations
print("\n=== Example 2: Debug state updates ===")

class DebuggableCell(brainstate.nn.Module):
    def __init__(self, size):
        super().__init__()
        self.state = brainstate.ShortTermState(jnp.zeros(size))
        self.weight = brainstate.ParamState(jax.random.normal(jax.random.PRNGKey(0), (size, size)))
    
    def step(self, x):
        jax.debug.print("Before update - state: {s}", s=self.state.value)
        
        # Update
        new_state = jnp.tanh(x @ self.weight.value + self.state.value)
        jax.debug.print("Computed new state: {s}", s=new_state)
        
        self.state.value = new_state
        jax.debug.print("After update - state: {s}", s=self.state.value)
        
        return new_state

cell = DebuggableCell(size=3)

@brainstate.transform.jit
def update_step(x):
    return cell.step(x)

x = jnp.array([1.0, 0.0, -1.0])
output = update_step(x)
print(f"\nOutput: {output}")
=== Example 2: Debug state updates ===
Before update - state: [0. 0. 0.]
Computed new state: [ 0.97147846  0.9105761  -0.79975927]
After update - state: [ 0.97147846  0.9105761  -0.79975927]

Output: [ 0.97147846  0.9105761  -0.79975927]
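In recurrent models the state is usually updated inside a loop. The plain-JAX sketch below records the state at every step of a jax.lax.scan with jax.debug.callback instead of printing it, which makes the trace easy to assert on or plot. The update rule and the names record and run are hypothetical. jax.debug.callback also accepts ordered=True when the calls must arrive in program order; here the log is simply sorted by step.

```python
import jax
import jax.numpy as jnp

history = []  # (step, state) pairs recorded at run time

def record(t, h):
    history.append((int(t), float(h)))

@jax.jit
def run(xs):
    def step(h, inputs):
        t, x = inputs
        h = 0.5 * h + x                   # leaky accumulator update
        jax.debug.callback(record, t, h)  # fires once per step, with runtime values
        return h, h
    steps = jnp.arange(xs.shape[0])
    return jax.lax.scan(step, jnp.array(0.0), (steps, xs))

h_final, hs = run(jnp.ones(3))
jax.effects_barrier()
history.sort()  # callbacks are unordered by default; sort by step
print(history)
```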

3.3 Debugging Gradients

# Example 3: Debug gradient computation
print("\n=== Example 3: Debug gradients ===")

param = brainstate.ParamState(jnp.array([2.0, 3.0]))

def loss_with_debug(x):
    jax.debug.print("Forward - param: {p}, input: {x}", p=param.value, x=x)
    
    pred = param.value * x
    jax.debug.print("Forward - prediction: {pred}", pred=pred)
    
    loss = jnp.sum(pred ** 2)
    jax.debug.print("Forward - loss: {loss}", loss=loss)
    
    return loss

# Gradient computation
x = jnp.array([0.5, 1.0])
grad_fn = brainstate.transform.grad(loss_with_debug, param)

print("\nComputing gradients:")
grads = grad_fn(x)
print(f"\nGradients: {grads}")
print("\nNote: Debug prints show forward pass values during gradient computation")
=== Example 3: Debug gradients ===

Computing gradients:
Forward - param: [2. 3.], input: [0.5 1. ]
Forward - prediction: [1. 3.]
Forward - loss: 10.0

Gradients: [1. 6.]

Note: Debug prints show forward pass values during gradient computation
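Debug prints inside a differentiated function show only forward values, as in the output above. To see values from the backward pass, one common pattern is an identity function with a custom VJP that logs the incoming cotangent. The sketch below uses plain JAX, and tap_grad is a hypothetical helper. For the same loss, it should report dloss/dpred = 2 * pred = [2, 6] and the gradient [1, 6] seen above.

```python
import jax
import jax.numpy as jnp
import numpy as np

cotangents = []  # filled during the backward pass

@jax.custom_vjp
def tap_grad(x):
    """Hypothetical helper: identity in the forward pass."""
    return x

def _tap_fwd(x):
    return x, None  # the identity needs no residuals

def _tap_bwd(_, g):
    # Runs in the backward pass: log the incoming cotangent, pass it through
    jax.debug.callback(lambda v: cotangents.append(np.asarray(v)), g)
    return (g,)

tap_grad.defvjp(_tap_fwd, _tap_bwd)

param = jnp.array([2.0, 3.0])
x = jnp.array([0.5, 1.0])

def loss(p):
    pred = tap_grad(p * x)  # tap the gradient flowing into `pred`
    return jnp.sum(pred ** 2)

grads = jax.grad(loss)(param)
jax.effects_barrier()  # make sure the callback has run
print("d loss / d pred:", cotangents[0])
print("d loss / d param:", grads)
```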

3.4 Debugging Vectorized Code

# Example 4: Debug vmap
print("\n=== Example 4: Debug vectorized code ===")

def process_item(x, index):
    jax.debug.print("Processing item {i}: {x}", i=index, x=x)
    return x ** 2

# Vmap over both arguments
vmapped_fn = brainstate.transform.vmap2(process_item)

batch_x = jnp.array([1.0, 2.0, 3.0, 4.0])
indices = jnp.arange(len(batch_x))

print("\nProcessing batch:")
results = vmapped_fn(batch_x, indices)
print(f"\nResults: {results}")
print("\nNote: The print fires once per batch element; here every line appears twice, consistent with the duplicated integer_pow in the vmap2 jaxpr of Section 2.5")
=== Example 4: Debug vectorized code ===

Processing batch:
Processing item 0: 1.0
Processing item 1: 2.0
Processing item 2: 3.0
Processing item 3: 4.0
Processing item 0: 1.0
Processing item 1: 2.0
Processing item 2: 3.0
Processing item 3: 4.0

Results: [ 1.  4.  9. 16.]

Note: The print fires once per batch element; here every line appears twice, consistent with the duplicated integer_pow in the vmap2 jaxpr of Section 2.5

3.5 Conditional Debugging

# Example 5: Conditional debug prints
print("\n=== Example 5: Conditional debugging ===")

iteration = brainstate.ShortTermState(jnp.array(0))

def training_step_with_debug(x, debug_every=5):
    # Update iteration
    iteration.value = iteration.value + 1
    it = iteration.value
    
    # Conditional debug print: jax.lax.cond selects a branch at run time,
    # so the print fires only when the counter is a multiple of debug_every
    jax.lax.cond(
        it % debug_every == 0,
        lambda it, x: jax.debug.print("Iteration {iter}: x={x}", iter=it, x=x),
        lambda it, x: None,
        it, x,
    )
    
    loss = jnp.sum(x ** 2)
    return loss

@brainstate.transform.jit
def train_step(x):
    return training_step_with_debug(x, debug_every=3)

print("\nRunning 10 training steps:")
for i in range(10):
    x = jax.random.normal(jax.random.PRNGKey(i), (5,))
    loss = train_step(x)

print("\nNote: Debug prints only at iterations 3, 6, 9")
=== Example 5: Conditional debugging ===

Running 10 training steps:
Iteration 3: x=[ 0.36057416  1.2849895  -0.73873436  1.1830745  -0.20641916]
Iteration 6: x=[-0.08437306  1.4110229   0.63048154 -1.3100973   1.3689315 ]
Iteration 9: x=[-0.55150557 -1.369112    2.7549403   0.5639917  -1.0112009 ]

Note: Debug prints only at iterations 3, 6, 9

3.6 Advanced: Custom Debug Callbacks

# Example 6: Custom debugging with callbacks
print("\n=== Example 6: Custom debug callbacks ===")

import numpy as np

def custom_debug_callback(name, value):
    """Custom callback for detailed debugging; runs on the host."""
    # The callback receives concrete arrays; NumPy computes the statistics
    # on the host without dispatching new JAX computations
    value = np.asarray(value)
    print(f"[DEBUG {name}]:")
    print(f"  Shape: {value.shape}")
    print(f"  Dtype: {value.dtype}")
    print(f"  Min: {np.min(value):.4f}")
    print(f"  Max: {np.max(value):.4f}")
    print(f"  Mean: {np.mean(value):.4f}")
    print(f"  Std: {np.std(value):.4f}")

@brainstate.transform.jit
def compute_with_callback(x):
    # Use debug callback for detailed inspection
    jax.debug.callback(custom_debug_callback, "input", x)
    
    y = jnp.tanh(x)
    jax.debug.callback(custom_debug_callback, "after_tanh", y)
    
    z = y @ y.T
    jax.debug.callback(custom_debug_callback, "output", z)
    
    return z

x = jax.random.normal(jax.random.PRNGKey(42), (5, 5))
print("\nExecuting with custom debug callbacks:")
result = compute_with_callback(x)
print(f"\nFinal result shape: {result.shape}")
=== Example 6: Custom debug callbacks ===

Executing with custom debug callbacks:
[DEBUG input]:
  Shape: (5, 5)
  Dtype: float32
  Min: -1.9389
  Max: 1.4458
  Mean: 0.1084
  Std: 0.8442
[DEBUG after_tanh]:
  Shape: (5, 5)
  Dtype: float32
  Min: -0.9594
  Max: 0.8949
  Mean: 0.1294
  Std: 0.5656
[DEBUG output]:
  Shape: (5, 5)
  Dtype: float32
  Min: -1.2547
  Max: 2.2312
  Mean: 0.1165
  Std: 0.9866

Final result shape: (5, 5)
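A common use of callbacks is checking for NaN or Inf in critical operations. The sketch below records the name of any tensor that contains a non-finite value. It uses functools.partial to bind the name, and check_finite and compute are hypothetical. For a built-in alternative that raises as soon as a NaN is produced, JAX also provides the jax_debug_nans flag, set with jax.config.update("jax_debug_nans", True).

```python
import functools
import jax
import jax.numpy as jnp
import numpy as np

nonfinite = []  # names of tensors that contained NaN or Inf

def check_finite(name, value):
    """Host-side check (hypothetical helper): record `name` if `value` is not all finite."""
    if not np.all(np.isfinite(np.asarray(value))):
        nonfinite.append(name)

@jax.jit
def compute(x):
    y = jnp.log(x)  # log of a negative number is NaN
    jax.debug.callback(functools.partial(check_finite, "log"), y)
    z = jnp.tanh(y)  # NaN propagates through tanh
    jax.debug.callback(functools.partial(check_finite, "tanh"), z)
    return z

compute(jnp.array([1.0, -1.0]))
jax.effects_barrier()
print(sorted(nonfinite))
```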

Summary

This tutorial covered two BrainState utilities and one JAX debugging primitive:

1. checkpoint: Memory-Efficient Gradients

  • Purpose: Reduce memory usage during gradient computation

  • Mechanism: Recompute activations during backward pass instead of storing them

  • Tradeoff: roughly one extra forward pass in exchange for lower activation memory (O(√n) vs O(n) with well-chosen segments)

  • When to use: when activation memory is the bottleneck, e.g. deep networks, long sequences, or large batches

  • Advanced: Custom policies control what to save vs. recompute

  • State-aware: Works seamlessly with BrainState’s State objects

2. make_jaxpr and StatefulFunction: Code Inspection

  • Purpose: See the primitive operations JAX traces from your code, before XLA compiles and optimizes it

  • Jaxpr: JAX’s intermediate representation showing primitive operations and data flow

  • StatefulFunction: Core mechanism enabling all BrainState transformations

    • Identifies state reads and writes

    • Compiles to Jaxpr with explicit state handling

    • Caches traced jaxprs per input shape (LRU cache; maxsize 128 in the example above)

    • Manages state values automatically

  • Use cases: Debugging compilation issues, understanding transformations, optimization analysis

3. jax.debug.print: Runtime Debugging

  • Purpose: Debug JIT-compiled code during execution

  • Key features:

    • Prints at runtime (not trace time)

    • Works inside @jit, vmap, grad, etc.

    • Supports formatted output and array inspection

  • Best practices:

    • Use debug flags to enable/disable

    • Print statistics, not full arrays

    • Check for NaN/Inf in critical ops

    • Use callbacks for complex debugging

    • Disable in production

Integration with BrainState

All three tools work with BrainState’s State objects:

  • checkpoint preserves state semantics during rematerialization

  • make_jaxpr reveals state reads/writes in compiled code

  • jax.debug.print can inspect state values during execution

These utilities are essential for developing, optimizing, and debugging complex stateful models in BrainState.