argmax


brainunit.math.argmax(a, axis=None, **kwargs)

Return the indices of the maximum values along an axis.

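If a is a Quantity, its unit is stripped before the maximum is located. Because every element of a single Quantity shares the same unit, stripping it does not change which element is largest. The returned indices are plain integers and carry no unit.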
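A minimal sketch of the "strip the unit, then search" idea follows. ToyQuantity and argmax_stripped are hypothetical names invented for illustration; they are not part of saiunit or brainunit, and the real Quantity type is richer. The sketch assumes the library ultimately runs jnp.argmax on the bare numeric values, which is consistent with the Array return type documented below but is not confirmed here.

```python
from dataclasses import dataclass

import jax.numpy as jnp


@dataclass
class ToyQuantity:
    """Hypothetical stand-in for a unit-carrying array (NOT saiunit.Quantity)."""
    mantissa: jnp.ndarray  # bare numeric values
    unit: str              # one unit label shared by every element


def argmax_stripped(a, axis=None):
    """Hypothetical sketch: drop the unit (if any), then run a plain argmax."""
    values = a.mantissa if isinstance(a, ToyQuantity) else a
    return jnp.argmax(jnp.asarray(values), axis=axis)


lengths_m = ToyQuantity(jnp.array([1.0, 3.0, 2.0]), "m")
lengths_mm = ToyQuantity(jnp.array([1000.0, 3000.0, 2000.0]), "mm")

# The same physical lengths in metres or millimetres give the same index,
# and the result is a bare integer array with no unit attached.
print(argmax_stripped(lengths_m))
print(argmax_stripped(lengths_mm))
print(argmax_stripped(jnp.array([1.0, 3.0, 2.0])))
```

Since a single unit rescales every element by the same positive factor, the ordering of the values, and therefore the position of the maximum, is preserved.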

Parameters:
  • a (Array | ndarray | bool | number | int | float | complex | saiunit.Quantity) – Input data.

  • axis (int | None) – Axis along which to search for the maximum. If None (the default), the input is flattened and the returned index refers to the flattened array.
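To make the axis rule concrete, the sketch below calls jnp.argmax directly on a plain 2-D array. It assumes brainunit.math.argmax forwards axis with the same meaning after stripping units; that forwarding is an assumption, not something this page states.

```python
import jax.numpy as jnp

x = jnp.array([[1.0, 5.0, 2.0],
               [4.0, 0.0, 6.0]])

flat = jnp.argmax(x)          # axis=None: index into the flattened array x.ravel()
cols = jnp.argmax(x, axis=0)  # one index per column (searches down axis 0)
rows = jnp.argmax(x, axis=1)  # one index per row (searches along axis 1)

# Convert the flat index back into (row, column) coordinates of x.
r, c = jnp.unravel_index(flat, x.shape)
print(flat, cols, rows, r, c)
```

The flattened index is convenient for "where is the global maximum" questions, while an explicit axis gives one index per slice.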

Returns:

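index – Indices of the maximum values. The result has the shape of a with the dimension along axis removed; if axis is None, it is a single index into the flattened input. The result carries no unit.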

Return type:

Array
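The shape of the returned array, and what happens when the maximum appears more than once, can be checked with jnp.argmax on plain arrays. This is a sketch that assumes the unit-aware wrapper inherits jnp.argmax's behaviour on the stripped values; it does not exercise brainunit itself.

```python
import jax.numpy as jnp

# Values increase along every axis, so the maximum along axis 1 is always at index 2.
a = jnp.arange(24.0).reshape(2, 3, 4)
idx = jnp.argmax(a, axis=1)
print(idx.shape)  # the searched axis (length 3) is removed from (2, 3, 4)
print(idx.dtype)  # an integer dtype, independent of the input dtype

# Ties: when the maximum occurs more than once, the first occurrence is returned.
tie = jnp.argmax(jnp.array([3.0, 7.0, 7.0]))
print(tie)
```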

Examples

>>> import saiunit as u
>>> import jax.numpy as jnp
>>> u.math.argmax(jnp.array([1.0, 3.0, 2.0]))
Array(1, dtype=int32)
>>> q = jnp.array([1.0, 3.0, 2.0]) * u.meter
>>> u.math.argmax(q)
Array(1, dtype=int32)