multi_dot
- class brainunit.linalg.multi_dot(arrays, *, precision=None, **kwargs)
Efficiently compute matrix products between a sequence of arrays.
JAX internally uses the opt_einsum library to compute the most efficient operation order. The resulting unit is the product of the units of all input arrays.
- Parameters:
arrays (Sequence[Array | ndarray | bool | number | int | float | complex | saiunit.Quantity]) – Sequence of arrays or quantities. All must be two-dimensional, except the first and last, which may be one-dimensional.
precision (None | str | Precision | tuple[str, str] | tuple[Precision, Precision] | DotAlgorithm | DotAlgorithmPreset) – Either None (default), or a Precision enum value.
- Returns:
output – An array representing the equivalent of reduce(jnp.matmul, arrays), evaluated in the optimal order. The resulting unit is the product of all input units.
- Return type:
Array | saiunit.Quantity
Examples
>>> import saiunit as u
>>> import jax
>>> k1, k2 = jax.random.split(jax.random.key(0))
>>> a = jax.random.normal(k1, shape=(3, 4)) * u.meter
>>> b = jax.random.normal(k2, shape=(4, 2)) * u.second
>>> u.math.multi_dot([a, b])  # unit is meter * second
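The equivalence with reduce(jnp.matmul, arrays) stated under Returns can be checked with a minimal sketch. It uses plain JAX arrays and jnp.linalg.multi_dot rather than the unit-aware wrapper, on the assumption that the wrapper computes the same numerical values; the array names and shapes are illustrative only.

```python
from functools import reduce

import jax.numpy as jnp

# Three conformable 2-D arrays: (3, 4) @ (4, 2) @ (2, 5) -> (3, 5).
a = jnp.arange(12.0).reshape(3, 4)
b = jnp.arange(8.0).reshape(4, 2)
c = jnp.arange(10.0).reshape(2, 5)

# multi_dot is free to choose the multiplication order.
chained = jnp.linalg.multi_dot([a, b, c])

# Left-to-right reference: (a @ b) @ c.
naive = reduce(jnp.matmul, [a, b, c])

# Only the first and last arrays may be 1-D; they act as a row and a
# column vector, so here the result collapses to a scalar.
row = jnp.ones(3)
col = jnp.ones(5)
scalar = jnp.linalg.multi_dot([row, a, b, c, col])
```

Both results agree up to floating-point rounding. Only the evaluation order, and therefore the cost, differs, and the saving grows with how unevenly the matrix dimensions are sized.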