qdwh

class brainunit.lax.qdwh(x)

Polar decomposition via QR-based dynamically weighted Halley iteration.

Compute the polar decomposition \(x = U \cdot H\) where \(U\) is unitary and \(H\) is Hermitian positive semi-definite.
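To make the idea concrete, the following is a minimal NumPy sketch of the plain Halley iteration with fixed weights. It is not this function's implementation: QDWH chooses the weights dynamically at each step and replaces the explicit inverse with QR factorizations. The name `halley_polar` and its `max_iters` and `tol` parameters are hypothetical and exist only for this illustration.

```python
import numpy as np


def halley_polar(x, max_iters=50, tol=1e-12):
    """Polar decomposition x = u @ h by the fixed-weight Halley iteration.

    Illustrative sketch only; it is not the QDWH algorithm itself.
    """
    x = np.asarray(x)
    eye = np.eye(x.shape[1])
    # Scale so the largest singular value of the starting iterate is 1.
    u = x / np.linalg.norm(x, 2)
    num_iters, converged = 0, False
    for num_iters in range(1, max_iters + 1):
        g = u.conj().T @ u
        # Halley update: U <- U (3 I + U^H U) (I + 3 U^H U)^{-1}
        u_next = u @ (3 * eye + g) @ np.linalg.inv(eye + 3 * g)
        converged = bool(
            np.linalg.norm(u_next - u) <= tol * np.linalg.norm(u_next)
        )
        u = u_next
        if converged:
            break
    # Recover the Hermitian factor and symmetrize away rounding noise.
    h = u.conj().T @ x
    h = 0.5 * (h + h.conj().T)
    return u, h, num_iters, converged


a = np.array([[1.0, 2.0], [3.0, 4.0]])
u_factor, h_factor, n_steps, ok = halley_polar(a)
```

The returned tuple mirrors the (u, h, num_iters, is_converged) layout documented here, so the factors can be checked directly: `u_factor @ h_factor` reproduces `a`, and `u_factor` has orthonormal columns.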

Parameters:

x (saiunit.Quantity | Array | ndarray | bool | number | int | float | complex) – A full-rank matrix with shape (M, N).

Return type:

tuple[Array, saiunit.Quantity | Array, int, bool]

Returns:

  • u (jax.Array) – The unitary factor (unitless).

  • h (jax.Array or Quantity) – The Hermitian positive semi-definite factor. If x has a unit, h preserves that unit.

  • num_iters (int) – Number of iterations performed.

  • is_converged (bool) – True if the algorithm converged within the maximum number of iterations.
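One way to see why the unit attaches to h rather than to u is to treat a unit as a positive scale factor, which is an assumption made only for this illustration. Scaling the input by a positive number scales the Hermitian factor by the same number and leaves the unitary factor unchanged. The sketch below checks this with NumPy, using an SVD-based reference decomposition in place of QDWH. The helper `polar_via_svd` is hypothetical and is not part of this library.

```python
import numpy as np


def polar_via_svd(x):
    """Reference polar decomposition x = u @ h computed from an SVD (not QDWH)."""
    w, s, vh = np.linalg.svd(x, full_matrices=False)
    u = w @ vh                      # unitary factor
    h = (vh.conj().T * s) @ vh      # Hermitian factor V diag(s) V^H
    return u, h


x = np.array([[1.0, 2.0], [3.0, 4.0]])
scale = 1000.0  # stand-in for a change of unit; assumed positive

u_base, h_base = polar_via_svd(x)
u_scaled, h_scaled = polar_via_svd(scale * x)

# The unitary factor ignores the scale; the Hermitian factor carries it.
same_unitary = np.allclose(u_base, u_scaled)
scaled_hermitian = np.allclose(h_scaled, scale * h_base)
```

Both flags evaluate to true for this full-rank input, which matches the documented behavior that u is unitless while h keeps the unit of x.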

Examples

>>> import jax.numpy as jnp
>>> import saiunit as u
>>> import saiunit.lax as sulax
>>> A = jnp.array([[1.0, 2.0], [3.0, 4.0]]) * u.second
>>> u_factor, h, num_iters, is_converged = sulax.qdwh(A)
>>> u.get_unit(h) == u.second
True