diag_indices_from


brainunit.math.diag_indices_from(arr, **kwargs)

Return indices for accessing the main diagonal of a given array.

If arr is a Quantity, its units are stripped before the indices are computed, so the returned index arrays are unitless integers.
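The sketch below shows what the returned tuple does in practice. It calls jax.numpy.diag_indices_from directly and assumes that brainunit.math.diag_indices_from returns the same indices for unitless input, which the example further down suggests. For a 3-D cube the tuple holds one index array per dimension, so arr[idx] selects arr[i, i, i]. Because JAX arrays are immutable, the diagonal is overwritten through the .at[...].set(...) update syntax, which returns a new array.

```python
import jax.numpy as jnp

# A 3 x 3 x 3 cube holding 0..26, so arr[i, j, k] == 9*i + 3*j + k.
arr = jnp.arange(27).reshape(3, 3, 3)

# One index array per dimension, each equal to [0, 1, 2].
idx = jnp.diag_indices_from(arr)

# Read the main diagonal: arr[0, 0, 0], arr[1, 1, 1], arr[2, 2, 2].
diag = arr[idx]

# Write the diagonal: .at[] returns a new array with the diagonal set to -1;
# arr itself is left unchanged.
updated = arr.at[idx].set(-1)
```

Here diag is [0, 13, 26], and updated matches arr everywhere except on those three diagonal positions.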

Parameters:

arr (Array | ndarray | bool | number | int | float | complex | saiunit.Quantity) – Input array. Must be at least 2-D, and all of its dimensions must have the same length.

Returns:

indices – A tuple of arr.ndim integer index arrays, one per dimension, that select the main diagonal of arr (the elements arr[i, i, ..., i]).

Return type:

tuple[Array, ...]

Examples

>>> import saiunit as u
>>> import jax.numpy as jnp
>>> arr = jnp.array([[1, 2], [3, 4]])
>>> u.math.diag_indices_from(arr)
(Array([0, 1], dtype=int32), Array([0, 1], dtype=int32))
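The shape requirement on arr can be checked with the sketch below. It again uses jax.numpy.diag_indices_from directly, which raises ValueError for inputs that are not at least 2-D or whose dimensions differ in length. The assumption is that brainunit.math.diag_indices_from, after stripping units, delegates to it and surfaces the same error.

```python
import jax.numpy as jnp

# A square input satisfies the requirement.
rows, cols = jnp.diag_indices_from(jnp.zeros((3, 3)))

# A non-square 2-D input and a 1-D input both violate it.
rejected = []
for shape in [(2, 3), (4,)]:
    try:
        jnp.diag_indices_from(jnp.zeros(shape))
    except ValueError:
        # jax.numpy raises ValueError for both cases.
        rejected.append(shape)
```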