TopK
Returns the k largest elements (or, with largest=False, the k smallest) of the input tensor along a given dimension.
Abstract Signature:
TopK(input: Tensor, k: int, dim: int = -1, largest: bool = True, sorted: bool = True)
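The following pure-Python reference sketch pins down what the abstract signature means. It is restricted to a 1-D sequence, so `dim` is implicit. The function name `topk_1d`, the stable tie-breaking (equal values keep the lower index first), and the choice to return the selected items in original index order when `sorted=False` are illustrative assumptions, not part of the spec above.

```python
import builtins
from typing import List, Sequence, Tuple


def topk_1d(
    xs: Sequence[float], k: int, largest: bool = True, sorted: bool = True
) -> Tuple[List[float], List[int]]:
    """Hypothetical reference TopK over a 1-D sequence (dim is implicit)."""
    if not 0 <= k <= len(xs):
        raise ValueError(f"k={k} is out of range for length {len(xs)}")
    # Rank positions by value: descending for largest=True, ascending otherwise.
    # Python's sort is stable, so equal values keep the lower index first.
    order = builtins.sorted(range(len(xs)), key=lambda i: xs[i], reverse=largest)
    indices = order[:k]
    if not sorted:
        # Any order is acceptable here; this sketch picks original index order.
        indices = builtins.sorted(indices)
    return [xs[i] for i in indices], indices


values, indices = topk_1d([3.0, 1.0, 4.0, 1.0, 5.0], k=2)
# values == [5.0, 4.0], indices == [4, 2]
```

The parameter is named `sorted` to mirror the abstract signature, which shadows the built-in inside the function; that is why the sketch calls `builtins.sorted` explicitly.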
JAX (Core)
API:
jax.lax.top_k
Strategy: Plugin (topk_adapter)
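A plugin is needed because `jax.lax.top_k(operand, k)` takes no `dim`, `largest`, or `sorted` arguments: it always selects the largest entries along the last axis and returns a `(values, indices)` pair in descending order. The code of `topk_adapter` is not shown here; the sketch below is a hypothetical adapter, `topk_jax`, illustrating one way such a mapping could work. It moves `dim` to the last axis, negates the input to emulate `largest=False`, and relies on the fact that sorted output also satisfies `sorted=False`.

```python
import jax
import jax.numpy as jnp


def topk_jax(x, k: int, dim: int = -1, largest: bool = True, sorted: bool = True):
    """Hypothetical adapter: abstract TopK -> jax.lax.top_k (not the real topk_adapter)."""
    # jax.lax.top_k reduces over the last axis, so move `dim` there first.
    moved = jnp.moveaxis(x, dim, -1)
    # jax.lax.top_k always selects the largest entries; negating the input selects
    # the smallest instead (unsafe for unsigned ints and the minimum signed int).
    values, indices = jax.lax.top_k(moved if largest else -moved, k)
    if not largest:
        values = -values
    # The output is already ordered by rank, which also satisfies sorted=False,
    # so `sorted` needs no handling here.
    return jnp.moveaxis(values, -1, dim), jnp.moveaxis(indices, -1, dim)


x = jnp.array([[3.0, 1.0, 4.0], [1.0, 5.0, 9.0]])
values, indices = topk_jax(x, 2)  # top 2 per row, along the last axis
```

With `largest=False` the negated values come back in descending order of `-x`, that is, ascending order of `x`, so the smallest element is first.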
TensorFlow
API:
tf.math.top_k
Strategy: Plugin (tf_topk_adapter)
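`tf.math.top_k(input, k, sorted)` also works only on the last dimension and only selects the largest entries, but unlike the JAX API it accepts a `sorted` flag that can be forwarded directly. The sketch below is a hypothetical adapter, `topk_tf`, not the actual `tf_topk_adapter`. It assumes the input's rank is known statically, swaps `dim` with the last axis via `tf.transpose`, and negates the input for `largest=False`.

```python
import tensorflow as tf


def topk_tf(x, k: int, dim: int = -1, largest: bool = True, sorted: bool = True):
    """Hypothetical adapter: abstract TopK -> tf.math.top_k (not the real tf_topk_adapter)."""
    x = tf.convert_to_tensor(x)
    rank = x.shape.rank  # assumes a statically known rank
    dim = dim % rank
    # tf.math.top_k works on the last dimension only; swap `dim` with it.
    perm = list(range(rank))
    perm[dim], perm[-1] = perm[-1], perm[dim]
    moved = tf.transpose(x, perm)
    # Negate to select the smallest entries (unsafe for unsigned ints and the minimum signed int).
    values, indices = tf.math.top_k(moved if largest else -moved, k=k, sorted=sorted)
    if not largest:
        values = -values
    # A single swap is its own inverse, so the same permutation restores the layout.
    return tf.transpose(values, perm), tf.transpose(indices, perm)


x = tf.constant([[3.0, 1.0, 4.0], [1.0, 5.0, 9.0]])
values, indices = topk_tf(x, 2)  # top 2 per row, along the last axis
```

Swapping two axes, rather than moving one, keeps the permutation self-inverse, so no separate inverse permutation has to be computed.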