gather
- class brainunit.math.gather(input, dim, index, **kwargs)
Gather values along the axis specified by dim, at the positions given by index.
JAX implementation of torch.gather.
- Parameters:
input (Array | saiunit.Quantity) – The source array or Quantity.
dim (int) – The axis along which to index.
index (Array) – The indices of the elements to gather.
- Returns:
out – Array with the gathered elements. Quantity if input is a Quantity.
- Return type:
jax.Array, Quantity
Examples
>>> import saiunit as u
>>> import jax.numpy as jnp
>>> a = jnp.array([[1, 2], [3, 4]]) * u.mV
>>> index = jnp.array([[0, 0], [1, 0]])
>>> u.math.gather(a, 1, index)
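To make the torch.gather indexing rule concrete, the sketch below reproduces it on a plain unitless JAX array with jnp.take_along_axis. For a 2-D input with dim=1 the rule is out[i][j] = input[i][index[i][j]], and with dim=0 it is out[i][j] = input[index[i][j]][j]. The helper name gather_reference is hypothetical. It is not part of brainunit or saiunit, and the assumption that brainunit.math.gather matches these semantics comes only from the "JAX implementation of torch.gather" description above.

```python
import jax.numpy as jnp


def gather_reference(x, dim, index):
    """Hypothetical helper: torch.gather-style indexing on a plain JAX array.

    jnp.take_along_axis picks, along `axis`, the element named by `index`
    at every position. Here `index` has the same shape as `x`.
    """
    return jnp.take_along_axis(x, index, axis=dim)


a = jnp.array([[1, 2], [3, 4]])
index = jnp.array([[0, 0], [1, 0]])

# dim=1: row i picks columns index[i][0] and index[i][1]
out_dim1 = gather_reference(a, 1, index)
print(out_dim1.tolist())  # [[1, 1], [4, 3]]

# dim=0: column j picks rows index[0][j] and index[1][j]
out_dim0 = gather_reference(a, 0, index)
print(out_dim0.tolist())  # [[1, 2], [3, 2]]
```

If gather follows torch.gather as described, the documented example should return the dim=1 values with the mV unit attached. Note that jnp.take_along_axis requires index to have the same number of dimensions as the input.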