tridiagonal_solve
- class saiunit.lax.tridiagonal_solve(dl, d, du, b)
Solve a tridiagonal linear system.
Compute the solution \(X\) of the tridiagonal system \(A \cdot X = B\), where the \(m \times m\) tridiagonal matrix \(A\) is specified by its three diagonals: the lower diagonal dl, the main diagonal d, and the upper diagonal du.
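For reference, the classic Thomas algorithm (forward elimination followed by back substitution) solves such a system directly from the three diagonals. The NumPy sketch below is illustrative only. It assumes the following: thomas_solve is a hypothetical helper (not part of saiunit), it works on plain unitless arrays with b of shape (m, n), and it performs no pivoting, so A should be well conditioned (for example, diagonally dominant). It is not necessarily the method saiunit or JAX uses internally.

```python
import numpy as np

def thomas_solve(dl, d, du, b):
    """Solve A @ X = B for a tridiagonal A given by its three diagonals.

    Hypothetical helper using the same layout as tridiagonal_solve:
    dl[i] = A[i, i-1] (dl[0] unused), d[i] = A[i, i],
    du[i] = A[i, i+1] (du[m-1] unused). b has shape (m, n).
    Assumes no pivoting is needed (e.g. A diagonally dominant).
    """
    dl, d, du = (np.asarray(v, dtype=float) for v in (dl, d, du))
    b = np.asarray(b, dtype=float)
    m = d.shape[0]
    c = np.zeros(m)        # modified upper diagonal
    y = np.zeros_like(b)   # modified right-hand side
    # Forward elimination: eliminate the lower diagonal row by row.
    c[0] = du[0] / d[0]
    y[0] = b[0] / d[0]
    for i in range(1, m):
        denom = d[i] - dl[i] * c[i - 1]
        # du[m-1] is unused, so the last modified coefficient is zero.
        c[i] = du[i] / denom if i < m - 1 else 0.0
        y[i] = (b[i] - dl[i] * y[i - 1]) / denom
    # Back substitution: x[i] = y[i] - c[i] * x[i+1].
    x = np.zeros_like(b)
    x[m - 1] = y[m - 1]
    for i in range(m - 2, -1, -1):
        x[i] = y[i] - c[i] * x[i + 1]
    return x
```

The tridiagonal structure matters mainly for cost: elimination only touches the three diagonals, so the work grows as O(m·n) instead of the O(m³) of a general dense solve.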
- Parameters:
dl (saiunit.Quantity | Array | ndarray | bool | number | bool | int | float | complex) – Lower diagonal with shape [..., m], where dl[i] = A[i, i-1]; dl[0] is unused. Must have the same unit as d and du.
d (saiunit.Quantity | Array | ndarray | bool | number | bool | int | float | complex) – Main diagonal with shape [..., m], where d[i] = A[i, i].
du (saiunit.Quantity | Array | ndarray | bool | number | bool | int | float | complex) – Upper diagonal with shape [..., m], where du[i] = A[i, i+1]; du[m-1] is unused. Must have the same unit as dl and d.
b (saiunit.Quantity | Array | ndarray | bool | number | bool | int | float | complex) – Right-hand-side matrix with shape [..., m, n].
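To make the indexing convention concrete, the NumPy sketch below expands the three diagonals into the dense matrix A they describe and uses np.linalg.solve as a dense reference. The helper tridiagonal_to_dense is hypothetical (not part of saiunit) and operates on plain unitless arrays; it is only meant for checking results on small systems.

```python
import numpy as np

def tridiagonal_to_dense(dl, d, du):
    """Return the dense (m, m) matrix A encoded by the three diagonals.

    Hypothetical helper following the documented layout:
    dl[i] = A[i, i-1] with dl[0] ignored, d[i] = A[i, i],
    du[i] = A[i, i+1] with du[m-1] ignored.
    """
    dl, d, du = (np.asarray(v, dtype=float) for v in (dl, d, du))
    # np.diag(v, k) places v on the k-th diagonal; slicing drops the unused slots.
    return np.diag(d) + np.diag(dl[1:], k=-1) + np.diag(du[:-1], k=1)

# A 4x4 example; dl[0] and du[3] are placeholders and never used.
dl = np.array([0.0, -1.0, -1.0, -1.0])
d = np.array([4.0, 4.0, 4.0, 4.0])
du = np.array([-1.0, -1.0, -1.0, 0.0])

A = tridiagonal_to_dense(dl, d, du)
x_true = np.array([[1.0], [2.0], [3.0], [4.0]])  # chosen solution, shape (m, n) = (4, 1)
b = A @ x_true                                   # right-hand side of the same shape
X = np.linalg.solve(A, b)                        # dense reference solution of A @ X = b
```

Because the placeholder entries are sliced away, changing dl[0] or du[m-1] leaves A unchanged, which mirrors the "unused" notes in the parameter descriptions.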
- Returns:
X – The solution of the tridiagonal system. If b has a unit, X preserves that unit.
- Return type:
saiunit.Quantity | Array
- Raises:
saiunit.DimensionMismatchError – If dl, d, and du do not share the same unit.
Examples
>>> import jax.numpy as jnp
>>> import saiunit as u
>>> import saiunit.lax as sulax
>>> dl = jnp.array([0.0, 1.0, 1.0])
>>> d = jnp.array([2.0, 2.0, 2.0])
>>> du = jnp.array([1.0, 1.0, 0.0])
>>> b = jnp.array([[1.0], [2.0], [3.0]]) * u.meter
>>> X = sulax.tridiagonal_solve(dl, d, du, b)
>>> u.get_unit(X) == u.meter
True
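Written out densely, the diagonals in this example describe a symmetric 3 × 3 matrix, and the solution can be checked by hand: each row of A times X reproduces the corresponding entry of b (2·0.5 + 0 = 1, 0.5 + 0 + 1.5 = 2, 0 + 2·1.5 = 3). The magnitudes of X in the example should therefore be approximately 0.5, 0, and 1.5, in meters.

```latex
A = \begin{pmatrix} 2 & 1 & 0 \\ 1 & 2 & 1 \\ 0 & 1 & 2 \end{pmatrix}, \qquad
A X = \begin{pmatrix} 1 \\ 2 \\ 3 \end{pmatrix}\,\mathrm{m}
\;\Longrightarrow\;
X = \begin{pmatrix} 0.5 \\ 0 \\ 1.5 \end{pmatrix}\,\mathrm{m}
```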