einshape
- class brainunit.math.einshape(x, pattern)
Parse a tensor shape to a dictionary mapping axis names to their lengths.
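Use an underscore (_) in the pattern to skip a dimension.
- Parameters:
  - x – The tensor whose shape is parsed.
  - pattern (str) – Space-separated axis names, one per dimension of x. Use _ for a dimension whose length should not appear in the result.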
- Returns:
out – Dictionary mapping axis names to their integer lengths.
- Return type:
dict[str, int]
Examples
>>> import jax.numpy as jnp
>>> import brainunit.math as bumath
>>> x = jnp.zeros((2, 3, 5))
>>> bumath.einshape(x, 'batch _ w')
{'batch': 2, 'w': 5}
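The parsing rule can be illustrated with a small pure-Python sketch. This is a hypothetical re-implementation, not brainunit's code: the name parse_shape_sketch is invented here, it takes a plain shape tuple instead of an array so it runs without JAX, and it assumes that a mismatch between the number of names and the number of dimensions is an error (brainunit's actual error handling may differ). The convention appears to match einops.parse_shape, where _ likewise skips a dimension.

```python
from typing import Dict, Sequence


def parse_shape_sketch(shape: Sequence[int], pattern: str) -> Dict[str, int]:
    """Map axis names in `pattern` to the matching lengths in `shape`.

    Hypothetical illustration only, not brainunit's implementation.
    Assumes whitespace-separated names, one per dimension, with "_"
    marking a dimension to skip.
    """
    names = pattern.split()
    # Assumption: every dimension must be named (or skipped with "_").
    if len(names) != len(shape):
        raise ValueError(
            f"pattern has {len(names)} axes but shape has {len(shape)} dimensions"
        )
    result: Dict[str, int] = {}
    for name, length in zip(names, shape):
        if name == "_":
            continue  # underscore: leave this dimension out of the result
        result[name] = int(length)
    return result


# Same shape and pattern as the doctest above, using a plain tuple.
print(parse_shape_sketch((2, 3, 5), "batch _ w"))  # {'batch': 2, 'w': 5}
```

Because the skipped axis is simply omitted, the returned dictionary only contains the names you actually want to reuse, for example as sizes when reshaping another array.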