broadcast_in_dim

class saiunit.lax.broadcast_in_dim(operand, shape, broadcast_dimensions)

Broadcast an array into a target shape (XLA BroadcastInDim).

Parameters:
  • operand (saiunit.Quantity | Array | ndarray | bool | number | int | float | complex) – The input array.

  • shape (Sequence[int | Any]) – The target shape for the broadcast.

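  • broadcast_dimensions (Sequence[int]) – Mapping from operand dimensions to target dimensions: dimension i of the operand becomes dimension broadcast_dimensions[i] of the result. It has one entry per operand dimension. Under XLA's BroadcastInDim semantics the entries are strictly increasing, and each operand dimension must have size 1 or match the size of the target dimension it maps to.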

Returns:

result – The broadcast array, with the unit of operand preserved.

Return type:

saiunit.Quantity | Array

Examples

>>> import saiunit as u
>>> import saiunit.lax as sulax
>>> import jax.numpy as jnp
>>> q = jnp.array([1.0, 2.0]) * u.meter
>>> result = sulax.broadcast_in_dim(q, shape=(3, 2), broadcast_dimensions=(1,))
>>> result.mantissa.shape
(3, 2)
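
The role of broadcast_dimensions is easiest to see by comparing two mappings of the same 1-D input. The sketch below is a plain NumPy reference for the dimension mapping only; it is not saiunit's implementation, it ignores units, and the helper name broadcast_in_dim_reference is hypothetical. It assumes broadcast_dimensions is strictly increasing, as XLA's BroadcastInDim requires.

```python
import numpy as np


def broadcast_in_dim_reference(operand, shape, broadcast_dimensions):
    """Hypothetical NumPy reference for the dimension mapping (units ignored)."""
    operand = np.asarray(operand)
    if len(broadcast_dimensions) != operand.ndim:
        raise ValueError("need one target dimension per operand dimension")
    # Start from a shape of all ones, then place each operand dimension
    # at the target position named by broadcast_dimensions.
    expanded = [1] * len(shape)
    for i, target_dim in enumerate(broadcast_dimensions):
        expanded[target_dim] = operand.shape[i]
    # Ordinary NumPy broadcasting then stretches the remaining size-1 dimensions.
    return np.broadcast_to(operand.reshape(expanded), shape)


x = np.array([1.0, 2.0, 3.0])
# Operand dimension 0 becomes result dimension 1: x varies along axis 1.
along_axis1 = broadcast_in_dim_reference(x, (3, 3), (1,))
# Operand dimension 0 becomes result dimension 0: x varies along axis 0.
along_axis0 = broadcast_in_dim_reference(x, (3, 3), (0,))
```

With broadcast_dimensions=(1,), every row of the result is a copy of x; with broadcast_dimensions=(0,), every column is. The target shape is the same in both cases, so the mapping alone decides the layout.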