triangular_solve
brainunit.lax.triangular_solve(a, b, left_side=False, lower=False, transpose_a=False, conjugate_a=False, unit_diagonal=False)
Triangular solve.
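Solve the matrix equation \(\mathit{op}(A) \cdot X = B\) (when left_side=True) or \(X \cdot \mathit{op}(A) = B\) (when left_side=False), where \(A\) is triangular and \(\mathit{op}(A)\) is \(A\), optionally transposed (transpose_a) and/or complex-conjugated (conjugate_a).

Parameters: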
- a (saiunit.Quantity | Array | ndarray | bool | number | bool | int | float | complex) – A batch of triangular matrices with shape [..., m, m].
- b (saiunit.Quantity | Array | ndarray | bool | number | bool | int | float | complex) – A batch of right-hand-side matrices with shape [..., m, n] (if left_side=True) or [..., n, m] otherwise.
- left_side (bool) – Selects which equation to solve. Default is False.
- lower (bool) – If True, use the lower triangle of a; otherwise, use the upper triangle. Default is False.
- transpose_a (bool) – If True, transpose a before solving. Default is False.
- conjugate_a (bool) – If True, use the complex conjugate of a. Default is False.
- unit_diagonal (bool) – If True, the diagonal of a is assumed to be all ones. Default is False.
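To make the effect of each flag concrete, the following is a pure-Python reference sketch for a single, unbatched matrix, with units ignored. It assumes the flags behave as in jax.lax.linalg.triangular_solve, which this wrapper presumably forwards to: only the selected triangle of a is read, transposing op(A) swaps which triangle holds the nonzeros, and with unit_diagonal the stored diagonal is never read. The names triangular_solve_ref, _substitute, and _transpose are hypothetical and are not part of the library.

```python
from typing import List

Matrix = List[list]  # list of rows; entries may be float or complex


def _transpose(m: Matrix) -> Matrix:
    """Plain (non-conjugating) transpose of a list-of-rows matrix."""
    return [list(row) for row in zip(*m)]


def _substitute(a: Matrix, b: Matrix, lower: bool, unit_diagonal: bool) -> Matrix:
    """Solve a @ x = b by forward (lower) or back (upper) substitution."""
    m, n = len(a), len(b[0])
    x = [[0.0] * n for _ in range(m)]
    rows = range(m) if lower else range(m - 1, -1, -1)
    for i in rows:
        # Only rows already solved, on the triangle's side, contribute.
        ks = range(i) if lower else range(i + 1, m)
        for j in range(n):
            s = b[i][j] - sum(a[i][k] * x[k][j] for k in ks)
            # With unit_diagonal the stored diagonal entry is never read.
            x[i][j] = s if unit_diagonal else s / a[i][i]
    return x


def triangular_solve_ref(a, b, left_side=False, lower=False,
                         transpose_a=False, conjugate_a=False,
                         unit_diagonal=False):
    """Hypothetical reference: solve op(A) @ X = B or X @ op(A) = B."""
    m = len(a)
    # Keep only the selected triangle; the other triangle is ignored.
    tri = [[a[i][k] if (k <= i if lower else k >= i) else 0.0
            for k in range(m)] for i in range(m)]
    if conjugate_a:
        tri = [[v.conjugate() for v in row] for row in tri]
    if transpose_a:
        # Transposing swaps which triangle holds the nonzeros.
        tri, lower = _transpose(tri), not lower
    if left_side:
        return _substitute(tri, b, lower, unit_diagonal)
    # X @ op(A) = B is equivalent to op(A)^T @ X^T = B^T.
    xt = _substitute(_transpose(tri), _transpose(b), not lower, unit_diagonal)
    return _transpose(xt)


if __name__ == "__main__":
    A = [[2.0, 0.0], [1.0, 3.0]]
    print(triangular_solve_ref(A, [[4.0], [7.0]], left_side=True, lower=True))
    # [[2.0], [1.6666666666666667]]
```

The left_side=False case is handled by transposing both sides so that a single substitution routine covers every combination of flags; a production implementation would instead call a batched BLAS/LAPACK-style kernel.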
Returns:
X – The solution, with the same shape and dtype as b. If b carries a unit, X preserves that unit.

Return type:
saiunit.Quantity | Array
Examples
>>> import jax.numpy as jnp
>>> import saiunit as u
>>> import saiunit.lax as sulax
>>> A = jnp.array([[2.0, 0.0], [1.0, 3.0]])
>>> b = jnp.array([[4.0], [7.0]]) * u.meter
>>> X = sulax.triangular_solve(A, b, left_side=True, lower=True)
>>> u.get_unit(X) == u.meter
True
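The example above only exercises left_side=True. The sketch below contrasts it with the default left_side=False form using plain, unitless JAX arrays. It calls jax.lax.linalg.triangular_solve directly, on the assumption that the wrapper forwards these flags unchanged, and checks each solution by substituting it back into its own equation.

```python
import jax.numpy as jnp
from jax.lax.linalg import triangular_solve

# Lower-triangular matrix with the same values as in the doctest example.
A = jnp.array([[2.0, 0.0],
               [1.0, 3.0]])

# left_side=True: solve A @ X = B, with B of shape [m, n] = [2, 1].
B_left = jnp.array([[4.0], [7.0]])
X_left = triangular_solve(A, B_left, left_side=True, lower=True)

# left_side=False: solve X @ A = B, with B of shape [n, m] = [1, 2].
B_right = jnp.array([[4.0, 6.0]])
X_right = triangular_solve(A, B_right, left_side=False, lower=True)

# Substitute each solution back into the equation it should satisfy.
print(jnp.allclose(A @ X_left, B_left))
print(jnp.allclose(X_right @ A, B_right))
```

Here X_right works out to [[1.0, 2.0]], because 1·2 + 2·1 = 4 and 2·3 = 6. Since left_side defaults to False, omitting it selects the X · op(A) = B form, so B must then have shape [..., n, m].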