TopK

Returns the k largest elements of the input tensor along the given dimension (or the k smallest, when largest is False).

Abstract Signature:

TopK(input: Tensor, k: int, dim: int = -1, largest: bool = True, sorted: bool = True)
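To pin down what the abstract signature means, here is a minimal NumPy reference sketch. The function name topk_reference is hypothetical, and returning a (values, indices) pair is an assumption modelled on torch.topk. The order of tied elements is whatever the sort produces, since the signature does not specify one.

```python
import numpy as np


def topk_reference(x, k, dim=-1, largest=True, sorted=True):
    """Hypothetical reference for TopK; returns (values, indices) along `dim`."""
    x = np.asarray(x)
    n = x.shape[dim]
    if not 0 <= k <= n:
        raise ValueError(f"k={k} is out of range for a dimension of size {n}")
    # Positions along `dim`, ordered from smallest to largest value.
    order = np.argsort(x, axis=dim, kind="stable")
    if largest:
        # Reverse the order so the largest values come first.
        order = np.flip(order, axis=dim)
    # Keep the first k positions along `dim`, then gather their values.
    indices = np.take(order, np.arange(k), axis=dim)
    values = np.take_along_axis(x, indices, axis=dim)
    # The result is always sorted. sorted=False only relaxes the ordering
    # requirement, so a sorted result still satisfies it.
    return values, indices
```

For example, on [[1, 5, 3], [4, 2, 6]] with k=2 this yields values [[5, 3], [6, 4]] and indices [[1, 2], [2, 0]].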

PyTorch

API: torch.topk
Strategy: Direct Mapping

JAX (Core)

API: jax.lax.top_k
Strategy: Plugin (topk_adapter)
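Per its documentation, jax.lax.top_k takes only an operand and k, selects along the last axis, and always returns the largest entries in descending order. That gap is presumably why this entry needs a plugin. The sketch below is a hypothetical illustration of such a bridge, not the actual topk_adapter: jax_topk_sketch moves dim to the last axis, negates the input when largest is False, and moves the results back.

```python
import jax.numpy as jnp
from jax import lax


def jax_topk_sketch(x, k, dim=-1, largest=True, sorted=True):
    """Hypothetical adapter: TopK(input, k, dim, largest, sorted) via lax.top_k."""
    x = jnp.asarray(x)
    # lax.top_k works on the last axis, so move `dim` there first.
    moved = jnp.moveaxis(x, dim, -1)
    # lax.top_k only finds the largest entries; negating finds the smallest.
    # Caveat: negation is wrong for unsigned ints and overflows at the
    # minimum signed int, so a real adapter would need to handle those.
    operand = moved if largest else -moved
    _, indices = lax.top_k(operand, k)
    # Gather the original (un-negated) values at the selected positions.
    values = jnp.take_along_axis(moved, indices, axis=-1)
    # lax.top_k output is already ordered, which satisfies sorted=True and
    # is also acceptable for sorted=False.
    return jnp.moveaxis(values, -1, dim), jnp.moveaxis(indices, -1, dim)
```

With largest=False, the negated operand is ranked in descending order, so the gathered values come back in ascending order, the smallest first.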

Keras

API: keras.ops.top_k
Strategy: Direct Mapping

TensorFlow

API: tf.math.top_k
Strategy: Plugin (tf_topk_adapter)
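tf.math.top_k finds the k largest entries of the last dimension only and has no dim or largest argument. This is presumably what tf_topk_adapter compensates for. The following is a hypothetical sketch of one way to do that, not the real plugin. It assumes the input has a statically known rank, and it swaps dim with the last axis via tf.transpose. Because that permutation is its own inverse, the same perm restores the layout.

```python
import tensorflow as tf


def tf_topk_sketch(x, k, dim=-1, largest=True, sorted=True):
    """Hypothetical adapter: TopK(input, k, dim, largest, sorted) via tf.math.top_k."""
    x = tf.convert_to_tensor(x)
    rank = x.shape.rank  # assumes a statically known rank >= 1
    axis = dim % rank
    # Permutation that swaps `axis` with the last axis (self-inverse).
    perm = list(range(rank))
    perm[axis], perm[-1] = perm[-1], perm[axis]
    moved = tf.transpose(x, perm)
    # tf.math.top_k only finds the largest entries; negating finds the
    # smallest (not valid for unsigned integer dtypes).
    operand = moved if largest else -moved
    result = tf.math.top_k(operand, k=k, sorted=sorted)
    values = result.values if largest else -result.values
    return tf.transpose(values, perm), tf.transpose(result.indices, perm)
```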

Apple MLX

API: mlx.core.topk
Strategy: Direct Mapping

Flax NNX

API: jax.lax.top_k
Strategy: Direct Mapping

PaxML / Praxis

API: jax.lax.top_k
Strategy: Direct Mapping