triangular_solve

class brainunit.lax.triangular_solve(a, b, left_side=False, lower=False, transpose_a=False, conjugate_a=False, unit_diagonal=False)

Triangular solve.

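Solves the matrix equation \(\mathit{op}(A) \cdot X = B\) when left_side=True, or \(X \cdot \mathit{op}(A) = B\) when left_side=False, where \(A\) is triangular and \(\mathit{op}(A)\) is \(A\), optionally transposed and/or complex-conjugated as selected by transpose_a and conjugate_a.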
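For intuition, the lower-triangular case with left_side=True and no transpose reduces to forward substitution, and the upper-triangular case to back substitution. The sketch below is a plain NumPy reference written only for illustration; the helper names forward_substitution and back_substitution are hypothetical and not part of saiunit, and this is not a claim about how the library computes its result.

```python
import numpy as np


def forward_substitution(L, B):
    """Solve L @ X = B for lower-triangular L (hypothetical helper)."""
    L = np.asarray(L, dtype=float)
    B = np.asarray(B, dtype=float)
    X = np.zeros_like(B)
    for i in range(L.shape[0]):
        # Row i depends only on the already-solved rows 0..i-1.
        X[i] = (B[i] - L[i, :i] @ X[:i]) / L[i, i]
    return X


def back_substitution(U, B):
    """Solve U @ X = B for upper-triangular U (hypothetical helper)."""
    U = np.asarray(U, dtype=float)
    B = np.asarray(B, dtype=float)
    X = np.zeros_like(B)
    for i in reversed(range(U.shape[0])):
        # Row i depends only on the already-solved rows i+1..m-1.
        X[i] = (B[i] - U[i, i + 1:] @ X[i + 1:]) / U[i, i]
    return X


L = np.array([[2.0, 0.0, 0.0],
              [1.0, 3.0, 0.0],
              [4.0, -1.0, 5.0]])
B = np.array([[1.0], [3.0], [5.0]])

# Both helpers should agree with a dense solve on the same system.
print(np.allclose(forward_substitution(L, B), np.linalg.solve(L, B)))
print(np.allclose(back_substitution(L.T, B), np.linalg.solve(L.T, B)))
```

Each row costs one dot product with the already-solved part of X, which is why a triangular solve needs only O(m²) work per right-hand-side column, compared with O(m³) for a general dense solve.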

Parameters:
  • a (saiunit.Quantity | Array | ndarray | bool | number | int | float | complex) – A batch of triangular matrices with shape [..., m, m].

  • b (saiunit.Quantity | Array | ndarray | bool | number | int | float | complex) – A batch of right-hand-side matrices with shape [..., m, n] if left_side=True, or [..., n, m] otherwise.

  • left_side (bool) – If True, solve op(A) · X = B; if False, solve X · op(A) = B. Default is False.

  • lower (bool) – If True, use the lower triangle of a; otherwise, use the upper triangle. Default is False.

  • transpose_a (bool) – If True, transpose a before solving. Default is False.

  • conjugate_a (bool) – If True, use the complex conjugate of a. Default is False.

  • unit_diagonal (bool) – If True, the diagonal of a is assumed to be all ones. Default is False.
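The sketch below illustrates how these flags select op(A) and the side of the product. It uses JAX's jax.lax.linalg.triangular_solve, which accepts the same parameters on plain arrays; the assumption is that the unit-aware function documented here applies the same semantics to the numeric values. Each result is checked against an explicit matrix product.

```python
import jax.numpy as jnp
from jax.lax.linalg import triangular_solve

A = jnp.array([[2.0, 0.0, 0.0],
               [1.0, 3.0, 0.0],
               [4.0, -1.0, 5.0]])  # lower-triangular, shape [m, m] = [3, 3]
B = jnp.array([[1.0, 2.0],
               [3.0, 4.0],
               [5.0, 6.0]])        # right-hand side, shape [m, n] = [3, 2]

# left_side=True, lower=True: op(A) = A, so this solves A @ X = B.
X = triangular_solve(A, B, left_side=True, lower=True)

# transpose_a=True: op(A) = A.T, so this solves A.T @ Xt = B.
Xt = triangular_solve(A, B, left_side=True, lower=True, transpose_a=True)

# unit_diagonal=True: the stored diagonal (2, 3, 5) is treated as all ones.
Xu = triangular_solve(A, B, left_side=True, lower=True, unit_diagonal=True)
A_unit = A - jnp.diag(jnp.diag(A)) + jnp.eye(3)

# left_side=False: solve Y @ A = C, where C has shape [n, m] = [2, 3].
C = B.T
Y = triangular_solve(A, C, left_side=False, lower=True)

# Each residual check should print True.
print(jnp.allclose(A @ X, B, atol=1e-4))
print(jnp.allclose(A.T @ Xt, B, atol=1e-4))
print(jnp.allclose(A_unit @ Xu, B, atol=1e-4))
print(jnp.allclose(Y @ A, C, atol=1e-4))
```

For real-valued inputs conjugate_a has no effect; for complex a, combining transpose_a=True with conjugate_a=True gives the conjugate transpose.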

Returns:

X – The solution, with the same shape and dtype as b. If b carries a unit and a is dimensionless (as in the example below), X carries the same unit as b.

Return type:

saiunit.Quantity | Array
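To make the shape contract concrete, the sketch below solves a batch of four systems on each side and checks that the output shape matches b. As above, it uses JAX's jax.lax.linalg.triangular_solve on plain arrays, under the assumption that the unit-aware wrapper follows the same shape rules.

```python
import jax.numpy as jnp
from jax.lax.linalg import triangular_solve

batch, m, n = 4, 3, 2

# A batch of lower-triangular matrices with nonzero diagonals, shape [4, 3, 3].
base = jnp.tril(jnp.arange(1.0, 1.0 + m * m).reshape(m, m))
A = jnp.stack([base + k * jnp.eye(m) for k in range(batch)])

B_left = jnp.ones((batch, m, n))   # left_side=True expects [..., m, n]
B_right = jnp.ones((batch, n, m))  # left_side=False expects [..., n, m]

X_left = triangular_solve(A, B_left, left_side=True, lower=True)
X_right = triangular_solve(A, B_right, left_side=False, lower=True)

print(X_left.shape)   # (4, 3, 2)
print(X_right.shape)  # (4, 2, 3)
```

The leading batch dimensions of a and b must agree; each batch element is solved independently.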

Examples

>>> import jax.numpy as jnp
>>> import saiunit as u
>>> import saiunit.lax as sulax
>>> A = jnp.array([[2.0, 0.0], [1.0, 3.0]])
>>> b = jnp.array([[4.0], [7.0]]) * u.meter
>>> X = sulax.triangular_solve(A, b, left_side=True, lower=True)
>>> u.get_unit(X) == u.meter
True
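The magnitudes in this example can be checked by hand with forward substitution: 2·x₀ = 4 gives x₀ = 2, and x₀ + 3·x₁ = 7 gives x₁ = 5/3, so X should be approximately [[2], [5/3]] meters. The sketch below reproduces those numbers with JAX's jax.lax.linalg.triangular_solve on the unitless values; it assumes the same numeric semantics and does not exercise the unit handling.

```python
import jax.numpy as jnp
from jax.lax.linalg import triangular_solve

A = jnp.array([[2.0, 0.0], [1.0, 3.0]])
b = jnp.array([[4.0], [7.0]])  # numeric values of b from the example, in meters

X = triangular_solve(A, b, left_side=True, lower=True)
# Forward substitution: x0 = 4 / 2 = 2, x1 = (7 - 1 * 2) / 3 = 5/3.
print(X)  # approximately [[2.0], [1.6667]]
```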