batch_matmul
- class brainunit.lax.batch_matmul(x, y, precision=None, **kwargs)
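Batch matrix multiplication. For each index over the shared leading (batch) dimensions, the corresponding matrices of x and y are multiplied, and physical units are propagated to the result.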
- Parameters:
x (saiunit.Quantity | Array | ndarray | bool | number | bool | int | float | complex) – Left input array of shape [..., m, k].
y (saiunit.Quantity | Array | ndarray | bool | number | bool | int | float | complex) – Right input array of shape [..., k, n].
precision (None | str | Precision | tuple[str, str] | tuple[Precision, Precision] | DotAlgorithm | DotAlgorithmPreset) – Numerical precision of the computation.
- Returns:
result – The batch matrix product of shape [..., m, n]. The resulting unit is unit(x) * unit(y).
- Return type:
saiunit.Quantity | Array | ndarray | bool | number | bool | int | float | complex
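The shape and unit rules stated above can be sketched without the library. This is a hedged, pure-Python illustration, not `saiunit` code: `output_shape` and `output_unit` are hypothetical helpers, units are modeled as plain dictionaries of base-unit exponents (not how `saiunit` represents them), and the shape check assumes the behavior of `jax.lax.batch_matmul`, where both inputs have the same number of dimensions and identical batch dimensions (no broadcasting).

```python
def output_shape(x_shape, y_shape):
    """Hypothetical helper: [..., m, k] and [..., k, n] -> [..., m, n].

    Assumes equal ndim (at least 2) and identical batch dimensions.
    """
    x_shape, y_shape = tuple(x_shape), tuple(y_shape)
    if len(x_shape) < 2 or len(x_shape) != len(y_shape):
        raise ValueError("inputs must have the same ndim, at least 2")
    if x_shape[:-2] != y_shape[:-2]:
        raise ValueError(f"batch dims differ: {x_shape[:-2]} vs {y_shape[:-2]}")
    m, k = x_shape[-2:]
    k_y, n = y_shape[-2:]
    if k != k_y:
        raise ValueError(f"contraction dims differ: {k} vs {k_y}")
    return x_shape[:-2] + (m, n)


def output_unit(unit_x, unit_y):
    """Hypothetical helper: unit(x) * unit(y) as summed base-unit exponents.

    Units are dicts such as {"m": 1}; exponents that cancel to 0 are dropped.
    """
    combined = dict(unit_x)
    for base, exp in unit_y.items():
        combined[base] = combined.get(base, 0) + exp
    return {base: exp for base, exp in combined.items() if exp != 0}


# Inputs matching the example in this section: meters times seconds.
print(output_shape((2, 3, 4), (2, 4, 5)))  # (2, 3, 5)
print(output_unit({"m": 1}, {"s": 1}))     # {'m': 1, 's': 1}
```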
Examples
>>> import saiunit as u
>>> import saiunit.lax as sulax
>>> import jax.numpy as jnp
>>> x = jnp.ones((2, 3, 4)) * u.meter
>>> y = jnp.ones((2, 4, 5)) * u.second
>>> result = sulax.batch_matmul(x, y)
>>> result.mantissa.shape
(2, 3, 5)
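The example above inspects only the shape. As a hedged cross-check in plain NumPy (no `saiunit` involved), the mantissa can be reproduced with `np.matmul`, which agrees with a batched product here because both inputs share the batch dimension 2. Each output entry sums k = 4 products of 1 × 1, so every mantissa value should be 4.0, and by the stated unit rule the unit should be meter * second.

```python
import numpy as np

# Mantissas of the example inputs with units stripped:
# x has shape [..., m, k] = (2, 3, 4), y has shape [..., k, n] = (2, 4, 5).
x = np.ones((2, 3, 4))
y = np.ones((2, 4, 5))

# np.matmul treats the leading axis as a batch axis: (2, 3, 4) @ (2, 4, 5) -> (2, 3, 5).
result = np.matmul(x, y)

# The same contraction written out: sum over k, batch index b kept.
explicit = np.einsum("bmk,bkn->bmn", x, y)

print(result.shape)                   # (2, 3, 5)
print(bool(np.all(result == 4.0)))    # True
print(np.allclose(result, explicit))  # True
```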