reduce
- class saiunit.lax.reduce(operands, init_values, computation, dimensions, **kwargs)
Reduce an array along dimensions using a computation.
Wraps XLA’s Reduce operator.
init_values and computation together must form a monoid for correctness: init_values must be an identity of computation, and computation must be associative.
- Parameters:
operands (Any) – The array(s) to reduce. If an operand is a Quantity, its underlying mantissa is extracted before the XLA operation; if it is a CustomArray, its .data attribute is unwrapped.
init_values (Any) – The initial value(s) for the reduction. Must be an identity element of computation. Accepts the same types as operands.
computation (Callable[[Any, Any], Any]) – A binary function used to combine elements (e.g. jax.lax.add).
dimensions (Sequence[int]) – The dimensions along which to reduce.
- Returns:
result – The reduced result. Note that unit information is not preserved through the raw XLA reduce; see Notes.
- Return type:
- Raises:
TypeError – If operands and init_values have incompatible types after unwrapping.
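The monoid requirement is easiest to see with jax.lax.reduce itself, which this function wraps. The sketch below calls jax.lax.reduce directly (not saiunit.lax.reduce) on plain float32 arrays, pairing each computation with its identity element and showing how dimensions selects the axes to collapse. As we understand XLA's Reduce semantics, a value that is not an identity (say 10 with lax.add) is not guaranteed to be added exactly once, because XLA does not specify how often or where the initial value is combined in.

```python
import jax.numpy as jnp
from jax import lax

x = jnp.array([4.0, -2.0, 7.0, 1.0], dtype=jnp.float32)

# Each init value is the identity of its computation, so it never changes the result.
total = lax.reduce(x, jnp.float32(0), lax.add, (0,))            # identity of add is 0
product = lax.reduce(x, jnp.float32(1), lax.mul, (0,))          # identity of mul is 1
largest = lax.reduce(x, jnp.float32(-jnp.inf), lax.max, (0,))   # identity of max is -inf
smallest = lax.reduce(x, jnp.float32(jnp.inf), lax.min, (0,))   # identity of min is +inf

# `dimensions` lists the axes to collapse; all other axes are kept.
m = jnp.arange(6, dtype=jnp.float32).reshape(2, 3)              # [[0, 1, 2], [3, 4, 5]]
row_sums = lax.reduce(m, jnp.float32(0), lax.add, (1,))         # shape (2,)
col_maxes = lax.reduce(m, jnp.float32(-jnp.inf), lax.max, (0,)) # shape (3,)
grand_total = lax.reduce(m, jnp.float32(0), lax.add, (0, 1))    # shape ()

print(total, product, largest, smallest)
print(row_sums, col_maxes, grand_total)
```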
See also
jax.lax.reduce – The underlying JAX primitive.
jax.numpy.sum – A higher-level sum; use saiunit.math.sum for the unit-preserving version.
Notes
Because this function delegates directly to jax.lax.reduce(), the unit metadata carried by a Quantity is stripped before the reduction. If you need the result to retain its unit, consider using the higher-level wrappers in saiunit.math (e.g. saiunit.math.sum).
Examples
Reducing a plain array with lax.add:
>>> import saiunit as u
>>> import saiunit.lax as sulax
>>> import jax.numpy as jnp
>>> from jax import lax
>>> x = jnp.array([1.0, 2.0, 3.0])
>>> sulax.reduce(x, jnp.float32(0), lax.add, [0])
Array(6., dtype=float32)
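operands and init_values are documented as accepting array(s) and value(s). With jax.lax.reduce, which this function wraps, matching tuples give a variadic reduction in which computation receives and returns tuples. Whether saiunit.lax.reduce forwards tuples unchanged is an assumption, so this sketch calls jax.lax.reduce directly. argmax_step is a hypothetical helper that computes an argmax over one axis; NaN inputs are not handled.

```python
import jax.numpy as jnp
from jax import lax

def argmax_step(a, b):
    """Combine two (value, index) pairs, keeping the larger value.

    Ties go to the smaller index, so this is a lexicographic max on
    (value, -index): associative and commutative, as reduce requires.
    """
    a_val, a_idx = a
    b_val, b_idx = b
    keep_a = (a_val > b_val) | ((a_val == b_val) & (a_idx < b_idx))
    return (jnp.where(keep_a, a_val, b_val), jnp.where(keep_a, a_idx, b_idx))

x = jnp.array([1.0, 5.0, 3.0, 5.0], dtype=jnp.float32)
idx = jnp.arange(x.shape[0], dtype=jnp.int32)

# Identity pair: -inf loses to every value, and the largest int32 loses every tie.
init = (jnp.float32(-jnp.inf), jnp.int32(jnp.iinfo(jnp.int32).max))

best_val, best_idx = lax.reduce((x, idx), init, argmax_step, (0,))
print(best_val, best_idx)
```

Operand and init-value dtypes must match pairwise (float32 with float32, int32 with int32), which is why the identity pair is built with explicit dtypes.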
Reducing a Quantity (unit is stripped, raw mantissa is reduced):
>>> import saiunit as u
>>> import saiunit.lax as sulax
>>> import jax.numpy as jnp
>>> from jax import lax
>>> q = jnp.array([1.0, 2.0, 3.0]) * u.meter
>>> sulax.reduce(q, jnp.float32(0) * u.meter, lax.add, [0])
Array(6., dtype=float32)
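In the Quantity example the unit is dropped, and the Notes recommend saiunit.math.sum when the unit matters. If you must stay at the lax level, one option is to reduce the mantissa and reattach the unit yourself. The sketch below illustrates that pattern with a hypothetical ToyQuantity NamedTuple (not a saiunit class) and a hypothetical helper reduce_keep_unit, calling jax.lax.reduce directly. Reattaching the input unit is only valid for unit-preserving computations such as lax.add, lax.max and lax.min; a product over n elements would carry the unit to the n-th power.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp
from jax import lax

class ToyQuantity(NamedTuple):
    """Hypothetical stand-in for a unit-carrying array (not a saiunit class)."""
    mantissa: jax.Array
    unit: str

# Computations whose output is in the same unit as their inputs.
UNIT_PRESERVING = (lax.add, lax.max, lax.min)

def reduce_keep_unit(q, init_mantissa, computation, dimensions):
    """Hypothetical helper: reduce q.mantissa, then reattach q.unit."""
    if computation not in UNIT_PRESERVING:
        # e.g. lax.mul over n elements yields unit**n, not unit.
        raise ValueError("result unit is not q.unit for this computation")
    out = lax.reduce(q.mantissa, init_mantissa, computation, dimensions)
    return ToyQuantity(out, q.unit)

q = ToyQuantity(jnp.array([1.0, 2.0, 3.0], dtype=jnp.float32), "meter")
total = reduce_keep_unit(q, jnp.float32(0), lax.add, (0,))
print(total.mantissa, total.unit)
```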