multi_dot

brainunit.math.multi_dot(arrays, *, precision=None, **kwargs)

Efficiently compute matrix products between a sequence of arrays.

Internally, JAX uses the opt_einsum library to choose the order of multiplications with the lowest computational cost. The resulting unit is the product of the units of all input arrays.
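Why the evaluation order matters can be seen from the classic matrix-chain cost calculation. The sketch below is the textbook dynamic program that counts scalar multiplications for each way of parenthesizing a chain. It is not the path search that opt_einsum actually runs, and the function name matrix_chain_cost and the example shapes are illustrative assumptions.

```python
from functools import lru_cache


def matrix_chain_cost(dims):
    """Minimum number of scalar multiplications needed to evaluate a
    matrix chain in which matrix i has shape (dims[i], dims[i + 1])."""
    n = len(dims) - 1  # number of matrices in the chain

    @lru_cache(maxsize=None)
    def best(i, j):
        # Cheapest way to multiply matrices i..j (inclusive).
        if i == j:
            return 0
        # Try every split point k: (i..k) @ (k+1..j).
        return min(
            best(i, k) + best(k + 1, j) + dims[i] * dims[k + 1] * dims[j + 1]
            for k in range(i, j)
        )

    return best(0, n - 1)


# A: 10x100, B: 100x5, C: 5x50
dims = [10, 100, 5, 50]
left_first = 10 * 100 * 5 + 10 * 5 * 50     # cost of (A @ B) @ C
right_first = 100 * 5 * 50 + 10 * 100 * 50  # cost of A @ (B @ C)
print(matrix_chain_cost(dims))  # 7500
```

For these shapes, multiplying left to right costs 7,500 scalar multiplications while the right-to-left order costs 75,000, a tenfold difference for the same mathematical result.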

Parameters:
  • arrays (Sequence[Array | ndarray | bool | number | int | float | complex | saiunit.Quantity]) – Sequence of arrays or quantities. All must be two-dimensional, except the first and last, which may be one-dimensional.

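  • precision (None | str | Precision | tuple[str, str] | tuple[Precision, Precision] | DotAlgorithm | DotAlgorithmPreset) – Either None (default), which uses the backend's default precision, a Precision enum value (Precision.DEFAULT, Precision.HIGH, or Precision.HIGHEST), or a tuple of two such values giving the precision for each operand.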
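The shape rules for arrays can be made concrete with a small checker. multi_dot_result_shape below is a hypothetical helper written for illustration (it is not part of brainunit or JAX). It assumes the same convention as numpy.linalg.multi_dot: a 1-D first operand acts as a row vector, a 1-D last operand acts as a column vector, and the corresponding axis is dropped from the result. Requiring at least two arrays is likewise an assumption borrowed from NumPy.

```python
def multi_dot_result_shape(shapes):
    """Hypothetical helper: output shape of a multi_dot call for the given input shapes."""
    if len(shapes) < 2:
        # Assumption borrowed from NumPy: a chain needs at least two operands.
        raise ValueError("expected at least two arrays")
    shapes = [tuple(s) for s in shapes]
    first_is_vector = len(shapes[0]) == 1
    last_is_vector = len(shapes[-1]) == 1
    # Promote 1-D end operands: first -> row vector (1, n), last -> column vector (n, 1).
    if first_is_vector:
        shapes[0] = (1,) + shapes[0]
    if last_is_vector:
        shapes[-1] = shapes[-1] + (1,)
    # Every operand must now be 2-D.
    for idx, shape in enumerate(shapes):
        if len(shape) != 2:
            raise ValueError(f"array {idx} must be 2-D, got shape {shape}")
    # Inner dimensions of adjacent operands must agree.
    for left, right in zip(shapes, shapes[1:]):
        if left[1] != right[0]:
            raise ValueError(f"shapes {left} and {right} are not aligned")
    # Drop the axes that were added for 1-D end operands.
    out = []
    if not first_is_vector:
        out.append(shapes[0][0])
    if not last_is_vector:
        out.append(shapes[-1][1])
    return tuple(out)


print(multi_dot_result_shape([(3, 4), (4, 2)]))        # (3, 2)
print(multi_dot_result_shape([(4,), (4, 3), (3, 5)]))  # (5,)
print(multi_dot_result_shape([(4,), (4, 3), (3,)]))    # ()
```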

Returns:

output – An array equivalent to reduce(jnp.matmul, arrays), evaluated in the optimal order. The resulting unit is the product of all input units.

Return type:

Array | saiunit.Quantity
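The equivalence with reduce(jnp.matmul, arrays) can be checked on plain, unitless JAX arrays. This sketch calls jax.numpy.linalg.multi_dot directly, on the assumption that the unit-aware function delegates the numerical work to it, as the description above suggests. It also shows a Precision enum value being passed through the precision argument.

```python
from functools import reduce

import jax.numpy as jnp
from jax import lax

arrays = [
    jnp.arange(4.0),                 # 1-D first operand: acts as a row vector
    jnp.arange(12.0).reshape(4, 3),  # 2-D middle operand
    jnp.arange(3.0),                 # 1-D last operand: acts as a column vector
]

# Evaluation order chosen by JAX; request the most accurate matmul algorithm.
optimized = jnp.linalg.multi_dot(arrays, precision=lax.Precision.HIGHEST)
# Strict left-to-right evaluation: (arrays[0] @ arrays[1]) @ arrays[2].
naive = reduce(jnp.matmul, arrays)

print(optimized, naive)  # both are the scalar 156.0
```

Only the evaluation order differs between the two calls, so in general the results agree up to floating-point rounding. Here the inputs are small integers, so both give exactly 156.0.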

Examples

>>> import saiunit as u
>>> import jax
>>> k1, k2 = jax.random.split(jax.random.key(0))
>>> a = jax.random.normal(k1, shape=(3, 4)) * u.meter
>>> b = jax.random.normal(k2, shape=(4, 2)) * u.second
>>> u.math.multi_dot([a, b])  # unit is meter * second