Error Handling and Runtime Checks

Inside a compiled (jit) function, ordinary Python assert and if statements that depend on array values do not work: at trace time the values are abstract tracers with no concrete value, so converting one to a Python bool raises an error. Without an explicit check, a division by zero, an out-of-bounds index, or a NaN propagates silently and surfaces much later as a meaningless result.
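
Both failure modes can be reproduced with a minimal plain-JAX sketch. It uses only jax and jax.numpy, not brainstate, and the name naive_reciprocal is illustrative: a value-dependent if fails while tracing, whereas an unchecked bad value flows straight through the compiled function.

```python
import jax
import jax.numpy as jnp


@jax.jit
def naive_reciprocal(x):
    # `x` is a tracer here, so this Python `if` cannot be evaluated.
    if jnp.any(x == 0.0):
        raise ValueError("zero entry")
    return 1.0 / x


try:
    naive_reciprocal(jnp.array([1.0, 0.0]))
    trace_failed = False
except jax.errors.ConcretizationTypeError:
    # Raised while tracing, before any computation runs.
    trace_failed = True

# Without a check, the bad value just flows through the compiled function.
silent = jax.jit(lambda x: 1.0 / x)(jnp.array([1.0, 0.0]))

print("if-on-tracer failed at trace time:", trace_failed)
print("unchecked result:", silent)
```

The first call never reaches the division, and the second returns an inf entry without any warning. That is the gap the tools in this tutorial are meant to close.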

brainstate.transform provides JIT-compatible runtime checks built on JAX’s checkify machinery, extended to understand State. This tutorial covers three tools:

  • checkify — turn value errors into an explicit error object you can inspect.

  • jit_error_if — raise on a bad condition from inside a compiled step.

  • debug_nan — pinpoint where a NaN or Inf first appears.

import jax
import jax.numpy as jnp

import brainstate
import brainstate.transform as T

brainstate.random.seed(0)
brainstate.__version__
'0.4.0'

checkify: functionalize runtime checks

checkify transforms a function so that, instead of raising, it returns an (error, result) pair. The error object is inert until you ask about it: error.get() returns None when all checks passed, or the failure message otherwise. Inside the function you assert conditions with check(pred, msg, *fmt_args).

def safe_log(x):
    T.check(jnp.all(x > 0), 'x must be positive, got {}', x)
    return jnp.log(x)

checked = T.checkify(safe_log, errors=T.user_checks)

err, out = checked(jnp.array([1.0, 2.0]))
print('valid input -> error:', err.get())
print('valid input -> result:', out)

err, out = checked(jnp.array([-1.0, 2.0]))
print('bad input   -> error:', err.get())
valid input -> error: None
valid input -> result: [0.        0.6931472]
bad input   -> error: x must be positive, got [-1.  2.] (`check` failed)

To turn a captured error back into a real exception at an outer boundary, call check_error(error) — it raises if the error is set, and is a no-op otherwise.
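
A hedged sketch of that re-raise boundary, written against jax.experimental.checkify directly (the machinery this module is described as building on). It assumes the brainstate.transform versions behave the same way, which is not verified here. safe_sqrt and run_step are hypothetical names.

```python
import jax.numpy as jnp
from jax.experimental import checkify


def safe_sqrt(x):
    checkify.check(jnp.all(x >= 0), "x must be non-negative, got {}", x)
    return jnp.sqrt(x)


checked_sqrt = checkify.checkify(safe_sqrt, errors=checkify.user_checks)


def run_step(x):
    """Outer boundary: convert a captured error back into an exception."""
    err, out = checked_sqrt(x)
    checkify.check_error(err)  # no-op when err is unset, raises otherwise
    return out


ok = run_step(jnp.array([4.0, 9.0]))
print("valid input ->", ok)

try:
    run_step(jnp.array([-1.0]))
    raised = False
except Exception as exc:  # JAX raises its own runtime-error type here
    raised = True
    message = str(exc)
    print("bad input   -> raised:", message)
```

Keeping the re-raise in one outer function means the inner code can stay purely functional, while callers still see an ordinary Python exception.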

Built-in check categories

Beyond your own checks, checkify can automatically detect whole classes of failures. Select them by passing a predefined set as errors:

Set            Detects
-------------  -------------------------------------
user_checks    your explicit check(...) assertions
nan_checks     NaN values produced by any primitive
float_checks   NaN and Inf from floating-point ops
div_checks     division by zero
index_checks   out-of-bounds array indexing
all_checks     every category above

No explicit check call is needed — the failure is caught wherever it occurs.

# NaN detection
nan_checked = T.checkify(lambda x: jnp.log(x), errors=T.nan_checks)
err, _ = nan_checked(jnp.array([-1.0]))
print('nan_checks  :', err.get())

# Division by zero
div_checked = T.checkify(lambda a, b: a / b, errors=T.div_checks)
err, _ = div_checked(jnp.array(1.0), jnp.array(0.0))
print('div_checks  :', err.get())

# Out-of-bounds indexing
idx_checked = T.checkify(lambda arr, i: arr[i], errors=T.index_checks)
err, _ = idx_checked(jnp.arange(3), 10)
print('index_checks:', err.get())
nan_checks  : nan generated by primitive: log.
div_checks  : division by zero
index_checks: out-of-bounds indexing for array of shape (3,): index 10 is out of bounds for axis 0 with size 3. 
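
The categories can also be combined. The sketch below uses jax.experimental.checkify directly, where the predefined sets are frozen sets and `|` takes their union. It assumes the sets exposed by brainstate.transform compose the same way. log_ratio is a hypothetical function.

```python
import jax.numpy as jnp
from jax.experimental import checkify

# The predefined sets are ordinary frozen sets, so `|` takes their union.
nan_or_div = checkify.nan_checks | checkify.div_checks


def log_ratio(a, b):
    return jnp.log(a) / b


checked_ratio = checkify.checkify(log_ratio, errors=nan_or_div)

err_div, _ = checked_ratio(jnp.array(2.0), jnp.array(0.0))
err_nan, _ = checked_ratio(jnp.array(-1.0), jnp.array(1.0))
err_ok, _ = checked_ratio(jnp.array(2.0), jnp.array(1.0))

print("zero divisor  :", err_div.get())
print("negative input:", err_nan.get())
print("valid input   :", err_ok.get())
```

Enabling only the categories you need keeps the instrumented function smaller than one checkified with all_checks.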

jit_error_if: raise from inside a compiled step

When you would rather fail loudly than thread an error object around, jit_error_if(pred, msg) raises a runtime error if pred is true. It works inside brainstate.transform.jit and is the right tool for guarding preconditions in a training or simulation step.

@brainstate.transform.jit
def reciprocal(x):
    T.jit_error_if(jnp.any(x == 0.0), 'reciprocal received a zero entry')
    return 1.0 / x

print('valid:', reciprocal(jnp.array([2.0, 4.0])))
valid: [0.5  0.25]

If the predicate is true at runtime, the call raises with your message instead of returning inf. The check compiles to a cheap conditional, so it is safe to leave in production steps.
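
For comparison, a similar fail-loudly guard can be sketched in plain JAX by wrapping a checkified function in jax.jit and throwing at the call site. This is not a description of how jit_error_if is implemented, which this tutorial does not cover. reciprocal_step and guarded_reciprocal are hypothetical names.

```python
import jax
import jax.numpy as jnp
from jax.experimental import checkify


def reciprocal_step(x):
    checkify.check(jnp.all(x != 0.0), "reciprocal received a zero entry")
    return 1.0 / x


# checkify goes inside jit: the check becomes part of the compiled program.
compiled_step = jax.jit(checkify.checkify(reciprocal_step))


def guarded_reciprocal(x):
    err, out = compiled_step(x)
    err.throw()  # raise on the host if the compiled check fired
    return out


good = guarded_reciprocal(jnp.array([2.0, 4.0]))
print("valid:", good)

try:
    guarded_reciprocal(jnp.array([2.0, 0.0]))
    failed = False
except Exception as exc:
    failed = True
    print("invalid:", exc)
```

The extra wrapper is the cost of this approach: every caller has to unpack and throw the error, which is the bookkeeping that jit_error_if is described as sparing you.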

debug_nan: locate the first NaN or Inf

When a model diverges, the hard part is finding where the NaN was born. debug_nan(fn, *args) runs fn with NaN/Inf detection enabled and raises — naming the offending primitive — the moment a non-finite value appears.

def unstable(x):
    y = x * 1e20
    return jnp.exp(y)   # overflows to inf

# A finite computation passes through untouched.
T.debug_nan(lambda x: x * 2.0, jnp.array([1.0, 2.0]))
print('finite computation: OK')

try:
    T.debug_nan(unstable, jnp.array([10.0]))
except Exception as e:
    print('debug_nan caught:', type(e).__name__)
finite computation: OK
debug_nan caught: RuntimeError

Use debug_nan_if(has_nan, fn, *args) to enable the (somewhat costly) detection only when an upstream flag already suspects trouble, keeping the fast path fast.
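
The gating idea can be illustrated in plain JAX: run the fast compiled function first, and pay for instrumentation only when its output looks suspicious. This is a sketch under stated assumptions, not debug_nan_if itself. locate_nan_if and unstable_log are hypothetical, and JAX's nan_checks flags NaN only, not Inf.

```python
import jax
import jax.numpy as jnp
from jax.experimental import checkify


def locate_nan_if(suspect, fn, *args):
    """Hypothetical helper: rerun `fn` with NaN checks only when `suspect` is true."""
    if not suspect:
        return None  # fast path: no extra work
    err, _ = checkify.checkify(fn, errors=checkify.nan_checks)(*args)
    return err.get()  # message naming the primitive, or None


def unstable_log(x):
    return jnp.log(x) * 2.0


fast = jax.jit(unstable_log)

x = jnp.array([1.0, -1.0])
out = fast(x)
suspect = bool(jnp.any(jnp.isnan(out)))  # cheap host-side flag
report = locate_nan_if(suspect, unstable_log, x)

print("suspect:", suspect)
print("report :", report)
```

On the healthy path the only overhead is one isnan reduction over the output. The instrumented rerun happens only after a NaN has already been seen.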

Summary

  • Value-dependent assert/if do not work inside jit; use the runtime-check tools instead.

  • checkify returns (error, result); inspect with error.get(), assert with check(...), re-raise with check_error(...).

  • Predefined sets — user_checks, nan_checks, float_checks, div_checks, index_checks, all_checks — catch whole categories of failures automatically.

  • jit_error_if raises on a bad condition from inside a compiled step.

  • debug_nan / debug_nan_if pinpoint the primitive that first produced a NaN or Inf.

See also

  • Debugging — printing and inspecting values inside transformed code.

  • JIT and compilation — why value-dependent control flow is restricted.