tensorsolve


class saiunit.linalg.tensorsolve(a, b, axes=None, **kwargs)

Solve the tensor equation a x = b for x.

SaiUnit implementation of numpy.linalg.tensorsolve().

The resulting unit is b.unit / a.unit.
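The unit rule follows from the linearity of a x = b: if the magnitudes of a grow by a factor, the magnitudes of x shrink by the same factor. The following is a minimal sketch using plain NumPy arrays as stand-ins for the magnitudes. The scale factor of 1000 (for example, metres re-expressed in millimetres) and the diagonal shift that keeps the system well conditioned are illustrative assumptions, not part of saiunit.

```python
import numpy as np

rng = np.random.default_rng(0)

# Magnitudes only; units are tracked by hand in this sketch.
# The identity-like shift keeps the flattened 4x4 system well conditioned.
a = rng.normal(size=(2, 2, 4)) + 4.0 * np.eye(4).reshape(2, 2, 4)
b = rng.normal(size=(2, 2))

x = np.linalg.tensorsolve(a, b)

# Re-express a in a unit 1000 times smaller: its magnitudes grow by 1000.
a_scaled = 1000.0 * a
x_scaled = np.linalg.tensorsolve(a_scaled, b)

# x scales inversely with a, consistent with a unit of b.unit / a.unit.
print(np.allclose(x_scaled, x / 1000.0))
```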

Parameters:
  • a (Array | ndarray | bool | number | int | float | complex | saiunit.Quantity) – Coefficient tensor. After reordering via axes (see below), its shape must be (*b.shape, *x.shape), with prod(x.shape) == prod(b.shape) so that the flattened system is square.

  • b (Array | ndarray | bool | number | int | float | complex | saiunit.Quantity) – Right-hand side tensor.

  • axes (tuple[int, ...] | None) – Axes of a that should be moved to the end before solving. If None (the default), no reordering is done.

Returns:

x – Solution tensor such that, after the axes of a are reordered according to axes, tensordot(a, x, axes=x.ndim) equals b up to numerical precision. The resulting unit is b.unit / a.unit.

Return type:

Array | ndarray | bool | number | int | float | complex | saiunit.Quantity
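The shape requirement on a can be read as "flatten to a square matrix, solve, reshape". The sketch below illustrates that reading with numpy.linalg.tensorsolve, the function this one is documented to mirror. It is not a description of saiunit's internals, and the shapes and the diagonal shift are arbitrary choices made for the illustration.

```python
import numpy as np

rng = np.random.default_rng(1)

b = rng.normal(size=(2, 3))
x_shape = (3, 2)              # prod(x_shape) must equal b.size
n = b.size                    # 6

# a has shape (*b.shape, *x_shape); the shift keeps the 6x6 system well conditioned.
a = rng.normal(size=b.shape + x_shape) + 6.0 * np.eye(n).reshape(b.shape + x_shape)

# Manual route: flatten to an n x n matrix equation, solve, reshape back.
x_flat = np.linalg.solve(a.reshape(n, n), b.reshape(n))
x_manual = x_flat.reshape(x_shape)

# Library route.
x = np.linalg.tensorsolve(a, b)

print(np.allclose(x, x_manual))
# Contracting the last x.ndim axes of a with x recovers b.
print(np.allclose(np.tensordot(a, x, axes=x.ndim), b))
```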

See also

saiunit.linalg.solve

Solve a linear system of equations.

saiunit.linalg.tensorinv

Compute the tensor inverse.

saiunit.linalg.tensordot

Compute tensor dot product.

Examples

>>> import saiunit as u
>>> import jax
>>> key1, key2 = jax.random.split(jax.random.key(8675309))
>>> a = jax.random.normal(key1, shape=(2, 2, 4)) * u.meter
>>> b = jax.random.normal(key2, shape=(2, 2)) * u.second
>>> x = u.linalg.tensorsolve(a, b)
>>> x.shape
(4,)
>>> b_reconstructed = u.linalg.tensordot(a, x, axes=x.ndim)
>>> u.math.allclose(b, b_reconstructed)
Array(True, dtype=bool)
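
The example above leaves axes at its default. As a hedged illustration of what axes does, the sketch below uses numpy.linalg.tensorsolve on unit-free arrays and assumes the saiunit wrapper reorders axes the same way. The array a_perm and the diagonal shift are made up for the illustration.

```python
import numpy as np

rng = np.random.default_rng(2)

# Reference layout: shape (*b.shape, *x.shape) = (2, 2, 4).
a = rng.normal(size=(2, 2, 4)) + 4.0 * np.eye(4).reshape(2, 2, 4)
b = rng.normal(size=(2, 2))
x_ref = np.linalg.tensorsolve(a, b)

# Same coefficients stored with the x axis first: shape (4, 2, 2).
a_perm = np.moveaxis(a, -1, 0)

# axes=(0,) moves axis 0 of a_perm to the end before solving,
# which restores the (2, 2, 4) layout.
x_perm = np.linalg.tensorsolve(a_perm, b, axes=(0,))

print(x_perm.shape)
print(np.allclose(x_perm, x_ref))
```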