tree_ones_like


class brainunit.math.tree_ones_like(tree)

Create a pytree with the same structure as tree, in which every leaf is replaced by ones.

Parameters:

tree (pytree) – A JAX-compatible pytree (nested dicts, lists, tuples, etc.) whose leaves are arrays or Quantity objects.

Returns:

out – A tree with the same structure, where every leaf is replaced by a ones-filled array (or Quantity) of the same shape, dtype, and unit.

Return type:

pytree

Examples

>>> import saiunit as u
>>> import jax.numpy as jnp
>>> tree = {'a': jnp.array([1.0, 2.0]), 'b': jnp.array([3.0])}
>>> u.math.tree_ones_like(tree)
{'a': Array([1., 1.], dtype=float32), 'b': Array([1.], dtype=float32)}
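For plain JAX arrays, the structural behavior described above can be approximated by mapping jnp.ones_like over every leaf with jax.tree_util.tree_map. The sketch below is only an illustration under that assumption: the helper name ones_like_tree is hypothetical, it is not the library's implementation, and it does not attach units the way tree_ones_like does for Quantity leaves. It also shows that nested dicts, lists, and tuples are preserved and that each leaf keeps its own dtype.

```python
import jax
import jax.numpy as jnp


def ones_like_tree(tree):
    """Hypothetical plain-JAX helper: replace every array leaf with ones.

    Shape and dtype of each leaf are preserved. Quantity units are NOT
    handled here; this is only a sketch of the pytree mapping.
    """
    return jax.tree_util.tree_map(jnp.ones_like, tree)


# A nested pytree mixing a dict, a list, and a tuple.
tree = {"w": jnp.zeros((2, 3)), "b": [jnp.array([5, 6]), (jnp.array(7.0),)]}
out = ones_like_tree(tree)

# The container structure is identical, leaf for leaf.
assert jax.tree_util.tree_structure(out) == jax.tree_util.tree_structure(tree)

# With JAX's default 32-bit settings, float leaves stay float32
# and integer leaves stay int32.
print(out["w"].shape, out["w"].dtype)
print(out["b"][0].dtype)
```

Mapping over leaves rather than rebuilding containers by hand is what keeps the structure intact for arbitrary nesting.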