ml_switcheroo.plugins.padding

Plugin for normalizing Padding arguments.

Addresses the semantic mismatch between:

1. PyTorch (pad(x, (left, right, top, bottom))): a flat tuple that pads starting from the last dimension.
2. JAX/NumPy (pad(x, ((n_b, n_a), (c_b, c_a), …))): explicit per-dimension (before, after) tuples.

This plugin transforms standard 4D tensor padding (images) into the explicit tuple-of-tuples format required by XLA compilers and NumPy-compatible libraries.
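The reordering between the two conventions can be illustrated at the value level. The following is a minimal sketch, not the plugin's implementation: the helper name torch_pad_to_pad_width is hypothetical, and it works on plain Python tuples rather than on CST nodes.

```python
from typing import Sequence, Tuple


def torch_pad_to_pad_width(pad: Sequence[int], ndim: int) -> Tuple[Tuple[int, int], ...]:
    """Convert a flat Torch-style pad tuple to NumPy-style per-dimension pairs.

    Hypothetical helper for illustration only; it is not part of ml_switcheroo.
    """
    if len(pad) % 2 != 0:
        raise ValueError("pad must contain an even number of values")
    n_padded = len(pad) // 2
    if n_padded > ndim:
        raise ValueError("pad describes more dimensions than the tensor has")

    # Torch lists (before, after) pairs starting from the LAST dimension.
    pairs = [(pad[i], pad[i + 1]) for i in range(0, len(pad), 2)]
    # NumPy expects pairs ordered from the FIRST dimension, so reverse them.
    pairs.reverse()
    # Leading dimensions that Torch leaves untouched get explicit (0, 0) entries.
    return tuple([(0, 0)] * (ndim - n_padded) + pairs)


# 4D image tensor (N, C, H, W) padded with (left, right, top, bottom) = (1, 2, 3, 4).
pad_width = torch_pad_to_pad_width((1, 2, 3, 4), ndim=4)
print(pad_width)  # ((0, 0), (0, 0), (3, 4), (1, 2))
```

The resulting tuple is in the shape accepted as the pad_width argument of numpy.pad or jax.numpy.pad: the width pair (left, right) moves to the last position and the height pair (top, bottom) to the second-to-last.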

Functions

transform_padding(→ libcst.Call)

Hook: Transforms padding coordinate format from Torch style to NumPy style.

Module Contents

ml_switcheroo.plugins.padding.transform_padding(node: libcst.Call, ctx: ml_switcheroo.core.hooks.HookContext) → libcst.Call

Hook: Transforms padding coordinate format from Torch style to NumPy style.

Trigger: Operations mapped to ‘pad’ with requires_plugin: “padding_converter”.

Decoupling:

Looks up only the ‘Pad’ API. If it is missing, the original node is returned unchanged.

Parameters:
  • node – The original CST Call node.

  • ctx – Hook Context containing target framework metadata.

Returns:

The transformed CST Call node if a mapping exists, otherwise the original node.