ml_switcheroo.plugins.topk

Plugin for TopK Output Adaptation.

Addresses a semantic mismatch:

  1. PyTorch: values, indices = torch.topk(x, k) returns a named-tuple-like result.

  2. JAX: values, indices = jax.lax.top_k(x, k) returns a plain tuple.

Transformation: Wraps the target function call in a collections.namedtuple factory so that attribute access (.values, .indices) keeps working even though the backend returns a raw tuple.
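At runtime, the wrapping amounts to building a named tuple type and unpacking the backend's raw tuple into it. In the sketch below, a pure-Python stand-in replaces jax.lax.top_k so the example runs without JAX installed; the helper `raw_top_k` and the type name `TopKResult` are hypothetical, chosen only for illustration.

```python
import collections


def raw_top_k(xs, k):
    """Hypothetical stand-in for jax.lax.top_k: returns a plain (values, indices) tuple."""
    # Indices of the k largest entries, largest first (ties keep their original order).
    order = sorted(range(len(xs)), key=lambda i: xs[i], reverse=True)[:k]
    return [xs[i] for i in order], order


# Named tuple type exposing the same accessors as torch.topk's result.
TopKResult = collections.namedtuple("TopKResult", ["values", "indices"])

raw = raw_top_k([3, 1, 4, 1, 5], 2)
wrapped = TopKResult(*raw)

# Positional unpacking behaves exactly as it does on the raw tuple...
values, indices = wrapped
# ...and Torch-style attribute access now works as well.
print(wrapped.values, wrapped.indices)  # [5, 4] [4, 2]
```

Because a named tuple is still a tuple, code that only unpacks the result is unaffected by the wrapping; only code that reads .values or .indices depends on it.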

Decoupling Logic:

  • Strict API lookup for “TopK”.

  • If no target is found, the original node is returned unchanged.
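The fallback rule reduces to a guard at the top of the hook. The sketch below is hypothetical: the real lookup goes through HookContext, whose method names are not documented here, so the mapping is modeled as a plain dictionary, "strict" is read as exact-key matching (an assumption), and `strict_lookup` / `guarded_hook` are illustrative names only.

```python
from typing import Mapping, Optional


def strict_lookup(api_map: Mapping[str, str], op_name: str) -> Optional[str]:
    """Exact-key lookup (assumed meaning of 'strict'): no case folding or default target."""
    return api_map.get(op_name)


def guarded_hook(node: object, api_map: Mapping[str, str]) -> object:
    """Hypothetical guard mirroring the documented fallback behaviour."""
    target = strict_lookup(api_map, "TopK")
    if target is None:
        # No "TopK" entry for this backend: hand back the original node unchanged.
        return node
    # Placeholder for the actual rewrite performed when a target exists.
    return ("rewritten", target, node)


call = object()  # stands in for a libcst.Call node
print(guarded_hook(call, {}) is call)                        # True: missing entry
print(guarded_hook(call, {"topk": "jax.lax.top_k"}) is call)  # True: different case, no match
```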

Functions

transform_topk(node, ctx) → libcst.CSTNode

Hook: Wraps target top_k call in a NamedTuple constructor.

Module Contents

ml_switcheroo.plugins.topk.transform_topk(node: libcst.Call, ctx: ml_switcheroo.core.hooks.HookContext) → libcst.CSTNode

Hook: Wraps target top_k call in a NamedTuple constructor.

Orchestrates the following:

  1. Looks up the target API for “TopK” (e.g., jax.lax.top_k).

  2. Strips arguments not supported by the target (e.g., largest, sorted).

  3. Injects import collections into the file preamble.

  4. Wraps the call execution in a collections.namedtuple factory to restore .values and .indices accessors expected by Torch code.
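The four steps above can be approximated with Python's standard-library ast module (Python 3.9+ for ast.unparse) standing in for libcst. Unlike libcst, ast discards comments and formatting, so this only illustrates the shape of the rewrite. Several details are assumptions rather than the plugin's documented behaviour: the target API string is passed in directly instead of being looked up through HookContext, calls are matched by the literal name torch.topk, and the emitted form collections.namedtuple('TopK', ['values', 'indices'])(*...) is a guess at the output. `TopKRewriter` and `rewrite` are hypothetical names.

```python
import ast

# Keyword arguments assumed to be unsupported by the target API.
UNSUPPORTED_KWARGS = {"largest", "sorted"}
# Assumed shape of the namedtuple factory wrapped around the call.
FACTORY_SRC = "collections.namedtuple('TopK', ['values', 'indices'])"


def _expr(src: str) -> ast.expr:
    """Parse a single Python expression into an AST node."""
    return ast.parse(src, mode="eval").body


class TopKRewriter(ast.NodeTransformer):
    """Hypothetical rewriter: torch.topk(...) -> factory(*target(...))."""

    def __init__(self, target: str) -> None:
        self.target = target
        self.changed = False

    def visit_Call(self, node: ast.Call) -> ast.AST:
        self.generic_visit(node)
        if ast.unparse(node.func) != "torch.topk":
            return node
        self.changed = True
        # Steps 1-2: call the target API, dropping keywords it does not accept.
        inner = ast.Call(
            func=_expr(self.target),
            args=node.args,
            keywords=[kw for kw in node.keywords if kw.arg not in UNSUPPORTED_KWARGS],
        )
        # Step 4: splat the raw tuple into the namedtuple factory.
        return ast.Call(
            func=_expr(FACTORY_SRC),
            args=[ast.Starred(value=inner, ctx=ast.Load())],
            keywords=[],
        )


def rewrite(source: str, target: str = "jax.lax.top_k") -> str:
    tree = ast.parse(source)
    rewriter = TopKRewriter(target)
    tree = rewriter.visit(tree)
    if rewriter.changed:
        # Step 3: inject `import collections` at the top of the module.
        tree.body.insert(0, ast.parse("import collections").body[0])
    return ast.unparse(ast.fix_missing_locations(tree))


print(rewrite("res = torch.topk(x, 3, largest=True, sorted=True)"))
# import collections
# res = collections.namedtuple('TopK', ['values', 'indices'])(*jax.lax.top_k(x, 3))
```

In this assumed output shape, building the namedtuple type inline keeps the rewrite to a single expression, at the cost of creating a new class on every call; hoisting the factory to module level would be an alternative if that overhead matters.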

Parameters:
  • node – The original CST Call node.

  • ctx – HookContext for API lookup and preamble injection.

Returns:

The transformed call, or the original node unchanged if no “TopK” target API is found.

Return type:

cst.Call