braintools.tree.split
- braintools.tree.split(tree, sizes, is_leaf=<function is_quantity>)
Split each leaf of a PyTree along axis 0 according to sizes.
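The behaviour described above can be sketched in plain NumPy. This is a minimal illustration, not braintools' implementation. It assumes NumPy arrays as leaves and plain dict, list and tuple containers. `split_tree` and `_tree_map` are hypothetical names. The sketch's `is_leaf` defaults to `None` rather than `u.math.is_quantity`. Every leaf is sliced between cumulative offsets of `sizes`, and the remainder becomes the last chunk.

```python
import numpy as np
from typing import Any, Callable, Optional, Sequence, Tuple


def _tree_map(fn: Callable[[Any], Any], node: Any,
              is_leaf: Optional[Callable[[Any], bool]] = None) -> Any:
    """Apply fn to each leaf, rebuilding dict / list / tuple containers."""
    if is_leaf is not None and is_leaf(node):
        return fn(node)  # the predicate stops traversal at this node
    if isinstance(node, dict):
        return {k: _tree_map(fn, v, is_leaf) for k, v in node.items()}
    if isinstance(node, (list, tuple)):
        return type(node)(_tree_map(fn, v, is_leaf) for v in node)
    return fn(node)  # any non-container object is treated as a leaf


def split_tree(tree: Any, sizes: Sequence[int],
               is_leaf: Optional[Callable[[Any], bool]] = None) -> Tuple[Any, ...]:
    """Hypothetical sketch: return len(sizes) + 1 trees, slicing every leaf on axis 0."""
    cuts = [int(c) for c in np.cumsum(sizes)]  # e.g. (2, 3) -> [2, 5]
    starts = [0] + cuts
    stops = cuts + [None]  # None slices to the end: the remainder chunk
    return tuple(
        _tree_map(lambda leaf: np.asarray(leaf)[start:stop], tree, is_leaf)
        for start, stop in zip(starts, stops)
    )


tree = {"x": np.arange(6), "y": [np.ones((6, 2))]}
head, mid, rest = split_tree(tree, (2, 3))
print(head["x"], mid["x"], rest["x"])  # [0 1] [2 3 4] [5]
print(mid["y"][0].shape)               # (3, 2)
```

Each returned tree keeps the container structure of the input (here a dict holding a list), and only the leaves are sliced.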
- Parameters:
tree (PyTree[Array]) – Input PyTree. Each leaf is sliced along its first axis.
sizes (Tuple[int]) – Lengths of the consecutive chunks taken along axis 0. The remainder (if any) after sum(sizes) is returned as a final chunk.
is_leaf (Callable[[Any], bool] | None) – Predicate that marks a node as a leaf during traversal. Defaults to u.math.is_quantity.
- Returns:
Tuple of length len(sizes) + 1 where each element is a PyTree containing the corresponding slice. If sum(sizes) equals the size along axis 0, the last PyTree's leaves are empty (length 0 along axis 0).
- Return type:
tuple of PyTrees
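The chunk boundaries follow from cumulative sums of `sizes`. Write $s_1, \dots, s_k$ for the entries of `sizes` and $n$ for a leaf's length along axis 0. This worked reading of the description assumes $c_k = \sum_j s_j \le n$; behaviour when `sum(sizes)` exceeds $n$ is not stated here.

$$
\begin{aligned}
c_0 &= 0, \qquad c_i = \sum_{j=1}^{i} s_j \quad (1 \le i \le k),\\
\mathrm{chunk}_i &= \mathrm{leaf}[c_{i-1} : c_i] \quad (1 \le i \le k),\\
\mathrm{chunk}_{k+1} &= \mathrm{leaf}[c_k : n], \qquad \text{of length } n - c_k .
\end{aligned}
$$

For example, sizes = (2, 3) and n = 6 give offsets (0, 2, 5), so the three chunks have lengths 2, 3 and 1. With n = 5 the final chunk has length 0.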
Notes
This function operates on axis 0 only.
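Because only axis 0 is split, one way to split along another axis is to move that axis to the front, split, and move it back. The sketch below shows this for a single NumPy array and can be checked against `np.split(..., axis=...)`. `split_along` is a hypothetical helper, not a braintools function. For a whole PyTree, the same `np.moveaxis` step would be applied to every leaf before and after the axis-0 split.

```python
import numpy as np
from typing import List, Sequence


def split_along(x: np.ndarray, sizes: Sequence[int], axis: int) -> List[np.ndarray]:
    """Hypothetical helper: split x along `axis` using only axis-0 slicing."""
    front = np.moveaxis(x, axis, 0)             # bring `axis` to position 0
    cuts = [int(c) for c in np.cumsum(sizes)]   # cut points, e.g. (1, 2) -> [1, 3]
    starts, stops = [0] + cuts, cuts + [None]   # last chunk is the remainder
    chunks = [front[a:b] for a, b in zip(starts, stops)]  # axis-0 slices
    return [np.moveaxis(c, 0, axis) for c in chunks]      # restore original axis order


x = np.arange(24).reshape(2, 4, 3)
parts = split_along(x, (1, 2), axis=1)
print([p.shape for p in parts])  # [(2, 1, 3), (2, 2, 3), (2, 1, 3)]
```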