brainstate.transform.cond
- brainstate.transform.cond(pred, true_fun, false_fun, *operands)
Conditionally apply true_fun or false_fun.

- Parameters:
  - pred (bool or array-like) – Boolean scalar selecting which branch to execute. Numeric inputs are treated as True when non-zero.
  - true_fun (Callable) – Function that receives *operands when pred is True.
  - false_fun (Callable) – Function that receives *operands when pred is False.
  - *operands (Any) – Operands forwarded to either branch. May be any pytree of arrays, scalars, or nested containers thereof.
- Returns:
  Value returned by the selected branch. Both true_fun and false_fun must produce outputs with the same pytree structure.
- Return type:
  Any
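The calling convention in the parameter list can be modeled in plain Python. The sketch below is not brainstate's implementation: it assumes the reference semantics stated in the Notes (only the selected branch is called), uses a plain dict as a stand-in for a pytree of arrays, and cond_model, scale_params, and zero_params are hypothetical names chosen for illustration.

```python
# Hypothetical pure-Python model of the documented calling convention
# (NOT brainstate's implementation): every operand is forwarded
# positionally to whichever branch is selected.
def cond_model(pred, true_fun, false_fun, *operands):
    # Numeric predicates follow Python truthiness: non-zero selects true_fun.
    return true_fun(*operands) if pred else false_fun(*operands)


def scale_params(params, factor):
    # Branch receiving two operands: a dict standing in for a pytree, and a scalar.
    return {k: v * factor for k, v in params.items()}


def zero_params(params, factor):
    # Accepts the same operands (factor is unused here) and returns a dict
    # with the same keys, so both branches agree in output structure.
    return {k: 0.0 * v for k, v in params.items()}


params = {"w": 2.0, "b": 0.5}
print(cond_model(1, scale_params, zero_params, params, 3.0))  # {'w': 6.0, 'b': 1.5}
print(cond_model(0, scale_params, zero_params, params, 3.0))  # {'w': 0.0, 'b': 0.0}
```

Both branches accept the same operands and return a dict with the same keys, mirroring the requirement that the two branches agree in output structure.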
Notes
Provided the arguments are correctly typed, cond() has semantics that match the following Python implementation, where pred must be a scalar:

>>> def cond(pred, true_fun, false_fun, *operands):
...     if pred:
...         return true_fun(*operands)
...     return false_fun(*operands)
In contrast with jax.lax.select(), using cond() indicates that only one branch is executed (subject to compiler rewrites and optimizations). However, when transformed with vmap() over a batch of predicates, cond() is converted to select(), so both branches are evaluated.

Examples
>>> import brainstate
>>>
>>> def branch_true(x):
...     return x + 1
>>>
>>> def branch_false(x):
...     return x - 1
>>>
>>> int(brainstate.transform.cond(True, branch_true, branch_false, 3))
4
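The Notes contrast cond(), where only one branch runs, with select(), which it becomes under vmap() over a batch of predicates. The pure-Python sketch below is a conceptual model of that difference, not how brainstate or JAX implements it. It counts how often each branch is called under the two strategies; cond_one and batched_as_select are hypothetical names.

```python
# Conceptual model only (NOT brainstate/JAX internals): count how often
# each branch runs under the two evaluation strategies.
calls = {"true": 0, "false": 0}


def inc(x):
    calls["true"] += 1
    return x + 1


def dec(x):
    calls["false"] += 1
    return x - 1


def cond_one(pred, true_fun, false_fun, x):
    # Unbatched cond: only the selected branch is evaluated.
    return true_fun(x) if pred else false_fun(x)


def batched_as_select(preds, true_fun, false_fun, xs):
    # Batched predicates (as with vmap): evaluate BOTH branches on every
    # element, then pick each result elementwise, like select().
    on_true = [true_fun(x) for x in xs]
    on_false = [false_fun(x) for x in xs]
    return [t if p else f for p, t, f in zip(preds, on_true, on_false)]


print(cond_one(True, inc, dec, 3), calls)  # 4 {'true': 1, 'false': 0}

calls.update(true=0, false=0)
print(batched_as_select([True, False, True], inc, dec, [3, 3, 3]), calls)
# [4, 2, 4] {'true': 3, 'false': 3}
```

The results agree elementwise, but the batched version calls both branches for every element. This is why expensive or side-effecting branches behave differently once the predicate is batched.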