einsum
- class saiunit.math.einsum(subscripts, /, *operands, optimize='optimal', precision=None, preferred_element_type=None)
Einstein summation for arrays and quantities.
einsum is a powerful and generic API for computing various reductions, inner products, outer products, axis reorderings, and combinations thereof across one or more input arrays. It has a somewhat complicated overloaded API; the arguments below reflect the most common calling convention. The Examples section below demonstrates some of the alternative calling conventions.
- Parameters:
subscripts (str) – Subscript string in Einstein summation notation, with the axis labels of each operand separated by commas and an optional -> to specify the output.
*operands (array_like or Quantity) – Sequence of one or more arrays / quantities corresponding to the subscripts.
optimize (str | bool | list[tuple[int, ...]]) – Optimization strategy. Defaults to "optimal".
precision (None | str | Precision | tuple[str, str] | tuple[Precision, Precision] | DotAlgorithm | DotAlgorithmPreset) – Precision of the computation. Default is None.
preferred_element_type (str | type[Any] | dtype | SupportsDType | None) – Accumulation and result dtype. Default is None.
- Returns:
out – Result of the Einstein summation. When operands carry physical units, the output unit is derived from the contraction.
- Return type:
Array
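The unit rule in the Returns entry follows from the fact that every term of an Einstein sum is a product of one element from each operand, so the output unit is the product of the operand units. The following is a minimal pure-Python sketch of that bookkeeping, independent of saiunit. The `DIMS` table and the `contract_units` helper are hypothetical names used only for illustration, and units are tracked as SI base-dimension exponents, ignoring prefixes such as milli.

```python
from collections import Counter

# Hypothetical table: SI base-dimension exponents for a few units.
# Illustrative bookkeeping only, not saiunit's internal representation.
DIMS = {
    "amp": {"amp": 1},
    "volt": {"kilogram": 1, "metre": 2, "second": -3, "amp": -1},
    "ohm": {"kilogram": 1, "metre": 2, "second": -3, "amp": -2},
}

def contract_units(*units):
    """Multiply units by adding exponents; drop dimensions that cancel."""
    total = Counter()
    for unit in units:
        for dim, exp in DIMS[unit].items():
            total[dim] += exp
    return {dim: exp for dim, exp in total.items() if exp != 0}

# 'ij,j->i' with M in ohm and x in amp: ohm * amp has the dimension of volt.
print(contract_units("ohm", "amp") == DIMS["volt"])

# A four-operand chain with units amp, ohm, volt, amp.
print(contract_units("amp", "ohm", "volt", "amp"))
```

The second call yields exponents 4, 2, -6 and -1 for metre, kilogram, second and amp, which is the dimension printed in the Chained dot products example.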
Examples
The mechanics of einsum are perhaps best demonstrated by example. Here we show how to use einsum to compute a number of quantities from one or more arrays. For more discussion and examples of einsum, see the documentation of numpy.einsum().
>>> import saiunit as bu
>>> M = bu.math.arange(16).reshape(4, 4) * bu.ohm
>>> x = bu.math.arange(4) * bu.mA
>>> y = bu.math.array([5, 4, 3, 2]) * bu.mV
Vector product
>>> bu.math.einsum('i,i', x, y)
16.0 uW
>>> bu.math.vecdot(x, y)
16.0 uW
Here are some alternative einsum calling conventions to compute the same result:
>>> bu.math.einsum('i,i->', x, y)  # explicit form
16.0 uW
>>> bu.math.einsum(x, (0,), y, (0,))  # implicit form via indices
16.0 uW
>>> bu.math.einsum(x, (0,), y, (0,), ())  # explicit form via indices
16.0 uW
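The index-based calls pair each operand with a tuple of integer axis labels, with an optional trailing tuple naming the output axes. One way to read them is to map each integer to a letter, which recovers the equivalent subscript string. A minimal sketch, where `indices_to_subscripts` is a hypothetical helper (not part of saiunit) and Ellipsis labels are not handled:

```python
import string

def indices_to_subscripts(index_lists, output=None):
    """Map integer axis labels to letters: 0 -> 'a', 1 -> 'b', ..."""
    def letters(indices):
        return "".join(string.ascii_lowercase[i] for i in indices)

    inputs = ",".join(letters(ix) for ix in index_lists)
    if output is None:
        # Implicit form: no '->' part, the output is inferred.
        return inputs
    # Explicit form: the trailing tuple names the output axes.
    return inputs + "->" + letters(output)

print(indices_to_subscripts([(0,), (0,)]))          # a,a
print(indices_to_subscripts([(0,), (0,)], ()))      # a,a->
print(indices_to_subscripts([(0, 1), (1,)], (0,)))  # ab,b->a
```

The letters differ from the `i`, `j` used in the string examples, but only the pattern of repeated and distinct labels matters.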
Matrix product
>>> bu.math.einsum('ij,j->i', M, x)  # explicit form
ArrayImpl([14., 38., 62., 86.], dtype=float32) * mvolt
>>> bu.math.matmul(M, x)
Array([14, 38, 62, 86], dtype=float32) * mvolt
Here are some alternative einsum calling conventions to compute the same result:
>>> bu.math.einsum('ij,j', M, x)  # implicit form
Array([14, 38, 62, 86], dtype=float32) * mvolt
>>> bu.math.einsum(M, (0, 1), x, (1,), (0,))  # explicit form via indices
Array([14, 38, 62, 86], dtype=float32) * mvolt
>>> bu.math.einsum(M, (0, 1), x, (1,))  # implicit form via indices
Array([14, 38, 62, 86], dtype=float32) * mvolt
Outer product
>>> bu.math.einsum("i,j->ij", x, y)
Array([[ 0,  0,  0,  0],
       [ 5,  4,  3,  2],
       [10,  8,  6,  4],
       [15, 12,  9,  6]], dtype=float32) * uwatt
>>> bu.math.outer(x, y)
Array([[ 0,  0,  0,  0],
       [ 5,  4,  3,  2],
       [10,  8,  6,  4],
       [15, 12,  9,  6]], dtype=float32) * uwatt
Some other ways of computing outer products:
>>> bu.math.einsum("i,j", x, y)  # implicit form
Array([[ 0,  0,  0,  0],
       [ 5,  4,  3,  2],
       [10,  8,  6,  4],
       [15, 12,  9,  6]], dtype=float32) * uwatt
>>> bu.math.einsum(x, (0,), y, (1,), (0, 1))  # explicit form via indices
Array([[ 0,  0,  0,  0],
       [ 5,  4,  3,  2],
       [10,  8,  6,  4],
       [15, 12,  9,  6]], dtype=float32) * uwatt
>>> bu.math.einsum(x, (0,), y, (1,))  # implicit form via indices
Array([[ 0,  0,  0,  0],
       [ 5,  4,  3,  2],
       [10,  8,  6,  4],
       [15, 12,  9,  6]], dtype=float32) * uwatt
1D array sum
>>> bu.math.einsum("i->", x)  # requires explicit form
Array(6, dtype=float32) * mA
>>> bu.math.einsum(x, (0,), ())  # explicit form via indices
Array(6, dtype=float32) * mA
>>> bu.math.sum(x)
Array(6, dtype=float32) * mA
Sum along an axis
>>> bu.math.einsum("...j->...", M)  # requires explicit form
Array([ 6, 22, 38, 54], dtype=float32) * ohm
>>> bu.math.einsum(M, (..., 0), (...,))  # explicit form via indices
Array([ 6, 22, 38, 54], dtype=float32) * ohm
>>> M.sum(-1)
Array([ 6, 22, 38, 54], dtype=float32) * ohm
Matrix transpose
>>> y = bu.math.array([[1, 2, 3],
...                    [4, 5, 6]]) * bu.mV
>>> bu.math.einsum("ij->ji", y)  # explicit form
Array([[1, 4],
       [2, 5],
       [3, 6]], dtype=float32) * mV
>>> bu.math.einsum("ji", y)  # implicit form
Array([[1, 4],
       [2, 5],
       [3, 6]], dtype=float32) * mV
>>> bu.math.einsum(y, (1, 0))  # implicit form via indices
Array([[1, 4],
       [2, 5],
       [3, 6]], dtype=float32) * mV
>>> bu.math.einsum(y, (0, 1), (1, 0))  # explicit form via indices
Array([[1, 4],
       [2, 5],
       [3, 6]], dtype=float32) * mV
>>> bu.math.transpose(y)
Array([[1, 4],
       [2, 5],
       [3, 6]], dtype=float32) * mV
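The implicit form "ji" transposes because the output axes are inferred rather than written out. Assuming the NumPy einsum convention, which the examples here are consistent with, labels that appear exactly once are kept and sorted alphabetically, and repeated labels are summed over. A small sketch of that rule, where `implicit_output` is a hypothetical helper and Ellipsis is not handled:

```python
from collections import Counter

def implicit_output(subscripts):
    """Output labels inferred for an implicit-form subscript string (no '->')."""
    counts = Counter(subscripts.replace(",", ""))
    # Keep labels seen exactly once, in alphabetical order.
    return "".join(sorted(label for label, n in counts.items() if n == 1))

print(repr(implicit_output("ji")))       # 'ij': same as 'ji->ij', a transpose
print(repr(implicit_output("ij,j")))     # 'i': matrix-vector product
print(repr(implicit_output("i,i")))      # '': scalar inner product
print(repr(implicit_output("ii")))       # '': scalar trace
print(repr(implicit_output("ijk,jlk")))  # 'il'
```

This is also why the 1D sum and the sum along an axis need the explicit form: with a single unrepeated label, the implicit rule would keep that axis instead of summing it.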
Matrix diagonal
>>> bu.math.einsum("ii->i", M)
Array([ 0,  5, 10, 15], dtype=float32) * ohm
>>> bu.math.diagonal(M)
Array([ 0,  5, 10, 15], dtype=float32) * ohm
Matrix trace
>>> bu.math.einsum("ii", M)
Array(30, dtype=float32) * ohm
>>> bu.math.trace(M)
Array(30, dtype=float32) * ohm
Tensor products
>>> x = bu.math.arange(30).reshape(2, 3, 5) * bu.mA
>>> y = bu.math.arange(60).reshape(3, 4, 5) * bu.ohm
>>> bu.math.einsum('ijk,jlk->il', x, y)  # explicit form
Array([[ 3340,  3865,  4390,  4915],
       [ 8290,  9940, 11590, 13240]], dtype=float32) * volt
>>> bu.math.tensordot(x, y, axes=((1, 2), (0, 2)))
Array([[ 3340,  3865,  4390,  4915],
       [ 8290,  9940, 11590, 13240]], dtype=float32) * volt
>>> bu.math.einsum('ijk,jlk', x, y)  # implicit form
Array([[ 3340,  3865,  4390,  4915],
       [ 8290,  9940, 11590, 13240]], dtype=float32) * volt
>>> bu.math.einsum(x, (0, 1, 2), y, (1, 3, 2), (0, 3))  # explicit form via indices
Array([[ 3340,  3865,  4390,  4915],
       [ 8290,  9940, 11590, 13240]], dtype=float32) * volt
>>> bu.math.einsum(x, (0, 1, 2), y, (1, 3, 2))  # implicit form via indices
Array([[ 3340,  3865,  4390,  4915],
       [ 8290,  9940, 11590, 13240]], dtype=float32) * volt
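To make the 'ijk,jlk->il' contraction concrete, it can be spelled out as nested loops: labels kept in the output (i, l) index the result, and labels dropped from it (j, k) are summed over. The sketch below uses plain nested lists with the same numeric values as the example and leaves units out. `contract_ijk_jlk` is a hypothetical helper written for illustration, not how the library evaluates the expression.

```python
# Plain nested lists standing in for arange(30).reshape(2, 3, 5)
# and arange(60).reshape(3, 4, 5); units are left out.
x = [[[15 * i + 5 * j + k for k in range(5)] for j in range(3)] for i in range(2)]
y = [[[20 * j + 5 * l + k for k in range(5)] for l in range(4)] for j in range(3)]

def contract_ijk_jlk(x, y):
    """out[i][l] = sum over j and k of x[i][j][k] * y[j][l][k]."""
    I, J, K = len(x), len(x[0]), len(x[0][0])
    L = len(y[0])
    out = [[0] * L for _ in range(I)]
    for i in range(I):              # free index, kept in the output
        for l in range(L):          # free index, kept in the output
            for j in range(J):      # summed index
                for k in range(K):  # summed index
                    out[i][l] += x[i][j][k] * y[j][l][k]
    return out

print(contract_ijk_jlk(x, y))
# [[3340, 3865, 4390, 4915], [8290, 9940, 11590, 13240]]
```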
Chained dot products
>>> w = bu.math.arange(5, 9).reshape(2, 2) * bu.mA
>>> x = bu.math.arange(6).reshape(2, 3) * bu.ohm
>>> y = bu.math.arange(-2, 4).reshape(3, 2) * bu.mV
>>> z = bu.math.array([[2, 4, 6], [3, 5, 7]]) * bu.mA
>>> bu.math.einsum('ij,jk,kl,lm->im', w, x, y, z)
Array([[ 481,  831, 1181],
       [ 651, 1125, 1599]], dtype=float32) * metre ** 4 * kilogram ** 2 * second ** -6 * amp ** -1
>>> bu.math.einsum(w, (0, 1), x, (1, 2), y, (2, 3), z, (3, 4))  # implicit, via indices
Array([[ 481,  831, 1181],
       [ 651, 1125, 1599]], dtype=float32) * metre ** 4 * kilogram ** 2 * second ** -6 * amp ** -1
>>> w @ x @ y @ z  # direct chain of matmuls
Array([[ 481,  831, 1181],
       [ 651, 1125, 1599]], dtype=float32) * metre ** 4 * kilogram ** 2 * second ** -6 * amp ** -1
>>> bu.math.multi_dot([w, x, y, z])
Array([[ 481,  831, 1181],
       [ 651, 1125, 1599]], dtype=float32) * metre ** 4 * kilogram ** 2 * second ** -6 * amp ** -1
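A chain with several operands can be contracted pairwise in more than one order. Every order gives the same result, but the amount of arithmetic can differ a lot, and choosing an order is the kind of decision the optimize parameter concerns. The sketch below counts scalar multiplications for a three-matrix chain. The shapes are hypothetical, chosen larger than those in the example so the difference is visible, and the helpers are illustrative names that do not reproduce the strategy the library uses.

```python
def matmul_cost(a_shape, b_shape):
    """Scalar multiplications for (m, n) @ (n, p), plus the result shape."""
    (m, n), (n2, p) = a_shape, b_shape
    assert n == n2, "inner dimensions must match"
    return m * n * p, (m, p)

def chain_cost_left(a, b, c):
    """Cost of evaluating (a @ b) @ c."""
    cost_ab, ab = matmul_cost(a, b)
    cost_abc, _ = matmul_cost(ab, c)
    return cost_ab + cost_abc

def chain_cost_right(a, b, c):
    """Cost of evaluating a @ (b @ c)."""
    cost_bc, bc = matmul_cost(b, c)
    cost_abc, _ = matmul_cost(a, bc)
    return cost_bc + cost_abc

# Hypothetical shapes for a three-matrix chain.
shapes = [(10, 1000), (1000, 10), (10, 1000)]
print(chain_cost_left(*shapes))   # 200000
print(chain_cost_right(*shapes))  # 20000000
```

With these shapes, contracting the first pair first produces a small 10 x 10 intermediate, while contracting the last pair first produces a 1000 x 1000 intermediate and costs a hundred times more multiplications.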