top_k
- class saiunit.lax.top_k(operand, k, **kwargs)
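Returns the top k values and their indices along the last axis of operand.
- Parameters:
  - operand – N-dimensional array (or saiunit.Quantity) whose last axis is searched.
  - k – integer number of top entries to return.
- Return type:
  tuple[saiunit.Quantity | Array, Array | ndarray | bool | number | bool | int | float | complex]
- Returns:
  A tuple (values, indices), where values is an array containing the top k values along the last axis, and indices is an array containing the positions of those values within the last axis.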
Examples
Find the largest three values, and their indices, within an array:
>>> import jax
>>> import jax.numpy as jnp
>>> x = jnp.array([9., 3., 6., 4., 10.])
>>> values, indices = jax.lax.top_k(x, 3)
>>> values
Array([10., 9., 6.], dtype=float32)
>>> indices
Array([4, 0, 2], dtype=int32)
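For an operand with more than one dimension, the selection happens independently along the last axis, so each row of a 2-D array gets its own top k. The sketch below calls jax.lax.top_k directly, as the example above does, and assumes saiunit.lax.top_k follows the same last-axis semantics for plain arrays. It also checks that the returned indices gather back the returned values, and compares against a NumPy argsort reference that is used here only for illustration (the input has no ties, so the ordering is unambiguous).

```python
import jax
import jax.numpy as jnp
import numpy as np

# A 2-D operand: top_k is applied row by row, i.e. along the last axis.
x = jnp.array([[1., 7., 3., 5.],
               [8., 2., 6., 4.]])
k = 2

values, indices = jax.lax.top_k(x, k)
# values  -> [[7., 5.], [8., 6.]]
# indices -> [[1, 3], [0, 2]]

# The indices refer to positions in the last axis of the operand,
# so gathering with them reproduces the returned values.
gathered = jnp.take_along_axis(x, indices, axis=-1)

# Illustrative NumPy reference: sort each row in descending order, keep k.
x_np = np.asarray(x)
ref_indices = np.argsort(-x_np, axis=-1)[:, :k]
ref_values = np.take_along_axis(x_np, ref_indices, axis=-1)
```

The values come back in descending order within each row, and the leading (batch) dimensions of the operand are preserved in both outputs.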