TopK
====

Retrieve the top-K largest or smallest elements along a specified axis.

Given an input tensor of shape [a_0, a_1, ..., a_{n-1}] and integer argument k, return two outputs:

* Value tensor of shape [a_0, a_1, ..., a_{axis-1}, k, a_{axis+1}, ..., a_{n-1}] which contains the values of the top k elements ...

**Abstract Signature:**

``TopK(X: Tensor, K: int, axis: int, largest: int, sorted: int)``

.. raw:: html

jax.lax.top_k
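
The following is a minimal sketch of how the abstract signature could be approximated with ``jax.lax.top_k``, which selects the largest entries along the last axis. The helper name ``topk`` and its keyword defaults are hypothetical and not part of the original page. The sketch assumes a floating-point or signed-integer input, because the ``largest=False`` case is handled by negating the input. It always returns sorted results, so the ``sorted`` flag is not modeled.

```python
import jax.numpy as jnp
from jax import lax


def topk(x, k, axis=-1, largest=True):
    """Hypothetical helper: TopK along `axis`, built on jax.lax.top_k."""
    # lax.top_k operates on the last axis, so move the target axis there.
    moved = jnp.moveaxis(x, axis, -1)
    if largest:
        values, indices = lax.top_k(moved, k)
    else:
        # Assumption: negation is safe for this dtype. The smallest entries
        # become the largest, and the sign is restored afterwards.
        neg_values, indices = lax.top_k(-moved, k)
        values = -neg_values
    # Move the reduced axis (now of length k) back to its original position.
    return jnp.moveaxis(values, -1, axis), jnp.moveaxis(indices, -1, axis)


x = jnp.array([[1.0, 4.0, 2.0],
               [9.0, 3.0, 7.0]])

# Two largest entries in each row (axis=1).
values, indices = topk(x, k=2, axis=1, largest=True)

# Single smallest entry in each column (axis=0).
small_values, small_indices = topk(x, k=1, axis=0, largest=False)
```

In this example ``values`` has shape ``(2, 2)`` and ``small_values`` has shape ``(1, 3)``, matching the output shape rule above, where the selected axis is replaced by ``k``.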