braintools.tree module

Operations for manipulating JAX PyTrees, including arithmetic, reshaping, and conversion utilities used across the toolkit.

scale(tree, x[, is_leaf])

Multiply every leaf array in a PyTree by a value.
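
As an illustration of the behaviour described above, the following sketch reproduces it with plain `jax.tree_util.tree_map`. The helper name `scale_sketch` is hypothetical and this is not the braintools source; it only assumes that `is_leaf` has the same meaning as in `jax.tree_util`.

```python
import jax
import jax.numpy as jnp


def scale_sketch(tree, x, is_leaf=None):
    # Hypothetical stand-in, not the braintools implementation:
    # multiply each leaf by x and keep the container structure unchanged.
    return jax.tree_util.tree_map(lambda leaf: leaf * x, tree, is_leaf=is_leaf)


params = {"w": jnp.array([1.0, 2.0]), "b": jnp.array([3.0])}
scaled = scale_sketch(params, 2.0)
# scaled["w"] is [2., 4.] and scaled["b"] is [6.]
```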

mul(tree, x[, is_leaf])

Elementwise multiplication of a PyTree with a value or another PyTree.
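
One way to read "a value or another PyTree" is sketched below. The helper `mul_sketch` is hypothetical, and the rule it uses to tell the two cases apart (a single leaf counts as a value) is an assumption, not a documented detail of braintools.

```python
import jax
import jax.numpy as jnp


def mul_sketch(tree, x, is_leaf=None):
    # Hypothetical stand-in. Assumption: x is "a value" when it is a single
    # leaf (Python scalar or array); otherwise it is a matching PyTree.
    if jax.tree_util.treedef_is_leaf(jax.tree_util.tree_structure(x)):
        return jax.tree_util.tree_map(lambda leaf: leaf * x, tree, is_leaf=is_leaf)
    # Leaf-by-leaf product of two trees with the same structure.
    return jax.tree_util.tree_map(lambda a, b: a * b, tree, x, is_leaf=is_leaf)


grads = {"w": jnp.array([1.0, 2.0]), "b": jnp.array([3.0])}
mask = {"w": jnp.array([0.0, 1.0]), "b": jnp.array([1.0])}

halved = mul_sketch(grads, 0.5)   # halved["w"] is [0.5, 1.]
masked = mul_sketch(grads, mask)  # masked["w"] is [0., 2.]
```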

shift(tree1, x[, is_leaf])

Add a value to every leaf array in a PyTree.

add(tree1, tree2[, is_leaf])

Elementwise addition of two PyTrees, or of a PyTree and a value.

sub(tree1, tree2[, is_leaf])

Elementwise subtraction of two PyTrees.
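
Operations on two PyTrees pair up leaves that sit at the same position in both trees. A minimal sketch with `jax.tree_util.tree_map` follows; `sub_sketch` is a hypothetical name and not the braintools implementation.

```python
import jax
import jax.numpy as jnp


def sub_sketch(tree1, tree2, is_leaf=None):
    # Hypothetical stand-in: subtract corresponding leaves.
    # Both trees need the same structure for tree_map to pair them.
    return jax.tree_util.tree_map(lambda a, b: a - b, tree1, tree2, is_leaf=is_leaf)


new_params = {"w": jnp.array([1.5, 2.0]), "b": jnp.array([0.5])}
old_params = {"w": jnp.array([1.0, 2.0]), "b": jnp.array([1.0])}
delta = sub_sketch(new_params, old_params)
# delta["w"] is [0.5, 0.] and delta["b"] is [-0.5]
```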

dot(a, b[, is_leaf])

Inner product over all leaves of two PyTrees.
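
The inner product can be pictured as flattening both trees into one long vector each and taking their dot product. The sketch below assumes real-valued leaves (whether complex inputs are conjugated is not stated here); `dot_sketch` is a hypothetical name.

```python
import jax
import jax.numpy as jnp


def dot_sketch(a, b, is_leaf=None):
    # Hypothetical stand-in: per-leaf sum of products, then a sum over leaves.
    per_leaf = jax.tree_util.tree_map(
        lambda x, y: jnp.sum(x * y), a, b, is_leaf=is_leaf
    )
    return sum(jax.tree_util.tree_leaves(per_leaf))


u = {"w": jnp.array([1.0, 2.0]), "b": jnp.array([3.0])}
v = {"w": jnp.array([4.0, 5.0]), "b": jnp.array([6.0])}
result = dot_sketch(u, v)
# 1*4 + 2*5 + 3*6 = 32
```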

sum(tree[, is_leaf])

Sum all elements across every leaf in a PyTree.

squared_norm(tree[, is_leaf])

Sum of squares of all elements in a PyTree.
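
Both reductions collapse a whole PyTree to a single scalar. The sketch below shows one plausible reading for real-valued leaves; `sum_sketch` and `squared_norm_sketch` are hypothetical names, not the braintools source.

```python
import jax
import jax.numpy as jnp


def sum_sketch(tree, is_leaf=None):
    # Hypothetical stand-in: add up every element of every leaf.
    leaves = jax.tree_util.tree_leaves(tree, is_leaf=is_leaf)
    return sum(jnp.sum(leaf) for leaf in leaves)


def squared_norm_sketch(tree, is_leaf=None):
    # Hypothetical stand-in: add up the square of every element.
    leaves = jax.tree_util.tree_leaves(tree, is_leaf=is_leaf)
    return sum(jnp.sum(jnp.square(leaf)) for leaf in leaves)


params = {"w": jnp.array([1.0, 2.0]), "b": jnp.array([3.0])}
total = sum_sketch(params)          # 1 + 2 + 3 = 6
sq = squared_norm_sketch(params)    # 1 + 4 + 9 = 14
```

Under this reading the squared norm is the squared Euclidean length of the flattened tree, which is useful for gradient-norm checks.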

concat(trees[, axis, is_leaf])

Concatenate corresponding leaves from a sequence of PyTrees.
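
A typical use is merging several batches that share one structure. The sketch below assumes `axis` defaults to 0, which is not stated on this page; `concat_sketch` is a hypothetical name.

```python
import jax
import jax.numpy as jnp


def concat_sketch(trees, axis=0, is_leaf=None):
    # Hypothetical stand-in: join the leaves found at the same position
    # in every tree of the sequence.
    return jax.tree_util.tree_map(
        lambda *leaves: jnp.concatenate(leaves, axis=axis), *trees, is_leaf=is_leaf
    )


batch1 = {"x": jnp.ones((2, 3)), "y": jnp.zeros((2,))}
batch2 = {"x": jnp.ones((1, 3)), "y": jnp.zeros((1,))}
merged = concat_sketch([batch1, batch2])
# merged["x"].shape is (3, 3) and merged["y"].shape is (3,)
```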

split(tree, sizes[, is_leaf])

Split each leaf of a PyTree along axis 0 according to sizes.
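
The sketch below shows one possible interpretation, in which `sizes` lists consecutive chunk lengths along axis 0 and the result is a tuple of PyTrees. Both points are assumptions, since this page does not specify them; `split_sketch` is a hypothetical name.

```python
import jax
import jax.numpy as jnp


def split_sketch(tree, sizes, is_leaf=None):
    # Hypothetical stand-in. Assumption: sizes are chunk lengths along axis 0.
    offsets = [0]
    for size in sizes:
        offsets.append(offsets[-1] + size)
    # One output tree per chunk, each with the structure of the input tree.
    return tuple(
        jax.tree_util.tree_map(
            lambda leaf, lo=lo, hi=hi: leaf[lo:hi], tree, is_leaf=is_leaf
        )
        for lo, hi in zip(offsets[:-1], offsets[1:])
    )


batch = {"x": jnp.arange(10.0).reshape(5, 2), "y": jnp.arange(5)}
train, test = split_sketch(batch, (2, 3))
# train["x"].shape is (2, 2); test["y"] is [2, 3, 4]
```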

idx(tree, idx[, is_leaf])

Index every leaf of a PyTree.

expand(tree, axis[, is_leaf])

Insert a length-1 axis into every leaf of a PyTree.
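
For example, inserting an axis at position 0 turns a single sample into a batch of one. The sketch uses `jnp.expand_dims` on each leaf; `expand_sketch` is a hypothetical name, not the braintools implementation.

```python
import jax
import jax.numpy as jnp


def expand_sketch(tree, axis, is_leaf=None):
    # Hypothetical stand-in: add a length-1 axis to every leaf.
    return jax.tree_util.tree_map(
        lambda leaf: jnp.expand_dims(leaf, axis), tree, is_leaf=is_leaf
    )


sample = {"x": jnp.ones((3,)), "y": jnp.ones((3, 4))}
batched = expand_sketch(sample, 0)
# batched["x"].shape is (1, 3) and batched["y"].shape is (1, 3, 4)
```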

take(tree, idx, axis[, is_leaf])

Take elements from every leaf of a PyTree along a given axis.
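
The difference between indexing and taking along an axis can be sketched as follows. The helpers `idx_sketch` and `take_sketch` are hypothetical; they assume that indexing applies `leaf[idx]` and that taking behaves like `jnp.take`.

```python
import jax
import jax.numpy as jnp


def idx_sketch(tree, idx, is_leaf=None):
    # Hypothetical stand-in: apply the same index expression to every leaf.
    return jax.tree_util.tree_map(lambda leaf: leaf[idx], tree, is_leaf=is_leaf)


def take_sketch(tree, idx, axis, is_leaf=None):
    # Hypothetical stand-in: gather positions idx along the chosen axis.
    return jax.tree_util.tree_map(
        lambda leaf: jnp.take(leaf, idx, axis=axis), tree, is_leaf=is_leaf
    )


batch = {"x": jnp.arange(12).reshape(3, 4), "y": jnp.arange(6).reshape(3, 2)}

first = idx_sketch(batch, 0)                           # first["x"] is [0, 1, 2, 3]
rows = take_sketch(batch, jnp.array([0, 2]), axis=0)   # rows["x"].shape is (2, 4)
cols = take_sketch(batch, jnp.array([1]), axis=1)      # cols["x"].shape is (3, 1)
```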

as_numpy(tree[, is_leaf])

Convert all leaves of a PyTree to NumPy arrays.
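
Conversion to NumPy is handy before saving or plotting results. The sketch below maps `numpy.asarray` over the leaves; `as_numpy_sketch` is a hypothetical name, and how braintools treats non-array leaves is not covered here.

```python
import jax
import jax.numpy as jnp
import numpy as np


def as_numpy_sketch(tree, is_leaf=None):
    # Hypothetical stand-in: turn each leaf into a NumPy array on the host.
    return jax.tree_util.tree_map(np.asarray, tree, is_leaf=is_leaf)


state = {"w": jnp.array([1.0, 2.0]), "step": jnp.array(3)}
host_state = as_numpy_sketch(state)
# every leaf of host_state is now a numpy.ndarray
```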