argmax
- class brainunit.math.argmax(a, axis=None, **kwargs)
Return the index of the maximum value along an axis.
Units are stripped before the maximum is located, so the returned index is a plain integer with no unit attached.
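- Parameters:
a – Input array or Quantity.
axis – Axis along which to search. By default (None), the index is into the flattened array.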
- Returns:
index – Index of the maximum value.
- Return type:
Array
Examples
>>> import saiunit as u
>>> import jax.numpy as jnp
>>> u.math.argmax(jnp.array([1.0, 3.0, 2.0]))
Array(1, dtype=int32)
>>> q = jnp.array([1.0, 3.0, 2.0]) * u.meter
>>> u.math.argmax(q)
Array(1, dtype=int32)
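Stripping units is harmless here because every element of a quantity shares one unit. A purely multiplicative conversion such as metre to millimetre therefore rescales all magnitudes by the same positive factor, which cannot change which element is largest. The plain-Python sketch below illustrates this and the per-row (axis=1) case. The (magnitudes, unit) pair and the helpers `argmax` and `argmax_rows` are hypothetical stand-ins, not brainunit's API. The tie rule (the lowest index wins) is assumed to mirror NumPy/JAX.

```python
# Plain-Python sketch of "strip units, then take argmax".
# A quantity is modelled as a hypothetical (magnitudes, unit) pair;
# this is NOT the brainunit Quantity class.

def argmax(values):
    """Index of the first occurrence of the largest value (ties -> lowest index)."""
    # max() returns the first maximal item it encounters, giving the tie rule above.
    return max(range(len(values)), key=lambda i: values[i])

def argmax_rows(matrix):
    """argmax along axis=1 of a 2-D list: one index per row."""
    return [argmax(row) for row in matrix]

lengths_m = ([1.0, 3.0, 2.0], "meter")

# Stripping the unit leaves the bare magnitudes.
magnitudes, _unit = lengths_m
idx = argmax(magnitudes)

# The same data in millimetres: every magnitude is scaled by 1000,
# so the ordering, and hence the index, is unchanged.
idx_mm = argmax([v * 1000 for v in magnitudes])

grid = [[1.0, 3.0, 2.0],
        [4.0, 0.5, 6.0]]
rows = argmax_rows(grid)

print(idx, idx_mm, rows)
```

Running it prints `1 1 [1, 2]`, matching the index returned by the doctest above for both the unitless array and the metre quantity.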