cumulative_sum

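Calculates the cumulative (running) sum of the elements of the input array x along axis. If x is one-dimensional, axis may be omitted; otherwise it must be specified. dtype, if given, is the data type of the returned array. When include_initial is True, the output starts with the additive identity (zero) along axis, so it is one element longer than x along that axis.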

Abstract Signature:

cumulative_sum(x: array, axis: Optional[int], dtype: Optional[dtype], include_initial: bool)
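A minimal sketch of these semantics, assuming a recent JAX release that provides jax.numpy.cumulative_sum. The input array and variable names are illustrative only.

```python
import jax.numpy as jnp

x = jnp.array([[1, 2, 3],
               [4, 5, 6]])

# x is 2-D, so axis must be given explicitly.
row_sums = jnp.cumulative_sum(x, axis=1)
# [[1, 3, 6], [4, 9, 15]]

# include_initial=True prepends the additive identity (zero),
# making the output one element longer along `axis`.
with_initial = jnp.cumulative_sum(x, axis=1, include_initial=True)
# [[0, 1, 3, 6], [0, 4, 9, 15]]

# dtype sets the data type of the result.
as_float = jnp.cumulative_sum(x, axis=0, dtype=jnp.float32)
# [[1., 2., 3.], [5., 7., 9.]]
```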

JAX (Core)

API: jax.numpy.cumulative_sum
Strategy: Direct Mapping
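Because the mapping is direct, an adapter only needs to forward the arguments unchanged. The sketch below uses a hypothetical wrapper name, cumulative_sum_jax, and marks the non-array arguments as static for jax.jit, on the assumption that they should be fixed at trace time because they determine the output's shape and dtype.

```python
from functools import partial

import jax
import jax.numpy as jnp


# Hypothetical adapter: forwards every argument to jax.numpy.cumulative_sum.
# axis, dtype and include_initial are static because they change the
# output's shape (axis, include_initial) or data type (dtype).
@partial(jax.jit, static_argnames=("axis", "dtype", "include_initial"))
def cumulative_sum_jax(x, axis=None, dtype=None, include_initial=False):
    return jnp.cumulative_sum(
        x, axis=axis, dtype=dtype, include_initial=include_initial
    )


v = jnp.array([1, 2, 3, 4])

# With include_initial=False the result matches jnp.cumsum.
plain = cumulative_sum_jax(v)
# With include_initial=True a leading zero is prepended.
padded = cumulative_sum_jax(v, include_initial=True)
```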

Flax NNX

API: jax.numpy.cumulative_sum
Strategy: Direct Mapping