GatherMM

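Matrix multiplication with matrix-level gather. Each entry of lhs_indices selects one matrix from the batch dimensions of a. The matching entry of rhs_indices selects one matrix from the batch dimensions of b. Each selected pair is then multiplied.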

Abstract Signature:

GatherMM(a: Tensor, b: Tensor, lhs_indices: Tensor, rhs_indices: Tensor)
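The equation below is an informal restatement of these semantics. It assumes the flat-index convention that MLX documents for mlx.core.gather_mm. Here A and B denote a and b with all batch dimensions collapsed into one leading axis, and P_a and P_b are the products of those batch dimensions. The symbols l and r denote lhs_indices and rhs_indices after broadcasting to a common index shape S, and s ranges over S. Under these assumptions each l_s lies in [0, P_a) and each r_s lies in [0, P_b).

```latex
\mathrm{out}[s] = A[l_s]\,B[r_s],
\qquad A \in \mathbb{R}^{P_a \times M \times K},\;
B \in \mathbb{R}^{P_b \times K \times N},\;
\mathrm{out} \in \mathbb{R}^{S \times M \times N}
```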

PyTorch

API: none (no direct equivalent)
Strategy: Custom / Partial

JAX (Core)

API: none (no direct equivalent)
Strategy: Custom / Partial
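Neither PyTorch nor JAX provides a single GatherMM call, so a Custom / Partial strategy has to compose it from other operations. The sketch below is one possible composition written in NumPy as a backend-neutral reference. It flattens the batch dimensions, gathers one matrix per index, and then runs a batched matmul. It assumes the flat-index and index-broadcasting convention documented for mlx.core.gather_mm. The function name gather_mm_reference is hypothetical and is not part of any library.

```python
import numpy as np


def gather_mm_reference(a, b, lhs_indices, rhs_indices):
    """Hypothetical GatherMM reference (not an official API).

    Assumes a has shape (*batch_a, M, K), b has shape (*batch_b, K, N), and
    the indices are flat offsets into the flattened batch dimensions.
    """
    a = np.asarray(a)
    b = np.asarray(b)
    # Collapse all batch dimensions so each index picks one whole matrix.
    a_flat = a.reshape((-1,) + a.shape[-2:])  # (prod(batch_a), M, K)
    b_flat = b.reshape((-1,) + b.shape[-2:])  # (prod(batch_b), K, N)
    # Broadcast the two index tensors against each other to a common shape S.
    lhs, rhs = np.broadcast_arrays(np.asarray(lhs_indices), np.asarray(rhs_indices))
    # Gather to (*S, M, K) and (*S, K, N), then batched matmul gives (*S, M, N).
    return np.matmul(a_flat[lhs], b_flat[rhs])


# Usage: pick matrices 1, 0, 1 from a and matrices 2, 2, 0 from b.
a = np.arange(12, dtype=np.float64).reshape(2, 2, 3)  # two 2x3 matrices
b = np.arange(18, dtype=np.float64).reshape(3, 3, 2)  # three 3x2 matrices
out = gather_mm_reference(a, b, [1, 0, 1], [2, 2, 0])
print(out.shape)  # (3, 2, 2)
```

In PyTorch or JAX the same three steps could be built from reshape, integer indexing (or take), and matmul. Because the gathered operands are materialized first, this composition is likely less efficient than a fused kernel. That cost is one reason the strategy counts as custom rather than a direct mapping.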

Apple MLX

API: mlx.core.gather_mm
Strategy: Direct Mapping