broadcast_in_dim
- class brainunit.lax.broadcast_in_dim(operand, shape, broadcast_dimensions)
Broadcast an array into a target shape (XLA BroadcastInDim).
- Parameters:
operand (saiunit.Quantity | Array | ndarray | bool | number | bool | int | float | complex) – The input array.
shape (Sequence[int | Any]) – The target shape for the broadcast.
broadcast_dimensions (Sequence[int]) – Mapping from operand dimensions to target dimensions, with one entry per operand dimension: dimension i of the operand becomes dimension broadcast_dimensions[i] of the result. Result dimensions not listed here are new axes along which the operand is replicated.
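To make the broadcast_dimensions mapping concrete, here is a minimal NumPy sketch of the BroadcastInDim semantics on plain, unitless arrays. The helper name broadcast_in_dim_reference is hypothetical and is not part of saiunit or JAX. Its checks (one entry per operand dimension, strictly increasing indices, each operand size either 1 or equal to the target size) are our reading of what jax.lax.broadcast_in_dim enforces, stated here as an assumption.

```python
import numpy as np


def broadcast_in_dim_reference(operand, shape, broadcast_dimensions):
    """Hypothetical NumPy sketch of BroadcastInDim semantics (no units)."""
    operand = np.asarray(operand)
    shape = tuple(shape)
    dims = tuple(broadcast_dimensions)
    # Assumed constraints: one target dimension per operand dimension,
    # listed in strictly increasing order.
    assert len(dims) == operand.ndim
    assert all(a < b for a, b in zip(dims, dims[1:]))
    # Start with size-1 axes everywhere, then place each operand size
    # at the target dimension it maps to.
    expanded = [1] * len(shape)
    for i, d in enumerate(dims):
        # Operand dimension i must be 1 (stretched) or match the target size.
        assert operand.shape[i] in (1, shape[d])
        expanded[d] = operand.shape[i]
    # Reshape to insert the new size-1 axes, then let NumPy replicate them.
    return np.broadcast_to(operand.reshape(expanded), shape)


x = np.array([1.0, 2.0])
rows = broadcast_in_dim_reference(x, (3, 2), (1,))  # x laid out along axis 1
cols = broadcast_in_dim_reference(x, (2, 3), (0,))  # x laid out along axis 0
print(rows)
print(cols)
```

With broadcast_dimensions=(1,) each row of the (3, 2) result is a copy of x; with (0,) each column of the (2, 3) result is a copy of x.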
- Returns:
result – The broadcasted array. If operand is a saiunit.Quantity, the result carries the same unit as operand.
- Return type:
saiunit.Quantity | Array
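The unit-preserving behaviour can be pictured as "broadcast the numeric part, keep the unit". The sketch below illustrates that idea with a hypothetical ToyQuantity (a value plus a unit label) and a hypothetical toy_broadcast_in_dim helper built on jax.lax.broadcast_in_dim. It is an assumption-labelled illustration, not saiunit's actual implementation or Quantity type.

```python
from typing import NamedTuple

import jax.numpy as jnp
from jax import lax


class ToyQuantity(NamedTuple):
    """Hypothetical stand-in for saiunit.Quantity: a value plus a unit label."""
    mantissa: jnp.ndarray
    unit: str


def toy_broadcast_in_dim(q, shape, broadcast_dimensions):
    """Hypothetical sketch: broadcast the mantissa and carry the unit over."""
    if isinstance(q, ToyQuantity):
        # Only the numeric part is broadcast; the unit is copied unchanged.
        m = lax.broadcast_in_dim(q.mantissa, shape, broadcast_dimensions)
        return ToyQuantity(m, q.unit)
    # Unitless inputs come back as plain JAX arrays.
    return lax.broadcast_in_dim(jnp.asarray(q), shape, broadcast_dimensions)


q = ToyQuantity(jnp.array([1.0, 2.0]), "m")
col = toy_broadcast_in_dim(q, (2, 3), (0,))  # q laid out along the first axis
print(col.unit, col.mantissa.shape)
```

Keeping the unit outside the broadcast means the underlying XLA operation only ever sees a unitless array, which is why the result's unit can simply be copied from the input.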
Examples
>>> import saiunit as u
>>> import saiunit.lax as sulax
>>> import jax.numpy as jnp
>>> q = jnp.array([1.0, 2.0]) * u.meter
>>> result = sulax.broadcast_in_dim(q, shape=(3, 2), broadcast_dimensions=(1,))
>>> result.mantissa.shape
(3, 2)