braintools.tree.expand

braintools.tree.expand(tree, axis, is_leaf=<function is_quantity>)

Insert a length-1 axis into every leaf of a PyTree.

Parameters:
  • tree (PyTree[Quantity]) – Input PyTree.

  • axis (int) – Position in the expanded shape where the new axis is placed.

  • is_leaf (Callable[[Any], bool] | None) – Predicate that decides whether a node is treated as a leaf during traversal. Defaults to u.math.is_quantity, so each Quantity is passed to expand_dims as a whole.

Returns:

A PyTree with the same structure as tree, with expand_dims(x, axis) applied to each leaf x.

Return type:

PyTree
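To illustrate the behavior described above, the following minimal sketch re-creates it with plain JAX. It assumes only jax is installed. expand_sketch is a hypothetical helper written for this example, not part of braintools and not its source. It operates on ordinary arrays rather than Quantity leaves, and its is_leaf defaults to None instead of u.math.is_quantity.

```python
import jax
import jax.numpy as jnp


def expand_sketch(tree, axis, is_leaf=None):
    """Hypothetical stand-in for braintools.tree.expand on plain JAX arrays.

    Applies jnp.expand_dims(x, axis) to every leaf x; is_leaf is forwarded
    to tree_map to decide where traversal stops.
    """
    return jax.tree_util.tree_map(
        lambda x: jnp.expand_dims(x, axis), tree, is_leaf=is_leaf
    )


# A nested PyTree (dict containing a list) whose leaves have different shapes.
params = {"w": jnp.ones((3, 4)), "b": jnp.zeros(4), "extra": [jnp.arange(5)]}

front = expand_sketch(params, axis=0)   # new length-1 axis becomes the first dimension
back = expand_sketch(params, axis=-1)   # new length-1 axis becomes the last dimension

print(front["w"].shape)         # (1, 3, 4)
print(back["w"].shape)          # (3, 4, 1)
print(front["b"].shape)         # (1, 4)
print(back["extra"][0].shape)   # (5, 1)
```

The container structure (the dict and the nested list) is preserved; only the shape of each leaf changes, gaining one axis of length 1 at the requested position.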