reduce_precision
- class brainunit.lax.reduce_precision(operand, exponent_bits, mantissa_bits, **kwargs)
Reduce the precision of array elements.
Wraps XLA's ReducePrecision operator.
When the input is a Quantity, the precision reduction is applied to its mantissa and the result is returned as a plain jax.Array (the unit is stripped).
- Parameters:
operand (Array | ndarray | bool | number | bool | int | float | complex | saiunit.Quantity) – The input values whose precision will be reduced. If a Quantity, the precision reduction is applied to its mantissa and the result is a plain array. If a CustomArray, its .data attribute is unwrapped first.
exponent_bits (int) – Number of exponent bits in the reduced-precision format.
mantissa_bits (int) – Number of mantissa bits in the reduced-precision format.
- Returns:
result – Array with reduced-precision values. Unit information from a Quantity input is not preserved.
- Return type:
Array | ndarray | bool | number | bool | int | float | complex
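To see how exponent_bits and mantissa_bits shape the simulated format, the sketch below calls the underlying jax.lax.reduce_precision directly (plain arrays only, no units) with a deliberately tiny 2-bit mantissa. It keeps 8 exponent bits, the same as float32, so only mantissa rounding is visible. With 2 mantissa bits, each binade [2^e, 2^(e+1)) holds just four values, and every input snaps to the nearest one.

```python
import jax.numpy as jnp
from jax import lax

# float32 inputs; 8 exponent bits keeps the float32 exponent range,
# so only the effect of the 2-bit mantissa shows up.
x = jnp.array([1.3, 0.3, 5.0], dtype=jnp.float32)

# With 2 mantissa bits, [1, 2) contains only 1.0, 1.25, 1.5, 1.75 and
# [0.25, 0.5) contains only 0.25, 0.3125, 0.375, 0.4375.
y = lax.reduce_precision(x, exponent_bits=8, mantissa_bits=2)
print(y.tolist())  # [1.25, 0.3125, 5.0]
```

5.0 is already representable with a 2-bit mantissa (4, 5, 6, 7 in [4, 8)), so it passes through unchanged.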
See also
jax.lax.reduce_precision – The underlying JAX primitive.
Notes
This function simulates the effect of converting values to a lower-precision floating-point format and back. It is useful for exploring the numerical effects of quantization without actually changing the storage dtype.
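A minimal sketch of such a quantization experiment, using jax.lax.reduce_precision directly on a plain float32 array. The output keeps the float32 dtype. Because the simulated format rounds to nearest, the per-element error for normal values stays within half a unit in the last place of that format, that is, at most |x| · 2^-(mantissa_bits + 1).

```python
import jax.numpy as jnp
from jax import lax

x = jnp.array([1.123456, 2.123456, 3.14159265, 1000.5], dtype=jnp.float32)

mantissa_bits = 10
# Simulate a float16-like format; the result is still stored as float32.
y = lax.reduce_precision(x, exponent_bits=5, mantissa_bits=mantissa_bits)
print(y.dtype)  # float32

# Absolute rounding error introduced by the simulated format.
err = jnp.abs(y - x)
# Round-to-nearest bounds the relative error by 2**-(mantissa_bits + 1).
bound = jnp.abs(x) * 2.0 ** -(mantissa_bits + 1)
print(bool(jnp.all(err <= bound)))  # True
```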
The exponent_bits and mantissa_bits together define a virtual floating-point format. For example, exponent_bits=5 and mantissa_bits=10 correspond to IEEE float16 (binary16).
Examples
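The float16 correspondence from the Notes can be checked with a sketch that compares jax.lax.reduce_precision against a real cast to jnp.float16 and back. The sketch assumes that all values lie in float16's normal range. Subnormals and overflow are deliberately avoided, because those are the cases where a simulated format and a real cast are most likely to differ. The same comparison with exponent_bits=8 and mantissa_bits=7 targets jnp.bfloat16, which uses 8 exponent bits and 7 mantissa bits.

```python
import jax.numpy as jnp
from jax import lax

# Positive and negative values, all well inside float16's normal range.
v = jnp.linspace(0.1, 100.0, 1001, dtype=jnp.float32)
x = jnp.concatenate([v, -v])

# exponent_bits=5, mantissa_bits=10 describes IEEE float16 (binary16).
sim_f16 = lax.reduce_precision(x, exponent_bits=5, mantissa_bits=10)
cast_f16 = x.astype(jnp.float16).astype(jnp.float32)

# exponent_bits=8, mantissa_bits=7 describes bfloat16.
sim_bf16 = lax.reduce_precision(x, exponent_bits=8, mantissa_bits=7)
cast_bf16 = x.astype(jnp.bfloat16).astype(jnp.float32)

print(bool(jnp.array_equal(sim_f16, cast_f16)))
print(bool(jnp.array_equal(sim_bf16, cast_bf16)))
```

The casts run eagerly here rather than under jax.jit, so each intermediate float16 or bfloat16 array is actually materialized before it is converted back to float32.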
Reducing precision of a plain array:
>>> import saiunit.lax as sulax
>>> import jax.numpy as jnp
>>> x = jnp.array([1.123456, 2.123456], dtype=jnp.float32)
>>> sulax.reduce_precision(x, exponent_bits=5, mantissa_bits=10)
Array([1.1230469, 2.1230469], dtype=float32)
Reducing precision of a Quantity (the mantissa is extracted and the unit is stripped):
>>> import saiunit as u
>>> import saiunit.lax as sulax
>>> import jax.numpy as jnp
>>> q = jnp.array([1.123456, 2.123456], dtype=jnp.float32) * u.meter
>>> sulax.reduce_precision(q, exponent_bits=5, mantissa_bits=10)
Array([1.1230469, 2.1230469], dtype=float32)