einsum

class saiunit.math.einsum(subscripts, /, *operands, optimize='optimal', precision=None, preferred_element_type=None)

Einstein summation for arrays and quantities.

einsum is a generic API for computing reductions, inner products, outer products, axis reorderings, and combinations of these across one or more input arrays. Its API is overloaded. Besides a subscript string followed by the operands, it also accepts operands interleaved with integer index lists, optionally followed by an output index list. The arguments below describe the subscript-string convention, and the Examples section below demonstrates the alternatives.
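To make the semantics concrete, here is a deliberately slow reference sketch written with plain NumPy arrays (no saiunit, no units). The helper name `naive_einsum` is hypothetical, and it handles only the explicit `->` form without `...`. It loops over every combination of label values, multiplies the matching element of each operand, and accumulates into the output position, so any label absent from the output is summed.

```python
import itertools

import numpy as np


def naive_einsum(subscripts, *operands):
    """Slow reference einsum for explicit subscripts such as 'ij,j->i'.

    Hypothetical helper for illustration: no ellipsis, no implicit mode.
    """
    inputs, output = subscripts.replace(" ", "").split("->")
    input_labels = inputs.split(",")
    # Record the length of every label and check that repeats agree.
    sizes = {}
    for labels, op in zip(input_labels, operands):
        for label, dim in zip(labels, op.shape):
            if sizes.setdefault(label, dim) != dim:
                raise ValueError(f"size mismatch for label {label!r}")
    all_labels = sorted(sizes)
    out = np.zeros([sizes[lab] for lab in output], dtype=np.result_type(*operands))
    # Visit every assignment of index values to every label.
    for values in itertools.product(*(range(sizes[lab]) for lab in all_labels)):
        env = dict(zip(all_labels, values))
        term = 1
        for labels, op in zip(input_labels, operands):
            term = term * op[tuple(env[lab] for lab in labels)]
        # Labels missing from the output are summed into the same slot.
        out[tuple(env[lab] for lab in output)] += term
    return out


M = np.arange(16).reshape(4, 4)
v = np.arange(4)
print(naive_einsum("ij,j->i", M, v))  # [14 38 62 86]
```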

Parameters:
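  • subscripts (str) – Subscript string in Einstein summation notation: one group of axis labels per operand, separated by commas, optionally followed by -> and the output labels. Without ->, the output keeps the labels that appear exactly once, in alphabetical order.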

  • *operands (array_like or Quantity) – Sequence of one or more arrays / quantities corresponding to the subscripts.

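  • optimize (str | bool | list[tuple[int, ...]]) – Strategy for choosing the order in which operands are contracted: a strategy name, a bool, or an explicit contraction path. Defaults to "optimal".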

  • precision (None | str | Precision | tuple[str, str] | tuple[Precision, Precision] | DotAlgorithm | DotAlgorithmPreset) – Precision of the computation. Default is None.

  • preferred_element_type (str | type[Any] | dtype | SupportsDType | None) – Accumulation and result dtype. Default is None.

Returns:

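out – Result of the Einstein summation. When operands carry physical units, the output unit is the product of the operand units, because every term of the sum multiplies one element from each operand. With a single operand (a sum, transpose, diagonal, or trace), the unit is unchanged.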

Return type:

Array | Quantity
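The unit rule above can be sketched without any units library. The `(power_of_ten, {base_unit: exponent})` encoding and the `einsum_unit` helper below are hypothetical and are not how saiunit stores units; they only show that multiplying operand units is enough to reproduce the units in the examples, such as microwatts for `mA` with `mV`, and the `metre ** 4 * kilogram ** 2 * second ** -6 * amp ** -1` dimension of the chained dot product.

```python
from collections import Counter

# Hypothetical unit encoding: (power of ten, {SI base unit: exponent}).
mA = (-3, {"A": 1})
mV = (-3, {"kg": 1, "m": 2, "s": -3, "A": -1})
ohm = (0, {"kg": 1, "m": 2, "s": -3, "A": -2})


def einsum_unit(*units):
    """Unit of an einsum result: the product of all operand units."""
    power = 0
    dims = Counter()
    for exp10, base in units:
        power += exp10     # prefixes multiply, so their powers of ten add
        dims.update(base)  # exponents of each base unit add
    # Drop base units whose exponents cancelled to zero.
    return power, {name: e for name, e in dims.items() if e != 0}


print(einsum_unit(mA, mV))           # (-6, {'kg': 1, 'm': 2, 's': -3})
print(einsum_unit(ohm, mA))          # millivolt
print(einsum_unit(mA, ohm, mV, mA))  # unit of the chained dot product
```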

Examples

The mechanics of einsum are perhaps best demonstrated by example. Here we show how to use einsum to compute a number of quantities from one or more arrays. For more discussion and examples of einsum, see the documentation of numpy.einsum().

>>> import saiunit as bu
>>> M = bu.math.arange(16).reshape(4, 4) * bu.ohm
>>> x = bu.math.arange(4) * bu.mA
>>> y = bu.math.array([5, 4, 3, 2]) * bu.mV

Vector product

>>> bu.math.einsum('i,i', x, y)
16.0 uW
>>> bu.math.vecdot(x, y)
16.0 uW

Here are some alternative einsum calling conventions to compute the same result:

>>> bu.math.einsum('i,i->', x, y)  # explicit form
16.0 uW
>>> bu.math.einsum(x, (0,), y, (0,))  # implicit form via indices
16.0 uW
>>> bu.math.einsum(x, (0,), y, (0,), ())  # explicit form via indices
16.0 uW
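The index-list convention carries the same information as a subscript string, with integers in place of letters. The helper `sublists_to_subscripts` below is hypothetical (not part of saiunit): it maps 0 to 'a', 1 to 'b', and so on, and does not handle `...`. NumPy's `np.einsum` serves as a unitless stand-in to check the translation.

```python
import string

import numpy as np


def sublists_to_subscripts(*args):
    """Rewrite (op0, idx0, op1, idx1, ..., [out]) as (subscripts, operands).

    Hypothetical helper: integer labels become letters (0 -> 'a', 1 -> 'b').
    Ellipsis entries are not supported.
    """
    # An odd number of arguments means a trailing output index list was given.
    if len(args) % 2 == 1:
        *args, output = args
    else:
        output = None

    def to_letters(indices):
        return "".join(string.ascii_letters[i] for i in indices)

    operands = args[0::2]
    inputs = ",".join(to_letters(idx) for idx in args[1::2])
    if output is None:
        return inputs, operands  # implicit form
    return f"{inputs}->{to_letters(output)}", operands  # explicit form


x = np.arange(4)
y = np.array([5, 4, 3, 2])
subscripts, operands = sublists_to_subscripts(x, (0,), y, (0,), ())
print(subscripts)                        # a,a->
print(np.einsum(subscripts, *operands))  # 16
```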

Matrix product

>>> bu.math.einsum('ij,j->i', M, x)  # explicit form
Array([14, 38, 62, 86], dtype=float32) * mvolt
>>> bu.math.matmul(M, x)
Array([14, 38, 62, 86], dtype=float32) * mvolt

Here are some alternative einsum calling conventions to compute the same result:

>>> bu.math.einsum('ij,j', M, x)  # implicit form
Array([14, 38, 62, 86], dtype=float32) * mvolt
>>> bu.math.einsum(M, (0, 1), x, (1,), (0,))  # explicit form via indices
Array([14, 38, 62, 86], dtype=float32) * mvolt
>>> bu.math.einsum(M, (0, 1), x, (1,))  # implicit form via indices
Array([14, 38, 62, 86], dtype=float32) * mvolt
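In the implicit forms the output labels are inferred. Following the NumPy convention, which the transpose and trace examples on this page are consistent with, the output keeps every label that appears exactly once, sorted alphabetically, and repeated labels are summed. The hypothetical helper `implicit_output` below spells that rule out (ignoring `...`) and checks it against NumPy.

```python
from collections import Counter

import numpy as np


def implicit_output(subscripts):
    """Output labels inferred when '->' is omitted (ellipsis not handled)."""
    counts = Counter(subscripts.replace(",", ""))
    # Keep labels seen exactly once, in alphabetical order.
    return "".join(sorted(label for label, n in counts.items() if n == 1))


M = np.arange(16).reshape(4, 4)
x = np.arange(4)
for s in ["ij,j", "ji", "ii"]:
    ops = (M, x) if "," in s else (M,)
    explicit = f"{s}->{implicit_output(s)}"
    # The implicit and the spelled-out explicit forms must agree.
    assert np.array_equal(np.einsum(s, *ops), np.einsum(explicit, *ops))
    print(f"{s!r} behaves like {explicit!r}")
```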

Outer product

>>> bu.math.einsum("i,j->ij", x, y)
Array([[ 0,  0,  0,  0],
         [ 5,  4,  3,  2],
         [10,  8,  6,  4],
         [15, 12,  9,  6]], dtype=float32) * uwatt
>>> bu.math.outer(x, y)
Array([[ 0,  0,  0,  0],
         [ 5,  4,  3,  2],
         [10,  8,  6,  4],
         [15, 12,  9,  6]], dtype=float32) * uwatt

Some other ways of computing outer products:

>>> bu.math.einsum("i,j", x, y)  # implicit form
Array([[ 0,  0,  0,  0],
         [ 5,  4,  3,  2],
         [10,  8,  6,  4],
         [15, 12,  9,  6]], dtype=float32) * uwatt
>>> bu.math.einsum(x, (0,), y, (1,), (0, 1))  # explicit form via indices
Array([[ 0,  0,  0,  0],
         [ 5,  4,  3,  2],
         [10,  8,  6,  4],
         [15, 12,  9,  6]], dtype=float32) * uwatt
>>> bu.math.einsum(x, (0,), y, (1,))  # implicit form via indices
Array([[ 0,  0,  0,  0],
         [ 5,  4,  3,  2],
         [10,  8,  6,  4],
         [15, 12,  9,  6]], dtype=float32) * uwatt
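Because `i,j->ij` shares no label between the operands and sums nothing, each output element is the single product `x[i] * y[j]`. A NumPy sketch (unitless stand-in) showing that this matches broadcasting a column against a row:

```python
import numpy as np

x = np.arange(4)
y = np.array([5, 4, 3, 2])
via_einsum = np.einsum("i,j->ij", x, y)
# Insert length-1 axes so x runs down the rows and y across the columns.
via_broadcast = x[:, None] * y[None, :]
print(np.array_equal(via_einsum, via_broadcast))  # True
```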

1D array sum

>>> bu.math.einsum("i->", x)  # requires explicit form
Array(6, dtype=float32) * mA
>>> bu.math.einsum(x, (0,), ())  # explicit form via indices
Array(6, dtype=float32) * mA
>>> bu.math.sum(x)
Array(6, dtype=float32) * mA
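The note "requires explicit form" matters: in implicit form a label that appears once is kept, so `'i'` alone returns the input unchanged, and only the empty output after `->` turns `i` into a summed label. Illustrated with NumPy as a unitless stand-in:

```python
import numpy as np

x = np.arange(4)
kept = np.einsum("i", x)      # 'i' appears once, so it is kept: same values as x
summed = np.einsum("i->", x)  # empty output, so 'i' is summed
print(kept, summed)           # [0 1 2 3] 6
```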

Sum along an axis

>>> bu.math.einsum("...j->...", M)  # requires explicit form
Array([ 6, 22, 38, 54], dtype=float32) * ohm
>>> bu.math.einsum(M, (..., 0), (...,))  # explicit form via indices
Array([ 6, 22, 38, 54], dtype=float32) * ohm
>>> M.sum(-1)
Array([ 6, 22, 38, 54], dtype=float32) * ohm
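`...` stands for any number of leading axes, so the same subscript also sums the last axis of a stack of matrices. A NumPy sketch with an assumed 2×3×4 input:

```python
import numpy as np

stack = np.arange(24).reshape(2, 3, 4)
# '...' absorbs the two leading axes; the trailing label j is summed.
row_sums = np.einsum("...j->...", stack)
print(row_sums.shape)  # (2, 3)
```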

Matrix transpose

>>> y = bu.math.array([[1, 2, 3],
...                    [4, 5, 6]]) * bu.mV
>>> bu.math.einsum("ij->ji", y)  # explicit form
Array([[1, 4],
         [2, 5],
         [3, 6]], dtype=float32) * mV
>>> bu.math.einsum("ji", y)  # implicit form
Array([[1, 4],
         [2, 5],
         [3, 6]], dtype=float32) * mV
>>> bu.math.einsum(y, (1, 0))  # implicit form via indices
Array([[1, 4],
         [2, 5],
         [3, 6]], dtype=float32) * mV
>>> bu.math.einsum(y, (0, 1), (1, 0))  # explicit form via indices
Array([[1, 4],
         [2, 5],
         [3, 6]], dtype=float32) * mV
>>> bu.math.transpose(y)
Array([[1, 4],
         [2, 5],
         [3, 6]], dtype=float32) * mV
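Writing the output labels in a different order is a general axis permutation, not only a 2-D transpose. A NumPy sketch on an assumed 3-D array, compared against `np.transpose`:

```python
import numpy as np

a = np.arange(24).reshape(2, 3, 4)
reversed_axes = np.einsum("ijk->kji", a)  # same as np.transpose(a)
rotated_axes = np.einsum("ijk->jki", a)   # same as np.transpose(a, (1, 2, 0))
print(reversed_axes.shape, rotated_axes.shape)  # (4, 3, 2) (3, 4, 2)
```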

Matrix diagonal

>>> bu.math.einsum("ii->i", M)
Array([ 0,  5, 10, 15], dtype=float32) * ohm
>>> bu.math.diagonal(M)
Array([ 0,  5, 10, 15], dtype=float32) * ohm
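A label repeated within one operand selects the elements whose indices on those axes are equal. Combined with `...`, this extracts the diagonal of every matrix in a stack. A NumPy sketch with an assumed stack of two 4×4 matrices:

```python
import numpy as np

stack = np.arange(32).reshape(2, 4, 4)  # two 4x4 matrices
# The repeated label i keeps only elements whose last two indices are equal.
diagonals = np.einsum("...ii->...i", stack)
print(diagonals.shape)  # (2, 4)
```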

Matrix trace

>>> bu.math.einsum("ii", M)
Array(30, dtype=float32) * ohm
>>> bu.math.trace(M)
Array(30, dtype=float32) * ohm
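Repeated labels can also span operands. `ij,ji->` multiplies `A[i, j]` by `B[j, i]` and sums over both labels, which is the trace of `A @ B`; written this way it only needs the n² products `A[i, j] * B[j, i]` rather than the full matrix product. A NumPy sketch with assumed inputs:

```python
import numpy as np

A = np.arange(16).reshape(4, 4)
B = np.arange(16, 32).reshape(4, 4)
# Sum of A[i, j] * B[j, i] over all i and j.
trace_of_product = np.einsum("ij,ji->", A, B)
print(trace_of_product == np.trace(A @ B))  # True
```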

Tensor products

>>> x = bu.math.arange(30).reshape(2, 3, 5) * bu.mA
>>> y = bu.math.arange(60).reshape(3, 4, 5) * bu.ohm
>>> bu.math.einsum('ijk,jlk->il', x, y)  # explicit form
Array([[ 3340,  3865,  4390,  4915],
         [ 8290,  9940, 11590, 13240]], dtype=float32) * mvolt
>>> bu.math.tensordot(x, y, axes=((1, 2), (0, 2)))
Array([[ 3340,  3865,  4390,  4915],
         [ 8290,  9940, 11590, 13240]], dtype=float32) * mvolt
>>> bu.math.einsum('ijk,jlk', x, y)  # implicit form
Array([[ 3340,  3865,  4390,  4915],
         [ 8290,  9940, 11590, 13240]], dtype=float32) * mvolt
>>> bu.math.einsum(x, (0, 1, 2), y, (1, 3, 2), (0, 3))  # explicit form via indices
Array([[ 3340,  3865,  4390,  4915],
         [ 8290,  9940, 11590, 13240]], dtype=float32) * mvolt
>>> bu.math.einsum(x, (0, 1, 2), y, (1, 3, 2))  # implicit form via indices
Array([[ 3340,  3865,  4390,  4915],
         [ 8290,  9940, 11590, 13240]], dtype=float32) * mvolt
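The values above can be checked without units. Here `j` and `k` are contracted and `i`, `l` survive, so entry `(0, 0)` is a double sum that can be written out by hand. A NumPy sketch using the same shapes as the example:

```python
import numpy as np

x = np.arange(30).reshape(2, 3, 5)
y = np.arange(60).reshape(3, 4, 5)
out = np.einsum("ijk,jlk->il", x, y)
# Entry (0, 0): sum over the contracted labels j and k.
manual_00 = sum(x[0, j, k] * y[j, 0, k] for j in range(3) for k in range(5))
print(out[0, 0], manual_00)  # 3340 3340
```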

Chained dot products

>>> w = bu.math.arange(5, 9).reshape(2, 2) * bu.mA
>>> x = bu.math.arange(6).reshape(2, 3) * bu.ohm
>>> y = bu.math.arange(-2, 4).reshape(3, 2) * bu.mV
>>> z = bu.math.array([[2, 4, 6], [3, 5, 7]]) * bu.mA
>>> bu.math.einsum('ij,jk,kl,lm->im', w, x, y, z)
Array([[ 481,  831, 1181],
         [ 651, 1125, 1599]], dtype=float32) * metre ** 4 * kilogram ** 2 * second ** -6 * amp ** -1
>>> bu.math.einsum(w, (0, 1), x, (1, 2), y, (2, 3), z, (3, 4))  # implicit, via indices
Array([[ 481,  831, 1181],
         [ 651, 1125, 1599]], dtype=float32) * metre ** 4 * kilogram ** 2 * second ** -6 * amp ** -1
>>> w @ x @ y @ z  # direct chain of matmuls
Array([[ 481,  831, 1181],
         [ 651, 1125, 1599]], dtype=float32) * metre ** 4 * kilogram ** 2 * second ** -6 * amp ** -1
>>> bu.math.multi_dot([w, x, y, z])
Array([[ 481,  831, 1181],
         [ 651, 1125, 1599]], dtype=float32) * metre ** 4 * kilogram ** 2 * second ** -6 * amp ** -1
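For a chain like this, the `optimize` argument decides which pairs of operands are contracted first. The sketch below uses NumPy's `np.einsum_path` to compute such an order and then passes it back to `np.einsum`. NumPy's path starts with an `'einsum_path'` marker; whether saiunit accepts that exact list for its `list[tuple[int, ...]]` form is an assumption to verify, not something shown on this page.

```python
import numpy as np

w = np.arange(5, 9).reshape(2, 2)
x = np.arange(6).reshape(2, 3)
y = np.arange(-2, 4).reshape(3, 2)
z = np.array([[2, 4, 6], [3, 5, 7]])

# Ask NumPy for a pairwise contraction order; path[0] is the 'einsum_path' marker.
path, report = np.einsum_path("ij,jk,kl,lm->im", w, x, y, z, optimize="optimal")
print(report)  # human-readable summary of the chosen order and its cost

# Reuse the precomputed order for the actual contraction.
result = np.einsum("ij,jk,kl,lm->im", w, x, y, z, optimize=path)
print(result)
```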