Other Transforms
This tutorial covers essential utilities in brainstate.transform, plus JAX's jax.debug.print, for introspection, optimization, and debugging:
- checkpoint: memory-efficient gradient computation through rematerialization
- make_jaxpr and StatefulFunction: inspect and understand traced computation graphs
- jax.debug.print: runtime debugging in JIT-compiled code
The examples emphasize the state-aware behavior that distinguishes BrainState from vanilla JAX.
Imports and Setup
```python
import jax
import jax.numpy as jnp
import brainstate
from brainstate.transform import checkpoint, make_jaxpr, StatefulFunction
```
1. checkpoint: Memory-Efficient Gradient Computation
checkpoint (also known as rematerialization or gradient checkpointing) is crucial for training deep neural networks and processing long sequences. It trades computation for memory during backpropagation.
How Gradient Computation Works
Without checkpointing:
- Forward pass: computes outputs and stores all intermediate activations
- Backward pass: uses the stored activations to compute gradients
- Memory usage: O(n), where n is the number of layers/steps
With checkpointing:
- Forward pass: computes outputs but stores only the inputs at checkpoints
- Backward pass: recomputes intermediate activations from the nearest checkpoint as needed
- Memory usage: O(√n) when checkpoints are placed roughly every √n layers
- Computation: roughly one extra forward pass (the recomputation during backward), i.e. about 2x the forward cost
Key principle: trade extra computation for reduced memory.
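A rough version of the O(√n) argument, under the simplifying assumption that all n layers have equal-sized activations and that checkpoints split the network into k segments. The forward pass keeps only the k segment inputs; the backward pass rebuilds one segment (n/k activations) at a time, so peak activation memory is approximately

$$
M(k) \approx k + \frac{n}{k}, \qquad \frac{dM}{dk} = 1 - \frac{n}{k^{2}} = 0 \;\Rightarrow\; k = \sqrt{n}, \qquad M(\sqrt{n}) \approx 2\sqrt{n}.
$$

For n = 100 layers this is about 20 activation buffers instead of 100. Each layer is recomputed once, which is where the extra forward pass comes from.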
1.1 Basic Usage with Gradient Computation
```python
# Example 1: Memory-efficient gradient computation
print("=== Example 1: Basic checkpoint usage ===")

# Without checkpoint: autodiff stores the intermediates needed for backward
def expensive_forward(x):
    """Chain of expensive operations."""
    y = jnp.sin(x)
    z = jnp.exp(y)
    w = jnp.tanh(z)
    return jnp.sum(w ** 2)

# With checkpoint: only the input is stored; y, z, w are recomputed during backward
@checkpoint
def checkpointed_forward(x):
    """Same computation, but memory-efficient."""
    y = jnp.sin(x)
    z = jnp.exp(y)
    w = jnp.tanh(z)
    return jnp.sum(w ** 2)

x = jnp.linspace(0, 10, 1000)

# Both produce the same results
value1, grad1 = jax.value_and_grad(expensive_forward)(x)
value2, grad2 = jax.value_and_grad(checkpointed_forward)(x)
print(f"Values match: {jnp.allclose(value1, value2)}")
print(f"Gradients match: {jnp.allclose(grad1, grad2)}")
print("\nMemory: checkpoint stores only x instead of the intermediates derived from y, z, w")
print("Cost: the forward computation runs twice (once forward, once during backward)")
```
```text
=== Example 1: Basic checkpoint usage ===
Values match: True
Gradients match: True

Memory: checkpoint stores only x instead of the intermediates derived from y, z, w
Cost: the forward computation runs twice (once forward, once during backward)
```
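To see which arrays are actually kept for the backward pass, plain JAX provides `jax.ad_checkpoint.print_saved_residuals`. The sketch below applies it to the same function wrapped in `jax.checkpoint` (plain JAX, not `brainstate.transform.checkpoint`). The checkpointed version should list only the argument `x`, while the plain version lists several 1000-element intermediates.

```python
import jax
import jax.numpy as jnp
from jax.ad_checkpoint import print_saved_residuals


def expensive_forward(x):
    """Same chain of elementwise ops as Example 1."""
    y = jnp.sin(x)
    z = jnp.exp(y)
    w = jnp.tanh(z)
    return jnp.sum(w ** 2)


# Plain-JAX rematerialization of the same function
checkpointed_forward = jax.checkpoint(expensive_forward)

x = jnp.linspace(0, 10, 1000)

print("Saved for backward without checkpoint:")
print_saved_residuals(expensive_forward, x)
print("Saved for backward with jax.checkpoint:")
print_saved_residuals(checkpointed_forward, x)
```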
1.2 Checkpointing Stateful Computations
BrainState's checkpoint properly handles State objects during gradient computation.
```python
# Example 2: Checkpoint with stateful neural network
print("\n=== Example 2: Checkpointed neural network ===")

class DeepNetwork(brainstate.nn.Module):
    """Deep network with many layers."""
    def __init__(self, layer_sizes):
        super().__init__()
        self.layers = []
        for i in range(len(layer_sizes) - 1):
            self.layers.append(
                brainstate.ParamState(jax.random.normal(
                    jax.random.PRNGKey(i),
                    (layer_sizes[i], layer_sizes[i + 1])
                ))
            )

    def forward(self, x, use_checkpoint=False):
        """Forward pass through all layers."""
        def layer_fn(x):
            h = x
            for W in self.layers[:-1]:
                h = jnp.tanh(h @ W.value)
            # Output layer (no activation)
            return h @ self.layers[-1].value

        if use_checkpoint:
            # ParamStates read inside layer_fn are tracked by brainstate's checkpoint
            return checkpoint(layer_fn)(x)
        else:
            return layer_fn(x)

# Create a deep network: 11 sizes -> 10 weight matrices (layers)
net = DeepNetwork([128, 256, 256, 256, 256, 256, 256, 256, 256, 128, 10])
x_batch = jax.random.normal(jax.random.PRNGKey(42), (32, 128))

# Define loss function
def loss_fn(use_checkpoint):
    y_pred = net.forward(x_batch, use_checkpoint=use_checkpoint)
    return jnp.mean(y_pred ** 2)

# Collect the trainable parameters
params = net.states(brainstate.ParamState)

# Compute gradients with and without checkpoint
grads_normal = brainstate.transform.grad(lambda: loss_fn(False), params)()
grads_checkpointed = brainstate.transform.grad(lambda: loss_fn(True), params)()

# Compare
print(f"Number of layers: {len(net.layers)}")
print(f"Gradient shapes match: {jax.tree.map(lambda a, b: a.shape == b.shape, grads_normal, grads_checkpointed)}")
print("\nWithout checkpoint: stores the activations of all 10 layers")
print("With checkpoint: stores only its inputs, recomputes activations during backward")
print("Note: a single checkpoint around the whole network recomputes all layers at once,")
print("so peak memory stays similar; checkpointing smaller segments is what lowers it")
```
```text
=== Example 2: Checkpointed neural network ===
Number of layers: 10
Gradient shapes match: {('layers', 0): True, ('layers', 1): True, ('layers', 2): True, ('layers', 3): True, ('layers', 4): True, ('layers', 5): True, ('layers', 6): True, ('layers', 7): True, ('layers', 8): True, ('layers', 9): True}

Without checkpoint: stores the activations of all 10 layers
With checkpoint: stores only its inputs, recomputes activations during backward
Note: a single checkpoint around the whole network recomputes all layers at once,
so peak memory stays similar; checkpointing smaller segments is what lowers it
```
1.3 Sequential Layer Checkpointing
For very deep networks, checkpoint individual layers or groups of layers.
```python
# Example 3: Per-layer checkpointing
print("\n=== Example 3: Granular checkpointing ===")

class CheckpointedDeepNetwork(brainstate.nn.Module):
    """Network that wraps every N-th layer in its own checkpoint."""
    def __init__(self, layer_sizes, checkpoint_every=2):
        super().__init__()
        self.checkpoint_every = checkpoint_every
        self.weights = []
        for i in range(len(layer_sizes) - 1):
            self.weights.append(
                brainstate.ParamState(jax.random.normal(
                    jax.random.PRNGKey(i),
                    (layer_sizes[i], layer_sizes[i + 1])
                ) * 0.1)
            )

    def __call__(self, x):
        h = x
        for i, W in enumerate(self.weights):
            # Layer computation; it closes over the current W and is called right away
            def layer_forward(h):
                return jnp.tanh(h @ W.value)

            # Wrap every N-th layer in its own checkpoint
            if (i + 1) % self.checkpoint_every == 0:
                h = checkpoint(layer_forward)(h)
            else:
                h = layer_forward(h)
        return h

# Create network: checkpoint every 2 layers
ckpt_net = CheckpointedDeepNetwork(
    [64, 128, 128, 128, 128, 128, 32],  # 6 layers
    checkpoint_every=2
)
x_in = jax.random.normal(jax.random.PRNGKey(123), (16, 64))

# Forward and backward
def forward_loss():
    return jnp.sum(ckpt_net(x_in) ** 2)

grads, value = brainstate.transform.grad(
    forward_loss,
    ckpt_net.states(brainstate.ParamState),
    return_value=True
)()

print(f"Network depth: {len(ckpt_net.weights)} layers")
print(f"Checkpoint frequency: every {ckpt_net.checkpoint_every} layers")
print(f"Checkpoints created: {len(ckpt_net.weights) // ckpt_net.checkpoint_every}")
print(f"Loss: {value:.4f}")
print("\nCheckpointed layers store only their input; their internals are recomputed in backward")
```
```text
=== Example 3: Granular checkpointing ===
Network depth: 6 layers
Checkpoint frequency: every 2 layers
Checkpoints created: 3
Loss: 85.1143

Checkpointed layers store only their input; their internals are recomputed in backward
```
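Example 3 wraps single layers. The O(√n) scheme from the start of this section instead checkpoints groups of consecutive layers. Here is a plain-JAX sketch of that grouping. `make_weights`, `segment`, and `loss` are hypothetical helpers written for this illustration, and `jax.checkpoint` stands in for `brainstate.transform.checkpoint`. The sketch checks that grouping changes only memory behavior, not the gradients.

```python
import jax
import jax.numpy as jnp


def make_weights(n_layers, width, seed=0):
    """Hypothetical helper: one small random weight matrix per layer."""
    keys = jax.random.split(jax.random.PRNGKey(seed), n_layers)
    return [0.1 * jax.random.normal(k, (width, width)) for k in keys]


def segment(ws, h):
    """A group of consecutive tanh layers (ws is a list of weight matrices)."""
    for W in ws:
        h = jnp.tanh(h @ W)
    return h


def loss(weights, x, group_size=None):
    """Sum-of-squares loss; if group_size is set, each group of layers is checkpointed."""
    if group_size is None:
        h = segment(weights, x)
    else:
        h = x
        for start in range(0, len(weights), group_size):
            group = weights[start:start + group_size]
            # Only the group's inputs (h and its weights) are saved for backward;
            # activations inside the group are recomputed
            h = jax.checkpoint(segment)(group, h)
    return jnp.sum(h ** 2)


weights = make_weights(n_layers=16, width=32)
x = jax.random.normal(jax.random.PRNGKey(1), (8, 32))

grads_plain = jax.grad(loss)(weights, x)
# 16 layers -> groups of sqrt(16) = 4 layers
grads_grouped = jax.grad(lambda w: loss(w, x, group_size=4))(weights)
print(all(jnp.allclose(a, b, atol=1e-6) for a, b in zip(grads_plain, grads_grouped)))
```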
1.4 Memory-Computation Tradeoff
When is checkpointing worth the extra computation? This example summarizes the tradeoff.
```python
# Example 4: When to use checkpoint
print("\n=== Example 4: When to use checkpoint ===")

class BenchmarkNet(brainstate.nn.Module):
    def __init__(self, n_layers, hidden_size):
        super().__init__()
        self.layers = []
        for i in range(n_layers):
            self.layers.append(
                brainstate.ParamState(jax.random.normal(
                    jax.random.PRNGKey(i),
                    (hidden_size, hidden_size)
                ) * 0.1)
            )

    def forward_normal(self, x):
        h = x
        for W in self.layers:
            h = jnp.tanh(h @ W.value)
        return jnp.sum(h)

    def forward_checkpointed(self, x):
        def layer_block(h):
            for W in self.layers:
                h = jnp.tanh(h @ W.value)
            return jnp.sum(h)
        return checkpoint(layer_block)(x)

# Small network: checkpoint overhead is not worth it
small_net = BenchmarkNet(n_layers=3, hidden_size=64)
x_small = jax.random.normal(jax.random.PRNGKey(0), (64,))

# Large network: activations take far more memory, so checkpointing can pay off
large_net = BenchmarkNet(n_layers=20, hidden_size=512)
x_large = jax.random.normal(jax.random.PRNGKey(0), (512,))

print("Small network (3 layers, 64 hidden):")
print("  → Normal gradient: Fast, low memory")
print("  → Checkpoint: Overhead not justified\n")
print("Large network (20 layers, 512 hidden):")
print("  → Normal gradient: Stores ~20 activations (high memory)")
print("  → Checkpoint: Recomputes activations (saves memory)")
print("  → Recommended: Use checkpoint for deep/wide networks\n")
print("Rule of thumb (a rough heuristic; profile your own model):")
print("  Use checkpoint when: depth > 10 OR width > 256")
print("  Skip checkpoint when: shallow networks (< 5 layers)")
```
```text
=== Example 4: When to use checkpoint ===
Small network (3 layers, 64 hidden):
  → Normal gradient: Fast, low memory
  → Checkpoint: Overhead not justified

Large network (20 layers, 512 hidden):
  → Normal gradient: Stores ~20 activations (high memory)
  → Checkpoint: Recomputes activations (saves memory)
  → Recommended: Use checkpoint for deep/wide networks

Rule of thumb (a rough heuristic; profile your own model):
  Use checkpoint when: depth > 10 OR width > 256
  Skip checkpoint when: shallow networks (< 5 layers)
```
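There is also a middle ground between "store everything" and "recompute everything". In plain JAX, `jax.checkpoint` accepts a `policy` that decides which intermediates are saved. This sketch keeps matmul outputs with `jax.checkpoint_policies.dots_saveable` and recomputes only the cheap `tanh`. It uses `jax.checkpoint` directly; whether `brainstate.transform.checkpoint` exposes the same argument is not covered in this tutorial.

```python
import jax
import jax.numpy as jnp
from jax.ad_checkpoint import print_saved_residuals


def mlp(ws, x):
    h = x
    for W in ws:
        h = jnp.tanh(h @ W)
    return jnp.sum(h)


keys = jax.random.split(jax.random.PRNGKey(0), 3)
ws = [0.1 * jax.random.normal(k, (64, 64)) for k in keys]
x = jax.random.normal(jax.random.PRNGKey(1), (8, 64))

# Keep matmul (dot) outputs; recompute only the elementwise tanh during backward
mlp_save_dots = jax.checkpoint(mlp, policy=jax.checkpoint_policies.dots_saveable)
# Default policy: save only the inputs and recompute everything
mlp_save_nothing = jax.checkpoint(mlp)

print("dots_saveable:")
print_saved_residuals(mlp_save_dots, ws, x)
print("default (nothing saveable):")
print_saved_residuals(mlp_save_nothing, ws, x)

g_ref = jax.grad(mlp)(ws, x)
g_dots = jax.grad(mlp_save_dots)(ws, x)
g_none = jax.grad(mlp_save_nothing)(ws, x)
```

Every policy yields the same gradients. Policies only move the balance between stored memory and recomputation.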
2. make_jaxpr and StatefulFunction: Inspecting Compiled Code
make_jaxpr converts a function into its JAX intermediate representation (Jaxpr). The Jaxpr shows the primitive operations JAX traced from your code before XLA compiles and optimizes it. StatefulFunction is the underlying mechanism that enables state-aware transformations.
What is Jaxpr?
Jaxpr is JAX's intermediate representation based on a simply-typed first-order lambda calculus with let-bindings. It shows:
- Primitive operations (add, mul, sin, etc.)
- Data dependencies between them (each variable is bound exactly once)
- How state reads/writes are handled (in BrainState, as extra inputs and outputs)
It does not show memory layout or XLA's optimizations, which happen later during compilation.
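Because a Jaxpr is an ordinary data structure, you can also inspect it programmatically. This sketch uses plain `jax.make_jaxpr`, which returns a `ClosedJaxpr`; its `.jaxpr` field holds the input variables, the equations (the let-bindings), and the output variables.

```python
import jax
import jax.numpy as jnp


def simple_fn(x):
    return jnp.cos(jnp.sin(x)) * 2


closed = jax.make_jaxpr(simple_fn)(jnp.float32(3.0))
# One primitive name per equation, in execution order
prims = [eqn.primitive.name for eqn in closed.jaxpr.eqns]
print(prims)
print("inputs:", closed.jaxpr.invars, "outputs:", closed.jaxpr.outvars)
```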
2.1 Basic Jaxpr Inspection
```python
# Example 1: Simple function jaxpr
print("=== Example 1: Basic jaxpr ===")

def simple_fn(x):
    y = jnp.sin(x)
    z = jnp.cos(y)
    return z * 2

# Create jaxpr
jaxpr_fn = make_jaxpr(simple_fn)
jaxpr, states = jaxpr_fn(3.0)
print("Function: z = cos(sin(x)) * 2")
print("\nJaxpr representation:")
print(jaxpr)
print(f"\nStates used: {len(states)} (none for this simple function)")
```
```text
=== Example 1: Basic jaxpr ===
Function: z = cos(sin(x)) * 2

Jaxpr representation:
{ lambda ; a:f32[]. let
    b:f32[] = sin a
    c:f32[] = cos b
    d:f32[] = mul c 2.0:f32[]
  in (d,) }

States used: 0 (none for this simple function)
```
2.2 Stateful Jaxpr: Tracking State Reads and Writes
BrainState's make_jaxpr reveals how states are accessed.
```python
# Example 2: Jaxpr with states
print("\n=== Example 2: Stateful jaxpr ===")

# Create states
counter = brainstate.ShortTermState(jnp.array(0))
accumulator = brainstate.ShortTermState(jnp.array(0.0))

def stateful_fn(x):
    # Read states
    count = counter.value
    accum = accumulator.value
    # Update states
    counter.value = count + 1
    accumulator.value = accum + x
    return accumulator.value / counter.value

# Inspect jaxpr
jaxpr_fn = make_jaxpr(stateful_fn)
jaxpr, states = jaxpr_fn(5.0)
print("Function: running average tracker")
print(f"\nStates accessed: {len(states)}")
for i, state in enumerate(states):
    print(f"  [{i}] {type(state).__name__}: {state.value}")
print("\nJaxpr (state operations visible):")
print(jaxpr)
print("\nNote: Jaxpr shows state reads as inputs, writes as outputs")
```
```text
=== Example 2: Stateful jaxpr ===
Function: running average tracker

States accessed: 2
  [0] ShortTermState: 0
  [1] ShortTermState: 0.0

Jaxpr (state operations visible):
{ lambda ; a:f32[] b:i32[] c:f32[]. let
    d:i32[] = add b 1:i32[]
    e:f32[] = add c a
    f:f32[] = convert_element_type[new_dtype=float32 weak_type=True] d
    g:f32[] = div e f
  in (g, d, e) }

Note: Jaxpr shows state reads as inputs, writes as outputs
```
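To make the "reads as inputs, writes as outputs" idea concrete, here is the same running average written by hand in plain JAX, with the state threaded through explicitly. This illustrates the idea only; it is not how BrainState is implemented. Its jaxpr has the same structure as the one above: three inputs (x, count, accum) and three outputs (the average and the two updated state values).

```python
import jax
import jax.numpy as jnp


def running_average_step(x, count, accum):
    # State values come in as ordinary arguments...
    new_count = count + 1
    new_accum = accum + x
    # ...and updated state values go out as ordinary results
    return new_accum / new_count, new_count, new_accum


closed = jax.make_jaxpr(running_average_step)(5.0, jnp.array(0), jnp.array(0.0))
print(closed)

avg, count, accum = running_average_step(5.0, jnp.array(0), jnp.array(0.0))
print(avg, count, accum)
```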
2.3 Understanding StatefulFunction
StatefulFunction is the core abstraction that enables all BrainState transformations. It:
- Identifies the states accessed during function execution
- Traces the function to a Jaxpr with explicit state inputs/outputs
- Manages state values before and after execution
- Caches traced Jaxprs for efficient repeated calls
```python
# Example 3: Using StatefulFunction directly
print("\n=== Example 3: StatefulFunction mechanics ===")

# Create a module with state
class NeuralCell(brainstate.nn.Module):
    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.W = brainstate.ParamState(jax.random.normal(
            jax.random.PRNGKey(0), (input_size, hidden_size)
        ))
        self.h = brainstate.ShortTermState(jnp.zeros(hidden_size))

    def __call__(self, x):
        # Update hidden state
        self.h.value = jnp.tanh(x @ self.W.value + self.h.value)
        return self.h.value

cell = NeuralCell(input_size=10, hidden_size=20)

# Wrap in StatefulFunction
sf = StatefulFunction(cell)

# Example input
x = jax.random.normal(jax.random.PRNGKey(1), (10,))

# Step 1: Compile and inspect
sf.make_jaxpr(x)
print("Step 1: Compilation")
print(f"  Compiled for input shape: {x.shape}")

# Step 2: Get tracked states
states = sf.get_states(x)
read_states = sf.get_read_states(x)
write_states = sf.get_write_states(x)
print("\nStep 2: State identification")
print(f"  Total states: {len(states)}")
print(f"  Read states: {len(read_states)}")
for s in read_states:
    print(f"    - {type(s).__name__}: shape {s.value.shape}")
print(f"  Write states: {len(write_states)}")
for s in write_states:
    print(f"    - {type(s).__name__}: shape {s.value.shape}")

# Step 3: Get jaxpr
jaxpr = sf.get_jaxpr(x)
print("\nStep 3: Jaxpr compilation")
print(f"  Jaxpr variables: {len(jaxpr.jaxpr.invars)} inputs, {len(jaxpr.jaxpr.outvars)} outputs")
print(f"  Jaxpr equations: {len(jaxpr.jaxpr.eqns)} operations")

# Step 4: Execute
output = sf(x)
print("\nStep 4: Execution")
print(f"  Output shape: {output.shape}")
print(f"  Hidden state updated: {cell.h.value.shape}")
```
```text
=== Example 3: StatefulFunction mechanics ===
Step 1: Compilation
  Compiled for input shape: (10,)

Step 2: State identification
  Total states: 2
  Read states: 1
    - ParamState: shape (10, 20)
  Write states: 1
    - ShortTermState: shape (20,)

Step 3: Jaxpr compilation
  Jaxpr variables: 3 inputs, 2 outputs
  Jaxpr equations: 3 operations

Step 4: Execution
  Output shape: (20,)
  Hidden state updated: (20,)
```
2.4 Jaxpr for Gradient Computation
Inspect how autodiff transforms your code.
```python
# Example 4: Gradient jaxpr
print("\n=== Example 4: Gradient computation jaxpr ===")

# Simple loss function
params = brainstate.ParamState(jnp.array([1.0, 2.0, 3.0]))

def loss_fn(x):
    return jnp.sum((params.value - x) ** 2)

# Original function jaxpr
print("Original function jaxpr:")
jaxpr_orig, _ = make_jaxpr(loss_fn)(jnp.array([0.5, 1.0, 1.5]))
print(jaxpr_orig)

# Gradient function jaxpr
print("\nGradient function jaxpr:")
grad_fn = brainstate.transform.grad(loss_fn, params)
jaxpr_grad, _ = make_jaxpr(grad_fn)(jnp.array([0.5, 1.0, 1.5]))
print(jaxpr_grad)

print("\nNote: Gradient jaxpr includes:")
print("  - Forward pass operations (the unused loss value is bound to _)")
print("  - Backward pass (VJP) operations computing 2 * (params - x)")
print("  - Only a few more equations than the original for this small loss")
```
```text
=== Example 4: Gradient computation jaxpr ===
Original function jaxpr:
{ lambda ; a:f32[3] b:f32[3]. let
    c:f32[3] = sub b a
    d:f32[3] = integer_pow[y=2] c
    e:f32[] = reduce_sum[axes=(0,) out_sharding=None] d
  in (e, b) }

Gradient function jaxpr:
{ lambda ; a:f32[3] b:f32[3]. let
    c:f32[3] = sub b a
    d:f32[3] = integer_pow[y=2] c
    e:f32[3] = integer_pow[y=1] c
    f:f32[3] = mul 2.0:f32[] e
    _:f32[] = reduce_sum[axes=(0,) out_sharding=None] d
    g:f32[3] = broadcast_in_dim[
      broadcast_dimensions=()
      shape=(3,)
      sharding=None
    ] 1.0:f32[]
    h:f32[3] = mul g f
  in (h, b) }

Note: Gradient jaxpr includes:
  - Forward pass operations (the unused loss value is bound to _)
  - Backward pass (VJP) operations computing 2 * (params - x)
  - Only a few more equations than the original for this small loss
```
2.5 Jaxpr for Transformed Functions
See how transformations affect the compiled code.
```python
# Example 5: Transformation jaxpr
print("\n=== Example 5: Transformed function jaxpr ===")

def simple_fn(x):
    return x ** 2

# Original
print("Original function:")
jaxpr1, _ = make_jaxpr(simple_fn)(jnp.array([1.0, 2.0, 3.0]))
print(jaxpr1)

# Vmapped version
print("\nVmapped function:")
vmapped_fn = brainstate.transform.vmap2(simple_fn)
jaxpr2, _ = make_jaxpr(vmapped_fn)(jnp.array([[1.0, 2.0], [3.0, 4.0]]))
print(jaxpr2)
print("\nNote: vmap adds batching dimensions to operations")
```
```text
=== Example 5: Transformed function jaxpr ===
Original function:
{ lambda ; a:f32[3]. let b:f32[3] = integer_pow[y=2] a in (b,) }

Vmapped function:
{ lambda ; a:f32[2,2]. let
    b:key<fry>[] = random_seed[impl=fry] 0:i32[]
    c:u32[2] = random_unwrap b
    d:key<fry>[] = random_wrap[impl=fry] c
    e:key<fry>[2] = random_split[shape=(2,)] d
    _:u32[2,2] = random_unwrap e
    _:f32[2,2] = integer_pow[y=2] a
    f:f32[2,2] = integer_pow[y=2] a
  in (f,) }

Note: vmap adds batching dimensions to operations
```
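For comparison, applying plain `jax.vmap` to the same function gives a jaxpr with a single batched `integer_pow` on `f32[2,2]`. This suggests that the `random_*` equations and the discarded duplicate `integer_pow` in the `vmap2` output come from `vmap2`'s own machinery rather than from batching `x ** 2`. Note that none of them feed the output `f`.

```python
import jax
import jax.numpy as jnp


def simple_fn(x):
    return x ** 2


closed = jax.make_jaxpr(jax.vmap(simple_fn))(jnp.array([[1.0, 2.0], [3.0, 4.0]]))
print(closed)
prims = [eqn.primitive.name for eqn in closed.jaxpr.eqns]
print(prims)
```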
2.6 StatefulFunction Caching
StatefulFunction caches compiled jaxprs for efficiency.
```python
# Example 6: Understanding compilation caching
print("\n=== Example 6: Compilation caching ===")

state = brainstate.ShortTermState(jnp.array(0.0))

def cached_fn(x):
    state.value = state.value + jnp.sum(x)
    return state.value

sf = StatefulFunction(cached_fn)

# First call: trace and cache
sf.make_jaxpr(jnp.array([1.0, 2.0]))
stats1 = sf.get_cache_stats()
print("After first compilation:")
print(f"  Jaxpr cache: {stats1['jaxpr_cache']}")

# Same shape: the cached jaxpr is reused, so no new entry is added
sf.make_jaxpr(jnp.array([3.0, 4.0]))
stats2 = sf.get_cache_stats()
print("\nAfter same-shape call:")
print(f"  Jaxpr cache: {stats2['jaxpr_cache']}")
print(f"  Cache size: {stats2['jaxpr_cache']['size']} (no new entry)")

# Different shape: new compilation
sf.make_jaxpr(jnp.array([1.0, 2.0, 3.0]))
stats3 = sf.get_cache_stats()
print("\nAfter different-shape call:")
print(f"  Jaxpr cache: {stats3['jaxpr_cache']}")
print(f"  Cache size: {stats3['jaxpr_cache']['size']} entries")

print("\nCaching strategy:")
print("  - Different shapes → new compilation")
print("  - Same shapes → cache reuse")
print("  - Cache size limited to 128 entries (LRU)")
```
```text
=== Example 6: Compilation caching ===
After first compilation:
  Jaxpr cache: {'size': 1, 'maxsize': 128, 'hits': 0, 'misses': 0, 'hit_rate': 0.0}

After same-shape call:
  Jaxpr cache: {'size': 1, 'maxsize': 128, 'hits': 0, 'misses': 0, 'hit_rate': 0.0}
  Cache size: 1 (no new entry)

After different-shape call:
  Jaxpr cache: {'size': 2, 'maxsize': 128, 'hits': 0, 'misses': 0, 'hit_rate': 0.0}
  Cache size: 2 entries

Caching strategy:
  - Different shapes → new compilation
  - Same shapes → cache reuse
  - Cache size limited to 128 entries (LRU)
```
In this run the hits/misses counters stay at 0 even for the same-shape call, so they are not a useful signal here. The cache size is: it stays at 1 for the repeated shape and grows to 2 for the new one.
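Plain `jax.jit` follows the same rule: its cache is keyed on argument shapes and dtypes. You can observe this with a Python side effect, which runs only while JAX traces the function. `trace_count` is a hypothetical counter used only for this illustration.

```python
import jax
import jax.numpy as jnp

trace_count = 0


@jax.jit
def summed(x):
    global trace_count
    trace_count += 1  # Python side effect: runs only while JAX traces the function
    return jnp.sum(x)


summed(jnp.array([1.0, 2.0]))       # new shape (2,): trace + compile
summed(jnp.array([3.0, 4.0]))       # same shape and dtype: cached, no trace
summed(jnp.array([1.0, 2.0, 3.0]))  # new shape (3,): trace again
print("traces:", trace_count)
```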
3. Debugging with jax.debug.print
jax.debug.print enables runtime debugging in JIT-compiled code. Unlike regular print, it:
- Executes at runtime (not during tracing)
- Works inside @jit, vmap, grad, etc.
- Supports formatted output
- Can print array values and shapes
Key principle: debug prints happen at execution time, not trace time.
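A minimal plain-JAX sketch of that difference. Python's `print` runs once, while the function is traced, and shows an abstract tracer. `jax.debug.print` is staged into the compiled program and shows the concrete value on every call, including cached calls that skip tracing.

```python
import jax
import jax.numpy as jnp


@jax.jit
def double(x):
    # Runs only while JAX traces the function: x is an abstract tracer here
    print("print:", x)
    # Staged into the compiled program: runs on every call with the concrete value
    jax.debug.print("jax.debug.print: {x}", x=x)
    return x * 2


a = double(jnp.array(1.0))  # first call: both lines appear (trace + execution)
b = double(jnp.array(2.0))  # cached: only the jax.debug.print line appears
jax.effects_barrier()       # wait for pending debug output before continuing
```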
3.1 Basic Debug Printing
```python
# Example 1: Basic debug printing in JIT
print("=== Example 1: Debug printing in JIT ===")

@brainstate.transform.jit
def compute_with_debug(x):
    jax.debug.print("Input: {x}", x=x)
    y = x ** 2
    jax.debug.print("After square: {y}", y=y)
    z = jnp.sum(y)
    jax.debug.print("Sum: {z}", z=z)
    return z

result = compute_with_debug(jnp.array([1.0, 2.0, 3.0]))
print(f"\nFinal result: {result}")
print("\nNote: Debug prints appear during execution, not compilation")
```
```text
=== Example 1: Debug printing in JIT ===
Input: [1. 2. 3.]
After square: [1. 4. 9.]
Sum: 14.0

Final result: 14.0

Note: Debug prints appear during execution, not compilation
```
3.2 Debugging State Updates
```python
# Example 2: Debugging stateful computations
print("\n=== Example 2: Debug state updates ===")

class DebuggableCell(brainstate.nn.Module):
    def __init__(self, size):
        super().__init__()
        self.state = brainstate.ShortTermState(jnp.zeros(size))
        self.weight = brainstate.ParamState(jax.random.normal(jax.random.PRNGKey(0), (size, size)))

    def step(self, x):
        jax.debug.print("Before update - state: {s}", s=self.state.value)
        # Update
        new_state = jnp.tanh(x @ self.weight.value + self.state.value)
        jax.debug.print("Computed new state: {s}", s=new_state)
        self.state.value = new_state
        jax.debug.print("After update - state: {s}", s=self.state.value)
        return new_state

cell = DebuggableCell(size=3)

@brainstate.transform.jit
def update_step(x):
    return cell.step(x)

x = jnp.array([1.0, 0.0, -1.0])
output = update_step(x)
print(f"\nOutput: {output}")
```
```text
=== Example 2: Debug state updates ===
Before update - state: [0. 0. 0.]
Computed new state: [ 0.97147846 0.9105761 -0.79975927]
After update - state: [ 0.97147846 0.9105761 -0.79975927]

Output: [ 0.97147846 0.9105761 -0.79975927]
```
3.3 Debugging Gradients
```python
# Example 3: Debug gradient computation
print("\n=== Example 3: Debug gradients ===")

param = brainstate.ParamState(jnp.array([2.0, 3.0]))

def loss_with_debug(x):
    jax.debug.print("Forward - param: {p}, input: {x}", p=param.value, x=x)
    pred = param.value * x
    jax.debug.print("Forward - prediction: {pred}", pred=pred)
    loss = jnp.sum(pred ** 2)
    jax.debug.print("Forward - loss: {loss}", loss=loss)
    return loss

# Gradient computation
x = jnp.array([0.5, 1.0])
grad_fn = brainstate.transform.grad(loss_with_debug, param)
print("\nComputing gradients:")
grads = grad_fn(x)
print(f"\nGradients: {grads}")
print("\nNote: Debug prints show forward pass values during gradient computation")
```
```text
=== Example 3: Debug gradients ===

Computing gradients:
Forward - param: [2. 3.], input: [0.5 1. ]
Forward - prediction: [1. 3.]
Forward - loss: 10.0

Gradients: [1. 6.]

Note: Debug prints show forward pass values during gradient computation
```
3.4 Debugging Vectorized Code
```python
# Example 4: Debug vmap
print("\n=== Example 4: Debug vectorized code ===")

def process_item(x, index):
    jax.debug.print("Processing item {i}: {x}", i=index, x=x)
    return x ** 2

# Vmap over both arguments
vmapped_fn = brainstate.transform.vmap2(process_item)
batch_x = jnp.array([1.0, 2.0, 3.0, 4.0])
indices = jnp.arange(len(batch_x))
print("\nProcessing batch:")
results = vmapped_fn(batch_x, indices)
print(f"\nResults: {results}")
print("\nNote: Debug prints execute for each element in the batch")
```
```text
=== Example 4: Debug vectorized code ===

Processing batch:
Processing item 0: 1.0
Processing item 1: 2.0
Processing item 2: 3.0
Processing item 3: 4.0
Processing item 0: 1.0
Processing item 1: 2.0
Processing item 2: 3.0
Processing item 3: 4.0

Results: [ 1. 4. 9. 16.]

Note: Debug prints execute for each element in the batch
```
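Each item appears twice in the output above. With plain `jax.vmap`, the JAX debugging docs show `jax.debug.print` firing once per batch element, so the sketch below should print four lines. The doubling under `vmap2` is consistent with the duplicated `integer_pow` in the `vmap2` jaxpr of section 2.5. This suggests `vmap2` evaluates the function twice; that is an inference from the outputs, not documented behavior.

```python
import jax
import jax.numpy as jnp


def process_item(x, index):
    jax.debug.print("Processing item {i}: {x}", i=index, x=x)
    return x ** 2


batch_x = jnp.array([1.0, 2.0, 3.0, 4.0])
results = jax.vmap(process_item)(batch_x, jnp.arange(4))
jax.effects_barrier()  # flush pending debug output
print(results)
```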
3.5 Conditional Debugging
```python
# Example 5: Conditional debug prints
print("\n=== Example 5: Conditional debugging ===")

iteration = brainstate.ShortTermState(jnp.array(0))

def training_step_with_debug(x, debug_every=5):
    # Update iteration
    iteration.value = iteration.value + 1
    it = iteration.value

    # Print only when it % debug_every == 0; lax.cond decides at run time
    jax.lax.cond(
        it % debug_every == 0,
        lambda it, x: jax.debug.print("Iteration {iter}: x={x}", iter=it, x=x),
        lambda it, x: None,
        it, x,
    )

    loss = jnp.sum(x ** 2)
    return loss

@brainstate.transform.jit
def train_step(x):
    return training_step_with_debug(x, debug_every=3)

print("\nRunning 10 training steps:")
for i in range(10):
    x = jax.random.normal(jax.random.PRNGKey(i), (5,))
    loss = train_step(x)

print("\nNote: Debug prints only at iterations 3, 6, 9")
```
```text
=== Example 5: Conditional debugging ===

Running 10 training steps:
Iteration 3: x=[ 0.36057416 1.2849895 -0.73873436 1.1830745 -0.20641916]
Iteration 6: x=[-0.08437306 1.4110229 0.63048154 -1.3100973 1.3689315 ]
Iteration 9: x=[-0.55150557 -1.369112 2.7549403 0.5639917 -1.0112009 ]

Note: Debug prints only at iterations 3, 6, 9
```
3.6 Advanced: Custom Debug Callbacks
```python
# Example 6: Custom debugging with callbacks
print("\n=== Example 6: Custom debug callbacks ===")

def custom_debug_callback(name, value):
    """Custom callback for detailed debugging."""
    print(f"[DEBUG {name}]:")
    print(f"  Shape: {value.shape}")
    print(f"  Dtype: {value.dtype}")
    print(f"  Min: {jnp.min(value):.4f}")
    print(f"  Max: {jnp.max(value):.4f}")
    print(f"  Mean: {jnp.mean(value):.4f}")
    print(f"  Std: {jnp.std(value):.4f}")

@brainstate.transform.jit
def compute_with_callback(x):
    # Use debug callback for detailed inspection
    jax.debug.callback(custom_debug_callback, "input", x)
    y = jnp.tanh(x)
    jax.debug.callback(custom_debug_callback, "after_tanh", y)
    z = y @ y.T
    jax.debug.callback(custom_debug_callback, "output", z)
    return z

x = jax.random.normal(jax.random.PRNGKey(42), (5, 5))
print("\nExecuting with custom debug callbacks:")
result = compute_with_callback(x)
print(f"\nFinal result shape: {result.shape}")
```
```text
=== Example 6: Custom debug callbacks ===

Executing with custom debug callbacks:
[DEBUG input]:
  Shape: (5, 5)
  Dtype: float32
  Min: -1.9389
  Max: 1.4458
  Mean: 0.1084
  Std: 0.8442
[DEBUG after_tanh]:
  Shape: (5, 5)
  Dtype: float32
  Min: -0.9594
  Max: 0.8949
  Mean: 0.1294
  Std: 0.5656
[DEBUG output]:
  Shape: (5, 5)
  Dtype: float32
  Min: -1.2547
  Max: 2.2312
  Mean: 0.1165
  Std: 0.9866

Final result shape: (5, 5)
```
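Callbacks can also collect values for later analysis instead of printing them. In this sketch, `record` and `collected` are hypothetical names. `functools.partial` binds the label on the host side, and `ordered=True` keeps the callbacks in program order, so the log reads top to bottom.

```python
import functools

import jax
import jax.numpy as jnp
import numpy as np

collected = []  # host-side log, filled in by the callback


def record(name, value):
    """Runs on the host with concrete values; stores a NumPy copy."""
    collected.append((name, np.asarray(value)))


@jax.jit
def compute(x):
    jax.debug.callback(functools.partial(record, "input"), x, ordered=True)
    y = jnp.tanh(x)
    jax.debug.callback(functools.partial(record, "after_tanh"), y, ordered=True)
    return y


out = compute(jnp.array([0.0, 1.0, -1.0]))
jax.effects_barrier()  # wait until all pending callbacks have run
print([name for name, _ in collected])
```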
Summary
This tutorial covered three essential utilities for BrainState code:
1. checkpoint: Memory-Efficient Gradients
- Purpose: reduce memory usage during gradient computation
- Mechanism: recompute activations during the backward pass instead of storing them
- Tradeoff: roughly one extra forward pass in exchange for lower memory (O(√n) instead of O(n) with segment-wise checkpoints)
- When to use: deep networks, wide networks, and long sequences where activation memory is the bottleneck (the >10 layers / >256 hidden thresholds are rough heuristics)
- Advanced: custom policies control what to save vs. recompute
- State-aware: works seamlessly with BrainState's State objects
2. make_jaxpr and StatefulFunction: Code Inspection
- Purpose: understand how JAX traces your code before XLA compiles and optimizes it
- Jaxpr: JAX's intermediate representation, showing primitive operations and data flow
- StatefulFunction: the core mechanism enabling all BrainState transformations
  - Identifies state reads and writes
  - Traces to a Jaxpr with explicit state handling
  - Caches traced Jaxprs for efficiency (LRU cache, 128 entries)
  - Manages state values automatically
- Use cases: debugging tracing/compilation issues, understanding transformations, optimization analysis
3. jax.debug.print: Runtime Debugging
- Purpose: debug JIT-compiled code during execution
- Key features:
  - Prints at runtime (not trace time)
  - Works inside @jit, vmap, grad, etc.
  - Supports formatted output and array inspection
- Best practices:
  - Use debug flags to enable/disable
  - Print statistics, not full arrays
  - Check for NaN/Inf in critical ops
  - Use callbacks for complex debugging
  - Disable in production
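Several of the practices above fit into one small guard. In this sketch, `DEBUG` and `check_finite` are hypothetical names. The Python flag is read at trace time, so with `DEBUG = False` no debug operations are staged at all. When the flag is on, `jax.lax.cond` prints summary statistics only if NaN/Inf values appear. For JAX's built-in alternative, see the `jax_debug_nans` config option.

```python
import jax
import jax.numpy as jnp

DEBUG = True  # hypothetical flag; read at trace time, so False stages no debug ops


def check_finite(name, x):
    """Hypothetical helper: print min/max of x only if it contains NaN or Inf."""
    if not DEBUG:
        return
    has_bad = jnp.logical_not(jnp.all(jnp.isfinite(x)))
    jax.lax.cond(
        has_bad,
        lambda v: jax.debug.print(f"{name}: non-finite values, min={{lo}} max={{hi}}",
                                  lo=jnp.min(v), hi=jnp.max(v)),
        lambda v: None,
        x,
    )


@jax.jit
def safe_log(x):
    y = jnp.log(x)
    check_finite("log(x)", y)
    return y


y_ok = safe_log(jnp.array([1.0, 2.0]))    # all finite: nothing is printed
y_bad = safe_log(jnp.array([1.0, -1.0]))  # log(-1) is NaN: one warning line is printed
jax.effects_barrier()
```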
Integration with BrainState
All three tools work with stateful BrainState code:
- checkpoint preserves state semantics during rematerialization
- make_jaxpr reveals state reads/writes in the traced code
- jax.debug.print can inspect state values during execution
These utilities are essential for developing, optimizing, and debugging complex stateful models in BrainState.