SegmentedMm
Segmented matrix multiplication.
Abstract Signature:
SegmentedMm(a: Tensor, b: Tensor, segments: Tensor)
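The signature above does not spell out shapes or the meaning of `segments`. As a minimal sketch under one common interpretation (an assumption, not confirmed by this page), `a` holds N rows of length K, `b` holds one K x M matrix per segment, and `segments` gives a segment id for each row of `a`. Each row is then multiplied by the matrix of its segment. The function name `segmented_mm` and the list-based layout are illustrative only.

```python
def segmented_mm(a, b, segments):
    """Reference sketch: row i of `a` is multiplied by the matrix b[segments[i]].

    Assumed layout (not stated in the source):
      a:        list of N rows, each of length K
      b:        list of S matrices, each K x M
      segments: list of N segment ids, each in range(S)
    """
    if len(a) != len(segments):
        raise ValueError("segments must hold one id per row of a")
    out = []
    for row, seg in zip(a, segments):
        weight = b[seg]  # K x M matrix selected by this row's segment id
        if len(weight) != len(row):
            raise ValueError("inner dimensions do not match")
        # Ordinary row-vector times matrix product.
        out.append([
            sum(row[k] * weight[k][j] for k in range(len(row)))
            for j in range(len(weight[0]))
        ])
    return out


a = [[1, 2], [3, 4], [5, 6]]
b = [
    [[1, 0], [0, 1]],  # segment 0: identity
    [[2, 0], [0, 2]],  # segment 1: scale by 2
]
segments = [0, 0, 1]
result = segmented_mm(a, b, segments)
```

Here the first two rows pass through the identity matrix unchanged and the third row is scaled by 2. A real backend would batch rows per segment rather than loop row by row.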
PyTorch
API:
(none)
Strategy: Plugin (segmented_ops)
JAX (Core)
API:
(none)
Strategy: Plugin (segmented_ops)
Keras
API:
(none)
Strategy: Plugin (segmented_ops)
TensorFlow
API:
tf.math.segment_sum
Strategy: Plugin (segmented_ops)
Flax NNX
API:
(none)
Strategy: Plugin (segmented_ops)
PaxML / Praxis
API:
(none)
Strategy: Plugin (segmented_ops)
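The only concrete API named in the mappings is TensorFlow's `tf.math.segment_sum`, which is a reduction rather than a matrix product: it sums the rows of a tensor that share a segment id, with ids sorted in non-decreasing order. This page does not say how the `segmented_ops` plugin combines it with a multiplication, so the following is only a pure-Python sketch of the reduction itself; the helper name `segment_sum` mirrors the TensorFlow call but is otherwise illustrative.

```python
def segment_sum(data, segment_ids):
    """Sum the rows of `data` that share a segment id.

    Assumes segment_ids are sorted, non-negative integers, one per row,
    which matches the documented contract of tf.math.segment_sum.
    """
    if len(data) != len(segment_ids):
        raise ValueError("segment_ids must hold one id per row of data")
    if not data:
        return []
    width = len(data[0])
    num_segments = segment_ids[-1] + 1  # ids are sorted, so the last is the max
    out = [[0] * width for _ in range(num_segments)]
    for row, seg in zip(data, segment_ids):
        for j, value in enumerate(row):
            out[seg][j] += value
    return out


data = [[1, 2], [3, 4], [5, 6]]
totals = segment_sum(data, [0, 0, 1])
```

The first two rows fall in segment 0 and are added together, while the third row forms segment 1 on its own.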