argwhere

class saiunit.math.argwhere(a, *, size=None, fill_value=None, **kwargs)

Find the indices of array elements that are non-zero.

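If a is a saiunit.Quantity, its unit is stripped before the search, and the returned indices are plain (unitless) integers. A value is zero regardless of its unit, so stripping the unit does not change which elements are reported.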

Parameters:
  • a (Array | numpy.ndarray | numpy.bool_ | numpy.number | bool | int | float | complex | saiunit.Quantity) – Input data.

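  • size (int | None) – Number of rows in the output. Must be a static Python int when calling inside jax.jit, because the number of non-zero elements is not known at trace time. If a has fewer non-zero elements than size, the output is padded with fill_value; if it has more, the extra indices are dropped.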

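  • fill_value (Array | numpy.ndarray | numpy.bool_ | numpy.number | bool | int | float | complex | None) – Value used for the padded rows when size exceeds the number of non-zero elements. Only meaningful when size is given; if None, padding uses 0, as in jax.numpy.argwhere.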

Returns:

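indices – Array of shape (N, a.ndim). Each row is the index of one non-zero element of a, and rows are in row-major (C) order. N is the number of non-zero elements, or size if size is given.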

Return type:

Array

Examples

>>> import saiunit as u
>>> import jax.numpy as jnp
>>> u.math.argwhere(jnp.array([0, 1, 0, 2]), size=2)
Array([[1],
       [3]], dtype=int32)