ml_switcheroo.plugins.arg_packing

Plugin for packing variable positional arguments into a sequence (tuple/list).

Essential for mapping vararg APIs (like torch.permute(x, 0, 2, 1)) to sequence-based APIs (like jax.numpy.transpose(x, axes=(0, 2, 1))).

Strategy:
  1. Identify the operation (e.g. permute_dims).
  2. Look up the target API (e.g. jax.numpy.transpose).
  3. Separate the primary input argument (first positional) from the varargs.
  4. Pack the trailing varargs into a Tuple node.
  5. Construct the new call with the sequence passed as a keyword argument (e.g. axes=…).
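The strategy above can be sketched in a few lines. This is a minimal illustration of steps 3 to 5 only. It assumes the target API has already been resolved (steps 1 and 2) and is passed in as a dotted string, or None when no mapping exists. It swaps in Python's stdlib ast module for libcst so that it runs without extra dependencies. Unlike libcst, ast does not preserve comments or formatting. The function name pack_varargs_sketch is hypothetical and not part of ml_switcheroo. ast.unparse requires Python 3.9+.

```python
import ast
from typing import Optional


def pack_varargs_sketch(call_src: str, target_api: Optional[str], keyword: str = "axes") -> str:
    """Rewrite `f(x, a, b, ...)` as `target_api(x, keyword=(a, b, ...))`.

    Hypothetical stand-in for the libcst-based hook: uses stdlib `ast`,
    so comments and original formatting are NOT preserved.
    """
    call = ast.parse(call_src, mode="eval").body
    # Steps 1-2 are assumed done by the caller: `target_api` is the looked-up
    # dotted path, or None when no mapping exists (input returned unchanged).
    # Assumption: with fewer than two positional args there is nothing to pack.
    if target_api is None or not isinstance(call, ast.Call) or len(call.args) < 2:
        return call_src

    # Step 3: separate the primary input (first positional) from the varargs.
    primary, *varargs = call.args

    # Step 4: pack the trailing varargs into a tuple node.
    packed = ast.Tuple(elts=varargs, ctx=ast.Load())

    # Step 5: build the new call, passing the tuple as a keyword argument.
    # Any keywords already on the source call are kept, before the packed one.
    new_call = ast.Call(
        func=ast.parse(target_api, mode="eval").body,
        args=[primary],
        keywords=[*call.keywords, ast.keyword(arg=keyword, value=packed)],
    )
    return ast.unparse(new_call)
```

For example, pack_varargs_sketch("torch.permute(x, 0, 2, 1)", "jax.numpy.transpose") yields "jax.numpy.transpose(x, axes=(0, 2, 1))", while passing None as the target returns the input source unchanged.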

Functions

pack_varargs(→ libcst.Call)

Plugin Hook: Packs trailing positional arguments into a keyword tuple.

Module Contents

ml_switcheroo.plugins.arg_packing.pack_varargs(node: libcst.Call, ctx: ml_switcheroo.core.hooks.HookContext) → libcst.Call

Plugin Hook: Packs trailing positional arguments into a keyword tuple.

Triggers:

Operations marked with requires_plugin: “pack_varargs”. Designed for abstract operations like permute_dims.

Transformation:

Input: torch.permute(x, 0, 2, 1)
Output: jax.numpy.transpose(x, axes=(0, 2, 1)) (if a target mapping exists)

Config:

Adapts the keyword name to the target framework:
  • TensorFlow: uses ‘perm’
  • Default (JAX/NumPy/MLX): uses ‘axes’
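The keyword selection amounts to a lookup with a default. The sketch below assumes that frameworks are identified by case-insensitive strings such as "tensorflow" and "jax". The real plugin's framework identifiers may differ, and packed_keyword_for is a hypothetical helper name.

```python
# Hypothetical framework keys; only TensorFlow deviates from the default.
_PACKED_KEYWORD = {"tensorflow": "perm"}
_DEFAULT_KEYWORD = "axes"  # JAX, NumPy, MLX


def packed_keyword_for(target_framework: str) -> str:
    """Return the keyword under which the packed sequence is passed."""
    return _PACKED_KEYWORD.get(target_framework.lower(), _DEFAULT_KEYWORD)
```

Treating TensorFlow as the only special case mirrors the underlying APIs: tf.transpose names its parameter perm, while numpy.transpose, jax.numpy.transpose and mlx.core.transpose all use axes.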

Parameters:
  • node – The original CST Call node.

  • ctx – HookContext, used to look up the target API directly from Semantics.

Returns:

The transformed CST Call node, or the original node unchanged if no target mapping exists.