GatherMM
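Matrix multiplication with matrix-level gather: whole matrices are first selected from a (by lhs_indices) and from b (by rhs_indices), and each selected pair is then multiplied.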
Abstract Signature:
GatherMM(a: Tensor, b: Tensor, lhs_indices: Tensor, rhs_indices: Tensor)
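The semantics implied by the signature can be sketched in NumPy. This sketch assumes (as in MLX's gather_mm) that each index selects one whole matrix from the flattened batch dimensions of its operand, and that lhs_indices and rhs_indices broadcast against each other. The function name gather_mm_reference is hypothetical. It materializes the gathered copies explicitly, whereas an optimized kernel would presumably fuse the gather into the multiplication.

```python
import numpy as np


def gather_mm_reference(a, b, lhs_indices, rhs_indices):
    """Hypothetical reference: out[i] = a_mats[lhs[i]] @ b_mats[rhs[i]]."""
    # Flatten all leading batch dimensions so each index picks one whole matrix.
    a_mats = a.reshape(-1, *a.shape[-2:])  # (num_a, M, K)
    b_mats = b.reshape(-1, *b.shape[-2:])  # (num_b, K, N)
    # Assumption: the two index arrays broadcast against each other.
    lhs, rhs = np.broadcast_arrays(np.asarray(lhs_indices), np.asarray(rhs_indices))
    # Matrix-level gather: integer-array indexing selects entire matrices.
    a_sel = a_mats[lhs]  # lhs.shape + (M, K)
    b_sel = b_mats[rhs]  # rhs.shape + (K, N)
    # Batched matrix multiplication over the (shared) index shape.
    return a_sel @ b_sel  # lhs.shape + (M, N)


a = np.arange(12).reshape(2, 2, 3)   # two 2x3 matrices
b = np.arange(18).reshape(3, 3, 2)   # three 3x2 matrices
out = gather_mm_reference(a, b, [0, 1, 1], [2, 0, 1])
print(out.shape)  # (3, 2, 2)
```

Each output matrix out[i] equals a[lhs_indices[i]] @ b[rhs_indices[i]], so the same operand matrix can be reused across several products without being copied up front by the caller.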
PyTorch
API:
Strategy: Custom / Partial
JAX (Core)
API:
Strategy: Custom / Partial