gather


class saiunit.math.gather(input, dim, index, **kwargs)

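Gather values along the axis dim at the positions given by index.

This is a JAX implementation of torch.gather. For a 2-D input, out[i][j] = input[index[i][j]][j] when dim=0, and out[i][j] = input[i][index[i][j]] when dim=1.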
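The indexing rule can be written out with explicit loops. The sketch below is a plain-JAX reference for 2-D, unitless inputs only, assuming torch.gather semantics. gather_2d_reference is a hypothetical helper, not part of saiunit. The sketch compares the loop against jnp.take_along_axis, which applies the same rule when index has the same number of dimensions as the input.

```python
import jax.numpy as jnp
import numpy as np


def gather_2d_reference(values, dim, index):
    """Hypothetical loop-based reference for torch.gather on 2-D, unitless arrays."""
    values = np.asarray(values)
    index = np.asarray(index)
    assert values.ndim == 2 and index.ndim == 2, "this sketch handles 2-D inputs only"
    dim = dim % 2  # map dim=-1 to 1 and dim=-2 to 0
    out = np.empty(index.shape, dtype=values.dtype)
    for i in range(index.shape[0]):
        for j in range(index.shape[1]):
            if dim == 0:
                # index replaces the row coordinate; the column stays j
                out[i, j] = values[index[i, j], j]
            else:
                # index replaces the column coordinate; the row stays i
                out[i, j] = values[i, index[i, j]]
    return jnp.asarray(out)


a = jnp.array([[1, 2], [3, 4]])
index = jnp.array([[0, 0], [1, 0]])

by_loop = gather_2d_reference(a, 1, index)
by_jax = jnp.take_along_axis(a, index, axis=1)  # same rule, vectorized
print(np.asarray(by_loop).tolist())  # [[1, 1], [4, 3]]
```

jnp.take_along_axis applies this rule without Python loops, which makes it a natural building block for a JAX port of torch.gather. Whether saiunit uses it internally is not stated on this page.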

Parameters:
  • input (Array | saiunit.Quantity) – The source array or Quantity.

  • dim (int) – The axis along which to index.

  • index (Array) – Integer indices of the elements to gather. Must have the same number of dimensions as input.

Returns:

out – Array of the gathered elements. If input is a Quantity, the result is a Quantity with the same unit.

Return type:

jax.Array | Quantity

Examples

>>> import saiunit as u
>>> import jax.numpy as jnp
>>> a = jnp.array([[1, 2], [3, 4]]) * u.mV
>>> index = jnp.array([[0, 0], [1, 0]])
>>> # dim=1: out[i][j] = a[i][index[i][j]], so the values are [[1, 1], [4, 3]] (in mV)
>>> u.math.gather(a, 1, index)
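Along dim, index does not have to match the size of input. A common pattern picks one entry per row by passing an index of shape (rows, 1), here the position of each row's maximum. This sketch assumes torch.gather semantics and calls jnp.take_along_axis on a unitless array rather than u.math.gather. With a Quantity input x, u.math.gather(x, 1, idx) would be expected to return the same values in the unit of x. That expectation is an assumption based on the Returns description.

```python
import jax.numpy as jnp
import numpy as np

# Unitless scores; with a Quantity the unit would be carried separately (assumption).
values = jnp.array([[1.0, 5.0, 2.0],
                    [7.0, 3.0, 4.0]])

# Column of the largest entry in each row, reshaped to (2, 1) so index has the
# same number of dimensions as values.
idx = jnp.argmax(values, axis=1)[:, None]

# Gather along dim=1: out[i][0] = values[i][idx[i][0]], one value per row.
row_max = jnp.take_along_axis(values, idx, axis=1)
print(row_max.shape)                 # (2, 1)
print(np.asarray(row_max).tolist())  # [[5.0], [7.0]]
```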