broadcast_shapes

class brainunit.lax.broadcast_shapes(*shapes, **kwargs)

Return the shape that results from broadcasting the given shapes together.

Computes the shape that would result from broadcasting arrays with the given shapes, following standard NumPy broadcasting rules. This is a thin wrapper around jax.lax.broadcast_shapes() that operates on shape tuples only; it performs no unit handling.

Parameters:

*shapes (tuple of int) – Two or more shapes to broadcast together. Each shape is a tuple of non-negative integers.

Returns:

result – The broadcasted shape.

Return type:

tuple of int

Raises:

ValueError – If the shapes are not broadcast-compatible (e.g. (2,) and (3,)).
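The condition that triggers this error can also be tested without catching an exception. The sketch below is illustrative only: `shapes_are_compatible` is a hypothetical plain-Python helper, not part of brainunit or JAX, and it assumes each shape is a tuple of non-negative integers.

```python
def shapes_are_compatible(*shapes):
    """Return True if the shapes can be broadcast together, else False.

    Hypothetical helper for illustration; it mirrors the condition under
    which broadcast_shapes raises ValueError, but returns a bool instead.
    """
    ndim = max((len(s) for s in shapes), default=0)
    # Left-pad every shape with ones so all shapes have the same length.
    padded = [(1,) * (ndim - len(s)) + tuple(s) for s in shapes]
    for axis in range(ndim):
        # Along each axis, at most one distinct size other than 1 may appear.
        if len({p[axis] for p in padded} - {1}) > 1:
            return False
    return True


print(shapes_are_compatible((2, 3), (3,)))  # True
print(shapes_are_compatible((2,), (3,)))    # False
```

A pre-check like this may be convenient when incompatible shapes are an expected case rather than an error; otherwise, catching ValueError from broadcast_shapes directly is simpler.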

See also

jax.lax.broadcast_shapes – The underlying JAX function.

numpy.broadcast_shapes – The NumPy equivalent.

Notes

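Broadcasting rules:

  1. Shapes are aligned at their trailing dimensions; any shape shorter than the longest one is padded with ones on the left.

  2. Along each axis, the sizes are compatible when all sizes other than 1 are equal (any number of them may be 1).

  3. The resulting size along each axis is the common size other than 1, or 1 if every size is 1. For positive sizes this equals the maximum; note that a size of 0 broadcast against 1 gives 0, not 1.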
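The rules above can be traced step by step in plain Python. This is a minimal sketch for illustration only: `broadcast_shapes_sketch` is a hypothetical name, not part of brainunit or JAX, and real code should call brainunit.lax.broadcast_shapes instead.

```python
from itertools import zip_longest


def broadcast_shapes_sketch(*shapes):
    """Pure-Python illustration of the NumPy broadcasting rules.

    Hypothetical helper for exposition only.
    """
    result = []
    # Rule 1: walk the shapes from their trailing dimensions; filling missing
    # entries with 1 is equivalent to left-padding shorter shapes with ones.
    for dims in zip_longest(*(reversed(s) for s in shapes), fillvalue=1):
        # Rule 2: ignoring size-1 entries, all remaining sizes must agree.
        sizes = {d for d in dims if d != 1}
        if len(sizes) > 1:
            raise ValueError(f"Incompatible shapes for broadcasting: {shapes}")
        # Rule 3: take the common non-1 size, or 1 if every size is 1.
        result.append(sizes.pop() if sizes else 1)
    # Results were collected from the last axis backwards, so reverse them.
    return tuple(reversed(result))


print(broadcast_shapes_sketch((1,), (3, 1), (1, 1, 5)))  # (1, 3, 5)
print(broadcast_shapes_sketch((0,), (1,)))               # (0,)
```

The second call shows why a size-0 dimension survives broadcasting against a size-1 dimension, which a plain "take the maximum" rule would get wrong.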

Examples

Basic broadcasting of two shapes:

>>> import brainunit.lax as bulax
>>> bulax.broadcast_shapes((2, 3), (3,))
(2, 3)

Broadcasting with dimension expansion:

>>> import brainunit.lax as bulax
>>> bulax.broadcast_shapes((1, 5), (3, 1))
(3, 5)

Broadcasting three shapes together:

>>> import brainunit.lax as bulax
>>> bulax.broadcast_shapes((1,), (3, 1), (1, 1, 5))
(1, 3, 5)