Loops and Conditionals
This tutorial covers state-aware control flow primitives in brainstate.transform. These APIs provide JAX-compatible loops and conditionals while safely handling State objects.
We’ll explore three categories of control flow:
Loop Transformations: scan, checkpointed_scan, for_loop, checkpointed_for_loop
While Loops: while_loop, bounded_while_loop
Conditional Control Flow: cond, switch, ifelse
Each API is designed to work seamlessly with BrainState’s state management system while maintaining JAX’s functional programming paradigm.
Imports and Setup
import jax
import jax.numpy as jnp
import brainstate
from brainstate.transform import (
    scan,
    checkpointed_scan,
    for_loop,
    checkpointed_for_loop,
    while_loop,
    bounded_while_loop,
    cond,
    switch,
    ifelse,
    ProgressBar,
)
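1. Loop Transformations
Loop transformations iterate over sequences while tracking State updates. Each loop is traced once and lowered to structured JAX loop primitives rather than unrolled into a Python loop. As a result, compilation cost does not grow with the number of iterations unless you request unrolling.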
1.1 scan: Stateful Scanning with Carry
scan is the fundamental loop primitive. It:
Iterates over a sequence along the leading axis
Maintains a “carry” value that threads through iterations
Collects outputs at each step
Properly handles State objects
Function signature:
scan(
    f: Callable[[Carry, X], Tuple[Carry, Y]],
    init: Carry,
    xs: X,
    length: int | None = None,
    reverse: bool = False,
    unroll: int | bool = 1,
    pbar: ProgressBar | int | None = None,
) -> Tuple[Carry, Y]
Parameters:
f: Function of type (carry, x) -> (new_carry, output)
init: Initial carry value
xs: Sequence to iterate over (along axis 0)
length: Optional iteration count (inferred from xs if not provided)
reverse: If True, iterate from the last element to the first; outputs stay aligned with their corresponding inputs
unroll: Number of iterations to unroll (1 = no unrolling, True = full unrolling)
pbar: Optional progress bar
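The following pure-Python sketch makes the parameter semantics concrete. It is modeled on the reference loop in the jax.lax.scan documentation. The helper name scan_reference is hypothetical and not part of brainstate. The sketch ignores tracing, unroll, pbar, and State handling, and it returns a Python list where the real scan stacks outputs into arrays.

```python
from typing import Any, Callable, List, Optional, Sequence, Tuple


def scan_reference(
    f: Callable[[Any, Any], Tuple[Any, Any]],
    init: Any,
    xs: Optional[Sequence[Any]],
    length: Optional[int] = None,
    reverse: bool = False,
) -> Tuple[Any, List[Any]]:
    """Hypothetical pure-Python model of scan: thread a carry, collect one output per step."""
    if xs is None:
        if length is None:
            raise ValueError("length is required when xs is None")
        # As in jax.lax.scan: with xs=None, f receives x=None for `length` steps.
        xs = [None] * length
    n = len(xs) if length is None else length
    order = range(n - 1, -1, -1) if reverse else range(n)
    carry = init
    ys: List[Any] = [None] * n
    for i in order:
        carry, y = f(carry, xs[i])
        ys[i] = y  # outputs stay aligned with their inputs, even when reverse=True
    return carry, ys


def cumsum_step(carry, x):
    new_carry = carry + x
    return new_carry, new_carry


data = [1.0, 2.0, 3.0, 4.0, 5.0]
final, forward = scan_reference(cumsum_step, 0.0, data)
_, backward = scan_reference(cumsum_step, 0.0, data, reverse=True)
# Without xs, the loop length must be given explicitly.
_, counter = scan_reference(lambda c, _: (c + 1, c), 0, None, length=3)
```

The backward result is [15.0, 14.0, 12.0, 9.0, 5.0]. This is the same alignment shown in the reverse-scan example: step i writes its output to position i, whatever the iteration order.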
# Example 1: Basic scan with carry
def cumsum_body(carry, x):
    """Accumulate sum and return both new carry and current sum."""
    new_carry = carry + x
    return new_carry, new_carry
xs = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0])
final_sum, cumulative_sums = scan(cumsum_body, init=0.0, xs=xs)
print("Input sequence:", xs)
print("Final sum:", final_sum)
print("Cumulative sums:", cumulative_sums)
Input sequence: [1. 2. 3. 4. 5.]
Final sum: 15.0
Cumulative sums: [ 1. 3. 6. 10. 15.]
# Example 2: Scan with stateful computation
class RunningStats(brainstate.nn.Module):
    """Maintain running mean and variance."""
    def __init__(self):
        super().__init__()
        self.count = brainstate.ShortTermState(jnp.array(0))
        self.mean = brainstate.ShortTermState(jnp.array(0.0))
        self.m2 = brainstate.ShortTermState(jnp.array(0.0))  # sum of squared differences
    def update(self, x):
        """Update statistics with new value using Welford's algorithm."""
        self.count.value = self.count.value + 1
        delta = x - self.mean.value
        self.mean.value = self.mean.value + delta / self.count.value
        delta2 = x - self.mean.value
        self.m2.value = self.m2.value + delta * delta2
        variance = self.m2.value / self.count.value
        return {'mean': self.mean.value, 'var': variance}
stats = RunningStats()
def stats_body(carry, x):
    result = stats.update(x)
    return carry, result
data = jnp.array([2.0, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0])
_, history = scan(stats_body, init=None, xs=data)
print("Data:", data)
print("\nRunning mean:", history['mean'])
print("Running variance:", history['var'])
print("\nFinal statistics:")
print(f" Count: {stats.count.value}")
print(f" Mean: {stats.mean.value}")
print(f" Variance: {stats.m2.value / stats.count.value}")
Data: [2. 4. 4. 4. 5. 5. 7. 9.]
Running mean: [2. 3. 3.3333333 3.5 3.8 4. 4.428571
5. ]
Running variance: [0. 1. 0.8888889 0.75 0.96000004 1.
1.9591838 4. ]
Final statistics:
Count: 8
Mean: 5.0
Variance: 4.0
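The running statistics come from Welford's update. You can sanity-check them by replaying the same recurrence in plain Python and comparing the result with a two-pass mean and variance. The function welford below is a hypothetical standalone sketch, not a brainstate API. Like RunningStats, it reports the population variance (M2 divided by the count).

```python
import math


def welford(values):
    """Return the running (mean, population variance) after each value."""
    count, mean, m2 = 0, 0.0, 0.0
    history = []
    for x in values:
        count += 1
        delta = x - mean          # deviation from the previous mean
        mean += delta / count
        delta2 = x - mean         # deviation from the updated mean
        m2 += delta * delta2      # running sum of squared deviations
        history.append((mean, m2 / count))
    return history


data = [2.0, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0]
history = welford(data)
final_mean, final_var = history[-1]

# Two-pass reference over the whole dataset for comparison.
two_pass_mean = sum(data) / len(data)
two_pass_var = sum((x - two_pass_mean) ** 2 for x in data) / len(data)
print(math.isclose(final_mean, two_pass_mean), math.isclose(final_var, two_pass_var))  # True True
```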
# Example 3: Reverse scan
def cumsum_step(carry, x):
    # The same step function is used for both directions; only `reverse` changes
    new_carry = carry + x
    return new_carry, new_carry
xs = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0])
_, forward_sums = scan(cumsum_step, 0.0, xs, reverse=False)
_, backward_sums = scan(cumsum_step, 0.0, xs, reverse=True)
print("Input:", xs)
print("Forward cumsum:", forward_sums)
print("Backward cumsum:", backward_sums)
Input: [1. 2. 3. 4. 5.]
Forward cumsum: [ 1. 3. 6. 10. 15.]
Backward cumsum: [15. 14. 12. 9. 5.]
Progress Bar with scan
The pbar parameter enables progress tracking during long-running scans. You can:
Pass a ProgressBar instance for full control over display options
Pass an integer for quick setup (updates every N iterations)
Customize the description with static or dynamic messages
# Example 4: Progress bar with scan - simple integer freq
print("\n=== Simple progress bar (update every 20 iterations) ===")
def expensive_computation(carry, x):
    """Simulate expensive computation."""
    # Some computation
    result = carry + jnp.sin(x) * jnp.cos(x)
    return result, result
# Create long sequence
long_sequence = jnp.linspace(0, 10 * jnp.pi, 100)
# Use integer for simple progress bar (updates every 20 iterations)
final, outputs = scan(expensive_computation, init=0.0, xs=long_sequence, pbar=20)
print(f"\nFinal result: {final}")
=== Simple progress bar (update every 20 iterations) ===
Final result: -4.0076361074170563e-07
# Example 5: Progress bar with custom ProgressBar instance
print("\n=== Custom progress bar with ProgressBar ===")
# Create ProgressBar with custom settings
pbar = ProgressBar(freq=10, desc="Processing sequence")
final, outputs = scan(expensive_computation, init=0.0, xs=long_sequence, pbar=pbar)
print(f"\nCompleted! Final result: {final}")
=== Custom progress bar with ProgressBar ===
Completed! Final result: -4.0076361074170563e-07
# Example 6: Dynamic progress bar description based on loop state
print("\n=== Dynamic progress bar with loop state ===")
class OptimizationTracker(brainstate.nn.Module):
    """Track optimization progress."""
    def __init__(self):
        super().__init__()
        self.best_loss = brainstate.ShortTermState(jnp.array(float('inf')))
    def step(self, params, x):
        # Compute loss
        loss = jnp.sum((params - x) ** 2)
        # Update best
        self.best_loss.value = jnp.minimum(self.best_loss.value, loss)
        # Gradient step: d(loss)/d(params) = 2 * (params - x), learning rate 0.1
        new_params = params - 0.1 * 2 * (params - x)
        return new_params, loss
tracker = OptimizationTracker()
def scan_body_with_tracking(params, x):
    return tracker.step(params, x)
# Define dynamic description
def format_progress(data):
    """Format progress with current loss and best loss."""
    return {
        "iter": data["i"],
        "loss": data["y"],
        "best": tracker.best_loss.value
    }
pbar_dynamic = ProgressBar(
    freq=15,
    desc=("Iter {iter:3d} | Loss: {loss:.4f} | Best: {best:.4f}", format_progress)
)
targets = jax.random.normal(jax.random.PRNGKey(42), (100,))
init_params = jnp.array(0.0)
final_params, loss_history = scan(
    scan_body_with_tracking,
    init=init_params,
    xs=targets,
    pbar=pbar_dynamic
)
print(f"\nOptimization completed!")
print(f"Final parameters: {final_params}")
print(f"Final loss: {loss_history[-1]}")
print(f"Best loss achieved: {tracker.best_loss.value}")
=== Dynamic progress bar with loop state ===
Optimization completed!
Final parameters: -0.04794257506728172
Final loss: 1.41942298412323
Best loss achieved: 3.858334093820304e-05
1.2 checkpointed_scan: Memory-Efficient Scanning
checkpointed_scan is a memory-optimized version of scan that uses gradient checkpointing. It is useful for:
Long sequences, where storing every intermediate activation during gradient computation would exhaust memory
Trading extra computation for lower memory during backpropagation
The forward pass stores only a subset of carries (the checkpoints). The backward pass then recomputes the intermediate values between checkpoints when they are needed.
Function signature:
checkpointed_scan(
    f: Callable[[Carry, X], Tuple[Carry, Y]],
    init: Carry,
    xs: X,
    length: Optional[int] = None,
    base: int = 16,
    pbar: Optional[ProgressBar | int] = None,
) -> Tuple[Carry, Y]
Key parameter:
base: Checkpointing base (default=16). Smaller values save more memory but increase recomputation during the backward pass. The implementation uses a hierarchical checkpointing scheme where max_steps = base^k for some k.
Memory savings during gradient computation:
Regular scan: stores all intermediate activations → O(n) memory for sequence length n
checkpointed_scan: stores only checkpoints, about base of them at each of the ~log_base(n) nesting levels → O(base · log_base(n)) memory
During the backward pass: recomputes intermediate values between checkpoints as needed
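You can estimate the memory/recomputation tradeoff with a small cost model of recursive checkpointing. The model splits a segment into up to base chunks, saves the carry at each chunk boundary, and recurses into each chunk on the way back. It is illustrative only: the helper recursive_checkpoint_cost is hypothetical, and brainstate's actual bookkeeping (which works with base^k steps) differs in detail. The model counts stored carries and forward step evaluations rather than bytes.

```python
import math
from typing import Tuple


def recursive_checkpoint_cost(length: int, base: int) -> Tuple[int, int]:
    """Illustrative cost of backpropagating through `length` loop steps with
    recursive checkpointing. Returns (peak_stored_carries, forward_step_evaluations)."""
    if base < 2:
        raise ValueError("base must be at least 2")
    if length <= base:
        # Leaf segment: re-run it once, keeping every intermediate carry.
        return length, length
    chunk = math.ceil(length / base)
    sizes = [min(chunk, length - start) for start in range(0, length, chunk)]
    # One forward sweep over the segment saves a carry at each chunk boundary...
    steps = length
    child_peak = 0
    # ...then the chunks are differentiated last-to-first, recursing into each.
    for size in reversed(sizes):
        peak, child_steps = recursive_checkpoint_cost(size, base)
        child_peak = max(child_peak, peak)
        steps += child_steps
    return len(sizes) + child_peak, steps


for n, base in [(100, 4), (100, 8), (1000, 10)]:
    peak, steps = recursive_checkpoint_cost(n, base)
    print(f"n={n:4d} base={base:2d}: peak carries={peak:3d} (plain scan: {n}), "
          f"forward steps={steps} (plain scan: {n})")
```

For 100 steps, the model keeps at most 17 carries alive with base=8, or 14 with base=4, instead of 100. The price is 300 and 384 forward step evaluations respectively, instead of 100. This matches the direction of the tradeoff described for base.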
# Example: Memory-efficient scan for gradient computation
class RecurrentCell(brainstate.nn.Module):
    """Simple recurrent cell with hidden state."""
    def __init__(self, hidden_size):
        super().__init__()
        self.hidden_size = hidden_size
        self.weight = brainstate.ParamState(jax.random.normal(
            jax.random.PRNGKey(0), (hidden_size, hidden_size)
        ))
    def step(self, hidden, x):
        """Single recurrent step."""
        new_hidden = jnp.tanh(jnp.dot(self.weight.value, hidden) + x)
        return new_hidden
# Create a cell and input sequence
cell = RecurrentCell(hidden_size=32)
sequence_length = 100
inputs = jax.random.normal(jax.random.PRNGKey(1), (sequence_length, 32))
def rnn_body(hidden, x):
    new_hidden = cell.step(hidden, x)
    return new_hidden, new_hidden
# Use checkpointed scan for memory efficiency during gradient computation
init_hidden = jnp.zeros(32)
final_hidden, all_hiddens = checkpointed_scan(
    rnn_body,
    init=init_hidden,
    xs=inputs,
    base=8  # checkpointing base: hierarchical levels of up to 8 steps each
)
print(f"Sequence length: {sequence_length}")
print(f"Hidden size: {cell.hidden_size}")
print(f"Final hidden shape: {final_hidden.shape}")
print(f"All hiddens shape: {all_hiddens.shape}")
print(f"\nCheckpointing configuration:")
print(f" Base: 8 (hierarchical checkpointing, nested levels of up to 8 steps each)")
print(f" Memory saved: stores a subset of carries instead of all {sequence_length} activations")
print(f" During backprop: Recomputes activations between checkpoints as needed")
Sequence length: 100
Hidden size: 32
Final hidden shape: (32,)
All hiddens shape: (100, 32)
Checkpointing configuration:
Base: 8 (hierarchical checkpointing, nested levels of up to 8 steps each)
Memory saved: stores a subset of carries instead of all 100 activations
During backprop: Recomputes activations between checkpoints as needed
Progress Bar with checkpointed_scan
checkpointed_scan also supports progress bars, which is especially useful for very long sequences where you want to monitor progress.
# Example: Progress bar with checkpointed_scan
print("\n=== Checkpointed scan with progress bar ===")
class LongRunningComputation(brainstate.nn.Module):
    """Simulate a long-running computation."""
    def __init__(self):
        super().__init__()
        self.total_ops = brainstate.ShortTermState(jnp.array(0))
    def process(self, state, x):
        self.total_ops.value = self.total_ops.value + 1
        # Some computation
        new_state = state + jnp.tanh(x)
        output = jnp.sin(new_state) * jnp.cos(x)
        return new_state, output
long_comp = LongRunningComputation()
def body(state, x):
    return long_comp.process(state, x)
# Long sequence
very_long_sequence = jnp.linspace(0, 20 * jnp.pi, 500)
# Progress bar that updates every 50 iterations
pbar_checkpointed = ProgressBar(
    freq=50,
    desc="Checkpointed scan progress"
)
final_state, results = checkpointed_scan(
    body,
    init=0.0,
    xs=very_long_sequence,
    base=10,
    pbar=pbar_checkpointed
)
print(f"\nProcessed {long_comp.total_ops.value} operations")
print(f"Final state: {final_state}")
print(f"Results shape: {results.shape}")
=== Checkpointed scan with progress bar ===
Processed 500 operations
Final state: 493.9846496582031
Results shape: (500,)
1.3 for_loop: Simplified Loop Without Carry
for_loop provides a simpler interface when you don’t need an explicit carry value. It:
Accepts variadic arguments that are sliced along axis 0
Collects the return value of your function at every iteration and stacks the results into the final output
Internally uses scan with None as the carry
Function signature:
for_loop(
    f: Callable[..., Y],
    *xs,
    length: Optional[int] = None,
    reverse: bool = False,
    unroll: int | bool = 1,
    pbar: Optional[ProgressBar | int] = None
) -> Y
Key differences from scan:
Function signature is (*xs) -> output instead of (carry, x) -> (carry, output)
No carry value to manage
Important: The return value at each iteration is collected and stacked along axis 0 to form the final output. If your function returns a scalar at each step, for_loop returns a 1D array. If it returns a vector of shape (d,), the output has shape (n, d), where n is the number of iterations.
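The stacking rule and the "scan with None as the carry" relationship can be written out as a pure-Python sketch. Both helpers, scan_reference and for_loop_reference, are hypothetical illustrations rather than brainstate internals. They return nested Python lists where the real for_loop returns stacked arrays.

```python
from typing import Any, Callable, List, Sequence, Tuple


def scan_reference(f: Callable[[Any, Any], Tuple[Any, Any]], init: Any,
                   xs: Sequence[Any], reverse: bool = False) -> Tuple[Any, List[Any]]:
    """Minimal scan model: thread a carry and keep outputs aligned with inputs."""
    order = range(len(xs) - 1, -1, -1) if reverse else range(len(xs))
    carry, ys = init, [None] * len(xs)
    for i in order:
        carry, ys[i] = f(carry, xs[i])
    return carry, ys


def for_loop_reference(f: Callable[..., Any], *xs: Sequence[Any],
                       reverse: bool = False) -> List[Any]:
    """for_loop modeled as scan with a None carry over per-step slices of every xs."""
    if len({len(x) for x in xs}) != 1:
        raise ValueError("all xs must share the same leading length")
    step_args = list(zip(*xs))  # step i receives (xs[0][i], xs[1][i], ...)

    def body(carry, args):
        return None, f(*args)  # the carry is unused; only outputs are collected

    _, ys = scan_reference(body, None, step_args, reverse=reverse)
    return ys


# Scalar output per step -> n outputs, i.e. shape (n,)
scalars = for_loop_reference(lambda x, y, z: x * y + z,
                             [1.0, 2.0, 3.0, 4.0], [2.0, 3.0, 4.0, 5.0], [0.5, 1.0, 1.5, 2.0])
# Output of length d = 2 per step -> nested result with shape (n, d) = (3, 2)
vectors = for_loop_reference(lambda x: [x, x * x], [1, 2, 3])
```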
# Example 1: Understanding output collection in for_loop
def compute(x, y, z):
    """Combine three inputs."""
    return x * y + z
xs = jnp.array([1.0, 2.0, 3.0, 4.0])
ys = jnp.array([2.0, 3.0, 4.0, 5.0])
zs = jnp.array([0.5, 1.0, 1.5, 2.0])
# for_loop collects the output from EACH iteration
results = for_loop(compute, xs, ys, zs)
print("x:", xs)
print("y:", ys)
print("z:", zs)
print("x * y + z:", results)
print(f"\nNotice: for_loop collected {len(results)} outputs (one per iteration)")
print(f"Each element results[i] = xs[i] * ys[i] + zs[i]")
print(f"Output shape: {results.shape} (stacked along axis 0)")
x: [1. 2. 3. 4.]
y: [2. 3. 4. 5.]
z: [0.5 1. 1.5 2. ]
x * y + z: [ 2.5 7. 13.5 22. ]
Notice: for_loop collected 4 outputs (one per iteration)
Each element results[i] = xs[i] * ys[i] + zs[i]
Output shape: (4,) (stacked along axis 0)
# Example 2: Stateful for_loop
class Accumulator(brainstate.nn.Module):
    """Simple accumulator that tracks total and count."""
    def __init__(self):
        super().__init__()
        self.total = brainstate.ShortTermState(jnp.array(0.0))
        self.count = brainstate.ShortTermState(jnp.array(0))
    def process(self, x):
        self.total.value = self.total.value + x
        self.count.value = self.count.value + 1
        return self.total.value / self.count.value  # running average
acc = Accumulator()
data = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
running_averages = for_loop(acc.process, data)
print("Data:", data)
print("Running averages:", running_averages)
print(f"\nFinal state: total={acc.total.value}, count={acc.count.value}")
print(f"Final average: {acc.total.value / acc.count.value}")
Data: [1. 2. 3. 4. 5. 6.]
Running averages: [1. 1.5 2. 2.5 3. 3.5]
Final state: total=21.0, count=6
Final average: 3.5
Progress Bar with for_loop
for_loop also supports progress bars. This is particularly useful when processing large batches of data.
# Example 3: Progress bar with for_loop - simple case
print("\n=== For loop with progress bar ===")
class DataProcessor(brainstate.nn.Module):
    """Process data with progress tracking."""
    def __init__(self):
        super().__init__()
        self.processed_count = brainstate.ShortTermState(jnp.array(0))
        self.sum_val = brainstate.ShortTermState(jnp.array(0.0))
    def process_item(self, x):
        self.processed_count.value = self.processed_count.value + 1
        self.sum_val.value = self.sum_val.value + x
        # Simulate some processing
        result = jnp.exp(x) / (1 + jnp.exp(x))  # sigmoid
        return result
processor = DataProcessor()
# Create dataset
dataset = jax.random.normal(jax.random.PRNGKey(123), (200,))
# Use simple integer for progress updates
processed = for_loop(processor.process_item, dataset, pbar=25)
print(f"\nProcessed {processor.processed_count.value} items")
print(f"Sum of inputs: {processor.sum_val.value}")
print(f"Processed data shape: {processed.shape}")
=== For loop with progress bar ===
Processed 200 items
Sum of inputs: 9.437446594238281
Processed data shape: (200,)
# Example 4: For loop with dynamic progress description
print("\n=== For loop with dynamic progress description ===")
class BatchProcessor(brainstate.nn.Module):
    """Process batches with statistics."""
    def __init__(self):
        super().__init__()
        self.mean = brainstate.ShortTermState(jnp.array(0.0))
        # Accumulates the sum of squared deviations (M2); divide by count for the variance
        self.variance = brainstate.ShortTermState(jnp.array(0.0))
        self.count = brainstate.ShortTermState(jnp.array(0))
    def update(self, x):
        self.count.value = self.count.value + 1
        delta = x - self.mean.value
        self.mean.value = self.mean.value + delta / self.count.value
        self.variance.value = self.variance.value + delta * (x - self.mean.value)
        return x ** 2
batch_proc = BatchProcessor()
def format_batch_progress(data):
    """Show current statistics."""
    return {
        "n": data["i"],
        "mean": batch_proc.mean.value,
        "var": batch_proc.variance.value / jnp.maximum(batch_proc.count.value, 1)
    }
pbar_batch = ProgressBar(
    freq=20,
    desc=("Batch {n:3d} | Mean: {mean:+.3f} | Var: {var:.3f}", format_batch_progress)
)
batch_data = jax.random.normal(jax.random.PRNGKey(456), (150,)) * 2.0 + 1.0
squared = for_loop(batch_proc.update, batch_data, pbar=pbar_batch)
print(f"\nFinal statistics:")
print(f" Mean: {batch_proc.mean.value}")
print(f" Variance: {batch_proc.variance.value / batch_proc.count.value}")
print(f" Count: {batch_proc.count.value}")
=== For loop with dynamic progress description ===
Final statistics:
Mean: 0.9223809242248535
Variance: 3.488231897354126
Count: 150
1.4 checkpointed_for_loop: Memory-Efficient For Loop
The checkpointed version of for_loop combines the simplicity of for_loop with the memory efficiency of checkpointing during gradient computation.
Function signature:
checkpointed_for_loop(
    f: Callable[..., Y],
    *xs: X,
    length: Optional[int] = None,
    base: int = 16,
    pbar: Optional[ProgressBar | int] = None,
) -> Y
Memory efficiency during gradient computation:
Like checkpointed_scan, this variant significantly reduces memory usage during backpropagation
Essential for training models on very long sequences, where storing all intermediate activations would cause out-of-memory errors
The base parameter controls the memory/computation tradeoff: smaller base = less memory but more recomputation
# Example: Processing long sequence with state
class ExpMovingAverage(brainstate.nn.Module):
    """Exponential moving average."""
    def __init__(self, alpha=0.1):
        super().__init__()
        self.alpha = alpha
        self.ema = brainstate.ShortTermState(jnp.array(0.0))
        self.initialized = brainstate.ShortTermState(jnp.array(False))
    def update(self, x):
        # Initialize with first value
        self.ema.value = jnp.where(
            self.initialized.value,
            self.alpha * x + (1 - self.alpha) * self.ema.value,
            x
        )
        # Keep the State a boolean array so its type stays fixed across iterations
        self.initialized.value = jnp.array(True)
        return self.ema.value
ema = ExpMovingAverage(alpha=0.3)
# Generate noisy signal
signal = jnp.sin(jnp.linspace(0, 4 * jnp.pi, 200)) + 0.2 * brainstate.random.normal(size=(200,))
# Process with checkpointed for_loop
smoothed = checkpointed_for_loop(ema.update, signal, base=10)
print(f"Signal length: {len(signal)}")
print(f"Smoothed signal shape: {smoothed.shape}")
print(f"Original signal range: [{signal.min():.3f}, {signal.max():.3f}]")
print(f"Smoothed signal range: [{smoothed.min():.3f}, {smoothed.max():.3f}]")
Signal length: 200
Smoothed signal shape: (200,)
Original signal range: [-1.401, 1.408]
Smoothed signal range: [-1.131, 1.200]
Progress Bar with checkpointed_for_loop
checkpointed_for_loop supports progress bars to help track processing of very long sequences.
# Example: Progress bar with checkpointed_for_loop
print("\n=== Checkpointed for loop with progress bar ===")
class StreamProcessor(brainstate.nn.Module):
    """Process streaming data."""
    def __init__(self, momentum=0.9):
        super().__init__()
        self.momentum = momentum
        self.running_avg = brainstate.ShortTermState(jnp.array(0.0))
    def process(self, x):
        # Update exponential moving average
        self.running_avg.value = (
            self.momentum * self.running_avg.value + (1 - self.momentum) * x
        )
        return self.running_avg.value
stream_proc = StreamProcessor(momentum=0.95)
# Generate long data stream
data_stream = jax.random.normal(jax.random.PRNGKey(789), (1000,))
# Progress bar with count parameter (updates exactly 10 times)
pbar_stream = ProgressBar(
    count=10,
    desc="Processing data stream"
)
smoothed_stream = checkpointed_for_loop(
    stream_proc.process,
    data_stream,
    base=20,
    pbar=pbar_stream
)
print(f"\nStream processed!")
print(f"Final running average: {stream_proc.running_avg.value}")
print(f"Smoothed stream shape: {smoothed_stream.shape}")
print(f"First 5 values: {smoothed_stream[:5]}")
print(f"Last 5 values: {smoothed_stream[-5:]}")
=== Checkpointed for loop with progress bar ===
Stream processed!
Final running average: 0.050981856882572174
Smoothed stream shape: (1000,)
First 5 values: [ 0.03755957 -0.01195641 0.0011875 0.03787947 0.04933156]
Last 5 values: [ 0.10296391 0.07546234 -0.00500105 0.06068815 0.05098186]
1.5 Comparison: scan vs for_loop
When to use each:
Use scan when:
You need to thread a carry value through iterations
Implementing recurrent patterns (RNNs, state machines)
You want explicit control over the accumulator
Use for_loop when:
No carry value is needed
Processing independent items with side effects (state updates)
Simpler, more Pythonic syntax is preferred
# Comparison example: Computing powers of 2
# Using scan: carry explicitly tracks the power
def scan_version(n):
    def body(carry, _):
        return carry * 2, carry
    _, powers = scan(body, init=1, xs=jnp.arange(n))
    return powers
# Using for_loop with state: state tracks the power
class PowerTracker(brainstate.nn.Module):
    def __init__(self):
        super().__init__()
        self.current = brainstate.ShortTermState(jnp.array(1))
    def next_power(self, _):
        result = self.current.value
        self.current.value = self.current.value * 2
        return result
def forloop_version(n):
    tracker = PowerTracker()
    return for_loop(tracker.next_power, jnp.arange(n))
n = 10
print(f"Powers of 2 (first {n} values):")
print("scan result: ", scan_version(n))
print("for_loop result:", forloop_version(n))
Powers of 2 (first 10 values):
scan result: [ 1 2 4 8 16 32 64 128 256 512]
for_loop result: [ 1 2 4 8 16 32 64 128 256 512]
2. While Loops
While loops provide conditional iteration where the number of iterations is not known in advance.
2.1 while_loop: Dynamic Conditional Iteration
while_loop executes a body function repeatedly while a condition remains true. This is the stateful version of jax.lax.while_loop.
Function signature:
while_loop(
    cond_fun: Callable[[T], BooleanNumeric],
    body_fun: Callable[[T], T],
    init_val: T
) -> T
Parameters:
cond_fun: Function that returns True to continue looping
body_fun: Function that updates the loop value
init_val: Initial loop value
Important constraints:
cond_fun cannot modify state (read-only)
The loop value must keep a fixed shape and dtype across iterations
Not reverse-mode differentiable (use bounded_while_loop instead)
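while_loop evaluates cond_fun first and applies body_fun only while it returns True, so the body may run zero times. The pure-Python sketch below reproduces the three examples in this section. while_loop_reference is a hypothetical helper, not a brainstate API. Plain jax.lax.while_loop code has no State, so the sketch carries counters inside the loop value instead.

```python
def while_loop_reference(cond_fun, body_fun, init_val):
    """Pure-Python model of while_loop: test first, then update, until cond_fun is False."""
    val = init_val
    while cond_fun(val):      # cond_fun only reads the loop value
        val = body_fun(val)   # body_fun must return a value with the same structure
    return val


# First power of two that is not below the threshold.
power = while_loop_reference(lambda v: v < 1000, lambda v: v * 2, 1)


# Newton's method for sqrt(2); the iteration count rides along in the loop value.
def newton_cond(state):
    x, _ = state
    return abs(x * x - 2.0) > 1e-6


def newton_body(state):
    x, iters = state
    return 0.5 * (x + 2.0 / x), iters + 1


root, newton_iters = while_loop_reference(newton_cond, newton_body, (1.0, 0))


# Collatz sequence from 27 with a tuple (pytree-like) loop value.
def collatz_body(state):
    n, steps = state
    return (n // 2 if n % 2 == 0 else 3 * n + 1), steps + 1


final_n, total_steps = while_loop_reference(lambda s: s[0] > 1, collatz_body, (27, 0))

# The body runs zero times if the predicate starts out False.
untouched = while_loop_reference(lambda v: v < 0, lambda v: v + 1, 5)
```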
# Example 1: Simple while loop - find first power of 2 above threshold
def find_power_of_2_above(threshold):
    def cond_fn(val):
        return val < threshold
    def body(val):
        return val * 2
    return while_loop(cond_fn, body, init_val=1)
threshold = 1000
result = find_power_of_2_above(threshold)
print(f"First power of 2 above {threshold}: {result}")
First power of 2 above 1000: 1024
# Example 2: Stateful while loop - iterative refinement
class IterativeRefiner(brainstate.nn.Module):
    """Iteratively refine an estimate using Newton's method."""
    def __init__(self, target):
        super().__init__()
        self.target = target
        self.iterations = brainstate.ShortTermState(jnp.array(0))
    def refine(self, x):
        """Newton's method step for computing sqrt(target)."""
        self.iterations.value = self.iterations.value + 1
        return 0.5 * (x + self.target / x)
# Compute square root of 2 using Newton's method
refiner = IterativeRefiner(target=2.0)
def cond_f(x):
    # Continue until error is small enough
    return jnp.abs(x * x - refiner.target) > 1e-6
def body(x):
    return refiner.refine(x)
result = while_loop(cond_f, body, init_val=1.0)
print(f"Computing sqrt(2)...")
print(f"Result: {result}")
print(f"Actual sqrt(2): {jnp.sqrt(2.0)}")
print(f"Error: {jnp.abs(result - jnp.sqrt(2.0))}")
print(f"Iterations: {refiner.iterations.value}")
Computing sqrt(2)...
Result: 1.4142135381698608
Actual sqrt(2): 1.4142135381698608
Error: 0.0
Iterations: 4
# Example 3: Complex loop value (pytree)
class Collatz(brainstate.nn.Module):
    """Track Collatz sequence statistics."""
    def __init__(self):
        super().__init__()
        self.max_value = brainstate.ShortTermState(jnp.array(0))
    def step(self, n):
        self.max_value.value = jnp.maximum(self.max_value.value, n)
        return jnp.where(n % 2 == 0, n // 2, 3 * n + 1)
collatz = Collatz()
def collatz_cond(state):
    n, steps = state
    return n > 1
def collatz_body(state):
    n, steps = state
    return collatz.step(n), steps + 1
start_value = 27
final_n, total_steps = while_loop(
    collatz_cond,
    collatz_body,
    init_val=(start_value, 0)
)
print(f"Collatz sequence starting from {start_value}:")
print(f" Converged to: {final_n}")
print(f" Steps taken: {total_steps}")
print(f" Maximum value reached: {collatz.max_value.value}")
Collatz sequence starting from 27:
Converged to: 1
Steps taken: 111
Maximum value reached: 9232
2.2 bounded_while_loop: While Loop with Maximum Steps
bounded_while_loop adds a maximum iteration limit to while loops. This is important for:
Preventing infinite loops
Enabling reverse-mode differentiation (unlike while_loop)
Providing compilation time guarantees
Function signature:
bounded_while_loop(
    cond_fun: Callable[[T], BooleanNumeric],
    body_fun: Callable[[T], T],
    init_val: T,
    *,
    max_steps: int,
    base: int = 16,
)
Key parameters:
max_steps: Maximum number of iterations before termination
base: Compilation/runtime tradeoff (default=16)
Larger base = faster compilation, slightly slower runtime
Smaller base = slower compilation, faster runtime
Compile time scales with math.ceil(math.log(max_steps, base))
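The following rough pure-Python model shows the nested structure behind base. Each level splits its block of steps into up to base sub-blocks, and each innermost iteration is guarded by cond_fun. The body therefore runs at most max_steps times and stops as soon as the predicate turns False. The helpers bounded_while_loop_reference and nesting_depth are hypothetical. The model mirrors the scheme described above, not brainstate's exact code.

```python
import math


def bounded_while_loop_reference(cond_fun, body_fun, init_val, *, max_steps, base=16):
    """Pure-Python model: at most `max_steps` guarded iterations in nested blocks."""
    if base < 2:
        raise ValueError("base must be at least 2")

    def run_block(val, steps):
        if steps <= base:
            # Innermost level: every iteration is guarded by the predicate.
            for _ in range(steps):
                if cond_fun(val):
                    val = body_fun(val)
            return val
        # Outer level: split this block into at most `base` sub-blocks.
        sub = math.ceil(steps / base)
        for start in range(0, steps, sub):
            val = run_block(val, min(sub, steps - start))
        return val

    return run_block(init_val, max_steps)


def nesting_depth(max_steps, base):
    """Smallest k >= 1 with base**k >= max_steps, i.e. ceil(log(max_steps, base)),
    computed with integers to avoid floating-point rounding at exact powers."""
    depth, span = 1, base
    while span < max_steps:
        span *= base
        depth += 1
    return depth


calls = {"n": 0}


def increment(x):
    calls["n"] += 1
    return x + 1


result = bounded_while_loop_reference(lambda x: x < 50, increment, 0, max_steps=100, base=8)
capped = bounded_while_loop_reference(lambda x: True, lambda x: x + 1, 0, max_steps=5)
gd = bounded_while_loop_reference(lambda x: abs(x - 3.0) > 1e-4,
                                  lambda x: x - 0.1 * 2 * (x - 3.0), 0.0, max_steps=100)
depths = [nesting_depth(100, b) for b in (2, 8, 16)]
```

In this model, counting up to 50 stops after exactly 50 body calls, and an always-true predicate is cut off at max_steps. nesting_depth(100, base) gives 7, 3, and 2 for base 2, 8, and 16. These are the recursion_depth values printed by compare_base_values.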
# Example 1: Gradient descent with bounded iterations
class GradientDescent(brainstate.nn.Module):
    """Simple gradient descent optimizer."""
    def __init__(self, learning_rate=0.1):
        super().__init__()
        self.lr = learning_rate
        self.steps = brainstate.ShortTermState(jnp.array(0))
    def step(self, x):
        # Gradient of f(x) = (x - 3)^2
        grad = 2 * (x - 3.0)
        self.steps.value = self.steps.value + 1
        return x - self.lr * grad
optimizer = GradientDescent(learning_rate=0.1)
def not_converged(x):
    # Keep iterating while x is still far from the optimum
    return jnp.abs(x - 3.0) > 1e-4
result = bounded_while_loop(
    not_converged,
    optimizer.step,
    init_val=0.0,
    max_steps=100
)
print(f"Minimizing f(x) = (x - 3)^2")
print(f"Starting from x = 0.0")
print(f"Final x: {result}")
print(f"Target x: 3.0")
print(f"Error: {jnp.abs(result - 3.0)}")
print(f"Iterations used: {optimizer.steps.value} / 100")
Minimizing f(x) = (x - 3)^2
Starting from x = 0.0
Final x: 3.001088857650757
Target x: 3.0
Error: 0.001088857650756836
Iterations used: 252 / 100
# Example 2: Comparing different base values
class Counter(brainstate.nn.Module):
    def __init__(self):
        super().__init__()
        self.count = brainstate.ShortTermState(jnp.array(0))
    def increment(self, x):
        self.count.value = self.count.value + 1
        return x + 1
def compare_base_values():
    max_steps = 100
    for base in [2, 8, 16]:
        counter = Counter()
        result = bounded_while_loop(
            lambda x: x < 50,
            counter.increment,
            init_val=0,
            max_steps=max_steps,
            base=base
        )
        recursion_depth = jnp.ceil(jnp.log(max_steps) / jnp.log(base))
        print(f"Base {base:2d}: result={result}, iterations={counter.count.value}, "
              f"recursion_depth≈{int(recursion_depth)}")
compare_base_values()
Base 2: result=206, iterations=50, recursion_depth≈7
Base 8: result=3746, iterations=50, recursion_depth≈3
Base 16: result=3346, iterations=50, recursion_depth≈2
# Example 3: Differentiable bounded_while_loop
def smooth_threshold(x, threshold=5.0, lr=0.5, max_steps=20):
    """Smoothly approach threshold using gradient descent."""
    def cond_fn(val):
        return val < threshold - 0.1
    def body(val):
        # Gradient of loss = (val - threshold)^2
        grad = 2 * (val - threshold)
        return val - lr * grad
    return bounded_while_loop(cond_fn, body, x, max_steps=max_steps)
# Compute gradient
x = 0.0
value, grad = jax.value_and_grad(smooth_threshold)(x)
print(f"Input: {x}")
print(f"Output: {value}")
print(f"Gradient: {grad}")
print(f"\nbounded_while_loop is differentiable!")
Input: 0.0
Output: 4085.0
Gradient: 0.0
bounded_while_loop is differentiable!
2.3 Comparison: while_loop vs bounded_while_loop
Use while_loop when:
Number of iterations is truly unknown
Not computing gradients
Want standard JAX while loop semantics
Use bounded_while_loop when:
Need gradient computation
Want safety against infinite loops
Can provide reasonable upper bound on iterations
Need predictable compilation characteristics
3. Conditional Control Flow
Conditional primitives enable branching logic that compiles efficiently and handles state properly.
3.1 cond: Binary Conditional (If/Else)
cond selectively executes one of two branches based on a boolean predicate. This is the stateful version of jax.lax.cond.
Function signature:
cond(
    pred,
    true_fun: Callable,
    false_fun: Callable,
    *operands
)
Parameters:
pred: Boolean scalar (or numeric, where non-zero is True)
true_fun: Function called when pred is True
false_fun: Function called when pred is False
*operands: Arguments passed to the selected function
Key properties:
Only the selected branch is executed at run time, although both branches are traced at compile time; when pred is batched under vmap, both branches run and the results are selected elementwise
Both branches must return the same pytree structure, with matching shapes and dtypes
State modifications in branches are properly tracked
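A pure-Python sketch shows the difference between a scalar and a batched predicate. With a scalar pred, exactly one branch runs. When pred is batched under vmap, JAX converts the conditional into a select, so both branches are evaluated for every element. cond_reference and batched_cond_reference are hypothetical helpers that only model this behavior; they are not brainstate's implementation.

```python
def cond_reference(pred, true_fun, false_fun, *operands):
    """Scalar predicate: exactly one branch is evaluated."""
    return true_fun(*operands) if pred else false_fun(*operands)


def batched_cond_reference(preds, true_fun, false_fun, xs):
    """Batched predicate (as under vmap): both branches run on every element,
    then a per-element select picks the result."""
    true_out = [true_fun(x) for x in xs]
    false_out = [false_fun(x) for x in xs]
    return [t if p else f for p, t, f in zip(preds, true_out, false_out)]


calls = {"true": 0, "false": 0}


def square(x):
    calls["true"] += 1
    return x ** 2


def negate(x):
    calls["false"] += 1
    return -x


single = cond_reference(-5.0 >= 0, square, negate, -5.0)  # only negate runs
after_single = dict(calls)

values = [-5.0, 3.0, 0.0]
batched = batched_cond_reference([v >= 0 for v in values], square, negate, values)
```

After the batched call, both counters have grown by three even though each element uses only one result. This is why expensive work in an unused branch is not skipped once the predicate is batched.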
# Example 1: Simple conditional
def positive_branch(x):
    return x ** 2
def negative_branch(x):
    return -x
for value in [-5.0, 3.0, 0.0]:
    result = cond(value >= 0, positive_branch, negative_branch, value)
    print(f"cond({value} >= 0): {result}")
cond(-5.0 >= 0): 5.0
cond(3.0 >= 0): 9.0
cond(0.0 >= 0): 0.0
# Example 2: Stateful conditional
class BranchTracker(brainstate.nn.Module):
    """Track which branches were taken."""
    def __init__(self):
        super().__init__()
        self.true_count = brainstate.ShortTermState(jnp.array(0))
        self.false_count = brainstate.ShortTermState(jnp.array(0))
    def true_branch(self, x):
        self.true_count.value = self.true_count.value + 1
        return x * 2
    def false_branch(self, x):
        self.false_count.value = self.false_count.value + 1
        return x / 2
tracker = BranchTracker()
# Test multiple values
values = jnp.array([1.0, -2.0, 3.0, -4.0, 5.0])
results = []
for v in values:
    result = cond(v > 0, tracker.true_branch, tracker.false_branch, v)
    results.append(result)
print("Values:", values)
print("Results:", jnp.array(results))
print(f"\nBranch statistics:")
print(f" True branch taken: {tracker.true_count.value} times")
print(f" False branch taken: {tracker.false_count.value} times")
Values: [ 1. -2. 3. -4. 5.]
Results: [ 2. -1. 6. -2. 10.]
Branch statistics:
True branch taken: 3 times
False branch taken: 2 times
# Example 3: Nested conditionals
class Classifier(brainstate.nn.Module):
    """Classify numbers into categories."""
    def __init__(self):
        super().__init__()
        self.classification_counts = brainstate.ShortTermState({
            'large_positive': jnp.array(0),
            'small_positive': jnp.array(0),
            'small_negative': jnp.array(0),
            'large_negative': jnp.array(0),
        })
    def classify_positive(self, x):
        def large(x):
            counts = self.classification_counts.value
            counts['large_positive'] = counts['large_positive'] + 1
            self.classification_counts.value = counts
            return 'large_positive'
        def small(x):
            counts = self.classification_counts.value
            counts['small_positive'] = counts['small_positive'] + 1
            self.classification_counts.value = counts
            return 'small_positive'
        return cond(x > 5.0, large, small, x)
    def classify_negative(self, x):
        def small(x):
            counts = self.classification_counts.value
            counts['small_negative'] = counts['small_negative'] + 1
            self.classification_counts.value = counts
            return 'small_negative'
        def large(x):
            counts = self.classification_counts.value
            counts['large_negative'] = counts['large_negative'] + 1
            self.classification_counts.value = counts
            return 'large_negative'
        return cond(x > -5.0, small, large, x)
    def classify(self, x):
        return cond(x >= 0, self.classify_positive, self.classify_negative, x)
classifier = Classifier()
# This example relies on jax.disable_jit(): the branches return Python strings,
# which a traced cond cannot output (traced branch outputs must be arrays).
with jax.disable_jit():
    test_values = jnp.array([10.0, 2.0, -3.0, -8.0, 7.0, -1.0])
    classifications = [classifier.classify(v) for v in test_values]
print("Values:", test_values)
print("Classifications:", classifications)
print("\nCategory counts:")
for category, count in classifier.classification_counts.value.items():
    print(f" {category}: {count}")
Values: [10. 2. -3. -8. 7. -1.]
Classifications: ['large_positive', 'small_positive', 'small_negative', 'large_negative', 'large_positive', 'small_negative']
Category counts:
large_positive: 2
small_positive: 1
small_negative: 2
large_negative: 1
3.2 switch: Multi-Way Branching#
switch generalizes cond to multiple branches, similar to a switch/case statement.
Function signature:
switch(
    index,
    branches: Sequence[Callable],
    *operands
)
Parameters:
index: Integer scalar selecting which branch to execute
branches: Sequence of callables (at least 1)
*operands: Arguments passed to the selected branch
Index handling:
Out-of-bounds indices are clamped to [0, len(branches) - 1]:
Negative indices are clamped to 0
Indices >= len(branches) are clamped to len(branches) - 1
# Example 1: Simple multi-way branch
def operation_0(x):
    return x + 1

def operation_1(x):
    return x * 2

def operation_2(x):
    return x ** 2

def operation_3(x):
    return -x

operations = [operation_0, operation_1, operation_2, operation_3]
x = 5.0
for i in range(len(operations)):
    result = switch(i, operations, x)
    print(f"Operation {i} on {x}: {result}")

# Test clamping
print(f"\nOut of bounds (index={len(operations)}): {switch(len(operations), operations, x)}")
print(f"Out of bounds (index={-1}): {switch(-1, operations, x)}")
Operation 0 on 5.0: 6.0
Operation 1 on 5.0: 10.0
Operation 2 on 5.0: 25.0
Operation 3 on 5.0: -5.0
Out of bounds (index=4): -5.0
Out of bounds (index=-1): 6.0
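The clamping rule also holds when the index is a traced value. As a sketch using jax.lax.switch directly (brainstate's switch wrapper is not used here), vectorizing over a batch of indices that includes out-of-range values reproduces the results above:

```python
import jax
import jax.numpy as jnp

# Same four operations as in Example 1.
operations = [
    lambda x: x + 1,
    lambda x: x * 2,
    lambda x: x ** 2,
    lambda x: -x,
]


@jax.jit
def apply_op(index, x):
    # jax.lax.switch clamps index into [0, len(operations) - 1].
    return jax.lax.switch(index, operations, x)


indices = jnp.array([-1, 0, 1, 2, 3, 4], dtype=jnp.int32)
# Batch over the index only; x is shared across the batch.
results = jax.vmap(apply_op, in_axes=(0, None))(indices, 5.0)
print(results)
```

Under jax.vmap with a batched index, JAX evaluates every branch and selects per element, so this pattern suits cheap branches.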
# Example 2: Stateful switch - activation function selector
class ActivationSelector(brainstate.nn.Module):
    """Select and apply different activation functions."""

    def __init__(self):
        super().__init__()
        self.usage_counts = brainstate.ShortTermState(jnp.zeros(5, dtype=jnp.int32))

    def _track_usage(self, index):
        # Functional update: .at[index].add returns a new array.
        counts = self.usage_counts.value
        counts = counts.at[index].add(1)
        self.usage_counts.value = counts

    def relu(self, x):
        self._track_usage(0)
        return jnp.maximum(0, x)

    def sigmoid(self, x):
        self._track_usage(1)
        return 1 / (1 + jnp.exp(-x))

    def tanh(self, x):
        self._track_usage(2)
        return jnp.tanh(x)

    def softplus(self, x):
        self._track_usage(3)
        # Numerically stable form of log(1 + exp(x)); avoids overflow for large x.
        return jnp.logaddexp(0.0, x)

    def identity(self, x):
        self._track_usage(4)
        return x

    def apply(self, index, x):
        return switch(
            index,
            [self.relu, self.sigmoid, self.tanh, self.softplus, self.identity],
            x
        )


selector = ActivationSelector()
activation_names = ['ReLU', 'Sigmoid', 'Tanh', 'Softplus', 'Identity']

# Test all activations
test_input = 2.0
print(f"Input: {test_input}\n")
for i in range(len(activation_names)):
    result = selector.apply(i, test_input)
    print(f"{activation_names[i]:10s}: {result:.4f}")
print(f"\nUsage counts: {selector.usage_counts.value}")
Input: 2.0
ReLU : 2.0000
Sigmoid : 0.8808
Tanh : 0.9640
Softplus : 2.1269
Identity : 2.0000
Usage counts: [1 1 1 1 1]
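Because the index is an ordinary operand, one jitted function can serve every activation. A minimal sketch in plain JAX, assuming the stock jax.nn activations can stand in for the methods above (usage tracking is omitted, and apply_activation is an illustrative name):

```python
import jax
import jax.numpy as jnp

activation_names = ["ReLU", "Sigmoid", "Tanh", "Softplus", "Identity"]
activations = [jax.nn.relu, jax.nn.sigmoid, jnp.tanh, jax.nn.softplus, lambda v: v]


@jax.jit
def apply_activation(index, x):
    # The integer index selects the branch at run time inside one compiled program.
    return jax.lax.switch(index, activations, x)


x = jnp.float32(2.0)
outputs = {name: float(apply_activation(i, x)) for i, name in enumerate(activation_names)}
for name, value in outputs.items():
    print(f"{name:10s}: {value:.4f}")
```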
# Example 3: Dynamic policy selection
class PolicySelector(brainstate.nn.Module):
    """Select different action policies based on state."""

    def __init__(self):
        super().__init__()
        self.total_reward = brainstate.ShortTermState(jnp.array(0.0))

    def aggressive_policy(self, state):
        action = state * 2.0
        reward = jnp.abs(action) * 0.5
        self.total_reward.value = self.total_reward.value + reward
        return {'action': action, 'reward': reward, 'policy': 'aggressive'}

    def conservative_policy(self, state):
        action = state * 0.5
        reward = jnp.abs(action) * 1.0
        self.total_reward.value = self.total_reward.value + reward
        return {'action': action, 'reward': reward, 'policy': 'conservative'}

    def random_policy(self, state):
        # Deterministic placeholder: no randomness is involved here.
        action = state * 1.0
        reward = jnp.abs(action) * 0.3
        self.total_reward.value = self.total_reward.value + reward
        return {'action': action, 'reward': reward, 'policy': 'random'}

    def select_and_act(self, policy_index, state):
        return switch(
            policy_index,
            [self.aggressive_policy, self.conservative_policy, self.random_policy],
            state
        )


policy_selector = PolicySelector()

# Simulate decision-making over time
states = jnp.array([1.0, -0.5, 2.0, -1.5, 0.8])
policies = jnp.array([0, 1, 0, 1, 2], dtype=jnp.int32)  # policy choices

# The 'policy' entry is a Python string, so this example must run under
# jax.disable_jit(); under jit, every branch output must be an array.
with jax.disable_jit():
    print("Simulation results:")
    for i, (policy_idx, state) in enumerate(zip(policies, states)):
        result = policy_selector.select_and_act(policy_idx, state)
        print(f"Step {i}: state={state:5.1f}, policy={result['policy']:12s}, "
              f"action={result['action']:5.2f}, reward={result['reward']:.2f}")
    print(f"\nTotal reward: {policy_selector.total_reward.value:.2f}")
Simulation results:
Step 0: state= 1.0, policy=aggressive , action= 2.00, reward=1.00
Step 1: state= -0.5, policy=conservative, action=-0.25, reward=0.25
Step 2: state= 2.0, policy=aggressive , action= 4.00, reward=2.00
Step 3: state= -1.5, policy=conservative, action=-0.75, reward=0.75
Step 4: state= 0.8, policy=random , action= 0.80, reward=0.24
Total reward: 4.24
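The string-valued 'policy' entry is what ties this example to jax.disable_jit(). A sketch of a jit-compatible rollout, assuming the policy name is recovered on the host from its index and the running reward is carried through jax.lax.scan instead of being stored in a State (rollout and POLICY_NAMES are illustrative names):

```python
import jax
import jax.numpy as jnp

POLICY_NAMES = ("aggressive", "conservative", "random")

# Each branch maps a state to (action, reward); scale factors match the example.
policies = [
    lambda s: (s * 2.0, jnp.abs(s * 2.0) * 0.5),
    lambda s: (s * 0.5, jnp.abs(s * 0.5) * 1.0),
    lambda s: (s * 1.0, jnp.abs(s * 1.0) * 0.3),
]


@jax.jit
def rollout(policy_indices, states):
    def step(total_reward, inputs):
        idx, s = inputs
        action, reward = jax.lax.switch(idx, policies, s)
        # The carry accumulates reward; per-step outputs are stacked by scan.
        return total_reward + reward, (action, reward)

    return jax.lax.scan(step, jnp.float32(0.0), (policy_indices, states))


states = jnp.array([1.0, -0.5, 2.0, -1.5, 0.8])
policy_indices = jnp.array([0, 1, 0, 1, 2], dtype=jnp.int32)
total_reward, (actions, rewards) = rollout(policy_indices, states)
print([POLICY_NAMES[int(i)] for i in policy_indices])
print(actions, rewards, total_reward)
```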
3.3 ifelse: Multi-Condition If/Elif/Else#
ifelse provides a high-level interface for multi-condition branching, similar to Python’s if/elif/else.
Function signature:
ifelse(
    conditions,
    branches,
    *operands,
    check_cond: bool = True
)
Parameters:
conditions: Sequence of boolean predicates (should be mutually exclusive)
branches: Sequence of callables (same length as conditions)
*operands: Arguments passed to the selected branch
check_cond: If True, verify that exactly one condition is True
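Common pattern:
Make the last condition True to create a default/else branch. Because an always-True condition overlaps with every earlier one, the exactly-one check performed by check_cond would fail whenever an earlier condition also holds, so pass check_cond=False:
ifelse(
    [x > 10, x > 5, True],  # last condition is always True
    [large_fn, medium_fn, small_fn],
    x,
    check_cond=False,
)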
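One way to express the if/elif/else reading of this pattern in plain JAX is to take the index of the first True condition and pass it to jax.lax.switch. The helper first_true_ifelse below is hypothetical and is not brainstate's implementation; it assumes overlapping conditions should resolve to the earliest True one and that at least one condition always holds:

```python
import jax
import jax.numpy as jnp


def first_true_ifelse(conditions, branches, *operands):
    """Hypothetical helper: run the branch of the first True condition."""
    # Stack Python bools and traced booleans into one int32 vector.
    flags = jnp.stack([jnp.asarray(c, dtype=jnp.int32) for c in conditions])
    # argmax returns the first index of the maximum, i.e. the first True.
    index = jnp.argmax(flags)
    return jax.lax.switch(index, branches, *operands)


@jax.jit
def bucket(x):
    return first_true_ifelse(
        [x > 10, x > 5, True],  # trailing True acts as the else branch
        [lambda v: jnp.int32(2), lambda v: jnp.int32(1), lambda v: jnp.int32(0)],
        x,
    )


buckets = [int(bucket(v)) for v in [15.0, 7.0, 2.0, 10.5, 5.0]]
print(buckets)
```

If no condition held, jnp.argmax would return 0 and the first branch would run silently, which is one reason to end the list with True.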
# Example 1: Simple if/elif/else
def classify_number(x):
    def large():
        return "large"

    def medium():
        return "medium"

    def small():
        return "small"

    return ifelse(
        [x > 10, jnp.logical_and(x > 5, x <= 10), x <= 5],  # mutually exclusive: exactly one is True
        [large, medium, small]
    )


# The branches return Python strings, so this runs under jax.disable_jit().
with jax.disable_jit():
    for value in [15.0, 7.0, 2.0, 10.5, 5.0]:
        category = classify_number(value)
        print(f"{value:5.1f} -> {category}")
15.0 -> large
7.0 -> medium
2.0 -> small
10.5 -> large
5.0 -> small
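When every branch only returns a constant, as in classify_number, the branching can also be done as a vectorized selection. The sketch below swaps in jnp.select (a different technique from ifelse) and assumes integer codes 0/1/2 for small/medium/large (LABELS and classify_codes are illustrative names):

```python
import jax
import jax.numpy as jnp

LABELS = ("small", "medium", "large")  # position = integer code


@jax.jit
def classify_codes(x):
    # Conditions are checked in order and the first match wins;
    # the default covers x <= 5.
    return jnp.select([x > 10, x > 5], [2, 1], default=0)


values = jnp.array([15.0, 7.0, 2.0, 10.5, 5.0])
codes = classify_codes(values)
print([LABELS[int(c)] for c in codes])
```

jnp.select evaluates every choice, which is cheap for constants; for expensive branches, cond, switch, or ifelse avoid running unselected work at run time (outside vmap).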
# Example 2: Stateful grade calculator
class GradeCalculator(brainstate.nn.Module):
    """Calculate letter grades and track statistics."""

    def __init__(self):
        super().__init__()
        self.grade_counts = brainstate.ShortTermState({
            'A': jnp.array(0),
            'B': jnp.array(0),
            'C': jnp.array(0),
            'D': jnp.array(0),
            'F': jnp.array(0),
        })

    def _record_grade(self, letter):
        # Copy the dict so the state's current value is not mutated in place.
        counts = dict(self.grade_counts.value)
        counts[letter] = counts[letter] + 1
        self.grade_counts.value = counts

    def grade_A(self):
        return self._record_grade('A')

    def grade_B(self):
        return self._record_grade('B')

    def grade_C(self):
        return self._record_grade('C')

    def grade_D(self):
        return self._record_grade('D')

    def grade_F(self):
        return self._record_grade('F')

    def calculate_grade(self, score):
        return ifelse(
            [
                score >= 90,
                jnp.logical_and(score >= 80, score < 90),
                jnp.logical_and(score >= 70, score < 80),
                jnp.logical_and(score >= 60, score < 70),
                score < 60
            ],
            [
                self.grade_A,
                self.grade_B,
                self.grade_C,
                self.grade_D,
                self.grade_F
            ]
        )


calculator = GradeCalculator()

# Process student scores; the branches return None, so only the state updates matter.
scores = jnp.array([95, 87, 76, 82, 59, 91, 68, 45, 88, 93])
for score in scores:
    calculator.calculate_grade(score)

print("\nGrade distribution:")
for letter, count in calculator.grade_counts.value.items():
    print(f" {letter}: {'*' * int(count)}")
Grade distribution:
A: ***
B: ***
C: *
D: *
F: **
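Since each grade branch only increments a counter, the distribution can also be computed without any branching. The sketch below assumes the same 90/80/70/60 thresholds and swaps in jnp.digitize plus jnp.bincount as a vectorized alternative to the stateful ifelse version (grade_distribution is an illustrative name):

```python
import jax
import jax.numpy as jnp

LETTERS = ("F", "D", "C", "B", "A")  # bucket 0..4 produced by jnp.digitize


@jax.jit
def grade_distribution(scores):
    # Bucket i means thresholds[i-1] <= score < thresholds[i].
    thresholds = jnp.array([60, 70, 80, 90])
    buckets = jnp.digitize(scores, thresholds)
    return jnp.bincount(buckets, length=len(LETTERS))


scores = jnp.array([95, 87, 76, 82, 59, 91, 68, 45, 88, 93])
counts = grade_distribution(scores)
# Print from A down to F to match the stateful version.
for letter, count in zip(reversed(LETTERS), reversed(counts.tolist())):
    print(f" {letter}: {'*' * count}")
```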
Summary#
This tutorial covered the control flow primitives in brainstate.transform:
Loop Transformations#
scan: Fundamental loop with carry and outputs
Use for: Recurrent patterns, accumulation, sequential processing
Collects outputs at each iteration
Key params: reverse, unroll, pbar
checkpointed_scan: Memory-efficient scan with gradient checkpointing
Use for: Long sequences, memory constraints during gradient computation
Key benefit: Stores only O(log_base(n)) checkpoints instead of O(n) activations during backpropagation
Trades computation (recomputation during the backward pass) for memory savings
Key param: base (checkpointing granularity)
for_loop: Simplified loop without an explicit carry
Use for: Simple iteration, state updates
Important: The return value at each timestep is saved and stacked into the final output array
Variadic inputs, no carry management
Output shape: Results are stacked along axis 0 (e.g., scalar→1D, vector→2D)
checkpointed_for_loop: Memory-efficient for loop with gradient checkpointing
Combines the simplicity of for_loop with memory efficiency during gradient computation
Essential for training with very long sequences
Same memory benefits as checkpointed_scan
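The checkpointed variants trade recomputation for memory. A plain-JAX sketch of that trade-off applies single-level rematerialization with jax.checkpoint to a scan body; this is an analogue under stated assumptions, not brainstate's base-controlled checkpointing scheme. The gradients match, while the backward pass recomputes the body's intermediates instead of storing them:

```python
import jax
import jax.numpy as jnp


def step(h, x):
    # Simple recurrent update; the new hidden state is also the per-step output.
    h = jnp.tanh(0.9 * h + x)
    return h, h


def loss(xs, use_checkpoint):
    # jax.checkpoint marks the body for recomputation in the backward pass.
    body = jax.checkpoint(step) if use_checkpoint else step
    _, hs = jax.lax.scan(body, jnp.float32(0.0), xs)
    return jnp.sum(hs ** 2)


xs = jnp.linspace(-1.0, 1.0, 100)
g_plain = jax.grad(loss)(xs, False)
g_ckpt = jax.grad(loss)(xs, True)
print(jnp.max(jnp.abs(g_plain - g_ckpt)))
```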
While Loops#
while_loop: Dynamic iteration controlled by a condition
Use for: Unknown iteration count when no reverse-mode gradients are needed
Constraint: cond_fun must be read-only (it must not write to State objects)
bounded_while_loop: While loop with a maximum number of steps
Use for: Gradients, safety, predictable compilation
Key params: max_steps, base
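The gradient restriction on while_loop comes from JAX: jax.lax.while_loop supports forward-mode but not reverse-mode differentiation. The sketch below contrasts it with a bounded alternative written as a fixed-length jax.lax.scan that freezes the iterate once converged. It illustrates the general idea only, not brainstate's bounded_while_loop implementation, and the Newton square-root example is chosen purely for clarity:

```python
import jax
import jax.numpy as jnp


def while_sqrt(a):
    """Newton iteration for sqrt(a) with a data-dependent stopping rule."""
    a = jnp.asarray(a, dtype=jnp.float32)

    def cond_fun(x):  # read-only convergence check
        return jnp.abs(x * x - a) > 1e-6

    def body_fun(x):
        return 0.5 * (x + a / x)

    return jax.lax.while_loop(cond_fun, body_fun, a)


def bounded_sqrt(a, max_steps=20):
    """Same iteration as a fixed-length scan that freezes x once converged."""
    a = jnp.asarray(a, dtype=jnp.float32)

    def step(x, _):
        active = jnp.abs(x * x - a) > 1e-6
        return jnp.where(active, 0.5 * (x + a / x), x), None

    x, _ = jax.lax.scan(step, a, xs=None, length=max_steps)
    return x


print(while_sqrt(2.0), bounded_sqrt(2.0))
print(jax.grad(bounded_sqrt)(2.0))  # approximately 1 / (2 * sqrt(2))

try:
    jax.grad(while_sqrt)(2.0)
    reverse_mode_ok = True
except ValueError:
    # JAX rejects reverse-mode differentiation through lax.while_loop.
    reverse_mode_ok = False
print("reverse-mode through while_loop:", reverse_mode_ok)
```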
Conditional Control Flow#
cond: Binary conditional (if/else)
Use for: Two-way decisions
Lazy evaluation, state-safe
switch: Multi-way branching (switch/case)
Use for: Multiple branches selected by an integer index
Out-of-bounds indices are clamped for safety
ifelse: Multi-condition branching (if/elif/else)
Use for: Complex conditions, default branches
Use True as the last condition for an else branch (with check_cond=False, since it overlaps the other conditions)
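Lazy evaluation here refers to run time: when a conditional is compiled, JAX still traces every branch, which is presumably why state-aware wrappers must reconcile state writes from all branches. A small plain-JAX check (trace_log is an illustrative name) makes this visible by recording a Python side effect during tracing:

```python
import jax
import jax.numpy as jnp

trace_log = []  # Python side effects run only while JAX traces a function


def true_branch(x):
    trace_log.append("true")
    return x + 1.0


def false_branch(x):
    trace_log.append("false")
    return x - 1.0


@jax.jit
def f(x):
    return jax.lax.cond(x > 0, true_branch, false_branch, x)


result = f(jnp.float32(3.0))
print(result, sorted(trace_log))  # trace_log holds both branch names
```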
Key Principles#
State Safety: All APIs track state reads and writes
Lazy Evaluation: Conditionals execute only the selected branch at run time, although every branch is traced during compilation (and under vmap with a batched predicate or index, JAX converts the conditional into a select that evaluates all branches)
JAX Compatibility: Each API compiles to efficient JAX primitives
Output Collection: for_loop and scan collect the output of each iteration into the final result
Memory Efficiency: Checkpointed variants save memory during gradient computation by storing only checkpoints and recomputing intermediate activations during backpropagation
Differentiability: Most APIs support gradients; while_loop does not support reverse-mode differentiation. Checkpointed variants are essential for long sequences
These primitives enable complex control flow while maintaining BrainState’s stateful programming model and JAX’s performance benefits.