brainstate.transform.debug_nan_if

brainstate.transform.debug_nan_if(has_nan, fn, *args, phase='')

Conditionally run fn with on-device NaN/Inf detection: fn is checked only when has_nan is true.

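Equivalent to:

    if has_nan:
        debug_nan(fn, *args, phase=phase)

but implemented with jax.lax.cond, so it also works inside jax.jit, where has_nan may be a traced array that a plain Python if cannot branch on.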
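The snippet below is an illustrative sketch in plain JAX (not brainstate code) of why the conditional has to go through jax.lax.cond. A Python if on a traced value fails under jax.jit, while jax.lax.cond picks a branch at run time. The helper names branch_with_python_if and branch_with_cond are hypothetical.

```python
import jax
import jax.numpy as jnp


def branch_with_python_if(flag, x):
    # Hypothetical helper: a plain Python `if` needs a concrete bool,
    # but under jax.jit `flag` is a tracer, so this raises during tracing.
    if flag:
        return x * 2.0
    return x


def branch_with_cond(flag, x):
    # Hypothetical helper: lax.cond traces both branches and selects one
    # at run time based on the (scalar) predicate.
    return jax.lax.cond(flag, lambda v: v * 2.0, lambda v: v, x)


x = jnp.array([1.0, 2.0])
flag = jnp.array(True)

try:
    jax.jit(branch_with_python_if)(flag, x)
    python_if_ok = True
except jax.errors.ConcretizationTypeError:
    # Branching on a traced boolean is rejected by JAX.
    python_if_ok = False

result = jax.jit(branch_with_cond)(flag, x)
```

Here python_if_ok ends up False, while the jax.lax.cond version compiles and returns the doubled array.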

Parameters:
  • has_nan (bool or jax.Array) – Condition that triggers debugging. A jax.Array must be a scalar, since jax.lax.cond only accepts scalar predicates.

  • fn (Callable) – The function to debug.

  • *args – Positional arguments passed to fn.

  • phase (str) – Label prepended to the error message. Defaults to ''.
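As a rough sketch of how has_nan, fn, *args and phase fit together, the code below reimplements the pattern in plain JAX under stated assumptions. The hypothetical nonfinite_count_if counts NaN/Inf values in fn's output only when has_nan is true, and the hypothetical nonfinite_message builds a phase-prefixed message on the host. This is not brainstate's implementation; the real function's detection and error reporting may differ.

```python
import jax
import jax.numpy as jnp


def _count_nonfinite(fn, *args):
    # Run fn and count NaN/Inf entries across every array in its output pytree.
    total = jnp.int32(0)
    for leaf in jax.tree_util.tree_leaves(fn(*args)):
        total = total + jnp.sum(~jnp.isfinite(leaf)).astype(jnp.int32)
    return total


def nonfinite_count_if(has_nan, fn, *args):
    # Hypothetical stand-in for the documented pattern: evaluate the check
    # only when has_nan is true. Both branches return an int32 scalar,
    # because jax.lax.cond requires matching output shapes and dtypes.
    return jax.lax.cond(
        has_nan,
        lambda ops: _count_nonfinite(fn, *ops),
        lambda ops: jnp.int32(0),
        args,
    )


def nonfinite_message(count, phase=""):
    # Hypothetical host-side reporting: prefix the message with `phase`,
    # mirroring the documented parameter. Returns "" when nothing was found.
    n = int(count)
    if n == 0:
        return ""
    prefix = f"{phase}: " if phase else ""
    return f"{prefix}{n} NaN/Inf value(s) detected"


@jax.jit
def forward(x):
    y = jnp.log(x)  # log(0) -> -inf, log(-1) -> nan
    has_nan = ~jnp.all(jnp.isfinite(y))  # scalar traced boolean
    return y, nonfinite_count_if(has_nan, jnp.log, x)


y, count = forward(jnp.array([1.0, 0.0, -1.0]))
print(nonfinite_message(count, phase="forward"))
```

With the input above, the printed message is "forward: 2 NaN/Inf value(s) detected". Putting the check inside jax.lax.cond keeps it on device and evaluates it only when the predicate is true. The exception is under jax.vmap, where jax.lax.cond with a batched predicate is converted to a select that runs both branches.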