tril_indices_from
- class saiunit.math.tril_indices_from(arr, k=0)
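Return the indices for the lower triangle of an (n, m) array.
- Parameters:
arr – Array whose shape (n, m) determines the returned indices.
k – Diagonal offset. k = 0 (the default) is the main diagonal, k < 0 is below it, and k > 0 is above it. Indices of the elements on and below the k-th diagonal are returned.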
- Returns:
out – Row and column indices for the lower triangle.
- Return type:
Tuple[Array, Array]
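The rule behind the returned indices can be sketched in plain NumPy. This is not saiunit's implementation: it assumes saiunit follows NumPy's `tril_indices_from` semantics, under which position (i, j) belongs to the lower triangle when j <= i + k, and the pairs are listed in row-major order. The helper name `tril_indices_from_sketch` is hypothetical.

```python
import numpy as np

def tril_indices_from_sketch(arr, k=0):
    """Hypothetical re-implementation of the lower-triangle index rule.

    Assumes NumPy semantics: (i, j) is selected when j <= i + k.
    Pairs come back in row-major order, like np.tril_indices_from.
    """
    arr = np.asarray(arr)
    if arr.ndim != 2:
        raise ValueError("input array must be 2-D")
    n, m = arr.shape
    rows, cols = np.indices((n, m))  # row index and column index of every cell
    mask = cols <= rows + k          # True on and below the k-th diagonal
    return np.nonzero(mask)          # (row, col) arrays in row-major order

row, col = tril_indices_from_sketch(np.ones((3, 3)))
print(row)  # [0 1 1 2 2 2]
print(col)  # [0 0 1 0 1 2]
```

Only the shape of `arr` is used; its values never affect the result, which is why an array of ones is enough in the example.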
Examples
>>> import saiunit as u
>>> import jax.numpy as jnp
>>> row, col = u.math.tril_indices_from(jnp.ones((3, 3)))
>>> row
Array([0, 1, 1, 2, 2, 2], dtype=int32)
>>> col
Array([0, 0, 1, 0, 1, 2], dtype=int32)
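The (row, col) tuple is typically used as a fancy index to gather or scatter lower-triangle entries. The sketch below uses NumPy's `np.tril_indices_from` as a stand-in, assuming saiunit returns the same indices; with JAX arrays, the scatter step would use `x.at[row, col].set(...)` instead of in-place assignment, since JAX arrays are immutable.

```python
import numpy as np

a = np.arange(1, 10).reshape(3, 3)  # [[1, 2, 3], [4, 5, 6], [7, 8, 9]]

# Gather the lower triangle (including the main diagonal) in row-major order.
row, col = np.tril_indices_from(a)
lower_values = a[row, col]          # [1 4 5 7 8 9]

# k=-1 selects the strictly lower triangle, excluding the main diagonal.
r1, c1 = np.tril_indices_from(a, k=-1)
strict_values = a[r1, c1]           # [4 7 8]

# Scatter: keep only the lower triangle, zeroing everything above it.
out = np.zeros_like(a)
out[row, col] = a[row, col]         # equivalent to np.tril(a)
```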