Debugging and Error Checking

JIT-compatible debugging utilities for locating NaN and Inf values, including those that arise during gradient computations, plus functionalized runtime error checking. These tools help diagnose numerical issues directly inside compiled (jitted) code.
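As a minimal illustration of the kind of problem these tools target, the plain-JAX snippet below (it uses none of the utilities on this page) defines a function whose forward value is finite at `x = 0` but whose gradient is NaN. The unselected `jnp.sqrt` branch of `jnp.where` still takes part in differentiation, where its infinite derivative at 0 is multiplied by a zero cotangent (`0 * inf`). The function name `safe_looking_sqrt` is illustrative only.

```python
import jax
import jax.numpy as jnp

def safe_looking_sqrt(x):
    # The forward value is finite everywhere: sqrt is only *selected* for x > 0.
    return jnp.where(x > 0, jnp.sqrt(x), 0.0)

x = jnp.array(0.0)
value = safe_looking_sqrt(x)          # 0.0, finite
g = jax.grad(safe_looking_sqrt)(x)    # NaN: the masked branch contributes 0 * inf

print("value:", value, "finite:", bool(jnp.isfinite(value)))
print("grad:", g, "finite:", bool(jnp.isfinite(g)))
```

Because the forward pass looks healthy, a NaN like this typically only becomes visible once gradients are inspected, which is the situation the NaN/Inf utilities below are meant to help with.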

NaN/Inf Debugging

debug_nan

Run fn with NaN / Inf detection (JIT-compatible).

debug_nan_if

Conditionally run fn with NaN / Inf detection.

breakpoint_if

As jax.debug.breakpoint, but only triggers if pred is True.
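The gating idea shared by `debug_nan_if` and `breakpoint_if` (do something on the host only when a traced predicate is True) can be approximated in plain JAX with `lax.cond` and `jax.debug.callback`. The sketch below is an assumption-laden stand-in, not how this library implements those functions: `report_nonfinite` and `_record` are hypothetical names, and a host callback that appends to a list replaces the breakpoint so the example runs non-interactively.

```python
import jax
import jax.numpy as jnp
from jax import lax

# Hypothetical host-side log standing in for a print or a breakpoint.
reports = []

def _record(x):
    # Runs on the host; x arrives as a NumPy array.
    reports.append(x)

def report_nonfinite(x):
    """Hypothetical helper: send x to the host only if it contains NaN or Inf."""
    has_bad = jnp.any(~jnp.isfinite(x))

    def _report(x):
        jax.debug.callback(_record, x)
        return x

    # Under jit, lax.cond executes only the selected branch, so the callback
    # fires only when has_bad is True.
    return lax.cond(has_bad, _report, lambda x: x, x)

@jax.jit
def step(x):
    y = jnp.log(x)  # log of a negative number is NaN
    return report_nonfinite(y)

step(jnp.array([1.0, 2.0]))   # all finite: nothing recorded
step(jnp.array([1.0, -1.0]))  # contains NaN: recorded once
jax.effects_barrier()         # wait for pending host callbacks
print(len(reports))
```

One caveat of this design: under `jax.vmap`, `lax.cond` with a batched predicate is lowered to a select that evaluates both branches, so a callback-based sketch like this one would fire for every batch.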

Error Checking

These utilities perform conditional checks inside JIT-compiled functions and raise an error when a specified condition holds, helping catch exceptional cases in compiled code.

jit_error_if

Raise an error from inside a JIT-compiled function when pred is True.

checkify

Functionalize runtime error checks in a stateful function.

check

Assert a runtime condition inside a checkify()-wrapped function.
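A minimal sketch of the checkify/check pattern, written against `jax.experimental.checkify` directly and assuming the `checkify` and `check` listed here behave like the JAX originals. `safe_log` is an illustrative name. The failed check becomes part of an explicit error value returned alongside the output, which is what allows the whole function to be jitted.

```python
import jax
import jax.numpy as jnp
from jax.experimental import checkify

def safe_log(x):
    # Traced assertion: recorded as an error value instead of raising at trace time.
    checkify.check(jnp.all(x > 0), "safe_log expects positive inputs")
    return jnp.log(x)

# checkify turns the check into an explicit Error output, so the result can be jitted.
checked_log = jax.jit(checkify.checkify(safe_log))

err_ok, out_ok = checked_log(jnp.array([1.0, 2.0]))
err_bad, out_bad = checked_log(jnp.array([1.0, -1.0]))

print(err_ok.get())   # None: no check failed
print(err_bad.get())  # message describing the failed check
# err_bad.throw() would raise the failure as a Python exception.
```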

check_error

Re-raise a previously captured Error inside a checkified context.
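The following sketch, adapted from the pattern in the `jax.experimental.checkify` documentation and assuming `check_error` here matches the JAX function, shows an error captured by an inner jitted, checkified call being re-raised in the enclosing function. Called directly, the failure surfaces as a Python exception; wrapped in `checkify` again, it is functionalized back into an error value. The names `f` and `with_inner_jit` are illustrative.

```python
import jax
from jax.experimental import checkify

def f(x):
    checkify.check(x > 0, "x must be positive")
    return x * 2

def with_inner_jit(x):
    # Run a jitted, checkified inner function that returns its error as a value...
    error, out = jax.jit(checkify.checkify(f))(x)
    # ...then re-raise that captured error in the enclosing context.
    checkify.check_error(error)
    return out

# Outside any checkify, a failed check_error raises a Python exception.
try:
    with_inner_jit(-1.0)
    raised = False
except Exception:
    raised = True

# Wrapping the outer function in checkify functionalizes check_error again.
outer_err, _ = checkify.checkify(with_inner_jit)(-1.0)
print(raised, outer_err.get() is not None)
```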

all_checks

Set of all error types that checkify can report: automatic_checks plus user_checks.

user_checks

Set of error types raised by explicit check() calls.

nan_checks

Set of error types for NaN values produced by floating-point operations.

div_checks

Set of error types for division by zero.

index_checks

Set of error types for out-of-bounds indexing.

float_checks

Set of error types for floating-point problems: nan_checks plus div_checks.

automatic_checks

Set of error types detected without an explicit check() call: float_checks plus index_checks.
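The check sets above are frozensets of error classes, passed as the `errors` argument of `checkify` to choose which checks are instrumented. The sketch below assumes they are the same objects as in `jax.experimental.checkify`; `lookup_sin` is an illustrative name. Because the sets are ordinary frozensets, they combine with `|`.

```python
import jax
import jax.numpy as jnp
from jax.experimental import checkify

def lookup_sin(x, i):
    # Indexing can go out of bounds; sin(inf) produces a NaN.
    return jnp.sin(x[i])

# Enable only out-of-bounds and NaN checks by taking the union of two sets.
enabled = checkify.index_checks | checkify.nan_checks
checked = jax.jit(checkify.checkify(lookup_sin, errors=enabled))

err_ok, _ = checked(jnp.arange(3.0), 1)                  # valid index, finite input
err_oob, _ = checked(jnp.arange(3.0), 5)                 # index 5 into a size-3 array
err_nan, _ = checked(jnp.array([0.0, jnp.inf, 1.0]), 1)  # sin(inf) is NaN

print(err_ok.get())   # None
print(err_oob.get())  # describes the out-of-bounds access
print(err_nan.get())  # describes the NaN produced by sin
```

Enabling only the checks you need keeps the instrumented program smaller, since each enabled check adds extra predicate computations to the compiled function.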