tree_zeros_like

class brainunit.math.tree_zeros_like(tree)

Create a tree with the same structure as the input, but with zeros in each leaf.
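When every leaf is a plain JAX array, the behaviour described above can be approximated by mapping `jnp.zeros_like` over the tree with `jax.tree_util.tree_map`. This is a minimal sketch, not the library's implementation. The helper name `zeros_like_tree` is hypothetical, and the sketch does not attach units, so it does not cover Quantity leaves.

```python
import jax
import jax.numpy as jnp


def zeros_like_tree(tree):
    """Hypothetical helper: replace every array leaf with zeros.

    Handles plain JAX/NumPy array leaves only; Quantity units are
    NOT preserved by this sketch.
    """
    # tree_map visits each leaf and rebuilds the same container structure.
    return jax.tree_util.tree_map(jnp.zeros_like, tree)


tree = {'a': jnp.array([1.0, 2.0]), 'b': jnp.array([3.0])}
zeros = zeros_like_tree(tree)
print(zeros)
```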

Parameters:

tree (pytree) – A JAX-compatible pytree (nested dicts, lists, tuples, etc.) whose leaves are arrays or Quantity objects.

Returns:

out – A tree with the same structure, where every leaf is replaced by a zero-filled array (or Quantity) of the same shape, dtype, and unit.

Return type:

pytree

Examples

>>> import saiunit as u
>>> import jax.numpy as jnp
>>> tree = {'a': jnp.array([1.0, 2.0]), 'b': jnp.array([3.0])}
>>> u.math.tree_zeros_like(tree)
{'a': Array([0., 0.], dtype=float32), 'b': Array([0.], dtype=float32)}
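Structure preservation also applies to nested lists and tuples, and each leaf keeps its own dtype. The check below demonstrates this for plain JAX arrays using `jax.tree_util.tree_map` directly. It illustrates the contract stated under Returns under that assumption, and it does not call `tree_zeros_like` itself or exercise Quantity leaves.

```python
import jax
import jax.numpy as jnp

# A nested tree mixing a list, a tuple, and an integer array.
tree = {
    'weights': [jnp.ones((2, 3)), jnp.ones((3,))],
    'counts': (jnp.array([4, 5, 6], dtype=jnp.int32),),
}

zeros = jax.tree_util.tree_map(jnp.zeros_like, tree)

# The container structure (dict -> list / tuple) is unchanged.
same_structure = (
    jax.tree_util.tree_structure(zeros) == jax.tree_util.tree_structure(tree)
)

# Every leaf keeps its shape and dtype, and holds only zeros.
for original, z in zip(jax.tree_util.tree_leaves(tree),
                       jax.tree_util.tree_leaves(zeros)):
    assert z.shape == original.shape
    assert z.dtype == original.dtype
    assert not z.any()

print(same_structure)  # True
```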