ml_switcheroo.plugins.arg_packing
Plugin for packing variable positional arguments into a sequence (tuple/list).
It is needed when mapping varargs-style APIs (like torch.permute(x, 0, 2, 1)) onto APIs that take the same values as a single sequence argument (like jax.numpy.transpose(x, axes=(0, 2, 1))).
Strategy:
1. Identify the operation (e.g. permute_dims).
2. Look up the target API (e.g. jax.numpy.transpose).
3. Separate the primary input argument (the first positional argument) from the varargs.
4. Pack the trailing varargs into a Tuple node.
5. Construct the new call, passing the sequence as a keyword argument (e.g. axes=…).
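The strategy above can be sketched with Python's standard-library `ast` module. This is a hedged illustration rather than the plugin's code: the real hook rewrites `libcst.Call` nodes and resolves the target through `HookContext`, whereas here `ast` stands in for libcst so the example runs without extra dependencies, and the lookup table `TARGET_APIS` and the functions `dotted_name` and `pack_varargs_sketch` are hypothetical.

```python
import ast
from typing import Dict, Optional, Tuple

# Hypothetical stand-in for the semantics lookup done through HookContext:
# maps (abstract operation, target framework) -> fully qualified target API.
TARGET_APIS: Dict[Tuple[str, str], str] = {
    ("permute_dims", "jax"): "jax.numpy.transpose",
}


def dotted_name(path: str) -> ast.expr:
    """Build a Name/Attribute chain from a dotted path such as 'jax.numpy.transpose'."""
    head, *attrs = path.split(".")
    node: ast.expr = ast.Name(id=head, ctx=ast.Load())
    for attr in attrs:
        node = ast.Attribute(value=node, attr=attr, ctx=ast.Load())
    return node


def pack_varargs_sketch(call: ast.Call, op: str, framework: str, keyword: str = "axes") -> ast.Call:
    # Step 1 (identifying the operation) is assumed done by the caller, which passes `op`.
    # Step 2: look up the target API; without a mapping, return the node unchanged.
    target: Optional[str] = TARGET_APIS.get((op, framework))
    if target is None or not call.args:
        return call
    # Step 3: split the primary input (first positional) from the trailing varargs.
    primary, *varargs = call.args
    # Step 4: pack the trailing varargs into a Tuple node.
    packed = ast.Tuple(elts=list(varargs), ctx=ast.Load())
    # Step 5: rebuild the call, passing the tuple as a keyword argument.
    keywords = list(call.keywords) + [ast.keyword(arg=keyword, value=packed)]
    return ast.Call(func=dotted_name(target), args=[primary], keywords=keywords)


src = ast.parse("torch.permute(x, 0, 2, 1)", mode="eval").body
assert isinstance(src, ast.Call)
print(ast.unparse(pack_varargs_sketch(src, "permute_dims", "jax")))
# jax.numpy.transpose(x, axes=(0, 2, 1))
```

Keyword arguments already present on the source call are carried over unchanged; the sketch does not check whether one of them clashes with the packing keyword. `ast.unparse` requires Python 3.9 or later.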
Functions
| pack_varargs(node, ctx) | Plugin Hook: Packs trailing positional arguments into a keyword tuple. |
Module Contents
- ml_switcheroo.plugins.arg_packing.pack_varargs(node: libcst.Call, ctx: ml_switcheroo.core.hooks.HookContext) → libcst.Call
Plugin Hook: Packs trailing positional arguments into a keyword tuple.
- Triggers:
Operations marked with requires_plugin: "pack_varargs". Designed for abstract operations such as permute_dims.
- Transformation:
Input: torch.permute(x, 0, 2, 1)
Output: jax.numpy.transpose(x, axes=(0, 2, 1)), when a target mapping exists
- Config:
Adapts the keyword name to the target framework:
- TensorFlow: uses 'perm'
- Default (JAX/NumPy/MLX): uses 'axes'
- Parameters:
node – The original CST Call node.
ctx – HookContext used to look up the target API directly from Semantics.
- Returns:
The transformed CST Call node, or the original node unchanged if no target mapping exists.
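The Config rule above (TensorFlow uses 'perm', other targets default to 'axes') amounts to a lookup with a fallback. The sketch below is hypothetical: the lowercase framework keys such as "tensorflow" and the helpers `sequence_keyword` and `render_packed_call` are illustrative assumptions, and the call is rendered as a source string rather than built as a libcst node.

```python
from typing import Dict, Sequence

# Hypothetical override table: only TensorFlow deviates from the default keyword.
KEYWORD_OVERRIDES: Dict[str, str] = {"tensorflow": "perm"}
DEFAULT_KEYWORD = "axes"  # JAX, NumPy, MLX and any framework not listed above


def sequence_keyword(framework: str) -> str:
    """Return the keyword that receives the packed sequence for a target framework."""
    return KEYWORD_OVERRIDES.get(framework.lower(), DEFAULT_KEYWORD)


def render_packed_call(target_api: str, framework: str, primary: str, dims: Sequence[int]) -> str:
    """Render target_api(primary, <keyword>=(d0, d1, ...)) as source text."""
    items = ", ".join(str(d) for d in dims)
    if len(dims) == 1:
        items += ","  # a one-element tuple needs a trailing comma
    return f"{target_api}({primary}, {sequence_keyword(framework)}=({items}))"


print(render_packed_call("tf.transpose", "tensorflow", "x", [0, 2, 1]))
# tf.transpose(x, perm=(0, 2, 1))
print(render_packed_call("jax.numpy.transpose", "jax", "x", [0, 2, 1]))
# jax.numpy.transpose(x, axes=(0, 2, 1))
```

The split mirrors the public APIs involved: TensorFlow's tf.transpose names its permutation argument `perm`, while numpy.transpose and jax.numpy.transpose name it `axes`.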