trace
- class brainunit.linalg.trace(x, *, offset=0, dtype=None, **kwargs)
Compute the trace of a matrix or stack of matrices.
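SaiUnit implementation of numpy.linalg.trace(). Unlike saiunit.math.trace() (which follows numpy.trace() and sums over the first two axes by default), this function follows linalg semantics and sums along the diagonals of the last two axes.
- Parameters:
x – Input array or Quantity of shape (..., M, N). The trace is taken over the last two axes.
offset – Offset of the diagonal relative to the main diagonal. 0 (the default) selects the main diagonal, positive values select diagonals above it, and negative values select diagonals below it.
dtype – Optional data type of the returned array.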
- Returns:
out – Trace of shape x.shape[:-2]. Carries the same unit as x.
- Return type:
Array | saiunit.Quantity
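The axis convention can be checked without SaiUnit. The sketch below uses plain NumPy and assumes that numpy.trace with axis1=-2, axis2=-1 reproduces the linalg semantics described above (including the offset convention); linalg_style_trace is a hypothetical helper for illustration, not part of SaiUnit.

```python
import numpy as np


def linalg_style_trace(x, offset=0):
    """Hypothetical helper: trace over the last two axes (linalg semantics)."""
    # Diagonals are taken from the last two axes, so the result has shape x.shape[:-2].
    return np.trace(x, offset=offset, axis1=-2, axis2=-1)


x = np.arange(18.0).reshape(2, 3, 3)

print(linalg_style_trace(x))            # [12. 39.]  one trace per 3x3 matrix
print(linalg_style_trace(x, offset=1))  # [ 6. 24.]  first superdiagonal of each matrix
print(np.trace(x))                      # [12. 14. 16.]  numpy.trace default: axes 0 and 1

batch = np.ones((4, 5, 3, 3))
print(linalg_style_trace(batch).shape)  # (4, 5), i.e. batch.shape[:-2]
```

The last call shows why the output shape is x.shape[:-2]: every leading axis is treated as a batch dimension, and only the trailing matrix axes are reduced.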
Examples
>>> import saiunit as u
>>> import jax.numpy as jnp
>>> x = jnp.arange(18.).reshape(2, 3, 3) * u.meter
>>> u.linalg.trace(x)
ArrayImpl([12., 39.], dtype=float32) * meter
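The result keeps the unit meter because a trace only adds entries that all share the same unit. The sketch below illustrates that reasoning with a hypothetical SimpleQuantity stand-in (a numeric array plus a unit string); it is not the SaiUnit Quantity class and does not reflect how SaiUnit stores units internally.

```python
from dataclasses import dataclass

import numpy as np


@dataclass
class SimpleQuantity:
    """Hypothetical stand-in for a unit-carrying array (not the SaiUnit class)."""
    mantissa: np.ndarray
    unit: str


def trace_with_unit(q, offset=0):
    # The trace is a sum of entries that all share q.unit, so the result
    # keeps that unit unchanged; only the numeric values are reduced.
    value = np.trace(q.mantissa, offset=offset, axis1=-2, axis2=-1)
    return SimpleQuantity(value, q.unit)


q = SimpleQuantity(np.arange(18.0).reshape(2, 3, 3), "meter")
t = trace_with_unit(q)
print(t.mantissa, t.unit)  # [12. 39.] meter
```

By contrast, a determinant of an n x n matrix in meters multiplies n entries together and therefore has unit meter**n, which is why trace is one of the reductions that can return the input unit unchanged.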