brainstate.transform.jit_error_if

brainstate.transform.jit_error_if(pred, error, *err_args, **err_kwargs)

Check a condition inside a JIT-compiled function and raise an error when it holds.

Parameters:
  • pred (bool or Array) – The boolean condition to check; the error is triggered when it is true.

  • error (Callable | str) – Either a function that raises the error, or a string giving the error message.

  • *err_args – Positional arguments passed to the error function.

  • **err_kwargs – Keyword arguments passed to the error function.

Examples

The error argument can be a function that receives values passed from the JIT-traced variables and raises an error.

>>> import jax
>>> from brainstate.transform import jit_error_if
>>> def error(x):
...     raise ValueError(f'error {x}')
>>> x = jax.random.uniform(jax.random.PRNGKey(0), (10,))
>>> jit_error_if(x.sum() < 5., error, x)

Alternatively, error can be a plain message string; here the {s} placeholder is supplied through the keyword argument s.

>>> x = jax.random.uniform(jax.random.PRNGKey(0), (10,))
>>> jit_error_if(x.sum() < 5., "Error: the sum is less than 5. Got {s}", s=x.sum())
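
The two forms of the error argument shown above can be mimicked in plain Python to make the calling convention easier to see. The sketch below is only an eager-mode stand-in and not brainstate's implementation: the helper name error_if_eager, the use of str.format to fill in the message, and the choice of ValueError for the string form are all assumptions made for illustration.

```python
from typing import Any, Callable, Union


def error_if_eager(pred: bool, error: Union[Callable[..., Any], str],
                   *err_args: Any, **err_kwargs: Any) -> None:
    """Hypothetical eager-mode stand-in; NOT brainstate's implementation."""
    if not pred:
        return  # condition not met: nothing happens
    if isinstance(error, str):
        # Assumption: the message is filled in from the keyword arguments,
        # and ValueError is an arbitrary exception type for this sketch.
        raise ValueError(error.format(**err_kwargs))
    # Callable form: the error function receives the extra arguments and raises.
    error(*err_args, **err_kwargs)


def too_small(total: float) -> None:
    raise ValueError(f"sum too small: {total}")


values = [0.5, 1.0, 1.5]
total = sum(values)

# Exercise both forms with the same condition and collect the messages.
messages = []
for err, args, kwargs in [
    (too_small, (total,), {}),
    ("Error: the sum is less than 5. Got {s}", (), {"s": total}),
]:
    try:
        error_if_eager(total < 5.0, err, *args, **kwargs)
    except ValueError as exc:
        messages.append(str(exc))
```

Under jax.jit the predicate is a traced value rather than a Python bool, so an ordinary if statement like the one in this sketch cannot be used there. That appears to be the gap jit_error_if is meant to fill.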