SegmentedMm

Segmented matrix multiplication.

Abstract Signature:

SegmentedMm(a: Tensor, b: Tensor, segments: Tensor)
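
The abstract signature does not say what `segments` contains, so the following is only a minimal reference sketch under stated assumptions: `a` has shape M x K, `b` has shape K x N, and each row of `segments` is a half-open `[start, end)` range over the shared inner dimension K, producing one M x N product per range. The function name `segmented_mm` and the nested-list representation are illustrative choices, not the API of any framework listed here.

```python
def segmented_mm(a, b, segments):
    """Return one (M x N) partial product per [start, end) slice of K.

    Assumed semantics (not confirmed by the signature above):
    a is M x K, b is K x N, segments is a sequence of (start, end) pairs.
    """
    m, k = len(a), len(b)
    n = len(b[0]) if b else 0
    out = []
    for start, end in segments:
        if not 0 <= start <= end <= k:
            raise ValueError(f"segment ({start}, {end}) is outside [0, {k}]")
        # Restrict the contraction to inner indices start..end-1.
        block = [
            [sum(a[i][p] * b[p][j] for p in range(start, end)) for j in range(n)]
            for i in range(m)
        ]
        out.append(block)
    return out


a = [[1, 2, 3],
     [4, 5, 6]]          # 2 x 3
b = [[1, 0],
     [0, 1],
     [1, 1]]             # 3 x 2
segments = [(0, 2), (2, 3)]

result = segmented_mm(a, b, segments)
# result[0] uses inner indices 0..1, result[1] uses inner index 2.
```

Under these assumptions, when the segments partition the inner dimension without overlap, summing the per-segment results element-wise gives the ordinary product of `a` and `b`.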

PyTorch

API: none
Strategy: Plugin (segmented_ops)

JAX (Core)

API: none
Strategy: Plugin (segmented_ops)

Keras

API: none
Strategy: Plugin (segmented_ops)

TensorFlow

API: tf.math.segment_sum
Strategy: Plugin (segmented_ops)

Apple MLX

API: mlx.core.segmented_mm
Strategy: Direct Mapping

Flax NNX

API: none
Strategy: Plugin (segmented_ops)

PaxML / Praxis

API: none
Strategy: Plugin (segmented_ops)