brainstate.transform.check

brainstate.transform.check(pred, msg, *fmt_args, debug=False, **fmt_kwargs)

Assert a runtime condition inside a checkify()-wrapped function.

A state-transparent re-export of jax.experimental.checkify.check(). When the enclosing function is transformed by checkify(), a false pred does not raise immediately. Instead, it is functionalized into the threaded Error value, so the check survives jit, vmap, and scan.
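A minimal sketch of this functionalization, written against the underlying jax.experimental.checkify API that this function re-exports. It does not import brainstate, so it assumes the brainstate wrapper behaves the same way. The failed check inside the hypothetical safe_log function is recorded in the returned Error instead of raising, even under jax.jit.

```python
import jax
import jax.numpy as jnp
from jax.experimental import checkify


def safe_log(x):
    # JAX's check expects a scalar boolean pred, so reduce the elementwise test.
    checkify.check(jnp.all(x > 0), "safe_log got a non-positive input: min={m}", m=jnp.min(x))
    return jnp.log(x)


# checkify functionalizes the check: the wrapped function returns (Error, output).
checked = jax.jit(checkify.checkify(safe_log))

err_ok, out_ok = checked(jnp.array([1.0, 2.0]))
err_bad, _ = checked(jnp.array([1.0, -2.0]))

print(err_ok.get())   # None: no failure recorded
print(err_bad.get())  # the formatted message plus the failing location
```

Calling err_bad.throw() instead of err_bad.get() raises the recorded error on the host.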

Parameters:
  • pred (Array | ndarray | bool | number | bool | int | float | complex | Quantity) – The condition that must hold. False records an error.

  • msg (str) – The error message. It may contain str.format-style fields ({} or {name}) that are filled from fmt_args and fmt_kwargs; the format arguments may be traced values.

  • *fmt_args – Positional format arguments for msg.

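  • debug (bool) – If True, the check is treated as a debug-only check (compare jax.experimental.checkify.debug_check(), which is enforced only under checkify() and dropped otherwise).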

  • **fmt_kwargs – Keyword format arguments for msg.
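The sketch below, again using jax.experimental.checkify directly, shows a positional ({}) and a keyword ({size}) format field filled from traced arguments. For the debug parameter it uses JAX's checkify.debug_check() as a stand-in, assuming debug=True behaves the same way: dropped outside checkify(), enforced inside it. check_index and double are hypothetical example functions.

```python
import jax
import jax.numpy as jnp
from jax.experimental import checkify


def check_index(i, n):
    # Positional {} and keyword {size} fields, both filled from traced values.
    checkify.check((i >= 0) & (i < n), "index {} out of range for size {size}", i, size=n)
    return i


err_idx, _ = checkify.checkify(check_index)(jnp.int32(7), jnp.int32(5))
msg = err_idx.get()  # a string describing the failed check


def double(x):
    # A debug-only check: only enforced when the function is checkified.
    checkify.debug_check(x > 0, "x must be positive, got {}", x)
    return x * 2


# Outside checkify the debug check is dropped, so this runs without error.
out = jax.jit(double)(jnp.float32(-1.0))

# Under checkify the same check is enforced and recorded in the Error.
err_dbg, _ = checkify.checkify(double)(jnp.float32(-1.0))
```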

See also

checkify, check_error

Return type:

None