ml_switcheroo.plugins.shape_packing

Plugin for Packing Shape Arguments.

This transformation converts variadic shape arguments into the single explicit tuple argument that some frameworks require.

Example

Source (PyTorch style): x.view(1, 2, -1)
Target (JAX/NumPy style): jnp.reshape(x, (1, 2, -1))
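The rewrite in the example can be sketched with Python's standard-library ast module (Python 3.9+ for ast.unparse). This is an illustration only: the plugin itself operates on libcst nodes, and the helper name pack_view_call and the default target jnp.reshape are assumptions made for the sketch. It also assumes that a call which already passes a single tuple should not be wrapped a second time.

```python
import ast


def pack_view_call(source: str, target: str = "jnp.reshape") -> str:
    """Rewrite ``x.view(d0, d1, ...)`` as ``target(x, (d0, d1, ...))``.

    Hypothetical helper: it uses the standard-library ``ast`` module,
    whereas the real plugin works on libcst nodes.
    """
    call = ast.parse(source, mode="eval").body
    if not (isinstance(call, ast.Call) and isinstance(call.func, ast.Attribute)):
        raise ValueError("expected a method call such as x.view(...)")

    receiver = call.func.value  # the tensor the method is called on, e.g. `x`
    dims = call.args            # trailing positional shape arguments

    # A single tuple/list argument is already packed; reuse it instead of
    # nesting it inside another tuple.
    if len(dims) == 1 and isinstance(dims[0], (ast.Tuple, ast.List)):
        shape = dims[0]
    else:
        shape = ast.Tuple(elts=list(dims), ctx=ast.Load())

    new_call = ast.Call(
        func=ast.parse(target, mode="eval").body,  # e.g. Attribute(jnp, reshape)
        args=[receiver, shape],                    # tensor first, then the shape tuple
        keywords=call.keywords,
    )
    return ast.unparse(new_call)


print(pack_view_call("x.view(1, 2, -1)"))    # jnp.reshape(x, (1, 2, -1))
print(pack_view_call("x.view((1, 2, -1))"))  # jnp.reshape(x, (1, 2, -1))
```

The receiver becomes the first positional argument because functional APIs such as jnp.reshape take the array explicitly rather than as self.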

Decoupling Logic:

This plugin does NOT enforce a framework whitelist; it runs unconditionally whenever it is wired in. However, it depends on finding a “Reshape” or “View” definition in the semantics, and if neither is defined for the target framework, it aborts.
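The decoupling rule can be pictured as a lookup against a per-framework semantics table rather than a hard-coded list of frameworks. In the sketch below, the table layout, the framework keys, the API strings (including the made-up "myfw"), the helper name resolve_reshape_api and the preference order (“Reshape” before “View”) are all assumptions for illustration; the plugin's real semantics source and lookup order are not described on this page.

```python
from typing import Dict, Optional

# Hypothetical semantics table: operation name -> framework -> target API.
SEMANTICS: Dict[str, Dict[str, str]] = {
    "Reshape": {"jax": "jax.numpy.reshape", "numpy": "numpy.reshape"},
    "View": {"jax": "jax.numpy.reshape", "myfw": "myfw.reshape"},  # "myfw" is made up
}


def resolve_reshape_api(
    semantics: Dict[str, Dict[str, str]], target: str
) -> Optional[str]:
    """Return the target API for a reshape, or None to signal 'abort'.

    There is no framework whitelist: any target works as long as the
    semantics define "Reshape" or "View" for it.
    """
    for op in ("Reshape", "View"):  # assumed preference order
        api = semantics.get(op, {}).get(target)
        if api is not None:
            return api
    return None  # neither definition exists, so the hook should abort


for fw in ("jax", "myfw", "tensorflow"):
    print(fw, "->", resolve_reshape_api(SEMANTICS, fw))
```

Supporting a new target framework then only requires adding a semantics entry, not changing the plugin.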

Functions

transform_shape_packing(node, ctx) → libcst.Call

Hook: Packs trailing positional arguments into a shape tuple.

Module Contents

ml_switcheroo.plugins.shape_packing.transform_shape_packing(node: libcst.Call, ctx: ml_switcheroo.core.hooks.HookContext) → libcst.Call

Hook: Packs trailing positional arguments into a shape tuple.

Logic:
  1. Resolve the target API via “Reshape” or “View”; abort if neither is defined.
  2. Pack the trailing positional arguments into a single shape tuple.

Parameters:
  • node – The original CST Call node.

  • ctx – The hook execution context.

Returns:

The transformed call with packed shape arguments.

Return type:

libcst.Call
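The two logic steps of transform_shape_packing can be combined into a hook-shaped function, sketched below under several assumptions: it uses the standard-library ast module instead of libcst, FakeHookContext is a hypothetical stand-in for ml_switcheroo.core.hooks.HookContext (its fields target_fw and semantics are invented for the sketch), and “abort” is assumed to mean returning the original node unchanged.

```python
import ast
from dataclasses import dataclass, field
from typing import Dict


@dataclass
class FakeHookContext:
    """Hypothetical stand-in for ml_switcheroo.core.hooks.HookContext."""
    target_fw: str
    semantics: Dict[str, Dict[str, str]] = field(default_factory=dict)


def shape_packing_hook(node: ast.Call, ctx: FakeHookContext) -> ast.Call:
    # Step 1: resolve the target API via "Reshape" or "View".
    api = None
    for op in ("Reshape", "View"):
        api = ctx.semantics.get(op, {}).get(ctx.target_fw)
        if api is not None:
            break
    if api is None or not isinstance(node.func, ast.Attribute):
        return node  # abort: leave the call untouched (assumed behaviour)

    # Step 2: pack the trailing positional arguments into one shape tuple,
    # unless the caller already passed a single tuple or list.
    dims = node.args
    if len(dims) == 1 and isinstance(dims[0], (ast.Tuple, ast.List)):
        shape = dims[0]
    else:
        shape = ast.Tuple(elts=list(dims), ctx=ast.Load())

    return ast.Call(
        func=ast.parse(api, mode="eval").body,  # dotted API name as an expression
        args=[node.func.value, shape],          # receiver first, then the shape
        keywords=node.keywords,
    )


call = ast.parse("x.view(1, 2, -1)", mode="eval").body
ok = FakeHookContext("jax", {"Reshape": {"jax": "jnp.reshape"}})
print(ast.unparse(shape_packing_hook(call, ok)))  # jnp.reshape(x, (1, 2, -1))

missing = FakeHookContext("jax")  # no Reshape/View semantics defined
print(shape_packing_hook(call, missing) is call)  # True: the hook aborted
```

Returning the original node on abort keeps the hook safe to wire unconditionally: frameworks without reshape semantics simply see no change.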