lu

saiunit.lax.lu(x)

LU decomposition with partial pivoting.

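Compute the decomposition \(P \cdot A = L \cdot U\) of each matrix \(A\) in x, where \(P\) is a permutation matrix, \(L\) is lower-triangular with unit diagonal, and \(U\) is upper-triangular. Partial pivoting means that at each elimination step, the remaining row with the largest-magnitude entry in the current column is swapped into the pivot position.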
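The decomposition can be illustrated with a plain NumPy sketch of Doolittle elimination with partial pivoting on a single matrix, ignoring units and batch dimensions. This is our own illustration, not the code that saiunit.lax.lu runs. The function name lu_partial_pivoting is hypothetical, and we assume its pivots and permutation follow the conventions described under Returns.

```python
import numpy as np


def lu_partial_pivoting(a):
    """LU decomposition with partial pivoting of one (m, n) matrix (illustrative only).

    Returns (lu, pivots, permutation): lu packs L below the diagonal (its unit
    diagonal is implicit) and U on and above it; at step i, row i was swapped
    with row pivots[i]; and a[permutation] equals L @ U.
    """
    lu = np.array(a, dtype=float)  # float working copy; the input is not modified
    m, n = lu.shape
    k_max = min(m, n)
    pivots = np.zeros(k_max, dtype=np.int32)
    permutation = np.arange(m, dtype=np.int32)
    for k in range(k_max):
        # Partial pivoting: choose the row at or below k with the largest |entry| in column k.
        p = k + int(np.argmax(np.abs(lu[k:, k])))
        pivots[k] = p
        if p != k:
            lu[[k, p]] = lu[[p, k]]                    # swap the rows of the working matrix
            permutation[[k, p]] = permutation[[p, k]]  # record the swap in the permutation
        # If the pivot is zero, the rest of column k is zero too, so nothing to eliminate.
        if lu[k, k] != 0.0:
            lu[k + 1:, k] /= lu[k, k]  # multipliers l_ik = a_ik / u_kk, stored below the diagonal
            # Eliminate column k from the trailing submatrix (Schur-complement update).
            lu[k + 1:, k + 1:] -= np.outer(lu[k + 1:, k], lu[k, k + 1:])
    return lu, pivots, permutation


A = np.array([[1.0, 2.0], [3.0, 4.0]])
lu, pivots, perm = lu_partial_pivoting(A)
L = np.tril(lu, -1) + np.eye(2)     # unit lower-triangular factor
U = np.triu(lu)                     # upper-triangular factor
print(pivots, perm)                 # [1 1] [1 0]
print(np.allclose(A[perm], L @ U))  # True
```

Because 3 > 1 in the first column, the two rows are swapped before elimination, which is why both pivots and perm record a swap of rows 0 and 1.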

Parameters:

x (saiunit.Quantity | Array | ndarray | bool | number | int | float | complex) – A batch of matrices with shape [..., m, n].

Return type:

tuple[saiunit.Quantity | Array, Array, Array]

Returns:

  • lu (jax.Array or Quantity) – A matrix with the same shape as x, [..., m, n], containing \(L\) in its strictly lower triangle and \(U\) in its upper triangle, including the diagonal; the unit diagonal of \(L\) is implicit and not stored. If x has a unit, lu preserves that unit.

  • pivots (jax.Array) – An int32 array with shape [..., min(m, n)] encoding the row swaps: applied in order, row i is swapped with row pivots[i].

  • permutation (jax.Array) – An int32 array with shape [..., m] expressing the same row swaps as a permutation, so that for a single matrix x[permutation] equals \(L \cdot U\).
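The packed lu output and the two pivot encodings can be turned into explicit factors as sketched below. This NumPy code uses hypothetical helper names (unpack_lu, pivots_to_permutation), ignores units and batch dimensions, and assumes the swaps in pivots are applied in order i = 0, 1, ..., min(m, n) - 1. The packed values are ones we worked out by hand for the 2×2 matrix used in the example, not output captured from the library.

```python
import numpy as np


def unpack_lu(lu):
    """Split a packed (m, n) LU matrix into L (m, k) and U (k, n), with k = min(m, n)."""
    m, n = lu.shape
    k = min(m, n)
    lower = np.tril(lu[:, :k], -1) + np.eye(m, k)  # restore the implicit unit diagonal
    upper = np.triu(lu[:k, :])
    return lower, upper


def pivots_to_permutation(pivots, m):
    """Replay the swaps 'row i <-> row pivots[i]' on the identity ordering of m rows."""
    perm = np.arange(m)
    for i, p in enumerate(pivots):
        perm[[i, p]] = perm[[p, i]]
    return perm


# Hand-computed packed result for A = [[1, 2], [3, 4]] with partial pivoting.
A = np.array([[1.0, 2.0], [3.0, 4.0]])
lu = np.array([[3.0, 4.0], [1.0 / 3.0, 2.0 / 3.0]])
pivots = np.array([1, 1], dtype=np.int32)

L, U = unpack_lu(lu)
perm = pivots_to_permutation(pivots, m=2)  # rows [1, 0]
P = np.eye(2)[perm]                        # permutation matrix with P @ A == A[perm]
print(np.allclose(P @ A, L @ U))           # True
```

The pivots form is the sequence of swaps as LAPACK-style routines record it, while the permutation form is convenient for indexing; the loop above converts the first into the second.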

Examples

>>> import jax.numpy as jnp
>>> import saiunit as u
>>> import saiunit.lax as sulax
>>> A = jnp.array([[1.0, 2.0], [3.0, 4.0]]) * u.second
>>> lu_mat, pivots, perm = sulax.lu(A)
>>> u.get_unit(lu_mat) == u.second
True
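For the matrix in this example, partial pivoting swaps the two rows because 3 > 1 in the first column. Working the elimination by hand (our own derivation, not output copied from the library) gives the factors below. The multiplier \(\ell_{10}\) is a ratio of two quantities in seconds and is therefore dimensionless, so in exact arithmetic only \(U\) carries the unit of A, even though the packed lu_mat as a whole is tagged with that unit.

```latex
A = \begin{pmatrix} 1 & 2 \\ 3 & 4 \end{pmatrix}\mathrm{s},
\qquad
P = \begin{pmatrix} 0 & 1 \\ 1 & 0 \end{pmatrix},
\qquad
P \cdot A = \begin{pmatrix} 3 & 4 \\ 1 & 2 \end{pmatrix}\mathrm{s},

\ell_{10} = \frac{1\,\mathrm{s}}{3\,\mathrm{s}} = \tfrac{1}{3},
\qquad
u_{11} = 2\,\mathrm{s} - \tfrac{1}{3} \cdot 4\,\mathrm{s} = \tfrac{2}{3}\,\mathrm{s},

L = \begin{pmatrix} 1 & 0 \\ \tfrac{1}{3} & 1 \end{pmatrix},
\qquad
U = \begin{pmatrix} 3 & 4 \\ 0 & \tfrac{2}{3} \end{pmatrix}\mathrm{s},
\qquad
L \cdot U = \begin{pmatrix} 3 & 4 \\ 1 & 2 \end{pmatrix}\mathrm{s} = P \cdot A.
```

Under the conventions described in Returns, this corresponds to pivots = (1, 1) and permutation = (1, 0).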