braintools.tree.split

braintools.tree.split(tree, sizes, is_leaf=<function is_quantity>)

Split each leaf of a PyTree along axis 0 according to sizes.
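For a single array leaf, the documented rule amounts to cutting axis 0 at the running totals of sizes, with whatever remains forming one extra chunk. The NumPy sketch below illustrates only that per-leaf arithmetic. split_leaf is a hypothetical helper, not part of braintools. It uses np.split because that function takes cut indices rather than sizes.

```python
import numpy as np

def split_leaf(x, sizes):
    """Hypothetical helper: split one array along axis 0 by consecutive sizes."""
    # Cut points are the running totals of `sizes`, e.g. (2, 3) -> [2, 5].
    cuts = np.cumsum(sizes)
    # np.split returns len(cuts) + 1 pieces; the last one is x[cuts[-1]:].
    return tuple(np.split(x, cuts, axis=0))

a, b, rest = split_leaf(np.arange(7), (2, 3))
print(a, b, rest)  # [0 1] [2 3 4] [5 6]
```

When sum(sizes) equals the length of axis 0, the final piece is an empty array. For example, split_leaf(np.arange(5), (2, 3))[-1] has shape (0,).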

Parameters:
  • tree (PyTree[Array]) – Input PyTree. Each leaf is sliced along its first axis (axis 0).

  • sizes (Tuple[int]) – Lengths of the consecutive chunks cut from the start of axis 0. Any elements remaining after the first sum(sizes) entries are returned as one final chunk.

  • is_leaf (Callable[[Any], bool] | None) – Predicate that decides whether a node is treated as a leaf during traversal instead of being descended into. Defaults to u.math.is_quantity.

Returns:

A tuple of length len(sizes) + 1, where element i is a PyTree holding the i-th slice of every leaf. The last PyTree holds the remainder, so it is empty when sum(sizes) equals the size along axis 0.

Return type:

Tuple[PyTree[Array], ...]

Notes

This function operates on axis 0 only.
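The PyTree-level behavior can be approximated by cutting every leaf and regrouping the pieces into len(sizes) + 1 trees of the same shape. The sketch below is a hypothetical, pure-NumPy analogue and is not the braintools implementation. tree_split_sketch is an illustrative name, only dicts, lists and tuples are recognized as containers, and the optional is_leaf predicate mimics the documented parameter. The sketch has no unit-aware default, whereas the real function defaults to u.math.is_quantity.

```python
import numpy as np

def tree_split_sketch(tree, sizes, is_leaf=None):
    """Hypothetical analogue of the documented behavior (not the braintools code).

    Dicts, lists and tuples are traversed as containers; any other node, or any
    node for which is_leaf(node) is true, is treated as an array leaf.
    """
    n_out = len(sizes) + 1
    cuts = np.cumsum(sizes)  # running totals of `sizes` = cut points on axis 0

    def split_leaf(x):
        # n_out pieces along axis 0; the last piece is the remainder.
        return list(np.split(np.asarray(x), cuts, axis=0))

    def recurse(node):
        # Returns a list of n_out subtrees, one per chunk.
        if is_leaf is not None and is_leaf(node):
            return split_leaf(node)
        if isinstance(node, dict):
            parts = {k: recurse(v) for k, v in node.items()}
            return [{k: p[i] for k, p in parts.items()} for i in range(n_out)]
        if isinstance(node, (list, tuple)):
            parts = [recurse(v) for v in node]
            return [type(node)(p[i] for p in parts) for i in range(n_out)]
        return split_leaf(node)

    return tuple(recurse(tree))

params = {"w": np.arange(10).reshape(5, 2), "b": [np.arange(5)]}
first, second, rest = tree_split_sketch(params, (1, 2))
print(first["w"].shape, second["b"][0], rest["w"].shape)  # (1, 2) [1 2] (2, 2)
```

Passing is_leaf changes where traversal stops. With is_leaf=lambda n: isinstance(n, tuple), a tuple of two length-4 arrays is stacked into a single (2, 4) leaf and split as a whole, rather than each array being split separately.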