brainstate.transform.switch

brainstate.transform.switch(index, branches, *operands)

Apply exactly one branch from branches based on index.

Parameters:
  • index (int or array-like) – Scalar integer specifying which branch to execute.

  • branches (Sequence[Callable]) – Sequence of callables; each receives *operands.

  • *operands (Any) – Operands forwarded to the selected branch. May be any pytree of arrays, scalars, or nested containers thereof.

Returns:

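The value returned by the selected branch. All branches must return outputs with the same pytree structure (and matching leaf shapes and dtypes), so the result has that structure regardless of which branch runs.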

Return type:

Any

Notes

If index is out of bounds, it is clamped to [0, len(branches) - 1]. Conceptually, switch() behaves like:

>>> def switch(index, branches, *operands):
...     # Clamp out-of-range indices into [0, len(branches) - 1].
...     safe_index = max(0, min(index, len(branches) - 1))
...     return branches[safe_index](*operands)

Internally this wraps XLA’s Conditional operator. However, when transformed with vmap() over a batch of indices, switch() is converted to select(): every branch is evaluated, and the per-element results are selected afterwards.

Examples

>>> import brainstate
>>>
>>> branches = (
...     lambda x: x - 1,
...     lambda x: x,
...     lambda x: x + 1,
... )
>>>
>>> brainstate.transform.switch(2, branches, 3)  # index 2 selects lambda x: x + 1, giving 4