sort_key_val
- class saiunit.lax.sort_key_val(keys, values, dimension=-1, is_stable=True, **kwargs)
Sort the array keys along the axis given by dimension, and apply the same permutation to the array values.
- Parameters:
  - keys (saiunit.Quantity | Array | ndarray | bool | number | bool | int | float | complex) – The array of keys to sort.
  - values (saiunit.Quantity | Array | ndarray | bool | number | bool | int | float | complex) – The array of values to permute according to the sorted order of keys.
  - dimension (int) – The dimension along which to sort. Default is -1.
  - is_stable (bool) – Whether to use a stable sort. Default is True.
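The "same permutation" behavior can be checked with plain JAX arrays. This sketch calls jax.lax.sort_key_val directly on unitless arrays, assuming that saiunit.lax.sort_key_val orders the mantissas the same way. It compares the result with jnp.argsort followed by jnp.take_along_axis, for both dimension=-1 (each row) and dimension=0 (each column).

```python
import jax.numpy as jnp
from jax import lax

# Plain (unitless) arrays standing in for the mantissas of keys and values.
keys = jnp.array([[3.0, 1.0, 2.0],
                  [0.5, 2.5, 1.5]])
vals = jnp.array([[30, 10, 20],
                  [5, 25, 15]])

# Sort within each row (dimension=-1) and within each column (dimension=0).
k_row, v_row = lax.sort_key_val(keys, vals, dimension=-1)
k_col, v_col = lax.sort_key_val(keys, vals, dimension=0)

# The same reordering expressed as an explicit permutation of the keys.
perm_row = jnp.argsort(keys, axis=-1)
perm_col = jnp.argsort(keys, axis=0)
v_row_ref = jnp.take_along_axis(vals, perm_row, axis=-1)
v_col_ref = jnp.take_along_axis(vals, perm_col, axis=0)

print(v_row)  # [[10 20 30] [ 5 15 25]]
print(v_col)  # [[ 5 10 15] [30 25 20]]
```

The values never influence the ordering. They are reordered along dimension exactly as the keys are.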
- Return type:
  tuple[saiunit.Quantity | Array, saiunit.Quantity | Array]
- Returns:
  - sorted_keys (jax.Array or Quantity) – The sorted keys. Preserves the unit of keys.
  - sorted_values (jax.Array or Quantity) – The values permuted to match the sorted order of keys. Preserves the unit of values.
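The unit-preservation contract can be pictured as "strip the units, sort the raw numbers, reattach the units". The sketch below illustrates that idea and is not saiunit's implementation. Tagged, _split, _join and sort_key_val_tagged are hypothetical names invented here. Tagged is a toy stand-in for saiunit.Quantity that stores its unit as a plain string.

```python
from dataclasses import dataclass
from typing import Optional

import jax.numpy as jnp
from jax import lax


@dataclass
class Tagged:
    """Hypothetical toy stand-in for a unit-carrying array (NOT saiunit.Quantity)."""
    mantissa: jnp.ndarray
    unit: Optional[str]


def _split(x):
    # Separate the numeric payload from its unit tag (None for plain arrays).
    if isinstance(x, Tagged):
        return x.mantissa, x.unit
    return jnp.asarray(x), None


def _join(mantissa, unit):
    # Reattach the unit only if the corresponding input carried one.
    return mantissa if unit is None else Tagged(mantissa, unit)


def sort_key_val_tagged(keys, values, dimension=-1, is_stable=True):
    """Hypothetical unit-preserving wrapper around jax.lax.sort_key_val."""
    k, k_unit = _split(keys)
    v, v_unit = _split(values)
    sk, sv = lax.sort_key_val(k, v, dimension=dimension, is_stable=is_stable)
    return _join(sk, k_unit), _join(sv, v_unit)


keys = Tagged(jnp.array([3.0, 1.0, 2.0]), "m")
vals = jnp.array([30, 10, 20])
sorted_keys, sorted_vals = sort_key_val_tagged(keys, vals)
```

Sorting only reorders elements, so no unit conversion is involved. Each output keeps exactly the unit, or lack of unit, of the input it came from.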
Examples
>>> import saiunit as u
>>> import saiunit.lax as sulax
>>> import jax.numpy as jnp
>>> keys = jnp.array([3.0, 1.0, 2.0]) * u.meter
>>> vals = jnp.array([30, 10, 20])
>>> sorted_keys, sorted_vals = sulax.sort_key_val(keys, vals)
>>> sorted_keys.mantissa
Array([1., 2., 3.], dtype=float32)
>>> sorted_vals
Array([10, 20, 30], dtype=int32)
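The is_stable flag only matters when keys contain ties. This sketch calls jax.lax.sort_key_val on plain arrays, assuming the saiunit wrapper forwards is_stable unchanged. Each value records its element's starting position. With is_stable=True, elements with equal keys therefore keep their original relative order.

```python
import jax.numpy as jnp
from jax import lax

# Duplicate keys: 2.0 sits at positions 0 and 2, 1.0 at positions 1 and 3.
keys = jnp.array([2.0, 1.0, 2.0, 1.0])
vals = jnp.array([0, 1, 2, 3])  # each value is its element's original index

sorted_keys, sorted_vals = lax.sort_key_val(keys, vals, is_stable=True)
print(sorted_keys)  # [1. 1. 2. 2.]
print(sorted_vals)  # [1 3 0 2]: ties stay in their original order
```

With is_stable=False, the relative order of tied elements is not guaranteed.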