tridiagonal

class saiunit.lax.tridiagonal(a, lower=True)

Reduce a symmetric/Hermitian matrix to tridiagonal form.

Currently implemented on CPU and GPU only.
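For readers unfamiliar with the reduction, the NumPy sketch that follows shows the classical Householder approach for a real symmetric matrix. Each step builds a reflector H = I - 2vvᵀ that zeroes one column below its first sub-diagonal and applies it as the similarity transform H·T·H, which preserves eigenvalues. This is an illustrative, unoptimized sketch and not the routine saiunit or JAX actually runs. The helper name householder_tridiagonal and the input matrix are hypothetical. Unlike tridiagonal, which stores the reflectors compactly in a_out and taus, the sketch returns the dense tridiagonal matrix.

```python
import numpy as np


def householder_tridiagonal(a: np.ndarray) -> np.ndarray:
    """Return a tridiagonal matrix orthogonally similar to the real symmetric matrix a.

    Hypothetical helper for illustration: each reflector is applied explicitly
    and the result is returned as a dense matrix.
    """
    t = np.array(a, dtype=float)  # work on a float copy of the input
    n = t.shape[0]
    for k in range(n - 2):
        x = t[k + 1:, k].copy()  # entries of column k below the diagonal
        # Give alpha the sign opposite to x[0] to avoid cancellation in v.
        alpha = -np.copysign(np.linalg.norm(x), x[0])
        v = x.copy()
        v[0] -= alpha
        v_norm = np.linalg.norm(v)
        if v_norm == 0.0:
            continue  # x is all zeros, so there is nothing to eliminate
        v /= v_norm
        # Reflector H = I - 2 v v^T, acting only on rows/columns k+1 .. n-1.
        h = np.eye(n)
        h[k + 1:, k + 1:] -= 2.0 * np.outer(v, v)
        # H is symmetric and orthogonal, so H t H is a similarity transform.
        t = h @ t @ h
    return t


# Illustrative symmetric 4x4 input.
a = np.array([[4.0, 1.0, -2.0, 2.0],
              [1.0, 2.0, 0.0, 1.0],
              [-2.0, 0.0, 3.0, -2.0],
              [2.0, 1.0, -2.0, -1.0]])
t = householder_tridiagonal(a)
print(np.round(t, 6))
```

Every step is a similarity transform, so np.linalg.eigvalsh(t) should match np.linalg.eigvalsh(a) up to rounding.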

Parameters:
  • a (saiunit.Quantity | Array | ndarray | bool | number | int | float | complex) – A floating-point or complex symmetric/Hermitian matrix (or batch of matrices) with shape [..., n, n].

  • lower (bool) – Selects which triangle of a is used: the lower triangle if True, the upper triangle if False. Default is True.

Return type:

tuple[saiunit.Quantity | Array, saiunit.Quantity | Array, saiunit.Quantity | Array, Array]

Returns:

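  • a_out (jax.Array or Quantity) – A matrix with the same shape as a. If lower=True, its diagonal and first sub-diagonal hold the tridiagonal matrix, and the elements below the first sub-diagonal, together with taus, encode the Householder reflectors. If lower=False, the first super-diagonal and the elements above it are used instead. If a has a unit, a_out preserves that unit.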

  • d (jax.Array or Quantity) – The diagonal of the tridiagonal matrix, with shape [..., n]. Preserves the unit of a if present.

  • e (jax.Array or Quantity) – The off-diagonal of the tridiagonal matrix, taken from the first sub-diagonal (lower=True) or super-diagonal (lower=False), with shape [..., n - 1]. Preserves the unit of a if present.

  • taus (jax.Array) – Scalar factors of the elementary Householder reflectors, with shape [..., n - 1]. Always unitless.

Examples

>>> import jax.numpy as jnp
>>> import saiunit as u
>>> import saiunit.lax as sulax
>>> A = jnp.array([[2.0, 1.0], [1.0, 3.0]]) * u.second
>>> a_out, d, e, taus = sulax.tridiagonal(A)
>>> u.get_unit(d) == u.second
True