Error Handling and Runtime Checks#
Inside a compiled (jit) function, ordinary Python asserts and if statements that depend on
array values do not work — the values are abstract tracers at trace time. A division by zero,
an out-of-bounds index, or a NaN therefore propagates silently and surfaces much later as a
meaningless result.
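To make the failure mode concrete, here is a minimal sketch in plain JAX. It uses jax.jit rather than brainstate.transform.jit, and the function names are invented for illustration. It shows a value-dependent Python if failing at trace time, and a division by zero passing through a compiled function without raising anything.

```python
import jax
import jax.numpy as jnp


@jax.jit
def abs_with_python_if(x):
    # `x` is an abstract tracer here, so Python cannot evaluate `x > 0`.
    if x > 0:
        return x
    return -x


try:
    abs_with_python_if(jnp.array(1.0))
    python_if_failed = False
except jax.errors.ConcretizationTypeError:
    # JAX refuses to convert a traced value into a Python bool.
    python_if_failed = True


@jax.jit
def reciprocal(x):
    # No guard: a zero entry is not an error as far as the compiled code is concerned.
    return 1.0 / x


silent = reciprocal(jnp.array([0.0, 2.0]))

print('python if failed at trace time:', python_if_failed)
print('division by zero returned     :', silent)
```

The second call is the more dangerous one: nothing is raised, and the non-finite entry simply flows into whatever consumes the result.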
brainstate.transform provides JIT-compatible runtime checks built on JAX’s checkify
machinery, extended to understand State. This tutorial covers three tools:
- checkify: turn value errors into an explicit error object you can inspect.
- jit_error_if: raise on a bad condition from inside a compiled step.
- debug_nan: pinpoint where a NaN or Inf first appears.
import jax
import jax.numpy as jnp
import brainstate
import brainstate.transform as T
brainstate.random.seed(0)
brainstate.__version__
'0.4.0'
checkify: functionalize runtime checks#
checkify transforms a function so that, instead of raising, it returns an (error, result)
pair. The error object is inert until you ask about it: error.get() returns None when all
checks passed, or the failure message otherwise. Inside the function you assert conditions with
check(pred, msg, *fmt_args).
def safe_log(x):
    T.check(jnp.all(x > 0), 'x must be positive, got {}', x)
    return jnp.log(x)
checked = T.checkify(safe_log, errors=T.user_checks)
err, out = checked(jnp.array([1.0, 2.0]))
print('valid input -> error:', err.get())
print('valid input -> result:', out)
err, out = checked(jnp.array([-1.0, 2.0]))
print('bad input -> error:', err.get())
valid input -> error: None
valid input -> result: [0. 0.6931472]
bad input -> error: x must be positive, got [-1. 2.] (`check` failed)
To turn a captured error back into a real exception at an outer boundary, call
check_error(error) — it raises if the error is set, and is a no-op otherwise.
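The capture-then-raise pattern can be sketched with the underlying JAX API, jax.experimental.checkify. This assumes the brainstate wrappers behave the same way for stateless functions. The names safe_sqrt and run are hypothetical, and the except clause is deliberately broad because the concrete exception class is a JAX implementation detail.

```python
import jax
import jax.numpy as jnp
from jax.experimental import checkify


def safe_sqrt(x):
    # User-level check; it is recorded in the error object instead of raising.
    checkify.check(jnp.all(x >= 0), 'x must be non-negative')
    return jnp.sqrt(x)


# checkify first, then jit: the compiled function returns (error, result).
checked_sqrt = jax.jit(checkify.checkify(safe_sqrt))


def run(x):
    err, out = checked_sqrt(x)
    # Outer boundary: turn a captured error back into a Python exception.
    checkify.check_error(err)
    return out


print('valid:', run(jnp.array([4.0, 9.0])))

try:
    run(jnp.array([-1.0]))
    raised = False
except Exception as e:
    raised = True
    print('raised:', type(e).__name__)
```

Keeping the raise outside the compiled function means the hot path only carries the error object, and the caller decides where a failure becomes an exception.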
Built-in check categories#
Beyond your own checks, checkify can automatically detect whole classes of failures. Select
them by passing a predefined set as errors:
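| Set | Detects |
|---|---|
| user_checks | your explicit check(...) calls |
| nan_checks | NaN values produced by any primitive |
| float_checks | NaN and Inf from floating-point ops |
| div_checks | division by zero |
| index_checks | out-of-bounds array indexing |
| all_checks | every category above |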
No explicit check call is needed — the failure is caught wherever it occurs.
# NaN detection
nan_checked = T.checkify(lambda x: jnp.log(x), errors=T.nan_checks)
err, _ = nan_checked(jnp.array([-1.0]))
print('nan_checks :', err.get())
# Division by zero
div_checked = T.checkify(lambda a, b: a / b, errors=T.div_checks)
err, _ = div_checked(jnp.array(1.0), jnp.array(0.0))
print('div_checks :', err.get())
# Out-of-bounds indexing
idx_checked = T.checkify(lambda arr, i: arr[i], errors=T.index_checks)
err, _ = idx_checked(jnp.arange(3), 10)
print('index_checks:', err.get())
nan_checks : nan generated by primitive: log.
div_checks : division by zero
index_checks: out-of-bounds indexing for array of shape (3,): index 10 is out of bounds for axis 0 with size 3.
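Several categories can be enabled at once. In JAX's own checkify module the predefined sets are ordinary sets that combine with the | operator; the sketch below uses jax.experimental.checkify directly and assumes the brainstate sets combine the same way. The function log_at is a made-up example.

```python
import jax.numpy as jnp
from jax.experimental import checkify


def log_at(arr, i):
    # Two things can go wrong: `i` may be out of range, or arr[i] may be negative.
    return jnp.log(arr[i])


# Enable NaN detection and index checking together.
errors = checkify.nan_checks | checkify.index_checks
checked_log_at = checkify.checkify(log_at, errors=errors)

values = jnp.array([-1.0, 2.0, 3.0])

err_ok, out_ok = checked_log_at(values, jnp.array(1))   # valid index, positive value
err_nan, _ = checked_log_at(values, jnp.array(0))       # log of a negative number
err_oob, _ = checked_log_at(values, jnp.array(10))      # index past the end

print('in range, positive:', err_ok.get())
print('negative entry    :', err_nan.get())
print('index too large   :', err_oob.get())
```

Enabling only the categories a function can actually hit keeps the added instrumentation small.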
jit_error_if: raise from inside a compiled step#
When you would rather fail loudly than thread an error object around, jit_error_if(pred, msg)
raises a runtime error if pred is true. It works inside brainstate.transform.jit and is the
right tool for guarding preconditions in a training or simulation step.
@brainstate.transform.jit
def reciprocal(x):
    T.jit_error_if(jnp.any(x == 0.0), 'reciprocal received a zero entry')
    return 1.0 / x
print('valid:', reciprocal(jnp.array([2.0, 4.0])))
valid: [0.5 0.25]
If the predicate is ever true at runtime, the call raises with your message instead of returning
inf. The check compiles down to a cheap conditional, so it is safe to leave in production
steps.
debug_nan: locate the first NaN or Inf#
When a model diverges, the hard part is finding where the NaN was born. debug_nan(fn, *args)
runs fn with NaN/Inf detection enabled and raises — naming the offending primitive — the moment
a non-finite value appears.
def unstable(x):
    y = x * 1e20
    return jnp.exp(y)  # overflows to inf
# A finite computation passes through untouched.
T.debug_nan(lambda x: x * 2.0, jnp.array([1.0, 2.0]))
print('finite computation: OK')
try:
    T.debug_nan(unstable, jnp.array([10.0]))
except Exception as e:
    print('debug_nan caught:', type(e).__name__)
finite computation: OK
debug_nan caught: RuntimeError
Use debug_nan_if(has_nan, fn, *args) to enable the (somewhat costly) detection only when an
upstream flag already signals trouble, keeping the fast path fast.
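The same idea, a cheap flag on the fast path and detailed diagnosis only when it fires, can be sketched in plain JAX. This is not how debug_nan_if is implemented; it is an illustrative stand-in built from jnp.isfinite and jax.experimental.checkify, with hypothetical names step and run_step.

```python
import jax
import jax.numpy as jnp
from jax.experimental import checkify


def step(x):
    return jnp.log(x) * 2.0


# Fast path: no instrumentation at all.
fast_step = jax.jit(step)
# Slow path: the same function with JAX's floating-point error set enabled.
diagnosed_step = jax.jit(checkify.checkify(step, errors=checkify.float_checks))


def run_step(x):
    out = fast_step(x)
    # Cheap flag: one reduction over the output.
    if bool(jnp.all(jnp.isfinite(out))):
        return out, None
    # Only now pay for the instrumented rerun to learn which primitive failed.
    err, _ = diagnosed_step(x)
    return out, err.get()


_, msg_ok = run_step(jnp.array([1.0, 2.0]))
_, msg_bad = run_step(jnp.array([-1.0]))

print('finite input  :', msg_ok)
print('negative input:', msg_bad)
```

The instrumented function is compiled separately, so the fast path carries no extra checks when the flag stays clear.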
Summary#
- Value-dependent assert/if do not work inside jit; use the runtime-check tools instead.
- checkify returns (error, result); inspect with error.get(), assert with check(...), re-raise with check_error(...).
- Predefined sets (user_checks, nan_checks, float_checks, div_checks, index_checks, all_checks) catch whole categories of failures automatically.
- jit_error_if raises on a bad condition from inside a compiled step.
- debug_nan / debug_nan_if pinpoint the primitive that first produced a NaN or Inf.
See also#
Debugging — printing and inspecting values inside transformed code.
JIT and compilation — why value-dependent control flow is restricted.