TopK

Retrieve the K largest or smallest elements along a specified axis. Given an input tensor of shape [a_0, a_1, …, a_{n-1}] and an integer argument k, return two outputs:

* Value tensor of shape [a_0, a_1, …, a_{axis-1}, k, a_{axis+1}, …, a_{n-1}], which contains the values of the top k elements …

Abstract Signature:

TopK(X: Tensor, K: int, axis: int, largest: int, sorted: int)
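
The semantics of the abstract signature can be illustrated with a small pure-Python reference over nested lists. This is only a sketch under stated assumptions, not any framework's implementation: the names topk_1d and topk_reference are hypothetical, the second output is assumed to be the index tensor (as torch.topk returns), ties are assumed to go to the lower index, axis is assumed non-negative, and the sorted argument is renamed sorted_ to avoid shadowing the Python builtin.

```python
def topk_1d(xs, k, largest=1):
    """Return (values, indices) for the k largest (or smallest) items of a flat list."""
    if not 1 <= k <= len(xs):
        raise ValueError("k must satisfy 1 <= k <= len(xs)")
    # Rank positions by value; ties go to the lower index (an assumption).
    if largest:
        order = sorted(range(len(xs)), key=lambda i: (-xs[i], i))
    else:
        order = sorted(range(len(xs)), key=lambda i: (xs[i], i))
    order = order[:k]
    return [xs[i] for i in order], order


def topk_reference(x, k, axis=0, largest=1, sorted_=1):
    """Hypothetical reference TopK over nested lists; axis must be non-negative.

    sorted_ is accepted for signature parity only: this sketch always
    returns its results in sorted order.
    """
    if not isinstance(x[0], list):
        # 1-D input: axis 0 is the only valid axis.
        if axis != 0:
            raise ValueError("axis out of range")
        return topk_1d(x, k, largest)
    if axis == 0:
        # Swap axes 0 and 1, select along axis 1, then swap back.
        swapped = [list(col) for col in zip(*x)]
        values, indices = topk_reference(swapped, k, 1, largest, sorted_)
        return ([list(r) for r in zip(*values)],
                [list(r) for r in zip(*indices)])
    # axis > 0: recurse into each slice along axis 0.
    pairs = [topk_reference(sub, k, axis - 1, largest, sorted_) for sub in x]
    return [p[0] for p in pairs], [p[1] for p in pairs]
```

For a 2 x 3 input, selecting along axis 1 with k = 2 yields outputs of shape [2, 2], and selecting along axis 0 with k = 1 yields outputs of shape [1, 3], which matches the shape rule in the description.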

PyTorch

API: torch.topk
Strategy: Direct Mapping

JAX (Core)

API: jax.lax.top_k
Strategy: Plugin (topk_adapter)
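
The source does not show what topk_adapter does. One plausible reason for a plugin here is that jax.lax.top_k has traditionally been called as top_k(operand, k), selecting the largest values along the last axis, so the axis and largest arguments of the abstract signature need extra handling. The sketch below illustrates that idea in plain Python on 2-D nested lists. Everything in it is an assumption for illustration: last_axis_top_k is a stand-in for the restricted primitive, and topk_adapter_sketch is a hypothetical name, not the actual plugin.

```python
def last_axis_top_k(rows, k):
    """Stand-in primitive: k largest values per row (last axis only)."""
    values, indices = [], []
    for row in rows:
        # Largest first; ties go to the lower index (an assumption).
        order = sorted(range(len(row)), key=lambda i: (-row[i], i))[:k]
        indices.append(order)
        values.append([row[i] for i in order])
    return values, indices


def transpose(matrix):
    """Swap the two axes of a 2-D nested list."""
    return [list(col) for col in zip(*matrix)]


def topk_adapter_sketch(x, k, axis=1, largest=1):
    """Hypothetical adapter: add axis and largest support for 2-D input."""
    if axis not in (0, 1):
        raise ValueError("this sketch only handles 2-D input")
    # Move the requested axis to the last position.
    work = transpose(x) if axis == 0 else x
    # Smallest-k equals largest-k of the negated values.
    if not largest:
        work = [[-v for v in row] for row in work]
    values, indices = last_axis_top_k(work, k)
    # Undo the negation on the values; indices are unaffected.
    if not largest:
        values = [[-v for v in row] for row in values]
    # Move the selected axis back to its original position.
    if axis == 0:
        values, indices = transpose(values), transpose(indices)
    return values, indices
```

The negation trick is safe for Python numbers, but with fixed-width integer or unsigned dtypes it can overflow or be unavailable, so a real adapter would likely need a different route for those types.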

Apple MLX

API: mlx.core.topk
Strategy: Direct Mapping