index_take
- class saiunit.lax.index_take(src, idxs, axes, **kwargs)
Take elements from an array at the given indices along the given axes.
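- Parameters:
  - src (Array | Quantity) – The array to take elements from.
  - idxs (tuple of Array) – Integer index arrays, one for each entry in axes.
  - axes (tuple of int) – The axes of src along which the corresponding index arrays are applied.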
- Returns:
  result – The gathered elements. If src is a Quantity, the result preserves the same unit.
- Return type:
  saiunit.Quantity | Array
Examples
>>> import saiunit as u
>>> import saiunit.lax as sulax
>>> import jax.numpy as jnp
>>> x = jnp.array([[1., 2., 3.], [4., 5., 6.]]) * u.meter
>>> idx = jnp.array([0, 2])
>>> result = sulax.index_take(x, (idx,), axes=(1,))
>>> result.mantissa
Array([[1., 4.],
       [3., 6.]], dtype=float32)
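The sketch below spells out the output layout. It is a minimal pure-JAX illustration of the indexing this function is understood to perform on a plain (unitless) array. It assumes that saiunit forwards the mantissa to `jax.lax.index_take`. That function places the gathered index dimension first, keeps the remaining axes in order, and wraps indices modulo the axis size. `index_take_reference` is a hypothetical helper written for illustration only. It is not part of saiunit or JAX.

```python
import jax.numpy as jnp


def index_take_reference(src, idxs, axes):
    """Hypothetical reference for the index_take layout (assumption, not saiunit API)."""
    # Wrap indices modulo the axis size (assumed to mirror jax.lax.index_take).
    idxs = [jnp.asarray(i) % src.shape[ax] for i, ax in zip(idxs, axes)]
    # Move the indexed axes to the front, in the order given by `axes`.
    rest = [d for d in range(src.ndim) if d not in axes]
    moved = jnp.transpose(src, tuple(axes) + tuple(rest))
    # Advanced indexing on the leading axes gives shape (n,) + remaining dims.
    return moved[tuple(idxs)]


x = jnp.array([[1., 2., 3.], [4., 5., 6.]])
# Gathers columns 0 and 2; each gathered column becomes a row: [[1., 4.], [3., 6.]]
print(index_take_reference(x, (jnp.array([0, 2]),), axes=(1,)))
```

With this layout, taking along axis 1 is not the same as `x[:, idx]`. The gathered index dimension comes first, so the single-axis case matches `jnp.take(x, idx, axis=1).T` for a 2-D input.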