brainstate.transform.breakpoint_if

brainstate.transform.breakpoint_if(pred, **breakpoint_kwargs)

Like jax.debug.breakpoint, but triggers only when pred is True.
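The conditional behaviour can be sketched with jax.lax.cond, which selects a branch at run time and therefore also accepts a traced pred inside jax.jit. This is a minimal sketch under that assumption: conditional_breakpoint and safe_log are hypothetical names used only for illustration, and the code is not claimed to be brainstate's implementation.

```python
import jax
import jax.numpy as jnp


def conditional_breakpoint(pred, **breakpoint_kwargs):
    """Hypothetical helper: enter the JAX debugger only when `pred` is True."""
    # Both branches return None, so their (empty) output pytrees match.
    jax.lax.cond(
        pred,
        lambda: jax.debug.breakpoint(**breakpoint_kwargs),  # taken when pred is True
        lambda: None,                                       # no-op otherwise
    )


@jax.jit
def safe_log(x):
    # Hypothetical caller: pause only if some input is non-positive.
    conditional_breakpoint(jnp.any(x <= 0))
    return jnp.log(x)


out = safe_log(jnp.array([1.0, 2.0]))  # all inputs positive: no pause
```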

Parameters:
  • pred (bool or jax.Array) – Predicate controlling the breakpoint; it is triggered only when pred is True.

  • **breakpoint_kwargs – Any other keyword arguments, forwarded unchanged to jax.debug.breakpoint.