trace

class brainunit.linalg.trace(x, *, offset=0, dtype=None, **kwargs)

Compute the trace of a matrix or stack of matrices.

SaiUnit implementation of numpy.linalg.trace().

Unlike saiunit.math.trace(), which follows numpy.trace() and by default takes the diagonal formed by the first two axes, this function follows linalg semantics: it sums the diagonal of each matrix held in the last two axes.
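The difference between the two conventions is easiest to see on a plain NumPy array. The sketch below is an illustration, not the SaiUnit code. It drops units and assumes that the linalg convention matches `numpy.trace` called with `axis1=-2, axis2=-1`.

```python
import numpy as np

# Same stack of two 3x3 matrices as in the doctest, without units.
x = np.arange(18.0).reshape(2, 3, 3)

# numpy.trace defaults to axis1=0, axis2=1: it pairs the stack axis with the
# first row axis, summing x[0, 0, :] + x[1, 1, :].
math_style = np.trace(x)

# linalg convention: one trace per matrix, taken over the last two axes.
linalg_style = np.trace(x, axis1=-2, axis2=-1)

print(math_style.tolist())    # [12.0, 14.0, 16.0]
print(linalg_style.tolist())  # [12.0, 39.0]
```

With the default axes the result has shape (3,) and mixes entries from both matrices. The linalg-style call returns one value per matrix.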

Parameters:
  • x (Array | ndarray | number | bool | saiunit.Quantity) – Input of shape (..., M, N).

  • offset (int) – Offset of the diagonal from the main diagonal. A positive value selects a diagonal above the main one, a negative value one below it (default: 0).

  • dtype (str | type[Any] | dtype | SupportsDType | None) – Data type of the returned array.
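Here is a minimal NumPy sketch of `offset` and `dtype`. It again assumes that the linalg convention corresponds to `numpy.trace` with `axis1=-2, axis2=-1`, and it omits units.

```python
import numpy as np

# Same stack of two 3x3 matrices as in the doctest, without units.
x = np.arange(18.0).reshape(2, 3, 3)

# offset > 0 picks a diagonal above the main one, offset < 0 one below it.
above = np.trace(x, offset=1, axis1=-2, axis2=-1)   # x[..., 0, 1] + x[..., 1, 2]
below = np.trace(x, offset=-1, axis1=-2, axis2=-1)  # x[..., 1, 0] + x[..., 2, 1]

# dtype sets the type of the accumulator and of the returned array.
as_f32 = np.trace(x, axis1=-2, axis2=-1, dtype=np.float32)

print(above.tolist())  # [6.0, 24.0]
print(below.tolist())  # [10.0, 28.0]
print(as_f32.dtype)    # float32
```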

Returns:

out – Traces of the matrices in x, with shape x.shape[:-2]. The result carries the same unit as x.

Return type:

Array | saiunit.Quantity
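The shape rule x.shape[:-2] holds for any number of batch dimensions. The NumPy sketch below checks it on a 4 × 5 grid of 3 × 3 matrices and cross-checks the result against an explicit "extract diagonals, then sum" formulation. Units are omitted and the random data is purely illustrative. The unit is unchanged in the real function because a trace is just a sum of diagonal entries that all share the unit of x.

```python
import numpy as np

# A 4-D batch: a 4x5 grid of 3x3 matrices (illustrative random data).
rng = np.random.default_rng(0)
x = rng.standard_normal((4, 5, 3, 3))

# One trace per matrix, so the two matrix axes disappear.
out = np.trace(x, axis1=-2, axis2=-1)
assert out.shape == x.shape[:-2]  # (4, 5)

# Equivalent formulation: np.diagonal moves each diagonal to the last axis,
# giving shape (4, 5, 3), and summing that axis yields the traces.
same = np.diagonal(x, axis1=-2, axis2=-1).sum(axis=-1)
assert np.allclose(out, same)
```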

Examples

>>> import saiunit as u
>>> import jax.numpy as jnp
>>> x = jnp.arange(18.).reshape(2, 3, 3) * u.meter
>>> u.linalg.trace(x)
ArrayImpl([12., 39.], dtype=float32) * meter