braintools.tree.concat

braintools.tree.concat(trees, axis=0, is_leaf=<function is_quantity>)

Concatenate corresponding leaves from a sequence of PyTrees.

Parameters:
  • trees (Sequence[Quantity]) – Sequence of PyTrees with identical structure whose corresponding leaves are concatenated.

  • axis (int) – Axis along which to concatenate leaves.

  • is_leaf (Callable[[Any], bool] | None) – Predicate that decides whether a node is treated as a leaf during traversal. Defaults to u.math.is_quantity.

Returns:

PyTree with the same structure as the inputs, where each leaf is concatenate([t[i] for t in trees], axis=axis) and t[i] denotes the i-th leaf of tree t.

Return type:

PyTree

Notes

All PyTrees in trees must share the same PyTree structure, and their corresponding leaves must be compatible for concatenation along axis.
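
Examples

The leaf-wise behavior described above can be illustrated with a minimal sketch built on plain JAX. This is not the braintools implementation: concat_sketch, tree_a, and tree_b are hypothetical names introduced here, and the sketch uses ordinary arrays rather than Quantity leaves.

```python
import jax
import jax.numpy as jnp


def concat_sketch(trees, axis=0, is_leaf=None):
    """Concatenate corresponding leaves of same-structure PyTrees (illustrative only)."""
    # tree_map walks all trees in lockstep and hands matching leaves to the lambda.
    return jax.tree_util.tree_map(
        lambda *leaves: jnp.concatenate(leaves, axis=axis),
        *trees,
        is_leaf=is_leaf,
    )


# Two hypothetical PyTrees that share the same dict structure.
tree_a = {"w": jnp.ones((2, 3)), "b": jnp.zeros((2,))}
tree_b = {"w": jnp.ones((4, 3)), "b": jnp.zeros((1,))}

result = concat_sketch([tree_a, tree_b], axis=0)

# Leaves grow along axis 0; all other dimensions must already match.
shapes = {name: leaf.shape for name, leaf in result.items()}
```

Here shapes maps "w" to (6, 3) and "b" to (3,). Based on the signature documented above, the corresponding library call would presumably be braintools.tree.concat([tree_a, tree_b], axis=0).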