ml_switcheroo.plugins.padding
Plugin for normalizing Padding arguments.
Addresses the semantic mismatch between:
1. PyTorch (pad(x, (left, right, top, bottom))): a flat tuple of (before, after) pairs, applied starting from the last dimension and moving backwards.
2. JAX/NumPy (pad(x, ((n_b, n_a), (c_b, c_a), …))): explicit (before, after) tuples, one per dimension, in dimension order.
This plugin transforms the padding of standard 4D image tensors (N, C, H, W) into the explicit tuple-of-tuples format expected by NumPy and NumPy-compatible libraries such as the XLA-backed jax.numpy.
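The coordinate conversion itself can be sketched in plain Python. This is a minimal illustration, not the plugin's code: torch_pad_to_numpy is a hypothetical helper name, and although the plugin is documented for the 4D case, the sketch assumes the same rule (group into pairs, reverse, zero-pad the leading dimensions) extends to any rank.

```python
from typing import Sequence, Tuple

PadWidth = Tuple[Tuple[int, int], ...]


def torch_pad_to_numpy(pad: Sequence[int], ndim: int) -> PadWidth:
    """Hypothetical helper: convert a flat PyTorch pad tuple to NumPy pad_width.

    PyTorch lists (before, after) pairs starting from the LAST dimension;
    NumPy expects one (before, after) pair per dimension, first to last.
    """
    if len(pad) % 2 != 0:
        raise ValueError("pad must contain an even number of values")
    n_pairs = len(pad) // 2
    if n_pairs > ndim:
        raise ValueError("pad covers more dimensions than the tensor has")
    # Group the flat tuple into pairs: [(left, right), (top, bottom), ...]
    pairs = [(pad[2 * i], pad[2 * i + 1]) for i in range(n_pairs)]
    # Pairs arrive last-dimension-first, so reverse them into dimension order.
    pairs.reverse()
    # Leading dimensions that PyTorch leaves untouched (e.g. N and C) get (0, 0).
    leading = [(0, 0)] * (ndim - n_pairs)
    return tuple(leading + pairs)


# 4D image tensor (N, C, H, W): W gets (left, right), H gets (top, bottom).
print(torch_pad_to_numpy((1, 2, 3, 4), ndim=4))  # ((0, 0), (0, 0), (3, 4), (1, 2))
```

The result can be passed as the pad_width argument of numpy.pad or jax.numpy.pad.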
Functions
| transform_padding(node, ctx) | Hook: Transforms padding coordinate format from Torch style to NumPy style. |
Module Contents
- ml_switcheroo.plugins.padding.transform_padding(node: libcst.Call, ctx: ml_switcheroo.core.hooks.HookContext) -> libcst.Call
Hook: Transforms padding coordinate format from Torch style to NumPy style.
Trigger: Operations mapped to 'pad' with requires_plugin: "padding_converter".
- Decoupling:
Looks up only the 'Pad' API in the target framework's mappings; if that mapping is missing, the original node is returned unchanged.
- Parameters:
node – The original CST Call node.
ctx – Hook Context containing target framework metadata.
- Returns:
The transformed CST Call node if a 'Pad' mapping exists; otherwise, the original node.
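The hook's contract (look up 'Pad', rewrite the coordinates, otherwise return the node untouched) can be sketched as follows. This is not the plugin's implementation: to stay dependency-free it uses the standard-library ast module in place of libcst, and a plain dict named api_map (hypothetical) stands in for the HookContext lookup, whose API is not documented here. transform_padding_sketch and _int_literal are illustrative names. The sketch only rewrites the 4D case where the pad argument is a literal 4-tuple of ints and no extra arguments are present.

```python
import ast
from typing import Dict, Optional


def _int_literal(node: ast.expr) -> Optional[int]:
    """Return the int value of a literal such as 3 or -3, else None."""
    try:
        value = ast.literal_eval(node)
    except (ValueError, TypeError):
        return None
    if isinstance(value, int) and not isinstance(value, bool):
        return value
    return None


def transform_padding_sketch(call: ast.Call, api_map: Dict[str, str]) -> ast.Call:
    """Hypothetical stand-in for transform_padding, using ast instead of libcst.

    api_map plays the role of the hook context: it maps the abstract op
    name "Pad" to a target function path such as "jnp.pad".
    """
    target = api_map.get("Pad")
    # Decoupling: without a "Pad" mapping, the node is returned untouched.
    if target is None:
        return call
    # Only handle pad(x, (l, r, t, b)); mode/value arguments are not translated.
    if len(call.args) != 2 or call.keywords or not isinstance(call.args[1], ast.Tuple):
        return call
    values = [_int_literal(e) for e in call.args[1].elts]
    if len(values) != 4 or any(v is None for v in values):
        return call
    left, right, top, bottom = values
    # N and C are unpadded; H gets (top, bottom), W gets (left, right).
    pad_width = ((0, 0), (0, 0), (top, bottom), (left, right))
    pad_node = ast.Tuple(
        elts=[
            ast.Tuple(elts=[ast.Constant(value=a), ast.Constant(value=b)], ctx=ast.Load())
            for a, b in pad_width
        ],
        ctx=ast.Load(),
    )
    func_node = ast.parse(target, mode="eval").body
    return ast.Call(func=func_node, args=[call.args[0], pad_node], keywords=[])


# Example: rewrite a PyTorch-style call into a NumPy-style one.
tree = ast.parse("y = F.pad(x, (1, 2, 3, 4))")
new_call = transform_padding_sketch(tree.body[0].value, {"Pad": "jnp.pad"})
print(ast.unparse(new_call))  # jnp.pad(x, ((0, 0), (0, 0), (3, 4), (1, 2)))
```

Returning the original node for anything the sketch does not recognize (no mapping, a non-literal pad, extra arguments) mirrors the documented fallback: an unhandled call is left as written rather than rewritten incorrectly.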