braintools.tree module
Operations for manipulating JAX PyTrees, including arithmetic, reshaping, and conversion utilities used across the toolkit.
Multiply every leaf array in a PyTree by a value.
Elementwise multiplication of a PyTree with a value or another PyTree.
Add a value to every leaf array in a PyTree.
Elementwise addition of two PyTrees, or of a PyTree and a value.
Elementwise subtraction of two PyTrees.
Inner product over all leaves of two PyTrees.
Sum all elements across every leaf in a PyTree.
Sum of squares of all elements in a PyTree.
Concatenate corresponding leaves from a sequence of PyTrees.
Split each leaf of a PyTree along axis 0 according to the given sizes.
Index every leaf of a PyTree.
Insert a length-1 axis into every leaf of a PyTree.
Take elements from every leaf of a PyTree along a given axis.
Convert all leaves of a PyTree to NumPy arrays. |
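The arithmetic operations in this module (multiply, add, subtract) share one pattern. Each applies an elementwise operation leaf by leaf with `jax.tree_util.tree_map`, and broadcasts the second operand when it is a single value rather than a PyTree. The sketch below shows that pattern. It is a minimal sketch under stated assumptions, not braintools' implementation: the names `tree_scale`, `tree_shift`, `tree_mul`, `tree_add`, `tree_sub` and `_leafwise` are hypothetical, and tree operands are assumed to share the same structure.

```python
import jax
import jax.numpy as jnp


def _leafwise(op, tree, other):
    # Hypothetical helper: pair leaves when `other` has the same tree structure
    # as `tree`; otherwise treat `other` as one value broadcast to every leaf.
    if jax.tree_util.tree_structure(other) == jax.tree_util.tree_structure(tree):
        return jax.tree_util.tree_map(op, tree, other)
    return jax.tree_util.tree_map(lambda leaf: op(leaf, other), tree)


def tree_scale(tree, value):
    # Multiply every leaf by the same scalar or array.
    return jax.tree_util.tree_map(lambda leaf: leaf * value, tree)


def tree_shift(tree, value):
    # Add the same scalar or array to every leaf.
    return jax.tree_util.tree_map(lambda leaf: leaf + value, tree)


def tree_mul(tree, other):
    # Elementwise product with another PyTree or with a single value.
    return _leafwise(jnp.multiply, tree, other)


def tree_add(tree, other):
    # Elementwise sum with another PyTree or with a single value.
    return _leafwise(jnp.add, tree, other)


def tree_sub(tree, other):
    # Elementwise difference of two same-structured PyTrees.
    return jax.tree_util.tree_map(jnp.subtract, tree, other)
```

Comparing tree structures is only one way to tell a PyTree operand from a broadcast value. With this approach, a bare array is treated as a value unless the first operand is also a bare array.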
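The reductions (inner product, total sum, sum of squares) collapse a whole PyTree to one scalar. For real-valued leaves, the inner product of trees a and b is computed in two steps: multiply each element a[l][i] * b[l][i] over every leaf l and element i, then sum all of those products. The sum of squares is then the inner product of a tree with itself. The sketch below uses the hypothetical names `tree_sum`, `tree_dot` and `tree_squared_norm`, which are not braintools' API, and assumes real-valued leaves, so no complex conjugation is applied.

```python
import jax
import jax.numpy as jnp


def tree_sum(tree):
    # Reduce each leaf to a scalar, then add the per-leaf totals.
    return sum(jnp.sum(leaf) for leaf in jax.tree_util.tree_leaves(tree))


def tree_dot(a, b):
    # Multiply matching leaves elementwise and sum every product into one scalar.
    per_leaf = jax.tree_util.tree_map(lambda x, y: jnp.sum(x * y), a, b)
    return tree_sum(per_leaf)


def tree_squared_norm(tree):
    # Sum of squares over every element: the inner product of the tree with itself.
    return tree_dot(tree, tree)
```

For example, the tree `{"a": [1., 2.], "b": [[3.]]}` has total sum 6 and sum of squares 1 + 4 + 9 = 14.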
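Concatenation and splitting are inverse operations on the leading axis of every leaf. The sketch below is hypothetical (`tree_concat` and `tree_split` are not braintools' API). It concatenates by mapping over several trees at once, and it splits using cumulative offsets. One behavior is an assumption: `sizes` is taken to list the length of each piece, and each leaf is assumed to have exactly `sum(sizes)` rows along axis 0.

```python
import itertools

import jax
import jax.numpy as jnp


def tree_concat(trees, axis=0):
    # tree_map over several trees passes the matching leaf of each tree to the function.
    return jax.tree_util.tree_map(
        lambda *leaves: jnp.concatenate(leaves, axis=axis), *trees
    )


def tree_split(tree, sizes):
    # Offsets [0, s0, s0 + s1, ...] delimit the pieces along axis 0; piece i is
    # leaf[offsets[i]:offsets[i + 1]]. Assumes each leaf has sum(sizes) rows.
    offsets = [0, *itertools.accumulate(sizes)]
    return tuple(
        jax.tree_util.tree_map(lambda leaf: leaf[start:stop], tree)
        for start, stop in zip(offsets[:-1], offsets[1:])
    )
```

Returning a tuple of trees, rather than one tree whose leaves are lists of pieces, keeps each piece usable wherever the original PyTree structure is expected.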
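Indexing, adding an axis, and taking elements all apply one array operation to every leaf. This is useful, for instance, to pull a single example out of a batched PyTree. A minimal sketch follows. The names `tree_index`, `tree_expand` and `tree_take` are hypothetical, and the default `axis=0` is an assumption rather than a documented braintools default.

```python
import jax
import jax.numpy as jnp


def tree_index(tree, idx):
    # Apply the same NumPy-style index (int, slice, or index array) to every leaf.
    return jax.tree_util.tree_map(lambda leaf: leaf[idx], tree)


def tree_expand(tree, axis=0):
    # Insert a length-1 axis at position `axis` in every leaf.
    return jax.tree_util.tree_map(lambda leaf: jnp.expand_dims(leaf, axis), tree)


def tree_take(tree, indices, axis=0):
    # Gather the given positions along `axis` from every leaf.
    return jax.tree_util.tree_map(
        lambda leaf: jnp.take(leaf, indices, axis=axis), tree
    )
```

For example, on a batch whose leaves share a leading batch axis, `tree_index(batch, 0)` yields the first example and `tree_expand(...)` restores a batch axis of length 1.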
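Converting to NumPy is a single leafwise map with `np.asarray`, which returns host-side `numpy.ndarray` leaves. This is handy before serialization or plotting. The function name `tree_as_numpy` below is hypothetical, and this sketch is not claimed to match braintools' own conversion.

```python
import jax
import jax.numpy as jnp
import numpy as np


def tree_as_numpy(tree):
    # np.asarray materializes each jax.Array leaf on the host as a numpy.ndarray.
    return jax.tree_util.tree_map(np.asarray, tree)
```

`jax.device_get(tree)` performs a comparable host transfer for a whole PyTree in one call.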