reduce

brainunit.lax.reduce(operands, init_values, computation, dimensions, **kwargs)

Reduce an array along dimensions using a computation.

Wraps XLA’s Reduce operator.

init_values and computation together must form a monoid for correctness: init_values must be an identity of computation, and computation must be associative.
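To illustrate the monoid requirement, the sketch below pairs each common combining function with its identity element and calls `jax.lax.reduce`, the primitive this function wraps, on a plain float32 array. The identities (0 for add, 1 for mul, -inf for max, +inf for min) are standard. The array values are arbitrary choices made for this illustration.

```python
import jax.numpy as jnp
from jax import lax

x = jnp.array([1.0, 2.0, 4.0], dtype=jnp.float32)

# (computation, identity element) pairs; each forms a monoid over float32.
monoids = {
    "sum": (lax.add, jnp.float32(0.0)),
    "prod": (lax.mul, jnp.float32(1.0)),
    "max": (lax.max, jnp.float32(-jnp.inf)),
    "min": (lax.min, jnp.float32(jnp.inf)),
}

# Reduce the single axis of x with every (computation, identity) pair.
results = {
    name: lax.reduce(x, init, op, (0,))
    for name, (op, init) in monoids.items()
}

# Passing a non-identity init value (e.g. 1.0 with lax.add) breaks the
# contract: the backend chooses the reduction order and how often the init
# value is folded in, so the result is not well defined.
print({name: float(v) for name, v in results.items()})
```

With these inputs, the four reductions give 7.0, 8.0, 4.0 and 1.0.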

Parameters:
  • operands (Any) – The array(s) to reduce. If operands is a Quantity, its underlying mantissa is extracted before the XLA operation; if it is a CustomArray, its .data attribute is unwrapped.

  • init_values (Any) – The initial value(s) for the reduction. Must be an identity element of computation. Accepts the same types as operands.

  • computation (Callable[[Any, Any], Any]) – A binary function used to combine elements (e.g. jax.lax.add).

  • dimensions (Sequence[int]) – The dimensions along which to reduce.
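The sketch below shows how the parameters interact, again using `jax.lax.reduce` directly. First, `dimensions` selects which axes are removed from the result. Second, `computation` can be any associative binary function of two scalars, including a Python lambda. Third, the variadic form passes tuples of operands and init values: `computation` then receives two tuples and must return one. The `argmax_combine` helper is a hypothetical name chosen for this illustration. Whether the unit-aware wrapper documented here accepts tuple operands is not stated above, so treat that part as applying to the underlying JAX primitive only.

```python
import jax.numpy as jnp
from jax import lax

m = jnp.array([[1.0, 5.0, 2.0],
               [4.0, 3.0, 6.0]], dtype=jnp.float32)  # shape (2, 3)
zero = jnp.float32(0.0)

col_sums = lax.reduce(m, zero, lax.add, (0,))    # axis 0 removed -> shape (3,)
row_sums = lax.reduce(m, zero, lax.add, (1,))    # axis 1 removed -> shape (2,)
total = lax.reduce(m, zero, lax.add, (0, 1))     # both removed -> scalar

# A Python lambda works as the combiner; it is traced on scalar inputs.
row_max = lax.reduce(m, jnp.float32(-jnp.inf),
                     lambda a, b: jnp.maximum(a, b), (1,))


# Variadic form: track (value, index) pairs to compute an argmax.
def argmax_combine(a, b):
    (a_val, a_idx), (b_val, b_idx) = a, b
    keep_a = a_val >= b_val
    return lax.select(keep_a, a_val, b_val), lax.select(keep_a, a_idx, b_idx)


v = jnp.array([3.0, 7.0, 5.0], dtype=jnp.float32)  # distinct values, so ties never arise
idx = jnp.arange(3, dtype=jnp.int32)
best_val, best_idx = lax.reduce((v, idx),
                                (jnp.float32(-jnp.inf), jnp.int32(-1)),
                                argmax_combine, (0,))
```

In the variadic call, each operand is paired with the init value in the same tuple position, and the two must share a dtype.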

Returns:

result – The reduced result. Note that unit information is not preserved through the raw XLA reduce; see Notes.

Return type:

Any

Raises:

TypeError – If operands and init_values have incompatible types after unwrapping.

See also

jax.lax.reduce

The underlying JAX primitive.

saiunit.math.sum

A higher-level sum that preserves units.

Notes

Because this function delegates directly to jax.lax.reduce(), the unit metadata carried by a Quantity is stripped before the reduction. If you need the result to retain its unit, consider using the higher-level wrappers in saiunit.math (e.g. saiunit.math.sum).

Examples

Reducing a plain array with lax.add:

>>> import saiunit as u
>>> import saiunit.lax as sulax
>>> import jax.numpy as jnp
>>> from jax import lax
>>> x = jnp.array([1.0, 2.0, 3.0])
>>> sulax.reduce(x, jnp.float32(0), lax.add, [0])
Array(6., dtype=float32)

Reducing a Quantity (the unit is stripped and the raw mantissa is reduced):

>>> import saiunit as u
>>> import saiunit.lax as sulax
>>> import jax.numpy as jnp
>>> from jax import lax
>>> q = jnp.array([1.0, 2.0, 3.0]) * u.meter
>>> sulax.reduce(q, jnp.float32(0) * u.meter, lax.add, [0])
Array(6., dtype=float32)