einreduce

class saiunit.math.einreduce(x, pattern, reduction, **axes_lengths)

Combine reordering and reduction using reader-friendly notation.

The notation follows einops: the pattern names the input axes, so both the axes to reduce and the order of the output axes can be read directly from it.
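The pattern string encodes two steps. Every axis that appears on the left but not on the right is reduced, and the surviving axes are arranged in the order written on the right. The plain-NumPy sketch below spells out that equivalence. `reduce_then_reorder` is a hypothetical helper written for illustration, not part of saiunit, and it assumes the simplest patterns only: space-separated axis names, each used once, with no parenthesized groups or ellipsis.

```python
import numpy as np


def reduce_then_reorder(x, pattern, reduction):
    """Hypothetical illustration of an einops-style reduce for flat patterns.

    Assumes `pattern` looks like 'b c h -> c b': space-separated axis names,
    each used once, with no parentheses or ellipsis.
    """
    lhs, rhs = (side.split() for side in pattern.split("->"))
    if len(lhs) != np.ndim(x):
        raise ValueError(f"pattern has {len(lhs)} input axes, array has {np.ndim(x)}")
    if not set(rhs) <= set(lhs):
        raise ValueError("output axes must also appear in the input")
    # Step 1: reduce every axis that is absent from the output.
    reduced_axes = tuple(i for i, name in enumerate(lhs) if name not in rhs)
    reduce_fn = getattr(np, reduction)  # e.g. np.sum, np.mean, np.max
    y = reduce_fn(x, axis=reduced_axes)
    # Step 2: permute the surviving axes into the order written on the right.
    kept = [name for name in lhs if name in rhs]
    return np.transpose(y, [kept.index(name) for name in rhs])


x = np.arange(24.0).reshape(2, 3, 4)  # axes: b=2, c=3, h=4
out = reduce_then_reorder(x, "b c h -> c b", "max")
print(out.shape)  # (3, 2)
```

Under these assumptions, `'b c h -> c b'` with `'max'` is the same as `x.max(axis=2).T`: the reduction removes `h`, and the transpose puts `c` before `b`.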

Parameters:
  • x (Array | ndarray | bool | number | bool | int | float | complex | saiunit.Quantity | Sequence[Array | ndarray | bool | number | bool | int | float | complex] | Sequence[saiunit.Quantity]) – Input tensor(s). A list of tensors of the same type and shape is also accepted.

  • pattern (str) – Reduction pattern in 'input -> output' form. Axes that appear on the left but not on the right are reduced; the remaining axes are returned in the order in which they appear on the right.

  • reduction (str | Callable[[Array | ndarray | bool | number | bool | int | float | complex, Tuple[int, ...]], Array | ndarray | bool | number | bool | int | float | complex]) – Reduction to apply: either a string naming the operation, such as 'sum' or 'mean' (as in the examples below), or a callable with signature f(tensor, reduced_axes) -> tensor, where reduced_axes is the tuple of axis positions to reduce.

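  • **axes_lengths (int) – Sizes of individual axes, passed as keyword arguments named after the axes in the pattern, for axes whose length cannot be inferred from the input shape alone.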
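Axis lengths matter when a pattern splits one input axis into a parenthesized group whose factors cannot all be inferred from the shape. In einops notation, for example, `'(h h2) w -> h w'` with `h2=2` sums adjacent pairs of rows, which is 2x1 sum pooling. The NumPy sketch below shows the reshape-then-reduce this corresponds to. It assumes saiunit follows einops semantics for grouped axes, which this entry does not spell out, so the saiunit call appears only as an assumed equivalent in a comment.

```python
import numpy as np

# Input with axes "(h h2) w": 6 rows treated as 3 groups of h2 = 2 rows.
x = np.arange(12.0).reshape(6, 2)
h2 = 2  # the size that would be supplied as an axis length, e.g. h2=2

# Assumed equivalent (einops semantics, not confirmed for saiunit):
#   sumath.einreduce(x, '(h h2) w -> h w', 'sum', h2=2)
# The grouped axis of length 6 is split row-major into (h, h2) = (6 // h2, h2);
# h2 is absent from the output, so it is summed away.
assert x.shape[0] % h2 == 0, "grouped axis must be divisible by h2"
h = x.shape[0] // h2
pooled = x.reshape(h, h2, x.shape[1]).sum(axis=1)
print(pooled.shape)  # (3, 2)
```

Only `h2` needs to be given: once it is known, `h` follows from the length of the grouped axis.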

Returns:

out – The reduced tensor with the same type as the input.

Return type:

Array | ndarray | bool | number | bool | int | float | complex | saiunit.Quantity

Examples

>>> import jax.numpy as jnp
>>> import saiunit.math as sumath
>>> x = jnp.ones((4, 3, 5))
>>> sumath.einreduce(x, 'b c h -> b c', 'sum').shape
(4, 3)
>>> sumath.einreduce(x, 'b c h -> b c', 'mean').shape
(4, 3)
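The reduction parameter also accepts a callable with signature `f(tensor, reduced_axes) -> tensor`. Below is a minimal sketch of such a callable; `l2_norm` is a hypothetical name written for illustration, not part of saiunit. It receives the tuple of axis positions to reduce and must remove exactly those axes. Here it is exercised directly on a unitless NumPy array. Passing it to saiunit would look like `sumath.einreduce(x, 'b c h -> b c', l2_norm)`, but whether the callable receives a `Quantity` or a bare array, and how units propagate through it, is not stated in this entry.

```python
import numpy as np


def l2_norm(tensor, reduced_axes):
    """Hypothetical custom reduction: Euclidean norm over `reduced_axes`.

    Uses only array methods (no np.* calls), so the same function also
    works on jax.Array inputs, assuming einreduce passes one through.
    """
    return (tensor * tensor).sum(axis=reduced_axes) ** 0.5


x = np.full((4, 3, 5), 2.0)  # axes b=4, c=3, h=5
# For the pattern 'b c h -> b c', the reduced axis is h, i.e. position 2.
out = l2_norm(x, (2,))
print(out.shape)  # (4, 3)
```

Each entry of `out` equals sqrt(5 * 2.0**2) = sqrt(20), the norm of five values of 2.0 along `h`.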