braintools.tree.take

braintools.tree.take(tree, idx, axis, is_leaf=<function is_quantity>)

Take elements from every leaf of a PyTree along a given axis.

Parameters:
  • tree (PyTree) – Input PyTree whose leaves are arrays or Quantity objects.

  • idx (int, slice or array_like) – Indices used for selection. If idx is a slice, it is applied with standard indexing along axis; otherwise u.math.take is applied to each leaf.

  • axis (int) – Axis along which to take values.

  • is_leaf (Callable[[Any], bool] | None) – Predicate to treat a node as a leaf during traversal. Defaults to u.math.is_quantity.
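The selection rule for idx and the role of is_leaf can be illustrated with a small NumPy sketch. It is not the braintools implementation: `take_sketch` and `_take_leaf` are hypothetical names, `np.take` stands in for `u.math.take`, and the traversal only understands dicts, plain lists and tuples rather than arbitrary JAX PyTrees.

```python
import numpy as np


def _take_leaf(leaf, idx, axis):
    """Select from one leaf: slices use indexing, everything else np.take."""
    leaf = np.asarray(leaf)
    if isinstance(idx, slice):
        # Slice: plain indexing along `axis`, all other axes kept whole.
        # Negative `axis` works because it indexes the Python list.
        index = [slice(None)] * leaf.ndim
        index[axis] = idx
        return leaf[tuple(index)]
    # int or array-like: element selection (an int drops the axis).
    return np.take(leaf, idx, axis=axis)


def take_sketch(tree, idx, axis, is_leaf=None):
    """Hypothetical stand-in mirroring the documented behaviour."""
    # Stop descending when the caller's predicate marks this node as a leaf.
    if is_leaf is not None and is_leaf(tree):
        return _take_leaf(tree, idx, axis)
    # Recurse into the container types this sketch understands.
    if isinstance(tree, dict):
        return {k: take_sketch(v, idx, axis, is_leaf) for k, v in tree.items()}
    if isinstance(tree, (list, tuple)):
        return type(tree)(take_sketch(v, idx, axis, is_leaf) for v in tree)
    # Anything else is treated as an array leaf.
    return _take_leaf(tree, idx, axis)


tree = {"w": np.arange(12).reshape(3, 4),
        "b": [np.arange(3), np.arange(6).reshape(3, 2)]}
print(take_sketch(tree, slice(0, 2), axis=0)["w"].shape)  # (2, 4)
print(take_sketch(tree, 1, axis=0)["w"])                  # [4 5 6 7]
```

The container structure is preserved and only the leaves change. Passing is_leaf (for example `lambda x: isinstance(x, tuple)`) makes a whole tuple be converted and indexed as a single array instead of being traversed element by element.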

Returns:

A PyTree with the same structure as tree, in which each leaf holds the elements selected along axis.

Return type:

PyTree