brainstate.transform.checkify

brainstate.transform.checkify(fun, errors=frozenset({<class 'jax._src.checkify.FailedCheckError'>}))

Functionalize runtime error checks in a stateful function.

A state-aware wrapper over jax.experimental.checkify.checkify(). The returned function runs fun with its error checks (NaN, division-by-zero, out-of-bounds, and user check() assertions, as selected by errors) threaded into an Error value instead of raising on the host. With the default errors, only user check() assertions are enabled. Because the error is returned as an ordinary value, it propagates through jit, vmap, and scan without host callbacks. State reads and writes performed by fun are handled transparently: writes are applied after the call, and read-only states keep their original values.

Unlike jit_error_if(), which is a fire-and-forget host debug.callback, checkify is functional and composable: the caller receives the Error and decides when to inspect it (err.get()) or raise it (err.throw()).

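Parameters:

fun (Callable) – The function to check. It may read and write State objects and may call check().

errors (frozenset) – The set of error classes to enable, as accepted by jax.experimental.checkify.checkify() (for example checkify.user_checks, checkify.float_checks, checkify.index_checks, or checkify.all_checks). Defaults to user check() assertions only.

Returns: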

A function with the same signature as fun that returns a tuple (error, out), where error is a jax.experimental.checkify.Error and out is fun’s original output.

Return type:

Callable

Notes

fun is re-run under jax.experimental.checkify.checkify() (rather than replaying a pre-traced jaxpr) so that check primitives are emitted directly into checkify’s trace and functionalized natively. The states fun touches are discovered once via StatefulFunction; after the call every state is restored (writes to their new values, reads to their originals) so no tracer leaks into the global State objects.

Examples

>>> import brainstate
>>> import jax.numpy as jnp
>>>
>>> def safe_double(x):
...     brainstate.transform.check(x > 0, 'x must be positive')
...     return x * 2.0
>>>
>>> checked = brainstate.transform.checkify(safe_double)
>>> err, out = checked(jnp.array(3.0))
>>> err.throw()   # no error: does nothing
>>> out
Array(6., dtype=float32, weak_type=True)