index_take

saiunit.lax.index_take(src, idxs, axes, **kwargs)

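Take elements from src at the given indices along the given axes. The result has a leading dimension that runs over the positions in the index arrays, followed by the axes of src that are not indexed, in their original order.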
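The output layout can be made concrete with a loop-based NumPy model of the gather. This is a minimal sketch, assuming that saiunit.lax.index_take forwards the numeric values to jax.lax.index_take. index_take_reference is a hypothetical helper for illustration only, and it assumes every index array is 1-D and all of them have the same length.

```python
import numpy as np


def index_take_reference(src, idxs, axes):
    """Loop-based model of the gather (hypothetical helper, not part of saiunit).

    Assumes each array in ``idxs`` is 1-D with the same length N, and that
    ``idxs[k]`` holds positions along ``axes[k]``.
    """
    src = np.asarray(src)
    idxs = [np.asarray(i) for i in idxs]
    axes = tuple(axes)
    n = idxs[0].shape[0]
    # Dimensions of src that are not indexed keep their original order.
    rest = [ax for ax in range(src.ndim) if ax not in axes]
    out = np.empty((n,) + tuple(src.shape[ax] for ax in rest), dtype=src.dtype)
    for i in range(n):
        # Fix every indexed axis to its i-th index; keep the other axes whole.
        sel = [slice(None)] * src.ndim
        for k, ax in enumerate(axes):
            sel[ax] = int(idxs[k][i])
        out[i] = src[tuple(sel)]
    return out


x = np.array([[1., 2., 3.], [4., 5., 6.]])
print(index_take_reference(x, (np.array([0, 2]),), axes=(1,)).tolist())
# [[1.0, 4.0], [3.0, 6.0]]: row i holds column idx[i] of x
```

Under this model the index dimension comes first, rather than replacing axis 1 in place as jnp.take(x, idx, axis=1) would (that call gives [[1., 3.], [4., 6.]]).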

Parameters:
  • src (saiunit.Quantity | Array | ndarray | bool | number | int | float | complex) – The source array from which to take elements.

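  • idxs (Array | ndarray | bool | number | int | float | complex) – A tuple of equal-length integer index arrays, one per entry in axes; idxs[k] lists the positions to take along axes[k]. For a single axis, pass a one-element tuple such as (idx,).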

  • axes (Sequence[int]) – The axes of src along which to index, one per array in idxs and in the same order.

Returns:

result – The gathered elements. If src is a Quantity, the result preserves the same unit.

Return type:

saiunit.Quantity | Array
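The parameter and return conventions can be sketched together in NumPy: each idxs[k] supplies positions along axes[k], and when src carries a unit only the numeric values are gathered while the unit passes through unchanged. This is a minimal sketch, not saiunit's implementation. ToyQuantity, _gather and index_take_with_unit are hypothetical names, the unit is a plain string label, and the output layout (index dimension first, then the remaining axes of src) is assumed to match jax.lax.index_take.

```python
from dataclasses import dataclass

import numpy as np


@dataclass
class ToyQuantity:
    """Hypothetical stand-in for saiunit.Quantity: a value array plus a unit label."""
    mantissa: np.ndarray
    unit: str


def _gather(a, idxs, axes):
    """Vectorized gather: move the indexed axes to the front, then index them.

    idxs[k] pairs with axes[k]. Because the index arrays land on adjacent
    leading axes, NumPy puts the index dimension first, followed by the
    non-indexed axes of ``a`` in their original order.
    """
    a = np.asarray(a)
    axes = tuple(axes)
    rest = tuple(ax for ax in range(a.ndim) if ax not in axes)
    return np.transpose(a, axes + rest)[tuple(np.asarray(i) for i in idxs)]


def index_take_with_unit(src, idxs, axes):
    """Hypothetical wrapper: gather the numbers and reattach the unit unchanged."""
    if isinstance(src, ToyQuantity):
        # Indexing only selects existing values, so the unit carries over as is.
        return ToyQuantity(_gather(src.mantissa, idxs, axes), src.unit)
    return _gather(src, idxs, axes)


# Two indexed axes (0 and 2): out[i] == src[idxs[0][i], :, idxs[1][i]].
src = ToyQuantity(np.arange(24.0).reshape(2, 3, 4), "m")
out = index_take_with_unit(src, (np.array([1, 0]), np.array([3, 1])), axes=(0, 2))
print(out.unit, out.mantissa.shape)  # m (2, 3)
```

The indices themselves are plain integers with no unit, which is why the sketch never touches idxs when deciding the unit of the result.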

Examples

>>> import saiunit as u
>>> import saiunit.lax as sulax
>>> import jax.numpy as jnp
>>> x = jnp.array([[1., 2., 3.], [4., 5., 6.]]) * u.meter
>>> idx = jnp.array([0, 2])
>>> result = sulax.index_take(x, (idx,), axes=(1,))
>>> result.mantissa
Array([[1., 4.],
       [3., 6.]], dtype=float32)