tridiagonal_solve

brainunit.lax.tridiagonal_solve(dl, d, du, b)

Solve a tridiagonal linear system.

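Compute the solution \(X\) of the tridiagonal system \(A \cdot X = B\). The \(m \times m\) tridiagonal matrix \(A\) is specified by its three diagonals dl, d, and du, and \(B\) is the right-hand side b. Leading dimensions (the ... in the shapes below) are batch dimensions, so a stack of independent systems is solved in one call.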
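Written out in full, the index convention used by the parameters places the diagonals of \(A\) as shown below. Blank entries are zero, and \(\mathrm{dl}_0\) and \(\mathrm{du}_{m-1}\) have no position in the matrix, which is why they are unused.

```latex
A =
\begin{pmatrix}
d_0           & \mathrm{du}_0 &                   &                   &                   \\
\mathrm{dl}_1 & d_1           & \mathrm{du}_1     &                   &                   \\
              & \ddots        & \ddots            & \ddots            &                   \\
              &               & \mathrm{dl}_{m-2} & d_{m-2}           & \mathrm{du}_{m-2} \\
              &               &                   & \mathrm{dl}_{m-1} & d_{m-1}
\end{pmatrix}
```

Row \(i\) of \(A \cdot X = B\) therefore reads \(\mathrm{dl}_i\, x_{i-1} + d_i\, x_i + \mathrm{du}_i\, x_{i+1} = b_i\), where \(x_i\) and \(b_i\) are the \(i\)-th rows of \(X\) and \(B\), and the terms involving \(x_{-1}\) and \(x_m\) are omitted.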

Parameters:
  • dl (saiunit.Quantity | Array | ndarray | bool | number | int | float | complex) – Lower diagonal with shape [..., m]. dl[i] = A[i, i-1]; dl[0] is unused. Must have the same unit as d and du.

  • d (saiunit.Quantity | Array | ndarray | bool | number | int | float | complex) – Main diagonal with shape [..., m]. d[i] = A[i, i].

  • du (saiunit.Quantity | Array | ndarray | bool | number | int | float | complex) – Upper diagonal with shape [..., m]. du[i] = A[i, i+1]; du[m-1] is unused. Must have the same unit as dl and d.

  • b (saiunit.Quantity | Array | ndarray | bool | number | int | float | complex) – Right-hand-side matrix \(B\) with shape [..., m, k]; its leading dimensions must match the shape of the diagonals.
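To make the index convention concrete, the sketch below assembles the dense matrix \(A\) from dl, d, and du with plain NumPy and solves the same system as the example with `numpy.linalg.solve`, as a reference answer. It does not call saiunit. The helper name `tridiagonal_to_dense` is hypothetical, and it handles a single unbatched system only.

```python
import numpy as np


def tridiagonal_to_dense(dl, d, du):
    """Build the dense m x m matrix A from its three diagonals (hypothetical helper).

    dl[i] = A[i, i-1] (dl[0] ignored), d[i] = A[i, i],
    du[i] = A[i, i+1] (du[m-1] ignored).
    """
    dl, d, du = (np.asarray(v, dtype=float) for v in (dl, d, du))
    # np.diag(v, k) places v on diagonal k: k=-1 is just below the main diagonal, k=1 just above.
    return np.diag(d) + np.diag(dl[1:], k=-1) + np.diag(du[:-1], k=1)


# dl[0] and du[-1] are set to 99.0 only to show that they never reach A.
dl = [99.0, 1.0, 1.0]
d = [2.0, 2.0, 2.0]
du = [1.0, 1.0, 99.0]
A = tridiagonal_to_dense(dl, d, du)  # [[2, 1, 0], [1, 2, 1], [0, 1, 2]]

b = np.array([[1.0], [2.0], [3.0]])  # shape (m, k) with m = 3, k = 1
X = np.linalg.solve(A, b)  # dense reference solution
print(X.ravel())  # approximately [0.5, 0.0, 1.5]
```

A dense solve costs \(O(m^3)\) time and stores all \(m^2\) entries, so it is only worth using as a check on small inputs.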

Returns:

X – The solution of the tridiagonal system, with the same shape as b. If b has a unit, X preserves that unit.

Return type:

saiunit.Quantity | Array

Raises:

saiunit.DimensionMismatchError – If dl, d, and du do not share the same unit.

Examples

>>> import jax.numpy as jnp
>>> import saiunit as u
>>> import saiunit.lax as sulax
>>> dl = jnp.array([0.0, 1.0, 1.0])
>>> d  = jnp.array([2.0, 2.0, 2.0])
>>> du = jnp.array([1.0, 1.0, 0.0])
>>> b  = jnp.array([[1.0], [2.0], [3.0]]) * u.meter
>>> X = sulax.tridiagonal_solve(dl, d, du, b)
>>> u.get_unit(X) == u.meter
True
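Why a tridiagonal system is cheap to solve can be seen from the classical Thomas algorithm: one forward elimination sweep followed by back substitution, \(O(m)\) work in total. The NumPy sketch below is an illustration only. It is not claimed to be how saiunit or JAX implement tridiagonal_solve, the name `thomas_solve` is hypothetical, it handles a single unbatched system with b of shape (m, k), and because it does not pivot it assumes a well-behaved matrix such as a diagonally dominant one.

```python
import numpy as np


def thomas_solve(dl, d, du, b):
    """Solve A @ X = b for one tridiagonal A given by its diagonals (hypothetical helper).

    Same convention as tridiagonal_solve: dl[i] = A[i, i-1] (dl[0] ignored),
    d[i] = A[i, i], du[i] = A[i, i+1] (du[m-1] ignored). b has shape (m, k).
    No pivoting: assumes A is, e.g., diagonally dominant.
    """
    dl, d, du, b = (np.asarray(v, dtype=float) for v in (dl, d, du, b))
    m = d.shape[0]
    c = np.zeros(m)       # upper diagonal after elimination (pivots scaled to 1)
    y = np.zeros_like(b)  # right-hand side after elimination

    # Forward sweep: eliminate the lower diagonal and scale each pivot to 1.
    c[0] = du[0] / d[0] if m > 1 else 0.0
    y[0] = b[0] / d[0]
    for i in range(1, m):
        denom = d[i] - dl[i] * c[i - 1]
        c[i] = du[i] / denom if i < m - 1 else 0.0
        y[i] = (b[i] - dl[i] * y[i - 1]) / denom

    # Back substitution on the resulting unit upper-bidiagonal system.
    x = np.empty_like(y)
    x[-1] = y[-1]
    for i in range(m - 2, -1, -1):
        x[i] = y[i] - c[i] * x[i + 1]
    return x


# Same diagonals and right-hand side as the doctest above, without units.
dl = np.array([0.0, 1.0, 1.0])
d = np.array([2.0, 2.0, 2.0])
du = np.array([1.0, 1.0, 0.0])
b = np.array([[1.0], [2.0], [3.0]])
x = thomas_solve(dl, d, du, b)
print(x.ravel())  # approximately [0.5, 0.0, 1.5]
```

These are the values the doctest's X should hold, in meters. Dropping pivoting is what keeps each sweep a single pass over the diagonals; the trade-off is that a zero or tiny pivot `denom` breaks the method, so it is not a general-purpose replacement for the library routine.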