Debugging Transformed Code

A plain print(x) inside a jit-compiled function runs only while the function is traced (normally once per new input shape and dtype), and it shows an abstract tracer rather than a value. To observe actual values and to set breakpoints, you need tools that execute at runtime, as part of the compiled computation. This tutorial covers the practical debugging workflow for BrainState code:

  • jax.debug.print for printing runtime values, including inside grad and vmap;

  • jax.debug.callback for richer inspection (shapes, statistics);

  • brainstate.transform.breakpoint_if for conditional breakpoints.
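The difference between trace time and runtime is easy to check directly. The sketch below uses plain JAX rather than BrainState, and the names trace_log, runtime_log and record are hypothetical. An ordinary Python side effect records what it sees during tracing. A jax.debug.callback records the concrete value on every execution. The function is called twice with the same shape and dtype, so JAX reuses the compiled version and does not retrace.

```python
import jax
import jax.numpy as jnp

trace_log = []    # filled by ordinary Python code, i.e. only while JAX traces f
runtime_log = []  # filled by jax.debug.callback, i.e. on every execution of f

def record(v):
    runtime_log.append(float(v))

@jax.jit
def f(x):
    # Plain Python side effect: runs during tracing and sees a tracer, not a number.
    trace_log.append(type(x).__name__)
    # Staged side effect: runs each time the compiled function executes.
    jax.debug.callback(record, x)
    return x + 1.0

f(jnp.float32(1.0))
f(jnp.float32(2.0))    # same shape and dtype: cached compilation, no retrace
jax.effects_barrier()  # wait until pending runtime callbacks have finished

print(trace_log)            # one entry: the name of a tracer class
print(sorted(runtime_log))  # [1.0, 2.0]
```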

import jax
import jax.numpy as jnp

import brainstate
import brainstate.transform as T

brainstate.random.seed(0)
brainstate.__version__
'0.4.0'

jax.debug.print: values at runtime

jax.debug.print is compilation-safe. The print is staged into the compiled program, and the concrete values are printed each time the function executes. Placeholders follow Python's str.format rules: {name} placeholders are filled by keyword arguments, and positional {} placeholders also work. The prints appear when the function runs, not when it is traced.

@brainstate.transform.jit
def compute(x):
    jax.debug.print('input      = {x}', x=x)
    y = x ** 2
    jax.debug.print('after square = {y}', y=y)
    return jnp.sum(y)

total = compute(jnp.array([1.0, 2.0, 3.0]))
print('returned:', float(total))
input      = [1. 2. 3.]
after square = [1. 4. 9.]
returned: 14.0
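The sketch below uses plain JAX and shows positional placeholders alongside keyword ones. It assumes the current JAX behavior, in which jax.debug.print writes its text to sys.stdout at runtime. Under that assumption, the output can be captured, for example in a unit test, by redirecting stdout and calling jax.effects_barrier() so the pending prints finish before the redirect ends. The function name scaled is hypothetical.

```python
import contextlib
import io

import jax
import jax.numpy as jnp

@jax.jit
def scaled(x, k):
    # Positional placeholders are filled in order, as with str.format.
    jax.debug.print('x = {}, k = {}', x, k)
    # Keyword placeholders, as in the tutorial above.
    jax.debug.print('product = {p}', p=x * k)
    return x * k

buf = io.StringIO()
with contextlib.redirect_stdout(buf):
    scaled(jnp.array([1.0, 2.0]), 3.0)
    jax.effects_barrier()  # make sure the runtime prints have completed

captured = buf.getvalue()
print(captured)
```

The two prints inside one compiled function are not guaranteed to appear in a fixed order, so any check on the captured text should not depend on their order.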

Inspecting state updates

Because the prints execute at runtime, they see the real State values before and after a mutation — invaluable when a buffer drifts or fails to update.

class Accumulator(brainstate.nn.Module):
    def __init__(self, size):
        super().__init__()
        self.total = brainstate.ShortTermState(jnp.zeros(size))

    def __call__(self, x):
        jax.debug.print('before: {s}', s=self.total.value)
        self.total.value = self.total.value + x
        jax.debug.print('after : {s}', s=self.total.value)
        return self.total.value

acc = Accumulator(3)
step = brainstate.transform.jit(acc)
_ = step(jnp.array([1.0, 2.0, 3.0]))
before: [0. 0. 0.]
after : [1. 2. 3.]
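A buffer that drifts usually does so across many steps of a loop. The sketch below is plain JAX, not BrainState: an explicit jax.lax.scan carry stands in for a ShortTermState. It logs the running total from inside the loop body on every iteration, together with the step index. The names accumulate, record and history are hypothetical.

```python
import jax
import jax.numpy as jnp

history = []  # host-side record of (step, running total)

def record(step, total):
    history.append((int(step), total.tolist()))

@jax.jit
def accumulate(xs):
    def body(total, inp):
        step, x = inp
        new_total = total + x
        # Fires once per loop iteration with the concrete running total.
        jax.debug.callback(record, step, new_total)
        return new_total, None

    steps = jnp.arange(xs.shape[0])
    final, _ = jax.lax.scan(body, jnp.zeros(xs.shape[1]), (steps, xs))
    return final

final = accumulate(jnp.array([[1.0, 0.0], [0.0, 2.0], [3.0, 3.0]]))
jax.effects_barrier()  # wait for all runtime callbacks
for step, total in sorted(history):
    print(step, total)
# 0 [1.0, 0.0]
# 1 [1.0, 2.0]
# 2 [4.0, 5.0]
```

Recording the step index and sorting on the host keeps the log readable even if callbacks are not delivered in program order.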

Debugging inside grad

Prints placed in a loss function fire during the forward pass of differentiation, letting you watch the quantities that feed the gradient.

weight = brainstate.ParamState(jnp.array([2.0, 3.0]))

def loss_fn(x):
    pred = weight.value * x
    jax.debug.print('prediction = {p}', p=pred)
    return jnp.sum(pred ** 2)

grads = brainstate.transform.grad(loss_fn, {'w': weight})(jnp.array([0.5, 1.0]))
print('grad:', grads['w'])
prediction = [1. 3.]
grad: [1. 6.]
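A print in the loss function only shows forward-pass values. One common JAX pattern for seeing what flows backward is to wrap an intermediate in an identity function with a custom VJP that logs the incoming cotangent. This is a general JAX technique, not a BrainState API. The sketch below assumes plain JAX, and the names tap, tap_fwd, tap_bwd and store_cotangent are hypothetical. It reproduces the tutorial's numbers: with pred = [1, 3], the cotangent of the loss with respect to pred is 2 * pred = [2, 6], and the gradient with respect to w is [1, 6].

```python
import jax
import jax.numpy as jnp

seen = {}  # host-side store for the backward-pass value

def store_cotangent(g):
    seen['cotangent'] = g.tolist()

@jax.custom_vjp
def tap(x):
    return x  # identity in the forward pass

def tap_fwd(x):
    return x, None  # nothing needs to be saved for the backward pass

def tap_bwd(_, g):
    # Runs during the backward pass; g is d(loss)/d(output of tap).
    jax.debug.callback(store_cotangent, g)
    return (g,)  # pass the cotangent through unchanged

tap.defvjp(tap_fwd, tap_bwd)

def loss(w, x):
    pred = tap(w * x)
    return jnp.sum(pred ** 2)

grad_w = jax.grad(loss)(jnp.array([2.0, 3.0]), jnp.array([0.5, 1.0]))
jax.effects_barrier()
print(seen['cotangent'])  # [2.0, 6.0], i.e. 2 * pred
print(grad_w.tolist())    # [1.0, 6.0]
```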

Debugging inside vmap

Under vmap the print runs once per batch element, so you can confirm exactly what each lane receives.

def process(x, index):
    jax.debug.print('lane {i}: x={x}', i=index, x=x)
    return x ** 2

batched = brainstate.transform.vmap(process, in_axes=(0, 0))
out = batched(jnp.array([1.0, 2.0, 3.0]), jnp.arange(3))
print('outputs:', out)
lane 0: x=1.0
lane 1: x=2.0
lane 2: x=3.0
outputs: [1. 4. 9.]
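The per-lane behavior also holds when the batched function is compiled. The sketch below uses plain jax.jit(jax.vmap(...)) instead of the BrainState wrappers and records, rather than prints, what each lane receives. That makes it easy to confirm that every batch element triggers the callback exactly once. The names record and lanes are hypothetical.

```python
import jax
import jax.numpy as jnp

lanes = []  # host-side record of (lane index, value received)

def record(i, x):
    lanes.append((int(i), float(x)))

def process(x, i):
    jax.debug.callback(record, i, x)
    return x ** 2

batched = jax.jit(jax.vmap(process))
out = batched(jnp.array([1.0, 2.0, 3.0]), jnp.arange(3))
jax.effects_barrier()

print(sorted(lanes))  # [(0, 1.0), (1, 2.0), (2, 3.0)]
print(out.tolist())   # [1.0, 4.0, 9.0]
```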

Richer inspection with jax.debug.callback

When a one-line print is not enough, jax.debug.callback hands the runtime values to an arbitrary Python function, which makes it well suited to logging summary statistics without leaving the compiled region. The callback runs only for its side effects: its return value is discarded, so use jax.pure_callback if the computation needs a result back from Python.

import numpy as np

def summarize(name, value):
    # The callback runs on the host: use NumPy here rather than dispatching JAX ops.
    v = np.asarray(value)
    print(f'[{name}] shape={v.shape} '
          f'min={v.min():.3f} max={v.max():.3f} mean={v.mean():.3f}')

@brainstate.transform.jit
def forward(x):
    jax.debug.callback(summarize, 'activations', x)
    return jnp.tanh(x)

_ = forward(brainstate.random.randn(100))
[activations] shape=(100,) min=-2.442 max=2.130 mean=-0.210
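A callback can also feed a host-side log across many calls. This is useful, for example, for counting NaNs step by step instead of reading them off a printout. The sketch below is plain JAX. The names layer, log_stats and stats_log are hypothetical, and jnp.log is used only as a convenient way to produce NaN from a negative input.

```python
import jax
import jax.numpy as jnp
import numpy as np

stats_log = []  # host-side log; the compiled code never reads it back

def log_stats(step, value):
    v = np.asarray(value)
    stats_log.append({
        'step': int(step),
        'nan_count': int(np.isnan(v).sum()),
        'max_abs': float(np.nanmax(np.abs(v))),  # ignores NaN entries
    })

@jax.jit
def layer(x, step):
    h = jnp.log(x)  # NaN for negative entries
    jax.debug.callback(log_stats, step, h)
    return h

layer(jnp.array([1.0, 4.0, -1.0]), 0)
layer(jnp.array([1.0, 2.0, 4.0]), 1)
jax.effects_barrier()

for entry in sorted(stats_log, key=lambda e: e['step']):
    print(entry)
```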

Conditional breakpoints with breakpoint_if

breakpoint_if(pred) drops into JAX’s interactive debugger, but only when pred is true at runtime. This lets you halt on a rare bad condition (a NaN, a negative value) without stopping on every iteration. Here the input is finite, so the predicate is false and execution proceeds normally. In a real session you would set the predicate to your suspected failure condition and inspect the live values when it triggers.

@brainstate.transform.jit
def guarded(x):
    # Pause for inspection only if a non-finite value appears.
    T.breakpoint_if(jnp.any(~jnp.isfinite(x)))
    return x * 2.0

print('clean input proceeds:', guarded(jnp.array([1.0, 2.0, 3.0])))
clean input proceeds: [2. 4. 6.]
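An interactive debugger is not available in batch jobs or CI. In those settings, JAX's checkify can turn the same predicate into a recorded error instead of a pause. Note that checkify is a different tool, swapped in here; it is not part of breakpoint_if or BrainState. The sketch below is plain JAX, and the function name guarded only mirrors the example above.

```python
import jax
import jax.numpy as jnp
from jax.experimental import checkify

def guarded(x):
    # Record an error, rather than pausing, when a non-finite value appears.
    checkify.check(jnp.all(jnp.isfinite(x)), 'non-finite value in x')
    return x * 2.0

# checkify.checkify makes the function return (error, output).
checked = jax.jit(checkify.checkify(guarded))

err, out = checked(jnp.array([1.0, 2.0, 3.0]))
print(err.get())  # None: the check passed

err, out = checked(jnp.array([1.0, jnp.nan, 3.0]))
print(err.get())  # an error message containing 'non-finite value in x'
```

Calling err.throw() instead of err.get() raises the recorded error as a Python exception, which suits scripts that should fail fast.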

Summary

  • Ordinary print runs at trace time and shows tracers; use runtime-aware tools instead.

  • jax.debug.print prints concrete values during execution — inside jit, grad, and vmap alike, including before/after State mutations.

  • jax.debug.callback sends values to any Python function for richer inspection.

  • breakpoint_if(pred) opens an interactive debugger only when a condition is met.

See also