Loops and Conditionals

This tutorial covers state-aware control flow primitives in brainstate.transform. These APIs provide JAX-compatible loops and conditionals while safely handling State objects.

We’ll explore three categories of control flow:

  1. Loop Transformations: scan, checkpointed_scan, for_loop, checkpointed_for_loop

  2. While Loops: while_loop, bounded_while_loop

  3. Conditional Control Flow: cond, switch, ifelse

Each API tracks reads and writes to State objects inside the loop body or branch and carries the updates back out, while preserving JAX's functional programming model.

Imports and Setup

import jax
import jax.numpy as jnp

import brainstate
from brainstate.transform import (
    scan,
    checkpointed_scan,
    for_loop,
    checkpointed_for_loop,
    while_loop,
    bounded_while_loop,
    cond,
    switch,
    ifelse,
)
# Import ProgressBar
from brainstate.transform import ProgressBar

1. Loop Transformations

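Loop transformations iterate over sequences while tracking state. Each loop body is traced once and compiled into a single JAX loop primitive instead of being unrolled in Python. As a result, compilation cost does not grow with the number of iterations unless unrolling is requested.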

1.1 scan: Stateful Scanning with Carry

scan is the fundamental loop primitive that:

  • Iterates over a sequence along the leading axis

  • Maintains a “carry” value that threads through iterations

  • Collects outputs at each step

  • Properly handles State objects

Function signature:

scan(
    f: Callable[[Carry, X], Tuple[Carry, Y]],
    init: Carry,
    xs: X,
    length: int | None = None,
    reverse: bool = False,
    unroll: int | bool = 1,
    pbar: ProgressBar | int | None = None,
) -> Tuple[Carry, Y]

Parameters:

  • f: Function of type (carry, x) -> (new_carry, output)

  • init: Initial carry value

  • xs: Sequence to iterate over (along axis 0)

  • length: Optional iteration count (inferred from xs if not provided). It must match the leading axis of xs and is required when xs is None.

  • reverse: If True, iterate in reverse order

  • unroll: Number of loop iterations to unroll into each compiled iteration (1 = no unrolling, True = full unrolling)

  • pbar: Optional progress bar
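The parameter list is easier to follow with a concrete model of what scan computes. Below is a pure-Python sketch adapted from the reference semantics in the jax.lax.scan documentation.

It is only a sketch:

  • It ignores unroll, pbar, and State tracking.
  • It returns outputs as a Python list instead of stacked arrays.
  • The name scan_reference is hypothetical.

```python
from typing import Any, Callable, List, Optional, Sequence, Tuple


def scan_reference(
    f: Callable[[Any, Any], Tuple[Any, Any]],
    init: Any,
    xs: Optional[Sequence[Any]],
    length: Optional[int] = None,
    reverse: bool = False,
) -> Tuple[Any, List[Any]]:
    """Hypothetical pure-Python model of scan (no compilation, no State handling)."""
    if xs is None:
        # Without xs, the number of iterations must be given explicitly.
        if length is None:
            raise ValueError("length is required when xs is None")
        xs = [None] * length
    n = len(xs)
    if length is not None and length != n:
        raise ValueError("length must equal the leading dimension of xs")
    carry = init
    ys: List[Any] = [None] * n
    indices = range(n - 1, -1, -1) if reverse else range(n)
    for i in indices:
        # The carry threads through iterations; ys[i] is stored at the
        # original position even when iterating in reverse.
        carry, ys[i] = f(carry, xs[i])
    return carry, ys


def add_and_emit(carry, x):
    new_carry = carry + x
    return new_carry, new_carry


final, sums = scan_reference(add_and_emit, 0.0, [1.0, 2.0, 3.0, 4.0, 5.0])
_, back = scan_reference(add_and_emit, 0.0, [1.0, 2.0, 3.0, 4.0, 5.0], reverse=True)
print(final, sums)  # 15.0 [1.0, 3.0, 6.0, 10.0, 15.0]
print(back)         # [15.0, 14.0, 12.0, 9.0, 5.0]
```

The reverse case matches the "Backward cumsum" output of Example 3 below. reverse=True changes the order in which elements are visited, not the index at which each output is stored.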

# Example 1: Basic scan with carry
def cumsum_body(carry, x):
    """Accumulate sum and return both new carry and current sum."""
    new_carry = carry + x
    return new_carry, new_carry


xs = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0])
final_sum, cumulative_sums = scan(cumsum_body, init=0.0, xs=xs)

print("Input sequence:", xs)
print("Final sum:", final_sum)
print("Cumulative sums:", cumulative_sums)
Input sequence: [1. 2. 3. 4. 5.]
Final sum: 15.0
Cumulative sums: [ 1.  3.  6. 10. 15.]
# Example 2: Scan with stateful computation
class RunningStats(brainstate.nn.Module):
    """Maintain running mean and variance."""

    def __init__(self):
        super().__init__()
        self.count = brainstate.ShortTermState(jnp.array(0))
        self.mean = brainstate.ShortTermState(jnp.array(0.0))
        self.m2 = brainstate.ShortTermState(jnp.array(0.0))  # sum of squared differences

    def update(self, x):
        """Update statistics with new value using Welford's algorithm."""
        self.count.value = self.count.value + 1
        delta = x - self.mean.value
        self.mean.value = self.mean.value + delta / self.count.value
        delta2 = x - self.mean.value
        self.m2.value = self.m2.value + delta * delta2

        variance = self.m2.value / self.count.value
        return {'mean': self.mean.value, 'var': variance}


stats = RunningStats()


def stats_body(carry, x):
    result = stats.update(x)
    return carry, result


data = jnp.array([2.0, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0])
_, history = scan(stats_body, init=None, xs=data)

print("Data:", data)
print("\nRunning mean:", history['mean'])
print("Running variance:", history['var'])
print("\nFinal statistics:")
print(f"  Count: {stats.count.value}")
print(f"  Mean: {stats.mean.value}")
print(f"  Variance: {stats.m2.value / stats.count.value}")
Data: [2. 4. 4. 4. 5. 5. 7. 9.]

Running mean: [2.        3.        3.3333333 3.5       3.8       4.        4.428571
 5.       ]
Running variance: [0.         1.         0.8888889  0.75       0.96000004 1.
 1.9591838  4.        ]

Final statistics:
  Count: 8
  Mean: 5.0
  Variance: 4.0
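To sanity-check the Welford update used in RunningStats, run the same recurrence in plain Python. Then compare it with the two-pass definition of the population variance (the mean of squared deviations from the mean). The helper names welford and two_pass are illustrative only.

```python
from typing import List, Tuple


def welford(data: List[float]) -> Tuple[float, float]:
    """Streaming mean and population variance (Welford's algorithm)."""
    count, mean, m2 = 0, 0.0, 0.0
    for x in data:
        count += 1
        delta = x - mean          # deviation from the old mean
        mean += delta / count
        m2 += delta * (x - mean)  # uses the updated mean
    return mean, m2 / count


def two_pass(data: List[float]) -> Tuple[float, float]:
    """Reference: compute the mean first, then the mean squared deviation."""
    mean = sum(data) / len(data)
    return mean, sum((x - mean) ** 2 for x in data) / len(data)


data = [2.0, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0]
print([round(v, 6) for v in welford(data)])  # [5.0, 4.0]
print(two_pass(data))                         # (5.0, 4.0)
```

The single-pass version needs only the running count, mean, and M2, which is why it fits naturally into a scan body.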
# Example 3: Reverse scan
def reverse_cumsum(carry, x):
    new_carry = carry + x
    return new_carry, new_carry


xs = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0])
_, forward_sums = scan(reverse_cumsum, 0.0, xs, reverse=False)
_, backward_sums = scan(reverse_cumsum, 0.0, xs, reverse=True)

print("Input:", xs)
print("Forward cumsum:", forward_sums)
print("Backward cumsum:", backward_sums)
Input: [1. 2. 3. 4. 5.]
Forward cumsum: [ 1.  3.  6. 10. 15.]
Backward cumsum: [15. 14. 12.  9.  5.]

Progress Bar with scan

The pbar parameter enables progress tracking during long-running scans. You can:

  • Pass a ProgressBar instance for full control over display options

  • Pass an integer for quick setup (updates every N iterations)

  • Customize the description with static or dynamic messages

# Example 4: Progress bar with scan - simple integer freq
print("\n=== Simple progress bar (update every 20 iterations) ===")


def expensive_computation(carry, x):
    """Simulate expensive computation."""
    # Some computation
    result = carry + jnp.sin(x) * jnp.cos(x)
    return result, result


# Create long sequence
long_sequence = jnp.linspace(0, 10 * jnp.pi, 100)

# Use integer for simple progress bar (updates every 20 iterations)
final, outputs = scan(expensive_computation, init=0.0, xs=long_sequence, pbar=20)
print(f"\nFinal result: {final}")
=== Simple progress bar (update every 20 iterations) ===
Final result: -4.0076361074170563e-07
# Example 5: Progress bar with custom ProgressBar instance
print("\n=== Custom progress bar with ProgressBar ===")

# Create ProgressBar with custom settings
pbar = ProgressBar(freq=10, desc="Processing sequence")

final, outputs = scan(expensive_computation, init=0.0, xs=long_sequence, pbar=pbar)
print(f"\nCompleted! Final result: {final}")
=== Custom progress bar with ProgressBar ===
Completed! Final result: -4.0076361074170563e-07
# Example 6: Dynamic progress bar description based on loop state
print("\n=== Dynamic progress bar with loop state ===")


class OptimizationTracker(brainstate.nn.Module):
    """Track optimization progress."""

    def __init__(self):
        super().__init__()
        self.best_loss = brainstate.ShortTermState(jnp.array(float('inf')))

    def step(self, params, x):
        # Compute loss
        loss = jnp.sum((params - x) ** 2)
        # Update best
        self.best_loss.value = jnp.minimum(self.best_loss.value, loss)
        # Update parameters
        new_params = params - 0.1 * 2 * (params - x)
        return new_params, loss


tracker = OptimizationTracker()


def scan_body_with_tracking(params, x):
    return tracker.step(params, x)


# Define dynamic description
def format_progress(data):
    """Format progress with current loss and best loss."""
    return {
        "iter": data["i"],
        "loss": data["y"],
        "best": tracker.best_loss.value
    }


pbar_dynamic = ProgressBar(
    freq=15,
    desc=("Iter {iter:3d} | Loss: {loss:.4f} | Best: {best:.4f}", format_progress)
)

targets = jax.random.normal(jax.random.PRNGKey(42), (100,))
init_params = jnp.array(0.0)

final_params, loss_history = scan(
    scan_body_with_tracking,
    init=init_params,
    xs=targets,
    pbar=pbar_dynamic
)

print(f"\nOptimization completed!")
print(f"Final parameters: {final_params}")
print(f"Final loss: {loss_history[-1]}")
print(f"Best loss achieved: {tracker.best_loss.value}")
=== Dynamic progress bar with loop state ===
Optimization completed!
Final parameters: -0.04794257506728172
Final loss: 1.41942298412323
Best loss achieved: 3.858334093820304e-05
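The dynamic description in Example 6 pairs a format template with a function that returns the values to fill in. The sketch below assumes ProgressBar fills the template with Python's str.format and the returned dict; check the ProgressBar documentation to confirm. Under that assumption, one update message would be built like this (the values are illustrative):

```python
template = "Iter {iter:3d} | Loss: {loss:.4f} | Best: {best:.4f}"

# Illustrative values standing in for what format_progress might return
values = {"iter": 15, "loss": 1.41942298412323, "best": 3.858334093820304e-05}

message = template.format(**values)
print(message)  # Iter  15 | Loss: 1.4194 | Best: 0.0000
```

Each field's format spec must suit its value. For example, str.format raises ValueError if a float is passed for the integer spec :3d.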

1.2 checkpointed_scan: Memory-Efficient Scanning

checkpointed_scan is a memory-optimized version of scan that uses gradient checkpointing. It is useful for:

  • Long sequences where storing all intermediate activations during gradient computation would exhaust memory

  • Trading extra computation time for lower memory use during backpropagation

During the forward pass, only checkpoints are stored. The intermediate values between them are recomputed during the backward pass when they are needed.

Function signature:

checkpointed_scan(
    f: Callable[[Carry, X], Tuple[Carry, Y]],
    init: Carry,
    xs: X,
    length: Optional[int] = None,
    base: int = 16,
    pbar: Optional[ProgressBar | int] = None,
) -> Tuple[Carry, Y]

Key parameter:

  • base: Checkpointing base (default=16). Smaller values save more memory but increase recomputation during the backward pass. The implementation uses a hierarchical checkpointing scheme that covers max_steps = base^k steps for some integer k.

Memory savings during gradient computation:

  • Regular scan: Stores all intermediate activations → O(n) memory for sequence length n

  • checkpointed_scan: Stores only checkpoints → roughly O(base · log_base(n)) saved carries (the stacked outputs are still returned in full)

  • During backward pass: Recomputes intermediate values between checkpoints as needed
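The scaling claims above can be made concrete with a little arithmetic. This back-of-the-envelope model is not brainstate's exact bookkeeping. It assumes the hierarchical scheme described for base:

  • There are k nested levels, each a scan of length base.
  • k is the smallest integer with base**k >= n.
  • Each level keeps about base carries during the forward pass.

Under those assumptions, the number of stored carries is roughly base * k. The helper names are hypothetical.

```python
def nesting_depth(n: int, base: int) -> int:
    """Smallest k with base**k >= n, computed with integers to avoid
    floating-point rounding in math.log near exact powers of base."""
    if n < 1 or base < 2:
        raise ValueError("need n >= 1 and base >= 2")
    k, span = 0, 1
    while span < n:
        span *= base
        k += 1
    return k


def approx_stored_carries(n: int, base: int) -> int:
    """Rough model: about `base` saved carries per nesting level."""
    return base * nesting_depth(n, base)


for base in (2, 8, 16):
    k = nesting_depth(100, base)
    print(f"n=100, base={base:2d}: depth={k}, ~{approx_stored_carries(100, base)} saved carries")
```

For n = 100 this gives about 14, 24, and 32 saved carries for bases 2, 8, and 16. All three are far below the 100 per-step activations a plain scan would keep, and smaller bases store fewer carries. Deeper nesting means more recomputation in the backward pass.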

# Example: Memory-efficient scan for gradient computation
class RecurrentCell(brainstate.nn.Module):
    """Simple recurrent cell with hidden state."""

    def __init__(self, hidden_size):
        super().__init__()
        self.hidden_size = hidden_size
        self.weight = brainstate.ParamState(jax.random.normal(
            jax.random.PRNGKey(0), (hidden_size, hidden_size)
        ))

    def step(self, hidden, x):
        """Single recurrent step."""
        new_hidden = jnp.tanh(jnp.dot(self.weight.value, hidden) + x)
        return new_hidden


# Create a cell and input sequence
cell = RecurrentCell(hidden_size=32)
sequence_length = 100
inputs = jax.random.normal(jax.random.PRNGKey(1), (sequence_length, 32))


def rnn_body(hidden, x):
    new_hidden = cell.step(hidden, x)
    return new_hidden, new_hidden


# Use checkpointed scan for memory efficiency during gradient computation
init_hidden = jnp.zeros(32)
final_hidden, all_hiddens = checkpointed_scan(
    rnn_body,
    init=init_hidden,
    xs=inputs,
    base=8  # checkpointing base: smaller means less memory, more recomputation
)

print(f"Sequence length: {sequence_length}")
print(f"Hidden size: {cell.hidden_size}")
print(f"Final hidden shape: {final_hidden.shape}")
print(f"All hiddens shape: {all_hiddens.shape}")
print(f"\nCheckpointing configuration:")
print(f"  Base: 8 (checkpoints are organized hierarchically, 8 per level)")
print(f"  Memory saved: keeps a few checkpoints per level instead of all {sequence_length} per-step activations")
print(f"  During backprop: Recomputes activations between checkpoints as needed")
Sequence length: 100
Hidden size: 32
Final hidden shape: (32,)
All hiddens shape: (100, 32)

Checkpointing configuration:
  Base: 8 (checkpoints are organized hierarchically, 8 per level)
  Memory saved: keeps a few checkpoints per level instead of all 100 per-step activations
  During backprop: Recomputes activations between checkpoints as needed

Progress Bar with checkpointed_scan

checkpointed_scan also supports progress bars, which is especially useful for very long sequences where you want to monitor progress.

# Example: Progress bar with checkpointed_scan
print("\n=== Checkpointed scan with progress bar ===")


class LongRunningComputation(brainstate.nn.Module):
    """Simulate a long-running computation."""

    def __init__(self):
        super().__init__()
        self.total_ops = brainstate.ShortTermState(jnp.array(0))

    def process(self, state, x):
        self.total_ops.value = self.total_ops.value + 1
        # Some computation
        new_state = state + jnp.tanh(x)
        output = jnp.sin(new_state) * jnp.cos(x)
        return new_state, output


long_comp = LongRunningComputation()


def body(state, x):
    return long_comp.process(state, x)


# Long sequence
very_long_sequence = jnp.linspace(0, 20 * jnp.pi, 500)

# Progress bar that updates every 50 iterations
pbar_checkpointed = ProgressBar(
    freq=50,
    desc="Checkpointed scan progress"
)

final_state, results = checkpointed_scan(
    body,
    init=0.0,
    xs=very_long_sequence,
    base=10,
    pbar=pbar_checkpointed
)

print(f"\nProcessed {long_comp.total_ops.value} operations")
print(f"Final state: {final_state}")
print(f"Results shape: {results.shape}")
=== Checkpointed scan with progress bar ===
Processed 500 operations
Final state: 493.9846496582031
Results shape: (500,)

1.3 for_loop: Simplified Loop Without Carry

for_loop provides a simpler interface when you don’t need an explicit carry value. It:

  • Accepts variadic arguments that are sliced along axis 0

  • Collects and returns the output of every iteration: the value your function returns at each step is saved and stacked into the final output array

  • Internally uses scan with None as the carry

Function signature:

for_loop(
    f: Callable[..., Y],
    *xs,
    length: Optional[int] = None,
    reverse: bool = False,
    unroll: int | bool = 1,
    pbar: Optional[ProgressBar | int] = None
) -> Y

Key differences from scan:

  • Function signature is (*xs) -> output instead of (carry, x) -> (carry, output)

  • No carry value to manage

  • Important: The return value at each iteration is collected and stacked along axis 0 to form the final output. This means if your function returns a scalar at each step, for_loop returns a 1D array; if it returns a vector of shape (d,), the output will be shape (n, d) where n is the number of iterations.
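The output collection rule can be modeled in plain Python as a scan without a carry. This hypothetical for_loop_reference slices every input at step i, calls f with one element from each, and stores the result at index i. The real for_loop stacks these results along axis 0 instead of returning a list, and also handles State and progress bars, which this sketch does not.

```python
from typing import Any, Callable, List, Optional, Sequence


def for_loop_reference(f: Callable[..., Any], *xs: Sequence[Any],
                       length: Optional[int] = None,
                       reverse: bool = False) -> List[Any]:
    """Hypothetical pure-Python model of for_loop: scan with no carry."""
    n = len(xs[0]) if xs else length
    if n is None:
        raise ValueError("length is required when no sequences are given")
    if any(len(x) != n for x in xs):
        raise ValueError("all sequences must share the same leading length")
    outputs: List[Any] = [None] * n
    indices = range(n - 1, -1, -1) if reverse else range(n)
    for i in indices:
        # Slice every input at step i and call f with one element from each.
        outputs[i] = f(*(x[i] for x in xs))
    return outputs  # the real for_loop stacks these along axis 0


def compute(x, y, z):
    return x * y + z


print(for_loop_reference(compute, [1.0, 2.0, 3.0, 4.0], [2.0, 3.0, 4.0, 5.0], [0.5, 1.0, 1.5, 2.0]))
# [2.5, 7.0, 13.5, 22.0]
```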

# Example 1: Understanding output collection in for_loop
def compute(x, y, z):
    """Combine three inputs."""
    return x * y + z


xs = jnp.array([1.0, 2.0, 3.0, 4.0])
ys = jnp.array([2.0, 3.0, 4.0, 5.0])
zs = jnp.array([0.5, 1.0, 1.5, 2.0])

# for_loop collects the output from EACH iteration
results = for_loop(compute, xs, ys, zs)

print("x:", xs)
print("y:", ys)
print("z:", zs)
print("x * y + z:", results)
print(f"\nNotice: for_loop collected {len(results)} outputs (one per iteration)")
print(f"Each element results[i] = xs[i] * ys[i] + zs[i]")
print(f"Output shape: {results.shape} (stacked along axis 0)")
x: [1. 2. 3. 4.]
y: [2. 3. 4. 5.]
z: [0.5 1.  1.5 2. ]
x * y + z: [ 2.5  7.  13.5 22. ]

Notice: for_loop collected 4 outputs (one per iteration)
Each element results[i] = xs[i] * ys[i] + zs[i]
Output shape: (4,) (stacked along axis 0)
# Example 2: Stateful for_loop
class Accumulator(brainstate.nn.Module):
    """Simple accumulator that tracks total and count."""

    def __init__(self):
        super().__init__()
        self.total = brainstate.ShortTermState(jnp.array(0.0))
        self.count = brainstate.ShortTermState(jnp.array(0))

    def process(self, x):
        self.total.value = self.total.value + x
        self.count.value = self.count.value + 1
        return self.total.value / self.count.value  # running average


acc = Accumulator()

data = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
running_averages = for_loop(acc.process, data)

print("Data:", data)
print("Running averages:", running_averages)
print(f"\nFinal state: total={acc.total.value}, count={acc.count.value}")
print(f"Final average: {acc.total.value / acc.count.value}")
Data: [1. 2. 3. 4. 5. 6.]
Running averages: [1.  1.5 2.  2.5 3.  3.5]

Final state: total=21.0, count=6
Final average: 3.5

Progress Bar with for_loop

for_loop also supports progress bars. This is particularly useful when processing large batches of data.

# Example 3: Progress bar with for_loop - simple case
print("\n=== For loop with progress bar ===")


class DataProcessor(brainstate.nn.Module):
    """Process data with progress tracking."""

    def __init__(self):
        super().__init__()
        self.processed_count = brainstate.ShortTermState(jnp.array(0))
        self.sum_val = brainstate.ShortTermState(jnp.array(0.0))

    def process_item(self, x):
        self.processed_count.value = self.processed_count.value + 1
        self.sum_val.value = self.sum_val.value + x
        # Simulate some processing
        result = jnp.exp(x) / (1 + jnp.exp(x))  # sigmoid
        return result


processor = DataProcessor()

# Create dataset
dataset = jax.random.normal(jax.random.PRNGKey(123), (200,))

# Use simple integer for progress updates
processed = for_loop(processor.process_item, dataset, pbar=25)

print(f"\nProcessed {processor.processed_count.value} items")
print(f"Sum of inputs: {processor.sum_val.value}")
print(f"Processed data shape: {processed.shape}")
=== For loop with progress bar ===
Processed 200 items
Sum of inputs: 9.437446594238281
Processed data shape: (200,)
# Example 4: For loop with dynamic progress description
print("\n=== For loop with dynamic progress description ===")


class BatchProcessor(brainstate.nn.Module):
    """Process batches with statistics."""

    def __init__(self):
        super().__init__()
        self.mean = brainstate.ShortTermState(jnp.array(0.0))
        self.variance = brainstate.ShortTermState(jnp.array(0.0))  # Welford M2 accumulator; variance = M2 / count
        self.count = brainstate.ShortTermState(jnp.array(0))

    def update(self, x):
        self.count.value = self.count.value + 1
        delta = x - self.mean.value
        self.mean.value = self.mean.value + delta / self.count.value
        self.variance.value = self.variance.value + delta * (x - self.mean.value)
        return x ** 2


batch_proc = BatchProcessor()


def format_batch_progress(data):
    """Show current statistics."""
    return {
        "n": data["i"],
        "mean": batch_proc.mean.value,
        "var": batch_proc.variance.value / jnp.maximum(batch_proc.count.value, 1)
    }


pbar_batch = ProgressBar(
    freq=20,
    desc=("Batch {n:3d} | Mean: {mean:+.3f} | Var: {var:.3f}", format_batch_progress)
)

batch_data = jax.random.normal(jax.random.PRNGKey(456), (150,)) * 2.0 + 1.0

squared = for_loop(batch_proc.update, batch_data, pbar=pbar_batch)

print(f"\nFinal statistics:")
print(f"  Mean: {batch_proc.mean.value}")
print(f"  Variance: {batch_proc.variance.value / batch_proc.count.value}")
print(f"  Count: {batch_proc.count.value}")
=== For loop with dynamic progress description ===
Final statistics:
  Mean: 0.9223809242248535
  Variance: 3.488231897354126
  Count: 150

1.4 checkpointed_for_loop: Memory-Efficient For Loop

The checkpointed version of for_loop combines the simplicity of for_loop with the memory efficiency of checkpointing during gradient computation.

Function signature:

checkpointed_for_loop(
    f: Callable[..., Y],
    *xs: X,
    length: Optional[int] = None,
    base: int = 16,
    pbar: Optional[ProgressBar | int] = None,
) -> Y

Memory efficiency during gradient computation:

  • Like checkpointed_scan, this variant significantly reduces memory usage during backpropagation

  • Essential for training models with very long sequences where storing all intermediate activations would cause out-of-memory errors

  • The base parameter controls the memory/computation tradeoff: smaller base = less memory but more recomputation

# Example: Processing long sequence with state
class ExpMovingAverage(brainstate.nn.Module):
    """Exponential moving average."""

    def __init__(self, alpha=0.1):
        super().__init__()
        self.alpha = alpha
        self.ema = brainstate.ShortTermState(jnp.array(0.0))
        self.initialized = brainstate.ShortTermState(jnp.array(False))

    def update(self, x):
        # Initialize with first value
        self.ema.value = jnp.where(
            self.initialized.value,
            self.alpha * x + (1 - self.alpha) * self.ema.value,
            x
        )
        self.initialized.value = jnp.array(True)
        return self.ema.value


ema = ExpMovingAverage(alpha=0.3)

# Generate noisy signal
signal = jnp.sin(jnp.linspace(0, 4 * jnp.pi, 200)) + 0.2 * brainstate.random.normal(size=(200,))

# Process with checkpointed for_loop
smoothed = checkpointed_for_loop(ema.update, signal, base=10)

print(f"Signal length: {len(signal)}")
print(f"Smoothed signal shape: {smoothed.shape}")
print(f"Original signal range: [{signal.min():.3f}, {signal.max():.3f}]")
print(f"Smoothed signal range: [{smoothed.min():.3f}, {smoothed.max():.3f}]")
Signal length: 200
Smoothed signal shape: (200,)
Original signal range: [-1.401, 1.408]
Smoothed signal range: [-1.131, 1.200]
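ExpMovingAverage implements the recurrence ema_0 = x_0, ema_t = alpha * x_t + (1 - alpha) * ema_{t-1}. A plain-Python version with no checkpointing is handy for checking smoothed values on a short input. The name ema_reference is hypothetical.

```python
from typing import List


def ema_reference(xs: List[float], alpha: float) -> List[float]:
    out: List[float] = []
    for x in xs:
        if not out:
            out.append(x)  # first sample initializes the average, like the `initialized` flag
        else:
            out.append(alpha * x + (1 - alpha) * out[-1])
    return out


print([round(v, 4) for v in ema_reference([1.0, 2.0, 3.0], alpha=0.3)])  # [1.0, 1.3, 1.81]
```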

Progress Bar with checkpointed_for_loop

checkpointed_for_loop supports progress bars to help track processing of very long sequences.

# Example: Progress bar with checkpointed_for_loop
print("\n=== Checkpointed for loop with progress bar ===")


class StreamProcessor(brainstate.nn.Module):
    """Process streaming data."""

    def __init__(self, momentum=0.9):
        super().__init__()
        self.momentum = momentum
        self.running_avg = brainstate.ShortTermState(jnp.array(0.0))

    def process(self, x):
        # Update exponential moving average
        self.running_avg.value = (
            self.momentum * self.running_avg.value + (1 - self.momentum) * x
        )
        return self.running_avg.value


stream_proc = StreamProcessor(momentum=0.95)

# Generate long data stream
data_stream = jax.random.normal(jax.random.PRNGKey(789), (1000,))

# Progress bar with count parameter (updates exactly 10 times)
pbar_stream = ProgressBar(
    count=10,
    desc="Processing data stream"
)

smoothed_stream = checkpointed_for_loop(
    stream_proc.process,
    data_stream,
    base=20,
    pbar=pbar_stream
)

print(f"\nStream processed!")
print(f"Final running average: {stream_proc.running_avg.value}")
print(f"Smoothed stream shape: {smoothed_stream.shape}")
print(f"First 5 values: {smoothed_stream[:5]}")
print(f"Last 5 values: {smoothed_stream[-5:]}")
=== Checkpointed for loop with progress bar ===
Stream processed!
Final running average: 0.050981856882572174
Smoothed stream shape: (1000,)
First 5 values: [ 0.03755957 -0.01195641  0.0011875   0.03787947  0.04933156]
Last 5 values: [ 0.10296391  0.07546234 -0.00500105  0.06068815  0.05098186]

1.5 Comparison: scan vs for_loop

When to use each:

Use scan when:

  • You need to thread a carry value through iterations

  • Implementing recurrent patterns (RNNs, state machines)

  • You want explicit control over the accumulator

Use for_loop when:

  • No carry value is needed

  • Processing independent items with side effects (state updates)

  • Simpler, more Pythonic syntax is preferred

# Comparison example: Computing powers of 2

# Using scan: carry explicitly tracks the power
def scan_version(n):
    def body(carry, _):
        return carry * 2, carry

    _, powers = scan(body, init=1, xs=jnp.arange(n))
    return powers


# Using for_loop with state: state tracks the power
class PowerTracker(brainstate.nn.Module):
    def __init__(self):
        super().__init__()
        self.current = brainstate.ShortTermState(jnp.array(1))

    def next_power(self, _):
        result = self.current.value
        self.current.value = self.current.value * 2
        return result


def forloop_version(n):
    tracker = PowerTracker()
    return for_loop(tracker.next_power, jnp.arange(n))


n = 10
print(f"Powers of 2 (first {n} values):")
print("scan result:    ", scan_version(n))
print("for_loop result:", forloop_version(n))
Powers of 2 (first 10 values):
scan result:     [  1   2   4   8  16  32  64 128 256 512]
for_loop result: [  1   2   4   8  16  32  64 128 256 512]

2. While Loops

While loops provide conditional iteration where the number of iterations is not known in advance.

2.1 while_loop: Dynamic Conditional Iteration

while_loop executes a body function repeatedly while a condition remains true. This is the stateful version of jax.lax.while_loop.

Function signature:

while_loop(
    cond_fun: Callable[[T], BooleanNumeric],
    body_fun: Callable[[T], T],
    init_val: T
) -> T

Parameters:

  • cond_fun: Function that returns True to continue looping

  • body_fun: Function that updates the loop value

  • init_val: Initial loop value

Important constraints:

  • cond_fun cannot modify state (read-only)

  • Loop value must maintain fixed shape and dtype

  • Not reverse-mode differentiable (use bounded_while_loop instead)
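The loop semantics match the Python equivalent given in the jax.lax.while_loop documentation, sketched below with a hypothetical name. The model does not capture the constraints listed above: fixed shapes and dtypes, read-only cond_fun, and no reverse-mode differentiation. It only shows when the body runs and what is returned.

```python
from typing import Callable, TypeVar

T = TypeVar("T")


def while_loop_reference(cond_fun: Callable[[T], bool],
                         body_fun: Callable[[T], T],
                         init_val: T) -> T:
    """Pure-Python model of while_loop: apply body_fun while cond_fun holds."""
    val = init_val
    while cond_fun(val):
        val = body_fun(val)
    return val


print(while_loop_reference(lambda v: v < 1000, lambda v: v * 2, 1))  # 1024


def collatz_body(state):
    n, steps = state
    return (n // 2 if n % 2 == 0 else 3 * n + 1), steps + 1


print(while_loop_reference(lambda s: s[0] > 1, collatz_body, (27, 0)))  # (1, 111)
```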

# Example 1: Simple while loop - find first power of 2 above threshold
def find_power_of_2_above(threshold):
    def cond_fn(val):
        return val < threshold

    def body(val):
        return val * 2

    return while_loop(cond_fn, body, init_val=1)


threshold = 1000
result = find_power_of_2_above(threshold)
print(f"First power of 2 above {threshold}: {result}")
First power of 2 above 1000: 1024
# Example 2: Stateful while loop - iterative refinement
class IterativeRefiner(brainstate.nn.Module):
    """Iteratively refine an estimate using Newton's method."""

    def __init__(self, target):
        super().__init__()
        self.target = target
        self.iterations = brainstate.ShortTermState(jnp.array(0))

    def refine(self, x):
        """Newton's method step for computing sqrt(target)."""
        self.iterations.value = self.iterations.value + 1
        return 0.5 * (x + self.target / x)


# Compute square root of 2 using Newton's method
refiner = IterativeRefiner(target=2.0)


def cond_f(x):
    # Continue until error is small enough
    return jnp.abs(x * x - refiner.target) > 1e-6


def body(x):
    return refiner.refine(x)


result = while_loop(cond_f, body, init_val=1.0)

print(f"Computing sqrt(2)...")
print(f"Result: {result}")
print(f"Actual sqrt(2): {jnp.sqrt(2.0)}")
print(f"Error: {jnp.abs(result - jnp.sqrt(2.0))}")
print(f"Iterations: {refiner.iterations.value}")
Computing sqrt(2)...
Result: 1.4142135381698608
Actual sqrt(2): 1.4142135381698608
Error: 0.0
Iterations: 4
# Example 3: Complex loop value (pytree)
class Collatz(brainstate.nn.Module):
    """Track Collatz sequence statistics."""

    def __init__(self):
        super().__init__()
        self.max_value = brainstate.ShortTermState(jnp.array(0))

    def step(self, n):
        self.max_value.value = jnp.maximum(self.max_value.value, n)
        return jnp.where(n % 2 == 0, n // 2, 3 * n + 1)


collatz = Collatz()


def collatz_cond(state):
    n, steps = state
    return n > 1


def collatz_body(state):
    n, steps = state
    return collatz.step(n), steps + 1


start_value = 27
final_n, total_steps = while_loop(
    collatz_cond,
    collatz_body,
    init_val=(start_value, 0)
)

print(f"Collatz sequence starting from {start_value}:")
print(f"  Converged to: {final_n}")
print(f"  Steps taken: {total_steps}")
print(f"  Maximum value reached: {collatz.max_value.value}")
Collatz sequence starting from 27:
  Converged to: 1
  Steps taken: 111
  Maximum value reached: 9232

2.2 bounded_while_loop: While Loop with Maximum Steps

bounded_while_loop adds a maximum iteration limit to while loops. This is important for:

  • Preventing infinite loops

  • Enabling reverse-mode differentiation (unlike while_loop)

  • Making the worst-case number of iterations known at compile time

Function signature:

bounded_while_loop(
    cond_fun: Callable[[T], BooleanNumeric],
    body_fun: Callable[[T], T],
    init_val: T,
    *,
    max_steps: int,
    base: int = 16,
)

Key parameters:

  • max_steps: Maximum number of iterations before termination

  • base: Compilation/runtime tradeoff (default=16)

    • Larger base = faster compilation, slightly slower runtime

    • Smaller base = slower compilation, faster runtime

    • Compile time scales with math.ceil(math.log(max_steps, base))
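Here is a minimal model of what bounded_while_loop should return: the same value as while_loop, except that it never runs more than max_steps iterations. The name bounded_while_loop_reference is hypothetical. The sketch also returns the step count for inspection.

The model ignores base and the nested-scan construction that makes the real primitive reverse-mode differentiable. It only pins down the expected value and iteration count.

```python
from typing import Callable, Tuple, TypeVar

T = TypeVar("T")


def bounded_while_loop_reference(cond_fun: Callable[[T], bool],
                                 body_fun: Callable[[T], T],
                                 init_val: T,
                                 max_steps: int) -> Tuple[T, int]:
    """Hypothetical model: like while_loop, but stops after max_steps iterations.
    Returns (final value, number of steps taken)."""
    val, steps = init_val, 0
    while steps < max_steps and cond_fun(val):
        val = body_fun(val)
        steps += 1
    return val, steps


# Gradient descent on f(x) = (x - 3)^2 with learning rate 0.1 (as in Example 1 below)
x, steps = bounded_while_loop_reference(
    lambda x: abs(x - 3.0) > 1e-4,
    lambda x: x - 0.1 * 2 * (x - 3.0),
    0.0,
    max_steps=100,
)
print(steps, round(x, 5))  # 47 2.99992

# The cap wins when the condition would keep going
print(bounded_while_loop_reference(lambda v: v < 50, lambda v: v + 1, 0, max_steps=30))  # (30, 30)
```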

# Example 1: Gradient descent with bounded iterations
class GradientDescent(brainstate.nn.Module):
    """Simple gradient descent optimizer."""

    def __init__(self, learning_rate=0.1):
        super().__init__()
        self.lr = learning_rate
        self.steps = brainstate.ShortTermState(jnp.array(0))

    def step(self, x):
        # Gradient of f(x) = (x - 3)^2
        grad = 2 * (x - 3.0)
        self.steps.value = self.steps.value + 1
        return x - self.lr * grad


optimizer = GradientDescent(learning_rate=0.1)


def not_converged(x):
    # Keep iterating while x is still far from the optimum
    return jnp.abs(x - 3.0) > 1e-4


result = bounded_while_loop(
    not_converged,
    optimizer.step,
    init_val=0.0,
    max_steps=100
)

print(f"Minimizing f(x) = (x - 3)^2")
print(f"Starting from x = 0.0")
print(f"Final x: {result}")
print(f"Target x: 3.0")
print(f"Error: {jnp.abs(result - 3.0)}")
print(f"Iterations used: {optimizer.steps.value} / 100")
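Each step multiplies the error |x - 3| by 1 - 2 · 0.1 = 0.8. Starting from an error of 3, the condition first fails after 47 steps (3 · 0.8^46 ≈ 1.05e-4, while 3 · 0.8^47 ≈ 8.4e-5). The loop should therefore stop with:

  • x ≈ 2.99992
  • an error below 1e-4
  • 47 of the allowed 100 iterations used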
# Example 2: Comparing different base values
class Counter(brainstate.nn.Module):
    def __init__(self):
        super().__init__()
        self.count = brainstate.ShortTermState(jnp.array(0))

    def increment(self, x):
        self.count.value = self.count.value + 1
        return x + 1


def compare_base_values():
    max_steps = 100

    for base in [2, 8, 16]:
        counter = Counter()

        result = bounded_while_loop(
            lambda x: x < 50,
            counter.increment,
            init_val=0,
            max_steps=max_steps,
            base=base
        )

        recursion_depth = jnp.ceil(jnp.log(max_steps) / jnp.log(base))
        print(f"Base {base:2d}: result={result}, iterations={counter.count.value}, "
              f"recursion_depth≈{int(recursion_depth)}")


compare_base_values()
With the condition x < 50, every base should stop at result=50 after 50 iterations. Only the nesting depth changes with base: about 7, 3, and 2 levels for bases 2, 8, and 16.
# Example 3: Differentiable bounded_while_loop
def smooth_threshold(x, threshold=5.0, lr=0.5, max_steps=20):
    """Smoothly approach threshold using gradient descent."""

    def cond_fn(val):
        return val < threshold - 0.1

    def body(val):
        # Gradient of loss = (val - threshold)^2
        grad = 2 * (val - threshold)
        return val - lr * grad

    return bounded_while_loop(cond_fn, body, x, max_steps=max_steps)


# Compute gradient
x = 0.0
value, grad = jax.value_and_grad(smooth_threshold)(x)

print(f"Input: {x}")
print(f"Output: {value}")
print(f"Gradient: {grad}")
print(f"\nbounded_while_loop is differentiable!")
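Since lr=0.5, a single step sends any val to exactly 5.0. The loop should therefore exit after one iteration with Output 5.0 and Gradient 0.0.

The point of the example is that jax.value_and_grad can trace through bounded_while_loop at all. Reverse-mode differentiation through while_loop raises an error.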
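The zero gradient follows from the chain rule. Each body step has derivative d(new val)/d(val) = 1 - 2 · lr, so after k steps the gradient with respect to x is (1 - 2 · lr)^k, which vanishes for lr = 0.5.

The plain-Python sketch below propagates that derivative alongside the value, forward-mode style, and shows a non-zero gradient for lr = 0.1. smooth_threshold_with_grad is a hypothetical helper, not part of brainstate. Like automatic differentiation through the loop, it treats the exit decision as fixed, since the stopping test is piecewise constant in x.

```python
from typing import Tuple


def smooth_threshold_with_grad(x: float, threshold: float = 5.0, lr: float = 0.5,
                               max_steps: int = 20) -> Tuple[float, float, int]:
    """Return (value, d value / d x, steps) for the loop in smooth_threshold."""
    val, dval_dx, steps = x, 1.0, 0
    while steps < max_steps and val < threshold - 0.1:
        val = val - lr * 2 * (val - threshold)
        dval_dx = dval_dx * (1 - 2 * lr)  # chain rule through one body step
        steps += 1
    return val, dval_dx, steps


print(smooth_threshold_with_grad(0.0))          # (5.0, 0.0, 1)
print(smooth_threshold_with_grad(0.0, lr=0.1))  # 18 steps; gradient is 0.8**18, about 0.018
```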

2.3 Comparison: while_loop vs bounded_while_loop

Use while_loop when:

  • Number of iterations is truly unknown

  • Not computing gradients

  • Want standard JAX while loop semantics

Use bounded_while_loop when:

  • Need gradient computation

  • Want safety against infinite loops

  • Can provide reasonable upper bound on iterations

  • Need predictable compilation characteristics

3. Conditional Control Flow

Conditional primitives enable branching logic that compiles efficiently and handles state properly.

3.1 cond: Binary Conditional (If/Else)

cond selectively executes one of two branches based on a boolean predicate. This is the stateful version of jax.lax.cond.

Function signature:

cond(
    pred,
    true_fun: Callable,
    false_fun: Callable,
    *operands
)

Parameters:

  • pred: Boolean scalar (or numeric, where non-zero is True)

  • true_fun: Function called when pred is True

  • false_fun: Function called when pred is False

  • *operands: Arguments passed to the selected function

Key properties:

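  • Only the selected branch is executed at run time, although both branches are traced at compile time. As with jax.lax.cond, a predicate batched by jax.vmap causes both branches to be evaluated and the results selected.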

  • Both branches must return the same pytree structure

  • State modifications in branches are properly tracked
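For a scalar predicate, the run-time behavior reduces to an ordinary if/else: exactly one branch runs on the operands. cond_reference below is a hypothetical plain-Python model that logs which branch was called.

The model does not capture tracing. Under JAX, both branches are traced to check that their outputs have the same structure, so Python-level side effects such as print can fire for both at trace time. The model also does not capture the vmap case noted above.

```python
from typing import Any, Callable, List


def cond_reference(pred: Any, true_fun: Callable, false_fun: Callable, *operands: Any) -> Any:
    """Pure-Python model of cond for a scalar predicate: exactly one branch runs."""
    return true_fun(*operands) if bool(pred) else false_fun(*operands)


calls: List[str] = []


def positive_branch(x):
    calls.append("true")
    return x ** 2


def negative_branch(x):
    calls.append("false")
    return -x


results = [cond_reference(v >= 0, positive_branch, negative_branch, v) for v in (-5.0, 3.0, 0.0)]
print(results)  # [5.0, 9.0, 0.0]
print(calls)    # ['false', 'true', 'true']: one branch call per evaluation
```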

# Example 1: Simple conditional
def positive_branch(x):
    return x ** 2


def negative_branch(x):
    return -x


for value in [-5.0, 3.0, 0.0]:
    result = cond(value >= 0, positive_branch, negative_branch, value)
    print(f"cond({value} >= 0): {result}")
cond(-5.0 >= 0): 5.0
cond(3.0 >= 0): 9.0
cond(0.0 >= 0): 0.0
# Example 2: Stateful conditional
class BranchTracker(brainstate.nn.Module):
    """Track which branches were taken."""

    def __init__(self):
        super().__init__()
        self.true_count = brainstate.ShortTermState(jnp.array(0))
        self.false_count = brainstate.ShortTermState(jnp.array(0))

    def true_branch(self, x):
        self.true_count.value = self.true_count.value + 1
        return x * 2

    def false_branch(self, x):
        self.false_count.value = self.false_count.value + 1
        return x / 2


tracker = BranchTracker()

# Test multiple values
values = jnp.array([1.0, -2.0, 3.0, -4.0, 5.0])
results = []

for v in values:
    result = cond(v > 0, tracker.true_branch, tracker.false_branch, v)
    results.append(result)

print("Values:", values)
print("Results:", jnp.array(results))
print("\nBranch statistics:")
print(f"  True branch taken: {tracker.true_count.value} times")
print(f"  False branch taken: {tracker.false_count.value} times")
Values: [ 1. -2.  3. -4.  5.]
Results: [ 2. -1.  6. -2. 10.]

Branch statistics:
  True branch taken: 3 times
  False branch taken: 2 times
# Example 3: Nested conditionals
class Classifier(brainstate.nn.Module):
    """Classify numbers into categories."""

    def __init__(self):
        super().__init__()
        self.classification_counts = brainstate.ShortTermState({
            'large_positive': jnp.array(0),
            'small_positive': jnp.array(0),
            'small_negative': jnp.array(0),
            'large_negative': jnp.array(0),
        })

    def classify_positive(self, x):
        def large(x):
            counts = self.classification_counts.value
            counts['large_positive'] = counts['large_positive'] + 1
            self.classification_counts.value = counts
            return 'large_positive'

        def small(x):
            counts = self.classification_counts.value
            counts['small_positive'] = counts['small_positive'] + 1
            self.classification_counts.value = counts
            return 'small_positive'

        return cond(x > 5.0, large, small, x)

    def classify_negative(self, x):
        def small(x):
            counts = self.classification_counts.value
            counts['small_negative'] = counts['small_negative'] + 1
            self.classification_counts.value = counts
            return 'small_negative'

        def large(x):
            counts = self.classification_counts.value
            counts['large_negative'] = counts['large_negative'] + 1
            self.classification_counts.value = counts
            return 'large_negative'

        return cond(x > -5.0, small, large, x)

    def classify(self, x):
        return cond(x >= 0, self.classify_positive, self.classify_negative, x)


classifier = Classifier()

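# The branches return Python strings, which JAX cannot trace, so this example runs with JIT disabled
with jax.disable_jit():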
    test_values = jnp.array([10.0, 2.0, -3.0, -8.0, 7.0, -1.0])
    classifications = [classifier.classify(v) for v in test_values]

    print("Values:", test_values)
    print("Classifications:", classifications)
    print("\nCategory counts:")
    for category, count in classifier.classification_counts.value.items():
        print(f"  {category}: {count}")
Values: [10.  2. -3. -8.  7. -1.]
Classifications: ['large_positive', 'small_positive', 'small_negative', 'large_negative', 'large_positive', 'small_negative']

Category counts:
  large_positive: 2
  small_positive: 1
  small_negative: 2
  large_negative: 1
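Because the Classifier's branches return Python strings, the example depends on running under jax.disable_jit(). A jit-compatible sketch, written with plain jax.lax.cond and without the State counters, returns integer codes from the branches and maps them to labels outside the compiled function. LABELS and classify_code are hypothetical names, and the code-to-label order is an arbitrary choice:

```python
import jax
import jax.numpy as jnp

# Hypothetical label table: code i maps to LABELS[i]
LABELS = ["large_negative", "small_negative", "small_positive", "large_positive"]


@jax.jit
def classify_code(x):
    # Branches return int32 arrays, which jit can trace (strings cannot be traced)
    def positive(x):
        return jax.lax.cond(x > 5.0, lambda _: jnp.int32(3), lambda _: jnp.int32(2), x)

    def negative(x):
        return jax.lax.cond(x > -5.0, lambda _: jnp.int32(1), lambda _: jnp.int32(0), x)

    return jax.lax.cond(x >= 0, positive, negative, x)


test_values = jnp.array([10.0, 2.0, -3.0, -8.0, 7.0, -1.0])
# Convert each compiled result to a Python int and look up its label outside jit
labels = [LABELS[int(classify_code(v))] for v in test_values]
print(labels)
```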

3.2 switch: Multi-Way Branching

switch generalizes cond to multiple branches, similar to a switch/case statement.

Function signature:

switch(
    index,
    branches: Sequence[Callable],
    *operands
)

Parameters:

  • index: Integer scalar selecting which branch to execute

  • branches: Sequence of callables (at least 1)

  • *operands: Arguments passed to the selected branch

Index handling:

  • The index is clamped to [0, len(branches) - 1]

  • Negative indices select branch 0; unlike Python list indexing, -1 does not select the last branch

  • Indices >= len(branches) select the last branch, len(branches) - 1
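The clamping rule differs from Python indexing in one easy-to-miss case: -1 selects the first branch, not the last. jax.lax.switch documents the same clamping, so the sketch below uses it directly (assuming brainstate's switch behaves the same), with a traced index under jax.jit and the same four operations as the example that follows:

```python
import jax

branches = [lambda x: x + 1, lambda x: x * 2, lambda x: x ** 2, lambda x: -x]


@jax.jit
def apply(index, x):
    # index is traced here, so the branch is chosen (and clamped) at runtime
    return jax.lax.switch(index, branches, x)


x = 5.0
print(branches[-1](x))  # -5.0: Python list indexing wraps -1 around to the last branch
print(apply(-1, x))     # 6.0: switch clamps -1 to 0 instead
print(apply(7, x))      # -5.0: indices past the end go to the last branch
```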

# Example 1: Simple multi-way branch
def operation_0(x):
    return x + 1


def operation_1(x):
    return x * 2


def operation_2(x):
    return x ** 2


def operation_3(x):
    return -x


operations = [operation_0, operation_1, operation_2, operation_3]

x = 5.0
for i in range(len(operations)):
    result = switch(i, operations, x)
    print(f"Operation {i} on {x}: {result}")

# Test clamping
print(f"\nOut of bounds (index={len(operations)}): {switch(len(operations), operations, x)}")
print(f"Out of bounds (index={-1}): {switch(-1, operations, x)}")
Operation 0 on 5.0: 6.0
Operation 1 on 5.0: 10.0
Operation 2 on 5.0: 25.0
Operation 3 on 5.0: -5.0

Out of bounds (index=4): -5.0
Out of bounds (index=-1): 6.0
# Example 2: Stateful switch - activation function selector
class ActivationSelector(brainstate.nn.Module):
    """Select and apply different activation functions."""

    def __init__(self):
        super().__init__()
        self.usage_counts = brainstate.ShortTermState(jnp.zeros(5, dtype=jnp.int32))

    def _track_usage(self, index):
        counts = self.usage_counts.value
        counts = counts.at[index].add(1)
        self.usage_counts.value = counts

    def relu(self, x):
        self._track_usage(0)
        return jnp.maximum(0, x)

    def sigmoid(self, x):
        self._track_usage(1)
        return 1 / (1 + jnp.exp(-x))

    def tanh(self, x):
        self._track_usage(2)
        return jnp.tanh(x)

    def softplus(self, x):
        self._track_usage(3)
        # logaddexp(0, x) equals log(1 + exp(x)) but does not overflow for large x
        return jnp.logaddexp(0.0, x)

    def identity(self, x):
        self._track_usage(4)
        return x

    def apply(self, index, x):
        return switch(
            index,
            [self.relu, self.sigmoid, self.tanh, self.softplus, self.identity],
            x
        )


selector = ActivationSelector()
activation_names = ['ReLU', 'Sigmoid', 'Tanh', 'Softplus', 'Identity']

# Test all activations
test_input = 2.0
print(f"Input: {test_input}\n")

for i in range(len(activation_names)):
    result = selector.apply(i, test_input)
    print(f"{activation_names[i]:10s}: {result:.4f}")

print(f"\nUsage counts: {selector.usage_counts.value}")
Input: 2.0

ReLU      : 2.0000
Sigmoid   : 0.8808
Tanh      : 0.9640
Softplus  : 2.1269
Identity  : 2.0000

Usage counts: [1 1 1 1 1]
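The per-call usage counts above rely on only the selected branch running. That no longer holds when the index itself is batched with jax.vmap: JAX then evaluates every branch and selects the results element-wise. The sketch below uses plain jax.lax.switch with jax.nn activations as stand-ins for the methods above; how brainstate's stateful switch treats State writes under a batched index is not covered here.

```python
import jax
import jax.numpy as jnp

# Stand-ins for the five ActivationSelector methods, in the same order
activations = [jax.nn.relu, jax.nn.sigmoid, jnp.tanh, jax.nn.softplus, lambda x: x]


def apply(index, x):
    return jax.lax.switch(index, activations, x)


indices = jnp.arange(5)  # one branch index per element
xs = jnp.full(5, 2.0)    # the same input, 2.0, for every element
# With a batched index, vmap turns switch into a select: all five branches run
out = jax.vmap(apply)(indices, xs)
print(out)  # approximately [2.0, 0.8808, 0.9640, 2.1269, 2.0]
```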
# Example 3: Dynamic policy selection
class PolicySelector(brainstate.nn.Module):
    """Select different action policies based on state."""

    def __init__(self):
        super().__init__()
        self.total_reward = brainstate.ShortTermState(jnp.array(0.0))

    def aggressive_policy(self, state):
        action = state * 2.0
        reward = jnp.abs(action) * 0.5
        self.total_reward.value = self.total_reward.value + reward
        return {'action': action, 'reward': reward, 'policy': 'aggressive'}

    def conservative_policy(self, state):
        action = state * 0.5
        reward = jnp.abs(action) * 1.0
        self.total_reward.value = self.total_reward.value + reward
        return {'action': action, 'reward': reward, 'policy': 'conservative'}

    def random_policy(self, state):
        action = state * 1.0
        reward = jnp.abs(action) * 0.3
        self.total_reward.value = self.total_reward.value + reward
        return {'action': action, 'reward': reward, 'policy': 'random'}

    def select_and_act(self, policy_index, state):
        return switch(
            policy_index,
            [self.aggressive_policy, self.conservative_policy, self.random_policy],
            state
        )


policy_selector = PolicySelector()

# Simulate decision-making over time
states = jnp.array([1.0, -0.5, 2.0, -1.5, 0.8])
policies = jnp.array([0, 1, 0, 1, 2], dtype=jnp.int32)  # policy choices

with jax.disable_jit():
    print("Simulation results:")
    for i, (policy_idx, state) in enumerate(zip(policies, states)):
        result = policy_selector.select_and_act(policy_idx, state)
        print(f"Step {i}: state={state:5.1f}, policy={result['policy']:12s}, "
              f"action={result['action']:5.2f}, reward={result['reward']:.2f}")

    print(f"\nTotal reward: {policy_selector.total_reward.value:.2f}")
Simulation results:
Step 0: state=  1.0, policy=aggressive  , action= 2.00, reward=1.00
Step 1: state= -0.5, policy=conservative, action=-0.25, reward=0.25
Step 2: state=  2.0, policy=aggressive  , action= 4.00, reward=2.00
Step 3: state= -1.5, policy=conservative, action=-0.75, reward=0.75
Step 4: state=  0.8, policy=random      , action= 0.80, reward=0.24

Total reward: 4.24
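This simulation runs under jax.disable_jit(), and its branches return a Python string. A jit-compatible variant keeps the branch outputs numeric and drives the steps with jax.lax.scan, accumulating the reward in the carry instead of in a State. The sketch below is plain JAX; SCALES, make_policy, POLICIES, and simulate are hypothetical names that reproduce the (action scale, reward scale) pairs of the three policies above.

```python
import jax
import jax.numpy as jnp

# (action scale, reward scale) for the aggressive, conservative, and "random" policies
SCALES = [(2.0, 0.5), (0.5, 1.0), (1.0, 0.3)]


def make_policy(action_scale, reward_scale):
    def policy(state):
        action = state * action_scale
        return action, jnp.abs(action) * reward_scale
    return policy


POLICIES = [make_policy(a, r) for a, r in SCALES]


@jax.jit
def simulate(policy_ids, states):
    def step(total_reward, inputs):
        idx, state = inputs
        action, reward = jax.lax.switch(idx, POLICIES, state)
        # Carry the running total; stack (action, reward) as per-step outputs
        return total_reward + reward, (action, reward)

    return jax.lax.scan(step, jnp.zeros((), jnp.float32), (policy_ids, states))


total, (actions, rewards) = simulate(
    jnp.array([0, 1, 0, 1, 2], dtype=jnp.int32),
    jnp.array([1.0, -0.5, 2.0, -1.5, 0.8]),
)
print(actions)  # approximately [2.0, -0.25, 4.0, -0.75, 0.8]
print(rewards)  # approximately [1.0, 0.25, 2.0, 0.75, 0.24]
print(total)    # approximately 4.24
```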

3.3 ifelse: Multi-Condition If/Elif/Else

ifelse provides a high-level interface for multi-condition branching, similar to Python’s if/elif/else.

Function signature:

ifelse(
    conditions,
    branches,
    *operands,
    check_cond: bool = True
)

Parameters:

  • conditions: Sequence of boolean predicates; they should be mutually exclusive, so that exactly one is True when check_cond=True

  • branches: Sequence of callables (same length as conditions)

  • *operands: Arguments passed to the selected branch

  • check_cond: If True, verify exactly one condition is True

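Common pattern: make the last condition the complement of the others so that it acts as a default/else branch. Because check_cond=True requires exactly one condition to be True, a literal True as the last condition would overlap with the earlier ones (for x = 15, all three conditions in [x > 10, x > 5, True] are True):

ifelse(
    [x > 10, jnp.logical_and(x > 5, x <= 10), x <= 5],  # last condition covers every remaining case
    [large_fn, medium_fn, small_fn],
    x
)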
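An ordered chain with overlapping conditions, such as [x > 10, x > 5, True], can be converted into mutually exclusive conditions by keeping each one only when no earlier condition holds. That reproduces Python's first-match if/elif behavior and leaves exactly one True entry, which is what check_cond expects. make_exclusive is a hypothetical helper written for this sketch, not part of brainstate:

```python
import jax.numpy as jnp


def make_exclusive(conditions):
    """Keep condition i only if no earlier condition is True (first match wins)."""
    conds = jnp.asarray(conditions, dtype=bool)
    c = conds.astype(jnp.int32)
    earlier = jnp.cumsum(c) - c  # number of True conditions strictly before position i
    return conds & (earlier == 0)


for x in [15.0, 7.0, 2.0]:
    exclusive = make_exclusive([x > 10, x > 5, True])
    # Exactly one entry is True: the first condition that holds
    print(x, exclusive, int(exclusive.sum()))
```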
# Example 1: Simple if/elif/else
def classify_number(x):
    def large():
        return "large"

    def medium():
        return "medium"

    def small():
        return "small"

    return ifelse(
        [x > 10, jnp.logical_and(x > 5, x <= 10), x <= 5],  # x <= 5 acts as 'else'
        [large, medium, small]
    )


with jax.disable_jit():
    for value in [15.0, 7.0, 2.0, 10.5, 5.0]:
        category = classify_number(value)
        print(f"{value:5.1f} -> {category}")
 15.0 -> large
  7.0 -> medium
  2.0 -> small
 10.5 -> large
  5.0 -> small
# Example 2: Stateful grade calculator
class GradeCalculator(brainstate.nn.Module):
    """Calculate letter grades and track statistics."""

    def __init__(self):
        super().__init__()
        self.grade_counts = brainstate.ShortTermState({
            'A': jnp.array(0),
            'B': jnp.array(0),
            'C': jnp.array(0),
            'D': jnp.array(0),
            'F': jnp.array(0),
        })

    def _record_grade(self, letter):
        counts = self.grade_counts.value
        counts[letter] = counts[letter] + 1
        self.grade_counts.value = counts

    def grade_A(self):
        return self._record_grade('A')

    def grade_B(self):
        return self._record_grade('B')

    def grade_C(self):
        return self._record_grade('C')

    def grade_D(self):
        return self._record_grade('D')

    def grade_F(self):
        return self._record_grade('F')

    def calculate_grade(self, score):
        return ifelse(
            [
                score >= 90,
                jnp.logical_and(score >= 80, score < 90),
                jnp.logical_and(score >= 70, score < 80),
                jnp.logical_and(score >= 60, score < 70),
                score < 60
            ],
            [
                self.grade_A,
                self.grade_B,
                self.grade_C,
                self.grade_D,
                self.grade_F
            ]
        )


calculator = GradeCalculator()

# Process student scores
scores = jnp.array([95, 87, 76, 82, 59, 91, 68, 45, 88, 93])
for score in scores:
    calculator.calculate_grade(score)  # records the grade in grade_counts; returns None

print("\nGrade distribution:")
for letter, count in calculator.grade_counts.value.items():
    print(f"  {letter}: {'*' * int(count)}")
Grade distribution:
  A: ***
  B: ***
  C: *
  D: *
  F: **

Summary

This tutorial covered all control flow primitives in brainstate.transform:

Loop Transformations

  • scan: Fundamental loop with carry and outputs

    • Use for: Recurrent patterns, accumulation, sequential processing

    • Collects outputs at each iteration

    • Key params: reverse, unroll, pbar

  • checkpointed_scan: Memory-efficient scan with gradient checkpointing

    • Use for: Long sequences, memory constraints during gradient computation

    • Key benefit: Stores only O(log_base(n)) checkpoints instead of O(n) activations during backpropagation

    • Trades computation (recomputation during backward pass) for memory savings

    • Key param: base (checkpointing granularity)

  • for_loop: Simplified loop without explicit carry

    • Use for: Simple iteration, state updates

    • Important: Return value at each timestep is saved and stacked into the final output array

    • Variadic inputs, no carry management

    • Output shape: stacks results along axis 0 (e.g., scalar→1D, vector→2D)

  • checkpointed_for_loop: Memory-efficient for loop with gradient checkpointing

    • Combines simplicity of for_loop with memory efficiency during gradient computation

    • Useful when training on very long sequences, where storing every step's activations may not fit in memory

    • Same memory benefits as checkpointed_scan
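The memory-for-computation trade-off behind the checkpointed variants can be sketched with jax.checkpoint. The two-level scheme below is an illustrative assumption, not how checkpointed_scan arranges its checkpoints through base: it splits base * base steps into chunks, keeps one carry per chunk, and recomputes a chunk's steps when backpropagating through it. The function names are hypothetical.

```python
import jax
import jax.numpy as jnp


def step(h, x):
    h = jnp.tanh(0.9 * h + x)
    return h, h


def plain_scan(h0, xs):
    h, _ = jax.lax.scan(step, h0, xs)
    return h


def two_level_checkpointed_scan(h0, xs, base):
    # Sketch only: split n = base * base steps into `base` chunks of `base` steps.
    # The outer scan keeps one carry per chunk; jax.checkpoint discards each
    # chunk's intermediate values and recomputes them in the backward pass.
    assert xs.shape[0] == base * base
    chunks = xs.reshape(base, base)

    @jax.checkpoint
    def run_chunk(h, chunk):
        h, _ = jax.lax.scan(step, h, chunk)
        return h, None

    h, _ = jax.lax.scan(run_chunk, h0, chunks)
    return h


h0 = jnp.zeros(())
xs = jnp.linspace(-1.0, 1.0, 16)
g_plain = jax.grad(plain_scan, argnums=1)(h0, xs)
g_ckpt = jax.grad(two_level_checkpointed_scan, argnums=1)(h0, xs, 4)
# True: checkpointing changes what is stored, not the gradients
print(jnp.allclose(g_plain, g_ckpt, atol=1e-6))
```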

While Loops

  • while_loop: Dynamic iteration with condition

    • Use for: Unknown iteration count, no reverse-mode gradients needed

    • Constraint: cond_fun must be read-only (it may read State but must not write to it)

  • bounded_while_loop: While loop with maximum steps

    • Use for: Gradients, safety, predictable compilation

    • Key params: max_steps, base
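Because cond_fun must be read-only, anything the loop needs to count or update belongs in the carry (or in State written by body_fun). A plain-JAX sketch of that split, using Newton's method for a square root; sqrt_newton and tol are hypothetical names chosen for this example:

```python
import jax
import jax.numpy as jnp


def sqrt_newton(a, tol=1e-5):
    a = jnp.asarray(a, dtype=jnp.float32)

    def cond_fun(carry):
        x, _ = carry
        # Read-only: inspects the carry and returns a boolean, writes nothing
        return jnp.abs(x * x - a) > tol

    def body_fun(carry):
        x, n_iter = carry
        # Every update, including the iteration counter, happens here
        return 0.5 * (x + a / x), n_iter + 1

    return jax.lax.while_loop(cond_fun, body_fun, (a, jnp.int32(0)))


root, n_iter = sqrt_newton(2.0)
print(root, n_iter)  # root is close to 1.4142135 after a few iterations
```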

Conditional Control Flow

  • cond: Binary conditional (if/else)

    • Use for: Two-way decisions

    • Lazy evaluation, state-safe

  • switch: Multi-way branching (switch/case)

    • Use for: Multiple branches with integer index

    • Index clamping for safety

  • ifelse: Multi-condition branching (if/elif/else)

    • Use for: Complex conditions, default branches

    • Make the last condition the complement of the others to act as the else branch
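The three conditionals are closely related: cond behaves like a two-branch switch whose index is the predicate converted to an integer, which puts the false branch at index 0. A plain-JAX sketch of that equivalence (f_false and f_true are hypothetical names):

```python
import jax
import jax.numpy as jnp


def f_false(x):
    return x / 2.0


def f_true(x):
    return x * 2.0


x = jnp.array(3.0)
for pred in [True, False]:
    via_cond = jax.lax.cond(pred, f_true, f_false, x)
    # switch lists branches by index, so the False branch must come first
    via_switch = jax.lax.switch(jnp.int32(pred), [f_false, f_true], x)
    print(pred, via_cond, via_switch)  # True 6.0 6.0, then False 1.5 1.5
```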

Key Principles

  1. State Safety: All APIs properly track state reads and writes

  2. Lazy Evaluation: Conditionals execute only the selected branch at runtime, although both branches are traced during compilation

  3. JAX Compatibility: Compile to efficient JAX primitives

  4. Output Collection: for_loop and scan collect outputs at each iteration into the final result

  5. Memory Efficiency: Checkpointed variants save memory during gradient computation by storing only checkpoints and recomputing intermediate activations during backpropagation

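  6. Differentiability: Most APIs support gradients; while_loop does not support reverse-mode differentiation, so use bounded_while_loop when a loop must be differentiated. Checkpointed variants reduce gradient memory for long sequences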

These primitives enable complex control flow while maintaining BrainState’s stateful programming model and JAX’s performance benefits.