brainstate.transform.cond

brainstate.transform.cond(pred, true_fun, false_fun, *operands)

Apply true_fun or false_fun to *operands, depending on the value of pred.

Parameters:
  • pred (bool or array-like) – Scalar predicate (boolean or numeric) selecting which branch to execute. Numeric inputs are treated as True when non-zero.

  • true_fun (Callable) – Function that receives *operands when pred is True.

  • false_fun (Callable) – Function that receives *operands when pred is False.

  • *operands (Any) – Operands forwarded to either branch. May be any pytree of arrays, scalars, or nested containers thereof.

Returns:

Value returned by the selected branch. Both branches must return outputs with the same pytree structure, shapes, and dtypes, so the result always has that common structure.

Return type:

Any

Notes

Provided the arguments are correctly typed, cond() has semantics that match the following Python implementation, where pred must be a scalar:

>>> def cond(pred, true_fun, false_fun, *operands):
...     if pred:
...         return true_fun(*operands)
...     return false_fun(*operands)

In contrast with jax.lax.select(), using cond() indicates that only one branch runs (subject to compiler rewrites and optimizations). When transformed with vmap() over a batch of predicates, cond() is converted to select(), so both branches are evaluated for every element of the batch.

Examples

>>> import brainstate
>>>
>>> def branch_true(x):
...     return x + 1
>>>
>>> def branch_false(x):
...     return x - 1
>>>
>>> int(brainstate.transform.cond(True, branch_true, branch_false, 3))
4
>>> int(brainstate.transform.cond(False, branch_true, branch_false, 3))
2