slogdet
- class saiunit.linalg.slogdet(a, *, method=None, **kwargs)
Compute the sign and (natural) logarithm of the absolute determinant.
SaiUnit implementation of numpy.linalg.slogdet(). The unit is stripped before computation, so both returned arrays are always dimensionless. This function is more numerically stable than computing log(det(a)) directly because it avoids overflow and underflow for matrices with very large or very small determinants.
- Parameters:
  a (Array | ndarray | bool | number | bool | int | float | complex | saiunit.Quantity) – Square input of shape (..., M, M) for which to compute the sign and log-determinant. If a carries a unit, the unit is removed before computation.
  method (str, optional) – Decomposition method used internally. One of:
    'lu' (default) – use the LU decomposition.
    'qr' – use the QR decomposition.
- Return type:
  tuple[Array, Array]
- Returns:
  sign (jax.Array) – Sign of the determinant (+1., -1., or 0.), of shape a.shape[:-2]. The sign is 0. when a is singular.
  logabsdet (jax.Array) – Natural logarithm of the absolute value of the determinant, of shape a.shape[:-2]. It is -inf when a is singular.
See also
saiunit.linalg.cond – Condition number of a matrix.
saiunit.linalg.matrix_rank – Rank of a matrix via SVD.
Notes
The determinant can be reconstructed as sign * exp(logabsdet). Using slogdet() instead of det() avoids overflow (to inf) and underflow (to 0) when the determinant is extremely large or extremely small.
Examples
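The Notes can be checked with plain `jax.numpy.linalg`, which computes the same quantities on unitless arrays. Whether the SaiUnit wrapper forwards to exactly these calls after stripping the unit is an assumption here. The sketch first reconstructs a moderate determinant from `sign * exp(logabsdet)`. It then uses 50x50 diagonal matrices whose determinants (1e500 and 1e-500) lie outside the floating-point range. For these, `log(det(a))` computed directly is infinite, while `slogdet()` still returns a finite log-determinant.

```python
import jax.numpy as jnp

# Moderate case: sign * exp(logabsdet) recovers the determinant.
a = jnp.array([[4., 1., 0.],
               [1., 3., 1.],
               [0., 1., 2.]])              # det(a) = 4*5 - 1*2 = 18
sign, logabsdet = jnp.linalg.slogdet(a)
reconstructed = sign * jnp.exp(logabsdet)  # approximately 18

# Extreme cases: 50x50 diagonal matrices with det = 1e500 and 1e-500.
big = 1e10 * jnp.eye(50)
tiny = 1e-10 * jnp.eye(50)

# Computing log(det(a)) directly fails: det overflows to inf or underflows to 0.
naive_big = jnp.log(jnp.linalg.det(big))    # inf
naive_tiny = jnp.log(jnp.linalg.det(tiny))  # -inf

# slogdet stays finite: log|det| = +/- 50 * log(1e10), about +/- 1151.29.
sign_big, logabsdet_big = jnp.linalg.slogdet(big)
sign_tiny, logabsdet_tiny = jnp.linalg.slogdet(tiny)

print(reconstructed, naive_big, naive_tiny)
print(sign_big, logabsdet_big, sign_tiny, logabsdet_tiny)
```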
>>> import saiunit as u
>>> import jax.numpy as jnp
>>> a = jnp.array([[1., 2.],
...                [3., 4.]]) * u.meter
>>> sign, logabsdet = u.linalg.slogdet(a)
>>> sign
Array(-1., dtype=float32)
>>> jnp.exp(logabsdet)
Array(2., dtype=float32)