ml_switcheroo.plugins.shape_packing
Plugin for Packing Shape Arguments.
This transformation converts shape arguments passed as separate positional values (varargs) into the single explicit tuple argument that some frameworks, such as JAX and NumPy, require.
Example
Source: x.view(1, 2, -1) (PyTorch style)
Target: jnp.reshape(x, (1, 2, -1)) (JAX/NumPy style)
- Decoupling Logic:
This plugin does NOT enforce a framework whitelist; it runs unconditionally whenever it is wired in. It does, however, resolve the target API by looking up the “Reshape” or “View” definition in the semantics. If neither definition exists for the target framework, the plugin aborts.
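A minimal sketch of this lookup-then-abort behaviour, assuming the semantics can be modelled as a plain nested dict. `SEMANTICS`, `resolve_reshape_api`, the framework keys, and the “Reshape” before “View” lookup order are all hypothetical; ml_switcheroo's actual semantics store and resolution code are not shown on this page.

```python
from typing import Mapping, Optional

# Hypothetical stand-in for the semantics store:
# abstract operation name -> {target framework: fully qualified API}.
SEMANTICS: Mapping[str, Mapping[str, str]] = {
    "Reshape": {"jax": "jax.numpy.reshape", "numpy": "numpy.reshape"},
    "View": {"torch": "torch.Tensor.view"},
}


def resolve_reshape_api(
    semantics: Mapping[str, Mapping[str, str]], target_fw: str
) -> Optional[str]:
    """Return the target API for "Reshape", falling back to "View".

    Returns None when neither operation is defined for ``target_fw``;
    the caller treats that as the signal to abort.
    """
    for op_name in ("Reshape", "View"):  # Assumed lookup order.
        api = semantics.get(op_name, {}).get(target_fw)
        if api is not None:
            return api
    return None


print(resolve_reshape_api(SEMANTICS, "jax"))         # jax.numpy.reshape
print(resolve_reshape_api(SEMANTICS, "torch"))       # torch.Tensor.view
print(resolve_reshape_api(SEMANTICS, "tensorflow"))  # None
```

Returning `None` instead of raising is one way to model "abort": the decision about what to do with an unmapped call stays with the caller.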
Functions
| transform_shape_packing(node, ctx) | Hook: Packs trailing positional arguments into a shape tuple. |
Module Contents
- ml_switcheroo.plugins.shape_packing.transform_shape_packing(node: libcst.Call, ctx: ml_switcheroo.core.hooks.HookContext) -> libcst.Call
Hook: Packs trailing positional arguments into a shape tuple.
Logic:
1. Resolve the target API via “Reshape” or “View”; abort if neither is defined.
2. Pack the trailing positional arguments into a shape tuple.
- Parameters:
node – The original CST Call node.
ctx – The hook execution context.
- Returns:
The transformed call with packed shape arguments.
- Return type:
libcst.Call
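The two-step logic described for this hook can be sketched with Python's standard-library `ast` module (Python 3.9+ for `ast.unparse`), used here in place of libcst so the example runs without extra dependencies; unlike libcst, `ast` does not preserve formatting or comments. `pack_shape_args` is a hypothetical name. Instead of a `HookContext`, it takes the already-resolved API string as `target_api`, with `None` meaning the lookup failed. It assumes that aborting means returning the original call unchanged, and it only handles the method form `x.view(...)`.

```python
import ast
from typing import Optional


def pack_shape_args(node: ast.Call, target_api: Optional[str]) -> ast.Call:
    """Sketch of the hook's two steps on a stdlib ``ast.Call``.

    Step 1 (abort): if the target API could not be resolved, return the
    call unchanged.
    Step 2 (pack): turn ``recv.method(d0, d1, ...)`` into
    ``target_api(recv, (d0, d1, ...))``.
    """
    if target_api is None or not isinstance(node.func, ast.Attribute):
        return node  # Abort: no mapping, or not a method-style call.

    receiver = node.func.value  # e.g. ``x`` in ``x.view(...)``

    # Assumption: a single tuple/list argument is already packed, so it
    # is reused as-is instead of being wrapped a second time.
    if len(node.args) == 1 and isinstance(node.args[0], (ast.Tuple, ast.List)):
        shape: ast.expr = node.args[0]
    else:
        shape = ast.Tuple(elts=list(node.args), ctx=ast.Load())

    # Build the dotted callee, e.g. "jnp.reshape" -> Attribute(Name("jnp"), "reshape").
    parts = target_api.split(".")
    func: ast.expr = ast.Name(id=parts[0], ctx=ast.Load())
    for attr in parts[1:]:
        func = ast.Attribute(value=func, attr=attr, ctx=ast.Load())

    return ast.Call(func=func, args=[receiver, shape], keywords=list(node.keywords))


call = ast.parse("x.view(1, 2, -1)", mode="eval").body
print(ast.unparse(pack_shape_args(call, "jnp.reshape")))  # jnp.reshape(x, (1, 2, -1))
print(ast.unparse(pack_shape_args(call, None)))           # x.view(1, 2, -1)
```

The usage lines reproduce the module's Source/Target example; passing `None` shows the abort path leaving the call untouched.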