brainstate.transform.switch
- brainstate.transform.switch(index, branches, *operands)
Apply exactly one branch from branches based on index.
- Parameters:
- Returns:
The value returned by the selected branch, with the same pytree structure as that branch's output.
- Return type:
Any
Notes
If index is out of bounds, it is clamped to [0, len(branches) - 1]. Conceptually, switch() behaves like:
>>> def switch(index, branches, *operands):
...     safe_index = clamp(0, index, len(branches) - 1)
...     return branches[safe_index](*operands)
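The conceptual definition above relies on a clamp helper that is not defined. As a minimal pure-Python sketch of the same clamping behavior, the following uses two hypothetical names, clamp and reference_switch, which are illustrative only and not part of brainstate:

```python
def clamp(lo, x, hi):
    """Restrict x to the closed interval [lo, hi] (illustrative helper)."""
    return max(lo, min(x, hi))


def reference_switch(index, branches, *operands):
    """Pure-Python model of the documented semantics, not the real implementation."""
    safe_index = clamp(0, index, len(branches) - 1)
    return branches[safe_index](*operands)


branches = (
    lambda x: x - 1,
    lambda x: x,
    lambda x: x + 1,
)

# Indices below 0 select the first branch; indices past the end select the last.
results = [reference_switch(i, branches, 3) for i in (-4, 0, 1, 2, 7)]
print(results)
```

Under this model, an out-of-range index never raises an error. It silently selects the first or last branch, so callers that need strict bounds checking should validate index themselves.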
Internally this wraps XLA’s Conditional operator. When transformed with vmap() over a batch of indices, switch() is converted to select().
Examples
>>> import brainstate
>>>
>>> branches = (
...     lambda x: x - 1,
...     lambda x: x,
...     lambda x: x + 1,
... )
>>>
>>> brainstate.transform.switch(2, branches, 3)
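The note about vmap() can be illustrated with a batched index. The sketch below uses the underlying JAX calls jax.lax.switch and jax.vmap directly, on the assumption that brainstate.transform.switch follows the same semantics. The helper name apply_branch is hypothetical.

```python
import jax
import jax.numpy as jnp

# All branches must return values of the same shape and dtype.
branches = (
    lambda x: x - 1.0,
    lambda x: x,
    lambda x: x + 1.0,
)


def apply_branch(index, x):
    # Scalar index: exactly one branch is applied.
    return jax.lax.switch(index, branches, x)


indices = jnp.array([0, 1, 2, 5])      # the last index is out of bounds
xs = jnp.array([3.0, 3.0, 3.0, 3.0])

# Mapping over a batch of indices: each element picks its own branch.
batched = jax.vmap(apply_branch)(indices, xs)
print(batched)
```

Because the batched form is lowered to a select, every branch may be evaluated for every batch element, with the result chosen per element afterwards. Branches with expensive computations therefore may not save work under vmap().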