ml_switcheroo.plugins.topk
Plugin for TopK Output Adaptation.
Addresses a semantic mismatch between backends:
1. PyTorch: values, indices = torch.topk(x, k) returns a named-tuple-like result, so .values and .indices are also available.
2. JAX: values, indices = jax.lax.top_k(x, k) returns a plain tuple.
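The mismatch is about the shape of the result, not the numbers in it. The sketch below reproduces it using only the standard library. `torch_like_topk` and `jax_like_top_k` are hypothetical pure-Python stand-ins that copy only the return shape of `torch.topk` and `jax.lax.top_k`, and they work on 1-D lists rather than tensors. Positional unpacking works for both, but attribute access works only for the named-tuple result.

```python
import collections
import heapq

# Hypothetical pure-Python stand-ins for the two backends; neither torch
# nor jax is imported. Both operate on a 1-D list only.
TorchStyleTopK = collections.namedtuple("TorchStyleTopK", ["values", "indices"])


def _top_k_pairs(xs, k):
    # The k largest (value, index) pairs, ordered by value descending.
    return heapq.nlargest(k, ((v, i) for i, v in enumerate(xs)))


def torch_like_topk(xs, k):
    """Mimics the *shape* of torch.topk's result: a namedtuple."""
    pairs = _top_k_pairs(xs, k)
    return TorchStyleTopK([v for v, _ in pairs], [i for _, i in pairs])


def jax_like_top_k(xs, k):
    """Mimics the *shape* of jax.lax.top_k's result: a plain tuple."""
    pairs = _top_k_pairs(xs, k)
    return ([v for v, _ in pairs], [i for _, i in pairs])


x = [3.0, 1.0, 4.0, 1.5, 9.0]

# Positional unpacking behaves the same for both results...
assert tuple(torch_like_topk(x, 2)) == jax_like_top_k(x, 2)

# ...but only the named-tuple result supports attribute access.
print(torch_like_topk(x, 2).values)  # [9.0, 4.0]
try:
    jax_like_top_k(x, 2).values
except AttributeError as err:
    print("AttributeError:", err)
```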
Transformation: Wraps the target function call in a collections.namedtuple factory, so that code relying on attribute access (e.g. .values, .indices) keeps working even though the target backend returns a raw tuple.
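At runtime, the wrapped expression works as follows: a namedtuple class is built inline, and the backend's raw tuple is splatted into it. This is a minimal sketch under a few assumptions. `raw_top_k` is a hypothetical stand-in for a tuple-returning backend such as `jax.lax.top_k`, and the typename `"TopK"` is assumed. The field names match the `.values` and `.indices` accessors named above.

```python
import collections


def raw_top_k(xs, k):
    """Hypothetical stand-in for a backend that returns a plain tuple."""
    order = sorted(range(len(xs)), key=lambda i: xs[i], reverse=True)[:k]
    return [xs[i] for i in order], list(order)


x = [3.0, 1.0, 4.0, 1.5, 9.0]

# Shape of the wrapped expression: build a namedtuple class on the fly and
# splat the backend's raw tuple into it. The typename "TopK" is an assumption.
result = collections.namedtuple("TopK", ["values", "indices"])(*raw_top_k(x, 2))

print(result.values)   # [9.0, 4.0]
print(result.indices)  # [4, 2]
values, indices = result  # positional unpacking still works
```

Building the factory inline keeps the rewrite to a single expression with no new module-level names. The trade-off is that a fresh class is created each time the expression runs, so `isinstance` checks against a class created by another evaluation will not match.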
Decoupling Logic:
- Strict API lookup for “TopK”.
- If no mapping is found, the original node is returned unchanged.
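The guard can be sketched with the standard library's `ast` module in place of libcst. Several pieces are hypothetical: the `TARGET_APIS` dict stands in for the lookup performed through `HookContext`, and `lookup_api` and `transform_topk_guard` are illustrative names, not ml_switcheroo's API. When the strict lookup misses, the very same node object is returned. When it hits, only the call target is swapped, and wrapping is left out to keep the example focused.

```python
import ast
from typing import Dict, Optional

# Hypothetical registry standing in for the API lookup that ml_switcheroo
# performs through HookContext. Keys and paths here are assumptions.
TARGET_APIS: Dict[str, Dict[str, str]] = {"jax": {"TopK": "jax.lax.top_k"}}


def lookup_api(target_fw: str, op_name: str) -> Optional[str]:
    """Strict lookup: exact op name only, no fuzzy match or fallback."""
    return TARGET_APIS.get(target_fw, {}).get(op_name)


def transform_topk_guard(node: ast.Call, target_fw: str) -> ast.expr:
    api = lookup_api(target_fw, "TopK")
    if api is None:
        return node  # no mapping: hand back the original node unchanged
    # Mapping found: retarget the call (namedtuple wrapping omitted here).
    new_func = ast.parse(api, mode="eval").body
    return ast.Call(func=new_func, args=node.args, keywords=node.keywords)


call = ast.parse("torch.topk(x, 3)", mode="eval").body
assert transform_topk_guard(call, "unknown_fw") is call
print(ast.unparse(transform_topk_guard(call, "jax")))  # jax.lax.top_k(x, 3)
```

`ast.unparse` requires Python 3.9 or later.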
Functions
| transform_topk(node, ctx) | Hook: Wraps the target top_k call in a collections.namedtuple constructor. |
Module Contents
- ml_switcheroo.plugins.topk.transform_topk(node: libcst.Call, ctx: ml_switcheroo.core.hooks.HookContext) → libcst.CSTNode
Hook: Wraps the target top_k call in a collections.namedtuple constructor.
Performs the following steps:
1. Looks up the target API for “TopK” (e.g., jax.lax.top_k).
2. Strips arguments the target does not support (e.g., largest, sorted).
3. Injects import collections into the file preamble.
4. Wraps the call in a collections.namedtuple factory to restore the .values and .indices accessors that Torch code expects.
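Taken together, these steps can be approximated as a source-to-source rewrite. The sketch below uses the standard library's `ast.NodeTransformer` instead of libcst, so unlike libcst it does not preserve comments or formatting. It is an illustration, not ml_switcheroo's implementation. `TopKRewriter`, `rewrite`, the hard-coded `"torch.topk"` match, the `"TopK"` typename, and the default target `"jax.lax.top_k"` are all assumptions. The import is inserted naively at the top of the module, ignoring docstrings and `__future__` imports. Python 3.9+ is required for `ast.unparse`.

```python
import ast

# Assumption: the arguments to strip, taken from the examples above.
UNSUPPORTED_KWARGS = {"largest", "sorted"}


class TopKRewriter(ast.NodeTransformer):
    """Rewrites torch.topk(...) calls. Uses stdlib ast, not libcst."""

    def __init__(self, target_api: str) -> None:
        self.target_api = target_api
        self.changed = False

    def visit_Call(self, node: ast.Call) -> ast.AST:
        self.generic_visit(node)
        if ast.unparse(node.func) != "torch.topk":
            return node
        self.changed = True
        # Drop keyword arguments the target API does not accept.
        kwargs = [kw for kw in node.keywords if kw.arg not in UNSUPPORTED_KWARGS]
        # Call the looked-up target instead of torch.topk.
        target = ast.parse(self.target_api, mode="eval").body
        inner = ast.Call(func=target, args=node.args, keywords=kwargs)
        # Wrap: collections.namedtuple("TopK", ["values", "indices"])(*inner)
        factory = ast.parse(
            'collections.namedtuple("TopK", ["values", "indices"])', mode="eval"
        ).body
        return ast.Call(
            func=factory, args=[ast.Starred(value=inner, ctx=ast.Load())], keywords=[]
        )


def _imports_collections(tree: ast.Module) -> bool:
    # True if a top-level `import collections` is already present.
    return any(
        isinstance(stmt, ast.Import) and any(a.name == "collections" for a in stmt.names)
        for stmt in tree.body
    )


def rewrite(source: str, target_api: str = "jax.lax.top_k") -> str:
    tree = ast.parse(source)
    rewriter = TopKRewriter(target_api)
    tree = rewriter.visit(tree)
    if rewriter.changed and not _imports_collections(tree):
        # Inject `import collections` at the top of the module (naively,
        # ignoring docstrings and __future__ imports).
        tree.body.insert(0, ast.Import(names=[ast.alias(name="collections")]))
    return ast.unparse(ast.fix_missing_locations(tree))


print(rewrite("r = torch.topk(x, 3, largest=True, sorted=True)\nprint(r.values)"))
```

The final call prints the rewritten module: an `import collections` line, then the assignment with `largest` and `sorted` removed and the `jax.lax.top_k(x, 3)` call splatted into the namedtuple factory. The untouched `print(r.values)` line then works against the wrapped result. Source without a `torch.topk` call is returned unchanged, with no import added.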
- Parameters:
node – The original CST Call node.
ctx – HookContext for API lookup and preamble injection.
- Returns:
The transformed call, or the original node if no “TopK” mapping is found.
- Return type:
cst.Call