reduce_precision

brainunit.lax.reduce_precision(operand, exponent_bits, mantissa_bits, **kwargs)

Reduce the precision of array elements.

Wraps XLA’s ReducePrecision operator.

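When the input is a Quantity, the precision reduction is applied to its mantissa (the Quantity's numerical value, which is unrelated to the floating-point mantissa controlled by mantissa_bits), and the result is returned as a plain jax.Array with the unit stripped.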

Parameters:
  • operand (Array | ndarray | bool | number | int | float | complex | saiunit.Quantity) – The input values whose precision will be reduced. If a Quantity, the precision reduction is applied to its mantissa and the result is a plain array. If a CustomArray, its .data attribute is unwrapped first.

  • exponent_bits (int) – Number of exponent bits in the reduced-precision format.

  • mantissa_bits (int) – Number of mantissa bits in the reduced-precision format.

Returns:

result – Array with reduced-precision values. Unit information from a Quantity input is not preserved.

Return type:

Array | ndarray | bool | number | int | float | complex

See also

jax.lax.reduce_precision

The underlying JAX primitive.

Notes

This function simulates the effect of converting values to a lower-precision floating-point format and back. It is useful for exploring the numerical effects of quantization without actually changing the storage dtype.
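A minimal sketch of this use, calling the underlying jax.lax.reduce_precision directly on a float32 array. The input range and the choice of exponent_bits=8, mantissa_bits=7 are illustrative assumptions. The result keeps its float32 dtype, and for values inside the normal range, round-to-nearest should keep the relative error within 2**-(mantissa_bits + 1):

```python
import jax.numpy as jnp
from jax import lax

# Illustrative float32 inputs, all well inside the normal range.
x = jnp.linspace(0.1, 100.0, 1000, dtype=jnp.float32)

# Simulate a format with 8 exponent bits and 7 mantissa bits.
mantissa_bits = 7
y = lax.reduce_precision(x, exponent_bits=8, mantissa_bits=mantissa_bits)

# The storage dtype is unchanged; only the values have been rounded.
print(y.dtype)

# Round-to-nearest bounds the relative error by 2**-(mantissa_bits + 1).
rel_err = float(jnp.max(jnp.abs(y - x) / jnp.abs(x)))
bound = 2.0 ** -(mantissa_bits + 1)
print(rel_err, bound)
```

Because y is still a float32 array, it can be fed into the rest of a float32 computation to see how the extra rounding propagates.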

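The exponent_bits and mantissa_bits together define a virtual floating-point format, where mantissa_bits counts the explicitly stored fraction bits (the implicit leading bit is not included). For example, exponent_bits=5 and mantissa_bits=10 correspond to IEEE float16, and exponent_bits=8 and mantissa_bits=7 correspond to bfloat16.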
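One way to check the float16 correspondence, assuming the values stay in the normal range of the target format, is to compare jax.lax.reduce_precision against an actual round trip through jnp.float16. The same comparison with exponent_bits=8 and mantissa_bits=7 against jnp.bfloat16 is included. The sample values are arbitrary:

```python
import jax.numpy as jnp
from jax import lax

# Arbitrary sample values inside the normal range of float16 and bfloat16.
x = jnp.array([1.123456, 2.123456, 3.14159265, -0.1, 1000.5], dtype=jnp.float32)

# exponent_bits=5, mantissa_bits=10: the IEEE float16 layout.
fp16_sim = lax.reduce_precision(x, exponent_bits=5, mantissa_bits=10)
fp16_cast = x.astype(jnp.float16).astype(jnp.float32)

# exponent_bits=8, mantissa_bits=7: the bfloat16 layout.
bf16_sim = lax.reduce_precision(x, exponent_bits=8, mantissa_bits=7)
bf16_cast = x.astype(jnp.bfloat16).astype(jnp.float32)

# Both comparisons are expected to agree for in-range values.
print(bool(jnp.array_equal(fp16_sim, fp16_cast)))
print(bool(jnp.array_equal(bf16_sim, bf16_cast)))
```

Values that overflow or fall into the subnormal range of the target format are deliberately left out here, since this sketch makes no claim about how the two paths treat them.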

Examples

Reducing precision of a plain array:

>>> import saiunit.lax as sulax
>>> import jax.numpy as jnp
>>> x = jnp.array([1.123456, 2.123456], dtype=jnp.float32)
>>> sulax.reduce_precision(x, exponent_bits=5, mantissa_bits=10)
Array([1.123047, 2.123047], dtype=float32)
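The printed values can be checked by hand. In [1, 2), a 10-bit mantissa gives a spacing of 2**-10, so 1.123456 rounds to 1150 * 2**-10 = 1.123046875. In [2, 4), the spacing is 2**-9, so 2.123456 rounds to 1087 * 2**-9 = 2.123046875. The NumPy sketch below reproduces this arithmetic with a hypothetical helper, round_significand, that is not part of saiunit or JAX. It rounds only the significand and ignores exponent_bits, so it is valid only for normal values that do not overflow:

```python
import numpy as np

def round_significand(x, mantissa_bits):
    """Hypothetical helper, not part of saiunit or JAX.

    Rounds each value to `mantissa_bits` stored fraction bits with
    round-half-to-even. The exponent range is ignored, so this only
    mirrors reduce_precision for normal, non-overflowing values.
    """
    x = np.asarray(x, dtype=np.float64)
    m, e = np.frexp(x)                  # x == m * 2**e with 0.5 <= |m| < 1
    precision = mantissa_bits + 1       # implicit leading bit + stored bits
    q = np.rint(m * 2.0 ** precision)   # integer significand, ties to even
    return np.ldexp(q, e - precision)

x = np.array([1.123456, 2.123456], dtype=np.float32)
r = round_significand(x, mantissa_bits=10)

# 1.123456 * 2**10 ~ 1150.42 -> 1150, and 1150 / 2**10 == 1.123046875
# 2.123456 * 2**9  ~ 1087.21 -> 1087, and 1087 / 2**9  == 2.123046875
print(r.tolist())

# Cross-check against a real round trip through NumPy's float16.
print(np.array_equal(r.astype(np.float32), x.astype(np.float16).astype(np.float32)))
```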

Reducing precision of a Quantity (mantissa is extracted, unit is stripped):

>>> import saiunit as u
>>> import saiunit.lax as sulax
>>> import jax.numpy as jnp
>>> q = jnp.array([1.123456, 2.123456], dtype=jnp.float32) * u.meter
>>> sulax.reduce_precision(q, exponent_bits=5, mantissa_bits=10)
Array([1.123047, 2.123047], dtype=float32)