Debugging Transformed Code
A plain print(x) inside a jit-compiled function runs only while the function is being traced (typically once per new input shape and dtype), and it shows an abstract tracer rather than a value. To observe actual values, and to set breakpoints, you need tools that execute at runtime, inside the compiled program. This tutorial covers the practical debugging workflow for BrainState code:
- jax.debug.print for runtime printing, including inside grad and vmap;
- jax.debug.callback for richer inspection (shapes, statistics);
- brainstate.transform.breakpoint_if for conditional breakpoints.
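The trace-time versus runtime distinction can be seen directly with plain jax.jit. In this minimal sketch (trace_log and scale are illustrative names, not part of BrainState), the ordinary Python side effects run only while JAX traces the function. A second call with the same shape and dtype reuses the cached compilation and skips them, while jax.debug.print fires on every call. jax.effects_barrier() waits until pending side effects, such as debug prints, have completed.

```python
import jax
import jax.numpy as jnp

trace_log = []  # illustrative log, filled by ordinary Python code during tracing


@jax.jit
def scale(x):
    # Ordinary Python: executes only while JAX traces `scale`.
    trace_log.append(type(x).__name__)
    print('plain print sees:', x)  # shows a tracer, not numbers
    # Staged into the compiled program: executes on every call.
    jax.debug.print('jax.debug.print sees: {x}', x=x)
    return 2.0 * x


first = scale(jnp.array([1.0, 2.0]))
second = scale(jnp.array([3.0, 4.0]))  # same shape/dtype: cached, no retrace
jax.effects_barrier()  # wait until pending debug prints have been emitted
print('traced', len(trace_log), 'time(s); tracer type:', trace_log[0])
```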
import jax
import jax.numpy as jnp
import brainstate
import brainstate.transform as T
brainstate.random.seed(0)
brainstate.__version__
'0.4.0'
jax.debug.print: values at runtime
jax.debug.print is compilation-safe: it defers to runtime and prints the concrete value each
time the function executes. Use {name} placeholders filled by keyword arguments. Note the
prints appear when the function runs, not when it is traced.
@brainstate.transform.jit
def compute(x):
    jax.debug.print('input = {x}', x=x)
    y = x ** 2
    jax.debug.print('after square = {y}', y=y)
    return jnp.sum(y)
total = compute(jnp.array([1.0, 2.0, 3.0]))
print('returned:', float(total))
input = [1. 2. 3.]
after square = [1. 4. 9.]
returned: 14.0
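The format string is interpreted like Python's str.format, so positional {} placeholders work alongside named ones. By default JAX does not guarantee the relative order of separate print statements. Passing ordered=True asks JAX to keep the ordered prints in program order. The following plain-JAX sketch (two_stage is an illustrative name) uses both features:

```python
import jax
import jax.numpy as jnp


@jax.jit
def two_stage(x):
    # Positional placeholders are filled from the extra positional arguments.
    jax.debug.print('stage 1: sum={} max={}', jnp.sum(x), jnp.max(x), ordered=True)
    y = jnp.exp(x)
    # Named placeholders are filled from keyword arguments.
    jax.debug.print('stage 2: y={y}', y=y, ordered=True)
    return y


result = two_stage(jnp.array([0.0, 1.0]))
jax.effects_barrier()  # wait for both prints before continuing
```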
Inspecting state updates
Because the prints execute at runtime, they see the real State values before and after a
mutation — invaluable when a buffer drifts or fails to update.
class Accumulator(brainstate.nn.Module):
    def __init__(self, size):
        super().__init__()
        self.total = brainstate.ShortTermState(jnp.zeros(size))

    def __call__(self, x):
        jax.debug.print('before: {s}', s=self.total.value)
        self.total.value = self.total.value + x
        jax.debug.print('after : {s}', s=self.total.value)
        return self.total.value
acc = Accumulator(3)
step = brainstate.transform.jit(acc)
_ = step(jnp.array([1.0, 2.0, 3.0]))
before: [0. 0. 0.]
after : [1. 2. 3.]
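Drift usually shows up over many steps, so it can help to print inside a loop body. The sketch below uses plain JAX with jax.lax.scan rather than any BrainState loop API. The leaky accumulator (leaky_step, decay factor 0.5) is a made-up example. The print runs once per iteration with that iteration's concrete carry.

```python
import jax
import jax.numpy as jnp


def leaky_step(carry, x):
    # Hypothetical leaky accumulator: decay the carry, then add the input.
    new_carry = 0.5 * carry + x
    # Runs once per loop iteration with that iteration's concrete values.
    jax.debug.print('x={x} -> carry={c}', x=x, c=new_carry)
    return new_carry, new_carry  # (next carry, per-step output)


@jax.jit
def run(xs):
    return jax.lax.scan(leaky_step, jnp.float32(0.0), xs)


final, trajectory = run(jnp.array([1.0, 2.0, 3.0]))
print('final carry:', float(final))  # 0.5 * (0.5 * 1 + 2) + 3 = 4.25
```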
Debugging inside grad
Prints placed in a loss function fire during the forward pass of differentiation, letting you watch the quantities that feed the gradient.
weight = brainstate.ParamState(jnp.array([2.0, 3.0]))
def loss_fn(x):
    pred = weight.value * x
    jax.debug.print('prediction = {p}', p=pred)
    return jnp.sum(pred ** 2)
grads = brainstate.transform.grad(loss_fn, {'w': weight})(jnp.array([0.5, 1.0]))
print('grad:', grads['w'])
prediction = [1. 3.]
grad: [1. 6.]
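The printed gradient can be checked by hand. The loss is sum_i (w_i * x_i)^2, so its derivative is d loss / d w_i = 2 * w_i * x_i^2. With w = [2, 3] and x = [0.5, 1], this gives [2 * 2 * 0.25, 2 * 3 * 1] = [1, 6]. The sketch below reproduces the check with plain jax.grad. It passes the weight explicitly and is a pure-function rewrite of loss_fn, not BrainState's API. The print inside still reports the forward-pass prediction.

```python
import jax
import jax.numpy as jnp


def pure_loss(w, x):
    # Same loss as loss_fn above, but with the weight passed explicitly.
    pred = w * x
    jax.debug.print('prediction = {p}', p=pred)
    return jnp.sum(pred ** 2)


w = jnp.array([2.0, 3.0])
x = jnp.array([0.5, 1.0])
auto_grad = jax.grad(pure_loss)(w, x)  # differentiates w.r.t. the first argument
hand_grad = 2.0 * w * x ** 2           # analytic derivative: 2 * w_i * x_i**2
print('autodiff:', auto_grad, 'analytic:', hand_grad)
```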
Debugging inside vmap
Under vmap the print runs once per batch element, so you can confirm exactly what each lane
receives.
def process(x, index):
    jax.debug.print('lane {i}: x={x}', i=index, x=x)
    return x ** 2
batched = brainstate.transform.vmap(process, in_axes=(0, 0))
out = batched(jnp.array([1.0, 2.0, 3.0]), jnp.arange(3))
print('outputs:', out)
lane 0: x=1.0
lane 1: x=2.0
lane 2: x=3.0
outputs: [1. 4. 9.]
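Arguments that are not mapped behave differently. With in_axes=None for an argument, every lane receives the same unbatched value, and the per-lane print shows it unchanged. This is a minimal sketch using plain jax.vmap rather than brainstate.transform.vmap, and add_offset is an illustrative name:

```python
import jax
import jax.numpy as jnp


def add_offset(x, offset):
    # x is one element of the batch; offset is shared by all lanes.
    jax.debug.print('x={x}, offset={o}', x=x, o=offset)
    return x + offset


# Map over axis 0 of x only; offset (in_axes=None) is passed through unchanged.
batched_add = jax.vmap(add_offset, in_axes=(0, None))
out = batched_add(jnp.array([1.0, 2.0, 3.0]), jnp.float32(10.0))
jax.effects_barrier()
print('outputs:', out)
```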
Richer inspection with jax.debug.callback
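When a one-line print is not enough, jax.debug.callback hands the runtime values to an arbitrary Python function, which is ideal for logging summary statistics without leaving the compiled region. The callback runs purely for its side effects. It should return None, and nothing it returns is passed back into the computation. When the computation needs a value computed in Python, use jax.pure_callback instead.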
def summarize(name, value):
    print(f'[{name}] shape={value.shape} '
          f'min={float(jnp.min(value)):.3f} '
          f'max={float(jnp.max(value)):.3f} '
          f'mean={float(jnp.mean(value)):.3f}')

@brainstate.transform.jit
def forward(x):
    jax.debug.callback(summarize, 'activations', x)
    return jnp.tanh(x)
_ = forward(brainstate.random.randn(100))
[activations] shape=(100,) min=-2.442 max=2.130 mean=-0.210
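Because the callback is ordinary Python, it can store values instead of printing them. For example, it can collect a per-call statistic for inspection after the run. The following is a minimal plain-JAX sketch with hypothetical names (history, record_mean, forward_logged). Depending on the JAX version, the callback may receive NumPy arrays or jax.Array objects, so it converts them with np.asarray. Passing ordered=True keeps the entries in call order.

```python
import jax
import jax.numpy as jnp
import numpy as np

history = []  # hypothetical log: one entry per executed call


def record_mean(value):
    # Runs on the host at execution time; returns nothing to the computation.
    history.append(float(np.mean(np.asarray(value))))


@jax.jit
def forward_logged(x):
    jax.debug.callback(record_mean, x, ordered=True)
    return jnp.tanh(x)


forward_logged(jnp.array([1.0, 2.0, 3.0]))
forward_logged(jnp.array([-1.0, 0.0, 1.0]))
jax.effects_barrier()  # make sure both callbacks have run before reading
print(history)
```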
Conditional breakpoints with breakpoint_if
breakpoint_if(pred) drops into JAX's interactive debugger, but only when pred is true at runtime. This lets you halt on a rare bad condition (a NaN, a negative value) without stopping on every iteration. Here the input contains only finite values, so the predicate is false and execution proceeds normally. In a real session you would set the predicate to your suspected failure condition and inspect the live values when it triggers.
@brainstate.transform.jit
def guarded(x):
    # Pause for inspection only if a non-finite value appears.
    T.breakpoint_if(jnp.any(~jnp.isfinite(x)))
    return x * 2.0
print('clean input proceeds:', guarded(jnp.array([1.0, 2.0, 3.0])))
clean input proceeds: [2. 4. 6.]
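An interactive debugger is not available in batch jobs or CI. A non-interactive alternative is to route the same predicate through jax.lax.cond so that a logging callback runs only when the condition holds. This is a different technique, not BrainState's API. In the sketch below, bad_indices, report_nonfinite and guarded_logged are hypothetical names:

```python
import jax
import jax.numpy as jnp
import numpy as np

bad_indices = []  # hypothetical log of positions holding NaN or inf


def report_nonfinite(x):
    # Host-side: record where the non-finite entries are.
    arr = np.asarray(x)
    bad_indices.append(np.flatnonzero(~np.isfinite(arr)).tolist())


@jax.jit
def guarded_logged(x):
    is_bad = jnp.any(~jnp.isfinite(x))
    # Only the taken branch executes, so the callback fires only on bad input.
    jax.lax.cond(
        is_bad,
        lambda v: jax.debug.callback(report_nonfinite, v),
        lambda v: None,
        x,
    )
    return x * 2.0


guarded_logged(jnp.array([1.0, 2.0, 3.0]))      # clean: nothing recorded
guarded_logged(jnp.array([1.0, jnp.nan, 3.0]))  # NaN at index 1
jax.effects_barrier()
print(bad_indices)
```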
Summary
- An ordinary print runs at trace time and shows tracers; use runtime-aware tools instead.
- jax.debug.print prints concrete values during execution, inside jit, grad, and vmap alike, including before and after State mutations.
- jax.debug.callback sends values to any Python function for richer inspection.
- breakpoint_if(pred) opens an interactive debugger only when a condition is met.
See also
Error handling and runtime checks — catching NaNs and bad inputs.
IR optimization and code generation — inspecting the compiled jaxpr.