einshape

class brainunit.math.einshape(x, pattern)

Parse a tensor shape to a dictionary mapping axis names to their lengths.

Use an underscore _ in the pattern to skip a dimension.

Parameters:
  • x (Array | ndarray | bool | number | int | float | complex | saiunit.Quantity) – Input array or Quantity, from any supported framework, whose shape is parsed.

  • pattern (str) – Space-separated axis names. Use _ to skip an axis.

Returns:

out – Dictionary mapping axis names to their integer lengths. Axes marked with _ are omitted.

Return type:

dict

Examples

>>> import jax.numpy as jnp
>>> import saiunit.math as sumath
>>> x = jnp.zeros((2, 3, 5))
>>> sumath.einshape(x, 'batch _ w')
{'batch': 2, 'w': 5}
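The parsing rule described above can be illustrated with a small pure-Python sketch. It is not the library's implementation: `parse_shape_sketch` is a hypothetical helper that takes a plain shape tuple (for example `x.shape`) instead of an array. It assumes the pattern names exactly one axis per dimension and does not handle any ellipsis syntax the real function may support. Rejecting a repeated axis name, while allowing `_` to appear more than once, is a choice made for this sketch.

```python
from typing import Dict, Sequence


def parse_shape_sketch(shape: Sequence[int], pattern: str) -> Dict[str, int]:
    """Hypothetical sketch of the einshape parsing rule (not the library code).

    Maps each space-separated name in ``pattern`` to the matching entry of
    ``shape``; names equal to ``_`` are skipped.
    """
    names = pattern.split()
    # Assumption of this sketch: exactly one name per dimension.
    if len(names) != len(shape):
        raise ValueError(
            f"pattern {pattern!r} has {len(names)} names, "
            f"but the shape has {len(shape)} dimensions"
        )
    result: Dict[str, int] = {}
    for name, length in zip(names, shape):
        if name == "_":
            continue  # underscore: leave this axis out of the result
        if name in result:
            # Sketch's choice: a real axis name may appear only once.
            raise ValueError(f"axis name {name!r} appears more than once")
        result[name] = int(length)
    return result


print(parse_shape_sketch((2, 3, 5), "batch _ w"))  # {'batch': 2, 'w': 5}
```

Called on the shape from the example above, the sketch returns the same dictionary as einshape: the middle axis is dropped because it is marked with _.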