TopK
Retrieve the top-K largest or smallest elements along a specified axis. Given an input tensor of shape [a_0, a_1, …, a_{n-1}] and integer argument k, return two outputs:
* Value tensor of shape [a_0, a_1, …, a_{axis-1}, k, a_{axis+1}, …, a_{n-1}] which contains the values of the top k elements …
Abstract Signature:
TopK(X: Tensor, K: int, axis: int, largest: int, sorted: int)
JAX (Core)
API:
jax.lax.top_k
Strategy: Plugin (topk_adapter)
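
The following is a minimal sketch of how the abstract signature could be mapped onto jax.lax.top_k, which selects the k largest entries along the last axis and returns both values and indices. The helper name topk_sketch is hypothetical and is not the actual topk_adapter plugin. The sketch assumes a floating-point input (the smallest-k branch relies on negation) and ignores the sorted argument.

```python
import jax
import jax.numpy as jnp


def topk_sketch(x, k, axis=-1, largest=1):
    """Hypothetical adapter: TopK along an arbitrary axis via jax.lax.top_k."""
    # jax.lax.top_k operates on the last axis, so move the target axis there.
    x_last = jnp.moveaxis(x, axis, -1)
    if largest:
        values, indices = jax.lax.top_k(x_last, k)
    else:
        # Assumption: negating the input turns "k smallest" into "k largest".
        # This is only safe for floating-point (or otherwise negatable) inputs.
        neg_values, indices = jax.lax.top_k(-x_last, k)
        values = -neg_values
    # Move the size-k axis back to the position of the original axis.
    return jnp.moveaxis(values, -1, axis), jnp.moveaxis(indices, -1, axis)


x = jnp.array([[1.0, 4.0, 2.0],
               [9.0, 3.0, 7.0]])
values, indices = topk_sketch(x, k=2, axis=1, largest=1)
print(values.shape)  # the size of axis 1 is replaced by k
```

With largest=1 the values along the chosen axis come back in descending order. With largest=0 the negation trick yields them in ascending order.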