svd

class saiunit.lax.svd(x, *, full_matrices=True, compute_uv=True, subset_by_index=None, algorithm=None)

Singular value decomposition.

Compute the SVD of a matrix, or of each matrix in a batch. When compute_uv is True, return (u, s, vh); otherwise return only the singular values s.

Parameters:
  • x (saiunit.Quantity | Array | ndarray | bool | number | int | float | complex) – A batch of matrices with shape [..., m, n].

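  • full_matrices (bool) – If True, compute full-size U of shape [..., m, m] and Vh of shape [..., n, n]. If False, compute the reduced factors U of shape [..., m, k] and Vh of shape [..., k, n], where k = min(m, n). Default is True.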

  • compute_uv (bool) – If True, compute U and Vh in addition to S. Default is True.

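  • subset_by_index (tuple[int, int] | None) – Optional (start, end) pair giving a half-open range [start, end) of singular-value indices to compute instead of all of them. Default is None, which computes all singular values.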

  • algorithm (SvdAlgorithm | None) – The SVD algorithm to use, given as a jax.lax.linalg.SvdAlgorithm value. Default is None, which leaves the choice to JAX.
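The shape conventions of full_matrices and compute_uv can be checked with a small sketch. It uses NumPy's np.linalg.svd purely as a stand-in (it is not part of saiunit.lax), on the assumption that it follows the same full_matrices and compute_uv convention as the JAX routine wrapped here. The batch of random 4x2 matrices is illustrative.

```python
import numpy as np

# A batch of three 4x2 matrices, i.e. shape [..., m, n] with m = 4, n = 2.
rng = np.random.default_rng(0)
x = rng.standard_normal((3, 4, 2))
m, n = x.shape[-2:]
k = min(m, n)

# full_matrices=True: U is [..., m, m], s is [..., k], Vh is [..., n, n].
U_full, s_full, Vh_full = np.linalg.svd(x, full_matrices=True)

# full_matrices=False: U is [..., m, k] and Vh is [..., k, n].
U_thin, s_thin, Vh_thin = np.linalg.svd(x, full_matrices=False)

# compute_uv=False: only the singular values are returned, shape [..., k].
s_only = np.linalg.svd(x, compute_uv=False)

# The reduced factors already reconstruct x: U @ diag(s) @ Vh, batched.
x_rebuilt = U_thin @ (s_thin[..., :, None] * Vh_thin)

print(U_full.shape, Vh_full.shape)
print(U_thin.shape, Vh_thin.shape)
print(s_only.shape)
```

The reduced form (full_matrices=False) is usually what you want when m and n differ a lot, since the extra columns of the full U or rows of the full Vh do not contribute to the reconstruction.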

Return type:

saiunit.Quantity | Array | ndarray | bool | number | int | float | complex | tuple[Array, saiunit.Quantity | Array, Array]

Returns:

  • u (jax.Array) – Left singular vectors (unitless). Only returned when compute_uv=True.

  • s (jax.Array or Quantity) – Singular values. If x has a unit, s preserves that unit.

  • vh (jax.Array) – Right singular vectors (unitless). Only returned when compute_uv=True.
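Why s carries the unit of x while u and vh stay unitless follows from the homogeneity of the SVD: if x = u · diag(s) · vh, then c·x = u · diag(c·s) · vh for any c > 0 (for negative c, s scales by |c|). Re-expressing the input in another unit, e.g. meters to millimeters (c = 1000), therefore rescales only the singular values. The sketch below checks this with plain NumPy as a stand-in for the unit machinery; the matrix mirrors the example below, and the factor 1000 is illustrative.

```python
import numpy as np

# Magnitudes of the example matrix, read as meters.
x_m = np.array([[1.0, 2.0], [3.0, 4.0]])
# The same physical matrix expressed in millimeters (c = 1000 > 0).
x_mm = 1000.0 * x_m

U1, s1, Vh1 = np.linalg.svd(x_m)
U2, s2, Vh2 = np.linalg.svd(x_mm)

# Singular values scale with the input, so they carry its unit.
ratio = s2 / s1

# Singular vectors do not depend on the scale, so they are dimensionless.
# Compare absolute values, since singular vectors are only fixed up to sign.
same_u = np.allclose(np.abs(U1), np.abs(U2))
same_vh = np.allclose(np.abs(Vh1), np.abs(Vh2))

# Reconstruction: u @ diag(s) @ vh recovers x; the unit enters only through s.
x_rebuilt = U1 @ np.diag(s1) @ Vh1

print(ratio, same_u, same_vh)
```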

Examples

>>> import jax.numpy as jnp
>>> import saiunit as u
>>> import saiunit.lax as sulax
>>> A = jnp.array([[1.0, 2.0], [3.0, 4.0]]) * u.meter
>>> U, s, Vh = sulax.svd(A)  # avoid rebinding u, the saiunit alias
>>> u.get_unit(s) == u.meter
True