SegmentedMm
===========

Segmented matrix multiplication.

**Abstract Signature:** ``SegmentedMm(a: Tensor, b: Tensor, segments: Tensor)``

.. list-table::
   :header-rows: 1

   * - Framework
     - API
     - Strategy
   * - PyTorch
     -
     - Plugin (segmented_ops)
   * - JAX (Core)
     -
     - Plugin (segmented_ops)
   * - Keras
     -
     - Plugin (segmented_ops)
   * - TensorFlow
     - ``tf.math.segment_sum``
     - Plugin (segmented_ops)
   * - Apple MLX
     - ``mlx.core.segmented_mm``
     - Direct Mapping
   * - Flax NNX
     -
     - Plugin (segmented_ops)
   * - PaxML / Praxis
     -
     - Plugin (segmented_ops)
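The abstract signature does not spell out what ``segments`` means. The following is a minimal pure-Python reference sketch that assumes the semantics of ``mlx.core.segmented_mm``, the one direct mapping in the table. Under that assumption, ``a`` is M x K, ``b`` is K x N, and each entry of ``segments`` is a half-open ``(start, end)`` range over the shared inner dimension K. The result holds one M x N partial product per segment. The function name ``segmented_mm`` and the list-of-lists matrix representation are illustrative assumptions. They are not the API of any framework listed here.

.. code-block:: python

   from typing import List, Sequence, Tuple

   Matrix = List[List[float]]


   def segmented_mm(
       a: Matrix, b: Matrix, segments: Sequence[Tuple[int, int]]
   ) -> List[Matrix]:
       """Hypothetical reference sketch of SegmentedMm.

       Assumption (mirroring mlx.core.segmented_mm): each segment is a
       half-open (start, end) range over the inner dimension K, and the
       output holds one M x N partial product per segment.
       """
       m = len(a)
       k = len(b)
       n = len(b[0]) if b else 0
       # a must be M x K so that its columns line up with the rows of b.
       if any(len(row) != k for row in a):
           raise ValueError("inner dimensions of a and b must match")
       out: List[Matrix] = []
       for start, end in segments:
           if not 0 <= start <= end <= k:
               raise ValueError(f"segment ({start}, {end}) out of range for K={k}")
           # Contract only over inner indices start .. end - 1.
           block = [
               [sum(a[i][p] * b[p][j] for p in range(start, end)) for j in range(n)]
               for i in range(m)
           ]
           out.append(block)
       return out

When the segments tile ``[0, K)`` exactly, summing the per-segment blocks element-wise recovers the ordinary product ``a @ b``. An empty segment such as ``(1, 1)`` contributes an all-zero block.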