tree_ones_like
- class brainunit.math.tree_ones_like(tree)
Create a tree with the same structure as the input, but with ones in each leaf.
- Parameters:
tree (pytree) – A JAX-compatible pytree (nested dicts, lists, tuples, etc.) whose leaves are arrays or Quantity objects.
- Returns:
out – A tree with the same structure, where every leaf is replaced by a ones-filled array (or Quantity) of the same shape, dtype, and unit.
- Return type:
pytree
Examples
>>> import saiunit as u
>>> import jax.numpy as jnp
>>> tree = {'a': jnp.array([1.0, 2.0]), 'b': jnp.array([3.0])}
>>> u.math.tree_ones_like(tree)
{'a': Array([1., 1.], dtype=float32), 'b': Array([1.], dtype=float32)}
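For plain-array leaves, the behavior described above (same structure, same shape and dtype, every entry set to one) can be approximated with JAX alone. The sketch below is only an illustration under that assumption: `ones_like_tree` is a hypothetical helper name, not part of this library, and it does not handle Quantity leaves or units. It is not necessarily how `tree_ones_like` is implemented.

```python
import jax
import jax.numpy as jnp


def ones_like_tree(tree):
    """Hypothetical helper: replace every array leaf with ones of the same shape and dtype."""
    # tree_map applies jnp.ones_like to each leaf and rebuilds the same container structure.
    return jax.tree_util.tree_map(jnp.ones_like, tree)


# A nested pytree mixing a dict, a tuple, and leaves of different dtypes.
tree = {
    'a': jnp.array([1.0, 2.0]),
    'b': (jnp.zeros((2, 3), dtype=jnp.int32),),
}

out = ones_like_tree(tree)
print(out)
```

The container types (dict, tuple) are preserved, and each leaf keeps its original shape and dtype, which mirrors the contract stated in the Returns field for the array case.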