brainstate.transform.checkify
- brainstate.transform.checkify(fun, errors=frozenset({<class 'jax._src.checkify.FailedCheckError'>}))
Functionalize runtime error checks in a stateful function.
A state-aware wrapper over jax.experimental.checkify.checkify(). The returned function runs fun with its error checks (NaN, division-by-zero, out-of-bounds, and user check() assertions, as selected by errors) threaded into an Error value instead of raising on the host. Because the error is an ordinary return value, this works under jit, vmap, and scan without host callbacks. State reads and writes performed by fun are handled transparently: writes are applied after the call, and reads are left unchanged.
Unlike jit_error_if() (a fire-and-forget host debug.callback), checkify is functional and composable: the caller receives the Error and decides when to inspect it (err.get()) or raise it (err.throw()).
- Parameters:
  fun (Callable) – The function to check. It may read and write State objects and call check().
  errors (Any) – The set of error categories to enable. Use the re-exported sets user_checks, nan_checks, div_checks, index_checks, float_checks, automatic_checks, or all_checks. The default enables only user check() assertions.
- Returns:
  A function with the same signature as fun that returns a tuple (error, out), where error is an Error and out is fun's original output.
- Return type:
See also
Notes
fun is re-run under jax.experimental.checkify.checkify() (rather than replaying a pre-traced jaxpr) so that check primitives are emitted directly into checkify's trace and functionalized natively. The states fun touches are discovered once via StatefulFunction; after the call, every state is restored (written states to their new values, read states to their original values), so no tracer leaks into the global State objects.
Examples
>>> import brainstate
>>> import jax.numpy as jnp
>>>
>>> def safe_double(x):
...     brainstate.transform.check(x > 0, 'x must be positive')
...     return x * 2.0
>>>
>>> checked = brainstate.transform.checkify(safe_double)
>>> err, out = checked(jnp.array(3.0))
>>> err.throw()  # no error: does nothing
>>> out
Array(6., dtype=float32, weak_type=True)