batch_matmul

class saiunit.lax.batch_matmul(x, y, precision=None, **kwargs)

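Unit-aware batch matrix multiplication. The last two axes of x and y are multiplied as matrices for every index of the leading (batch) axes.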

Parameters:
  • x (saiunit.Quantity | Array | ndarray | bool | number | int | float | complex) – Left input array of shape [..., m, k].

  • y (saiunit.Quantity | Array | ndarray | bool | number | int | float | complex) – Right input array of shape [..., k, n].

  • precision (None | str | Precision | tuple[str, str] | tuple[Precision, Precision] | DotAlgorithm | DotAlgorithmPreset) – Numerical precision of the computation, in any form JAX accepts for its precision argument (for example jax.lax.Precision.HIGHEST). None uses the backend's default precision.
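For plain (unitless) arrays, the axis bookkeeping described above can be checked with JAX alone. The sketch below uses jax.lax.batch_matmul, jnp.einsum and jnp.matmul from JAX, not saiunit. It assumes that saiunit's numeric result matches JAX's batched product, which this page does not state explicitly. It also passes a precision value of the kind listed in the parameter types.

```python
import jax
import jax.numpy as jnp

# Plain (unitless) arrays: x is [batch, m, k], y is [batch, k, n].
x = jnp.arange(2 * 3 * 4, dtype=jnp.float32).reshape(2, 3, 4)
y = jnp.arange(2 * 4 * 5, dtype=jnp.float32).reshape(2, 4, 5)

# The last axis of x is contracted with the second-to-last axis of y;
# the leading axis is a batch axis that is carried through unchanged.
hi = jax.lax.Precision.HIGHEST
out = jax.lax.batch_matmul(x, y, precision=hi)

# The same contraction written with einsum ("..." stands for the batch axes).
ref_einsum = jnp.einsum("...mk,...kn->...mn", x, y, precision=hi)

# The same result as an explicit loop of ordinary matrix products,
# one per batch entry.
ref_loop = jnp.stack(
    [jnp.matmul(x[b], y[b], precision=hi) for b in range(x.shape[0])]
)

print(out.shape)  # (2, 3, 5)
```

The einsum form makes the rule [..., m, k] x [..., k, n] -> [..., m, n] explicit. The loop form shows that the batch axes are not contracted.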

Returns:

result – The batch matrix product of shape [..., m, n]. The resulting unit is unit(x) * unit(y).

Return type:

saiunit.Quantity | Array | ndarray | bool | number | int | float | complex

Examples

>>> import saiunit as u
>>> import saiunit.lax as sulax
>>> import jax.numpy as jnp
>>> x = jnp.ones((2, 3, 4)) * u.meter
>>> y = jnp.ones((2, 4, 5)) * u.second
>>> result = sulax.batch_matmul(x, y)
>>> result.mantissa.shape
(2, 3, 5)