braintools.tree.take
- braintools.tree.take(tree, idx, axis, is_leaf=<function is_quantity>)
Take elements from every leaf of a PyTree along a given axis.
- Parameters:
  - tree (… Quantity]) – Input PyTree.
  - idx (int, slice or array_like) – Indices used for selection. If a slice, it is applied by standard indexing; otherwise u.math.take is used per leaf.
  - axis (int) – Axis along which to take values.
  - is_leaf (Callable[[Any], bool] | None) – Predicate that decides whether a node is treated as a leaf during traversal. Defaults to u.math.is_quantity.
- Returns:
A PyTree of the same structure as tree, in which each leaf holds the elements selected along axis.
- Return type:
PyTree
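The per-leaf behavior described above can be sketched with NumPy. This is a hypothetical, minimal sketch, not the braintools implementation: take_sketch and tree_map are illustrative names, the tree walker only understands dicts, lists and tuples (braintools itself traverses JAX PyTrees and, by default, treats Quantity objects as leaves via u.math.is_quantity), and applying a slice along axis rather than along the first dimension is an assumption about how the slice case combines with axis.

```python
import numpy as np


def tree_map(fn, tree, is_leaf=None):
    """Apply fn to every leaf of a nested dict/list/tuple (minimal stand-in for jax.tree.map)."""
    # A user-supplied predicate can stop traversal early and mark a node as a leaf.
    if is_leaf is not None and is_leaf(tree):
        return fn(tree)
    if isinstance(tree, dict):
        return {k: tree_map(fn, v, is_leaf) for k, v in tree.items()}
    if isinstance(tree, (list, tuple)):
        # Rebuild the same container type (plain list or tuple only).
        return type(tree)(tree_map(fn, v, is_leaf) for v in tree)
    # Anything else is a leaf.
    return fn(tree)


def take_sketch(tree, idx, axis, is_leaf=None):
    """Hypothetical analogue of braintools.tree.take for NumPy leaves."""

    def take_leaf(x):
        x = np.asarray(x)
        if isinstance(idx, slice):
            # Slice case: standard indexing, placing the slice at position `axis` (assumption).
            index = [slice(None)] * x.ndim
            index[axis] = idx
            return x[tuple(index)]
        # Integer or array-like case: np.take along `axis` for each leaf.
        return np.take(x, idx, axis=axis)

    return tree_map(take_leaf, tree, is_leaf)


params = {"w": np.arange(12).reshape(3, 4), "b": [np.arange(4), np.arange(4) * 10]}

# Select columns 0 and 2 of every leaf (last axis).
out = take_sketch(params, [0, 2], axis=-1)
print(out["w"].tolist())     # [[0, 2], [4, 6], [8, 10]]
print(out["b"][1].tolist())  # [0, 20]

# Slice rows 1..2 of every leaf (first axis).
sliced = take_sketch(params, slice(1, 3), axis=0)
print(sliced["w"].shape)       # (2, 4)
print(sliced["b"][0].tolist())  # [1, 2]
```

Because the same index is applied to every leaf, all leaves must have a compatible extent along axis; a leaf that is too short raises an IndexError in the np.take branch.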