sort_key_val

class saiunit.lax.sort_key_val(keys, values, dimension=-1, is_stable=True, **kwargs)

Sort keys along dimension and apply the same permutation to values.

Parameters:
  • keys (saiunit.Quantity | Array | ndarray | bool | number | bool | int | float | complex) – The array of keys to sort.

  • values (saiunit.Quantity | Array | ndarray | bool | number | bool | int | float | complex) – The array of values to permute according to the sorted order of keys.

  • dimension (int) – The dimension (axis) along which to sort. Default is -1, the last dimension.

  • is_stable (bool) – Whether to use a stable sort, which keeps elements with equal keys in their original relative order. Default is True.
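The behaviour described by these parameters can be reproduced in plain NumPy: argsort the keys along dimension, then gather both arrays with the same indices. The sketch below is a reference for the semantics only, not saiunit's implementation. The helper name sort_key_val_reference is hypothetical, and mapping is_stable=False to NumPy's non-stable "quicksort" is an assumption made for illustration. It also assumes keys and values have the same shape.

```python
import numpy as np


def sort_key_val_reference(keys, values, dimension=-1, is_stable=True):
    """Hypothetical NumPy reference for the sort_key_val semantics.

    Sorts `keys` along `dimension` and applies the same permutation to `values`.
    """
    keys = np.asarray(keys)
    values = np.asarray(values)
    # This sketch assumes keys and values have identical shapes.
    if keys.shape != values.shape:
        raise ValueError("keys and values must have the same shape")
    # A stable sort keeps equal keys in their original relative order;
    # NumPy's "quicksort" (introsort) makes no such guarantee.
    kind = "stable" if is_stable else "quicksort"
    perm = np.argsort(keys, axis=dimension, kind=kind)
    # Gather both arrays with the same per-slice permutation.
    return (np.take_along_axis(keys, perm, axis=dimension),
            np.take_along_axis(values, perm, axis=dimension))


keys = np.array([[2, 1, 2], [0, 3, 0]])
vals = np.array([["a", "b", "c"], ["d", "e", "f"]])

# Sort each row (dimension=-1); ties keep their original order.
k, v = sort_key_val_reference(keys, vals)
# k.tolist() -> [[1, 2, 2], [0, 0, 3]]
# v.tolist() -> [['b', 'a', 'c'], ['d', 'f', 'e']]

# Sort each column instead (dimension=0).
k0, v0 = sort_key_val_reference(keys, vals, dimension=0)
# k0.tolist() -> [[0, 1, 0], [2, 3, 2]]
# v0.tolist() -> [['d', 'b', 'f'], ['a', 'e', 'c']]
```

In the first row the two keys equal to 2 keep their input order ("a" before "c"), which is exactly what is_stable=True guarantees.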

Return type:

tuple[saiunit.Quantity | Array, saiunit.Quantity | Array]

Returns:

  • sorted_keys (jax.Array or Quantity) – The sorted keys. Preserves the unit of keys.

  • sorted_values (jax.Array or Quantity) – The values permuted to match the sorted order of keys. Preserves the unit of values.
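Every element of a single quantity array shares one unit with a positive scale. Ordering the mantissas therefore gives the same order as ordering the physical values, and each input's unit can simply be re-attached after sorting. The sketch below illustrates that idea with a hypothetical Q stand-in (a mantissa array plus a unit label). It is an assumption-based illustration and does not reflect how saiunit.Quantity is actually implemented.

```python
from typing import NamedTuple

import numpy as np


class Q(NamedTuple):
    """Hypothetical stand-in for a quantity: magnitudes plus one shared unit label."""
    mantissa: np.ndarray
    unit: str


def sort_key_val_with_units(keys, values, dimension=-1):
    # Strip units (if present) so sorting operates on plain magnitudes.
    k = keys.mantissa if isinstance(keys, Q) else np.asarray(keys)
    v = values.mantissa if isinstance(values, Q) else np.asarray(values)
    # Stable sort of the key magnitudes; the same permutation is applied to values.
    perm = np.argsort(k, axis=dimension, kind="stable")
    k_sorted = np.take_along_axis(k, perm, axis=dimension)
    v_sorted = np.take_along_axis(v, perm, axis=dimension)
    # Re-attach each input's own unit, so keys and values keep their units.
    k_out = Q(k_sorted, keys.unit) if isinstance(keys, Q) else k_sorted
    v_out = Q(v_sorted, values.unit) if isinstance(values, Q) else v_sorted
    return k_out, v_out


# Mirrors the documented example: keys in meters, unitless values.
sk, sv = sort_key_val_with_units(Q(np.array([3.0, 1.0, 2.0]), "m"), np.array([30, 10, 20]))
# sk.mantissa.tolist() -> [1.0, 2.0, 3.0], sk.unit -> 'm'
# sv.tolist() -> [10, 20, 30]
```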

Examples

>>> import saiunit as u
>>> import saiunit.lax as sulax
>>> import jax.numpy as jnp
>>> keys = jnp.array([3.0, 1.0, 2.0]) * u.meter
>>> vals = jnp.array([30, 10, 20])
>>> sorted_keys, sorted_vals = sulax.sort_key_val(keys, vals)
>>> sorted_keys.mantissa
Array([1., 2., 3.], dtype=float32)
>>> sorted_vals
Array([10, 20, 30], dtype=int32)