lu

class brainunit.lax.lu(x)
LU decomposition with partial pivoting.
Compute the matrix decomposition \(P \cdot A = L \cdot U\) where \(P\) is a permutation matrix, \(L\) is lower-triangular with unit diagonal, and \(U\) is upper-triangular.
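The decomposition above can be illustrated with a minimal NumPy sketch of right-looking (Doolittle-style) LU with partial pivoting. This is not the library's implementation. The function name `lu_partial_pivot` is hypothetical, and the sketch assumes a plain 2-D float array without units. It produces three outputs in the same shape as the ones this function returns: a packed matrix with \(L\) strictly below the diagonal and \(U\) on and above it, the pivot row chosen at each step, and the resulting row permutation.

```python
import numpy as np


def lu_partial_pivot(a):
    """Illustrative packed LU with partial pivoting (hypothetical helper).

    Returns (lu, pivots, permutation) such that a[permutation] == L @ U,
    where L = tril(lu, -1) + I (unit diagonal, implicit) and U = triu(lu).
    Assumed pivot convention: at step j, row j was swapped with row pivots[j].
    """
    lu = np.array(a, dtype=float)  # work on a float copy of the input
    m, n = lu.shape
    k = min(m, n)
    pivots = np.zeros(k, dtype=np.int32)
    permutation = np.arange(m, dtype=np.int32)
    for j in range(k):
        # Partial pivoting: pick the row with the largest |entry| in column j,
        # searching only at or below the diagonal.
        p = j + int(np.argmax(np.abs(lu[j:, j])))
        pivots[j] = p
        if p != j:
            # Swap whole rows, so previously stored L multipliers move too.
            lu[[j, p]] = lu[[p, j]]
            permutation[[j, p]] = permutation[[p, j]]
        if lu[j, j] != 0:
            # Store the multipliers of L below the diagonal.
            lu[j + 1:, j] /= lu[j, j]
            # Eliminate: update the trailing block (Schur complement).
            lu[j + 1:, j + 1:] -= np.outer(lu[j + 1:, j], lu[j, j + 1:])
    return lu, pivots, permutation


A = np.array([[1.0, 2.0], [3.0, 4.0]])
lu, pivots, permutation = lu_partial_pivot(A)
L = np.tril(lu, -1) + np.eye(2)  # unit lower-triangular factor
U = np.triu(lu)                  # upper-triangular factor
print(np.allclose(A[permutation], L @ U))  # True
```

Swapping entire rows of the packed matrix, rather than only the trailing block, keeps the already-computed multipliers of \(L\) aligned with the permuted rows, which is what makes `A[permutation] == L @ U` hold at the end.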
- Parameters:
  x (saiunit.Quantity | Array | ndarray | bool | number | int | float | complex) – A batch of matrices with shape [..., m, n].
- Return type:
  tuple[saiunit.Quantity | Array, Array, Array]
- Returns:
  lu (jax.Array or Quantity) – A matrix containing \(L\) in its lower triangle and \(U\) in its upper triangle (the unit diagonal of \(L\) is implicit). If x has a unit, lu preserves that unit.
  pivots (jax.Array) – An int32 array with shape [..., min(m, n)] encoding the sequence of row swaps.
  permutation (jax.Array) – An int32 array with shape [..., m] representing the same row swaps as a single row permutation (the permutation \(P\) above).
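Working with the packed outputs usually means separating \(L\) and \(U\) and, when only the pivots are at hand, replaying the row swaps to recover the permutation. The minimal NumPy sketch below assumes the LAPACK-style convention in which step i swaps row i with row pivots[i] (0-based, applied in order); `unpack_lu` and `pivots_to_permutation` are hypothetical helpers, not part of the library. The example values are the hand-computed factorization of A = [[1, 2], [3, 4]], not output captured from the library.

```python
import numpy as np


def unpack_lu(lu):
    """Split a packed (m, n) LU matrix into L (m, k) and U (k, n), k = min(m, n)."""
    m, n = lu.shape
    k = min(m, n)
    # Strict lower triangle plus the implicit unit diagonal gives L.
    L = np.tril(lu, -1)[:, :k] + np.eye(m, k, dtype=lu.dtype)
    # Upper triangle (diagonal included) gives U.
    U = np.triu(lu)[:k, :]
    return L, U


def pivots_to_permutation(pivots, m):
    """Replay the swaps in order (assumed convention: row i <-> row pivots[i])."""
    perm = np.arange(m)
    for i, p in enumerate(pivots):
        perm[[i, p]] = perm[[p, i]]
    return perm


A = np.array([[1.0, 2.0], [3.0, 4.0]])
lu = np.array([[3.0, 4.0], [1.0 / 3.0, 2.0 / 3.0]])  # hand-computed packed factors
pivots = np.array([1, 1], dtype=np.int32)            # step 0 swaps rows 0 and 1
perm = pivots_to_permutation(pivots, m=2)
L, U = unpack_lu(lu)
print(perm)                         # [1 0]
print(np.allclose(A[perm], L @ U))  # True
```

For a Quantity-valued lu, the unit would have to be handled separately before using plain NumPy helpers like these.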
Examples
>>> import jax.numpy as jnp
>>> import saiunit as u
>>> import saiunit.lax as sulax
>>> A = jnp.array([[1.0, 2.0], [3.0, 4.0]]) * u.second
>>> lu_mat, pivots, perm = sulax.lu(A)
>>> u.get_unit(lu_mat) == u.second
True