tree_zeros_like
class brainunit.math.tree_zeros_like(tree)
Create a tree with the same structure as the input, but with zeros in each leaf.
- Parameters:
  tree (pytree) – A JAX-compatible pytree (nested dicts, lists, tuples, etc.) whose leaves are arrays or Quantity objects.
- Returns:
  out – A tree with the same structure, where every leaf is replaced by a zero-filled array (or Quantity) of the same shape, dtype, and unit.
- Return type:
  pytree
Examples
>>> import saiunit as u
>>> import jax.numpy as jnp
>>> tree = {'a': jnp.array([1.0, 2.0]), 'b': jnp.array([3.0])}
>>> u.math.tree_zeros_like(tree)
{'a': Array([0., 0.], dtype=float32), 'b': Array([0.], dtype=float32)}
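For trees whose leaves are plain JAX arrays, the behavior described above can be approximated with `jax.tree_util.tree_map` and `jnp.zeros_like`. The sketch below is an illustration under that assumption, not the library's implementation: `zeros_like_tree` is a hypothetical name, and the sketch makes no attempt to handle Quantity leaves or their units, which `tree_zeros_like` is documented to preserve.

```python
import jax
import jax.numpy as jnp


def zeros_like_tree(tree):
    """Hypothetical stand-in for tree_zeros_like, for plain array leaves only.

    Assumption: leaves are JAX arrays; Quantity units are NOT handled here.
    """
    # tree_map walks dicts, lists, tuples (and other registered pytree nodes),
    # applies jnp.zeros_like to every leaf, and rebuilds the same structure.
    # jnp.zeros_like keeps each leaf's shape and dtype.
    return jax.tree_util.tree_map(jnp.zeros_like, tree)


# A nested tree mixing a dict, a tuple, and leaves of different dtypes.
tree = {
    'a': jnp.array([1.0, 2.0]),
    'b': (jnp.array([3.0]), jnp.array([4, 5], dtype=jnp.int32)),
}
zeros = zeros_like_tree(tree)
print(zeros)
```

Because only leaves are replaced, the output has the same tree structure as the input, and integer leaves stay integer rather than being promoted to float.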