cumsum

saiunit.lax.cumsum(operand, axis=0, reverse=False, **kwargs)

Compute the cumulative sum of operand along axis.
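Written out for a one-dimensional operand x = (x_0, ..., x_{n-1}), the result y is given below. For higher-dimensional inputs the same formula is applied independently to every 1-D slice taken along axis. The reverse=True line assumes the same semantics as jax.lax.cumsum, which the signature mirrors.

```latex
y_i = \sum_{j=0}^{i} x_j \quad \text{(reverse=False)}, \qquad
y_i = \sum_{j=i}^{n-1} x_j \quad \text{(reverse=True)}, \qquad i = 0, 1, \dots, n-1.
```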

Parameters:
  • operand (saiunit.Quantity | Array | ndarray | bool | number | int | float | complex) – Input array or quantity.

  • axis (int) – The axis along which to compute the cumulative sum. Default is 0.

  • reverse (bool) – If True, accumulate from the last element toward the first along axis. Default is False.
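The effect of axis and reverse can be checked on plain, unitless arrays. The sketch below calls jax.lax.cumsum directly, assuming that saiunit.lax.cumsum forwards these two arguments to it unchanged. It illustrates the semantics only and is not the library's implementation.

```python
import jax.numpy as jnp
from jax import lax

# Assumption: saiunit.lax.cumsum passes axis/reverse through to jax.lax.cumsum.
x = jnp.array([[1.0, 2.0, 3.0],
               [4.0, 5.0, 6.0]])

# axis=0 (the default): running sum down each column.
down_columns = lax.cumsum(x, axis=0)                  # [[1, 2, 3], [5, 7, 9]]

# axis=1: running sum across each row.
across_rows = lax.cumsum(x, axis=1)                   # [[1, 3, 6], [4, 9, 15]]

# reverse=True: accumulate from the last element toward the first.
reversed_rows = lax.cumsum(x, axis=1, reverse=True)   # [[6, 5, 3], [15, 11, 6]]

# Reverse mode is equivalent to flip -> forward cumsum -> flip back.
flipped = jnp.flip(lax.cumsum(jnp.flip(x, axis=1), axis=1), axis=1)
assert jnp.allclose(reversed_rows, flipped)
```

The flip-based check is only there to make the meaning of reverse concrete; calling lax.cumsum with reverse=True directly avoids the two extra copies.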

Returns:

result – The cumulative sums, with the same shape as operand. The unit of operand is preserved.
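Because a cumulative sum only ever adds elements of operand to one another, every partial sum has the same unit as the input, which is why the unit can be carried through unchanged. The sketch below illustrates that idea with a hypothetical ToyQuantity type (a mantissa array plus a unit label) and a hypothetical toy_cumsum helper; it makes no claim about how saiunit represents or propagates units internally.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp
from jax import lax


class ToyQuantity(NamedTuple):
    """Hypothetical stand-in for a unit-carrying array (illustration only)."""
    mantissa: jax.Array
    unit: str


def toy_cumsum(q: ToyQuantity, axis: int = 0, reverse: bool = False) -> ToyQuantity:
    # Sum only the numeric part. Adding values that share one unit yields
    # values in that same unit, so the unit label is reattached as-is.
    return ToyQuantity(lax.cumsum(q.mantissa, axis=axis, reverse=reverse), q.unit)


distances = ToyQuantity(jnp.array([1.0, 2.0, 3.0, 4.0]), "m")
total = toy_cumsum(distances)
# total.mantissa holds 1, 3, 6, 10 and total.unit is still "m".
print(total.mantissa, total.unit)
```

By contrast, a running product would change the unit from one element to the next (m, m², m³, ...), so this simple pass-through only works because the operation is a sum.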

Return type:

saiunit.Quantity | Array

Examples

>>> import saiunit as u
>>> import saiunit.lax as sulax
>>> import jax.numpy as jnp
>>> q = jnp.array([1.0, 2.0, 3.0, 4.0]) * u.meter
>>> result = sulax.cumsum(q)
>>> result.mantissa
Array([ 1.,  3.,  6., 10.], dtype=float32)