brainstate.transform.jit_error_if
- brainstate.transform.jit_error_if(pred, error, *err_args, **err_kwargs)
Check for errors inside a JIT-compiled function: the given error is raised when pred is true.
- Parameters:
Examples
The error can be given as a function that receives arguments passed from the JIT variables and raises an error.
>>> import jax
>>> from brainstate.transform import jit_error_if
>>> def error(x):
...     raise ValueError(f'error {x}')
>>> x = jax.random.uniform(jax.random.PRNGKey(0), (10,))
>>> jit_error_if(x.sum() < 5., error, x)
Alternatively, the error can be a simple string message whose placeholders are filled from the keyword arguments.
>>> x = jax.random.uniform(jax.random.PRNGKey(0), (10,))
>>> jit_error_if(x.sum() < 5., "Error: the sum is less than 5. Got {s}", s=x.sum())
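The two forms of the error argument can be compared side by side in a minimal eager-mode sketch. The helper eager_error_if below is hypothetical and is not the library's implementation: it uses plain Python values instead of JIT variables, and it assumes that the string form is filled in with str.format and raised as a ValueError, which the documentation above does not confirm.

```python
def eager_error_if(pred, error, *err_args, **err_kwargs):
    """Hypothetical eager-mode stand-in for jit_error_if (not the library code)."""
    if not pred:
        return  # predicate is False: nothing to report
    if callable(error):
        # Callable form: forward the extra arguments; the callable itself raises.
        error(*err_args, **err_kwargs)
    else:
        # String form: assumed here to be filled in with str.format
        # and raised as a ValueError (both are assumptions of this sketch).
        raise ValueError(str(error).format(*err_args, **err_kwargs))


def error(x):
    raise ValueError(f'error {x}')


values = [0.5, 1.0, 1.5]   # plain floats standing in for a JAX array
total = sum(values)        # the predicate `total < 5.` below is True

# Callable form: the extra positional argument is passed to `error`.
try:
    eager_error_if(total < 5., error, values)
except ValueError as exc:
    callable_message = str(exc)

# String form: the keyword argument `s` fills the `{s}` placeholder.
try:
    eager_error_if(total < 5., "Error: the sum is less than 5. Got {s}", s=total)
except ValueError as exc:
    string_message = str(exc)
```

In both cases nothing happens when the predicate is false. Inside a real JIT-compiled function the predicate is a traced value, so an ordinary Python if statement cannot branch on it, which is the situation jit_error_if is meant for.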