broadcast_shapes
- class saiunit.lax.broadcast_shapes(*shapes, **kwargs)
Return the shape that results from NumPy broadcasting of shapes.
Computes the shape that would result from broadcasting arrays with the given shapes, following standard NumPy broadcasting rules. This is a thin wrapper around jax.lax.broadcast_shapes() and does not involve any unit handling.
- Parameters:
*shapes (tuple of int) – Two or more shapes to broadcast together. Each shape is a tuple of non-negative integers.
- Returns:
result – The broadcasted shape.
- Return type:
- Raises:
ValueError – If the shapes are not broadcast-compatible (e.g. (2,) and (3,)).
See also
jax.lax.broadcast_shapes – The underlying JAX function.
numpy.broadcast_shapes – The NumPy equivalent.
Notes
Broadcasting rules:
If the shapes differ in length, the shorter shape is padded with ones on the left.
Dimensions are compatible when they are equal, or one of them is 1.
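The resulting dimension is the size that is not 1 (or the shared size when both are equal). For positive sizes this is the maximum of the two; a size of 0 paired with 1 yields 0.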
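The rules above can be sketched in plain Python. This is only an illustration of the broadcasting logic, not the code used by saiunit or JAX; the helper name broadcast_shapes_sketch is hypothetical and introduced here for the example.

```python
from itertools import zip_longest


def broadcast_shapes_sketch(*shapes):
    """Illustrative re-implementation of the broadcasting rules (hypothetical helper)."""
    result = []
    # Align shapes on their trailing dimension; a missing leading
    # dimension behaves like a dimension of size 1 (left padding).
    for dims in zip_longest(*(reversed(s) for s in shapes), fillvalue=1):
        # Collect the distinct sizes that are not 1 at this position.
        non_one = {d for d in dims if d != 1}
        if len(non_one) > 1:
            # Two different non-1 sizes cannot be broadcast together.
            raise ValueError(f"Incompatible shapes for broadcasting: {shapes}")
        # Use the single non-1 size if there is one, otherwise 1.
        result.append(non_one.pop() if non_one else 1)
    return tuple(reversed(result))


print(broadcast_shapes_sketch((2, 3), (3,)))
print(broadcast_shapes_sketch((1, 5), (3, 1)))
```

Calling the sketch with (2,) and (3,) raises ValueError, mirroring the error case described under Raises.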
Examples
Basic broadcasting of two shapes:
>>> import saiunit.lax as sulax
>>> sulax.broadcast_shapes((2, 3), (3,))
(2, 3)
Broadcasting with dimension expansion:
>>> import saiunit.lax as sulax
>>> sulax.broadcast_shapes((1, 5), (3, 1))
(3, 5)
Broadcasting three shapes together:
>>> import saiunit.lax as sulax
>>> sulax.broadcast_shapes((1,), (3, 1), (1, 1, 5))
(1, 3, 5)