cumsum

class saiunit.lax.cumsum(operand, axis=0, reverse=False, **kwargs)
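Compute a cumulative sum of operand along axis.

Parameters:
- operand – The input array or quantity to accumulate.
- axis – The axis along which the running sum is taken. Defaults to 0.
- reverse – If True, accumulate from the end of axis toward its start. Defaults to False.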

Returns:
result – The cumulative sum array. Preserves the unit of operand.

Return type:
saiunit.Quantity | Array
Examples
>>> import saiunit as u
>>> import saiunit.lax as sulax
>>> import jax.numpy as jnp
>>> q = jnp.array([1.0, 2.0, 3.0, 4.0]) * u.meter
>>> result = sulax.cumsum(q)
>>> result.mantissa
Array([ 1.,  3.,  6., 10.], dtype=float32)