ml_switcheroo.plugins.padding
Plugin for normalizing padding arguments.
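Addresses the semantic mismatch between:
1. PyTorch (pad(x, (left, right, top, bottom))): a flat sequence of (before, after) pairs, starting from the last dimension and working backwards.
2. JAX/NumPy (pad(x, ((n_b, n_a), (c_b, c_a), …))): explicit per-dimension (before, after) tuples, listed from the first dimension to the last.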
This plugin transforms padding for standard 4D image tensors (batch, channel, height, width) into the explicit tuple-of-tuples format used by JAX/NumPy-style pad functions when targeting XLA.
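The reordering can be sketched as a small pure-Python helper. This is an illustration under stated assumptions, not the plugin's own code: the name torch_pad_to_numpy is hypothetical, and the helper assumes the PyTorch pad sequence is a flat list of integer (before, after) pairs covering at most ndim trailing dimensions.

```python
from typing import Sequence, Tuple

# One (before, after) pair per dimension, in axis order (NumPy/JAX style).
PadWidth = Tuple[Tuple[int, int], ...]


def torch_pad_to_numpy(pad: Sequence[int], ndim: int) -> PadWidth:
    """Hypothetical helper: convert a PyTorch-style flat pad sequence into a
    NumPy/JAX pad_width.

    PyTorch lists (before, after) pairs starting from the LAST dimension,
    so (left, right, top, bottom) pads W, then H, of an NCHW tensor.
    NumPy/JAX expect one (before, after) pair per dimension, first to last.
    """
    if len(pad) % 2 != 0:
        raise ValueError("pad must contain an even number of values")
    n_pairs = len(pad) // 2
    if n_pairs > ndim:
        raise ValueError("pad covers more dimensions than the tensor has")
    # Pair i pads dimension -(i + 1): pair 0 is the last dim, pair 1 the one before it, ...
    pairs = [(pad[2 * i], pad[2 * i + 1]) for i in range(n_pairs)]
    # Reverse into axis order and leave the untouched leading dims unpadded.
    return tuple([(0, 0)] * (ndim - n_pairs) + pairs[::-1])
```

For the 4D case, torch_pad_to_numpy((1, 2, 3, 4), ndim=4) returns ((0, 0), (0, 0), (3, 4), (1, 2)): the (top, bottom) pair lands on the height axis, (left, right) on the width axis, and batch and channel stay unpadded.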
Functions
| transform_padding(node, ctx) | Hook: Transforms padding coordinate format. |
Module Contents
- ml_switcheroo.plugins.padding.transform_padding(node: libcst.Call, ctx: ml_switcheroo.core.hooks.HookContext) -> libcst.Call
Hook: Transforms padding coordinate format.
Trigger: operations mapped to ‘pad’ with requires_plugin: “padding_converter”.
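To illustrate what a call-rewriting hook of this kind does, here is a sketch that uses Python's standard-library ast module in place of libcst so it runs without extra dependencies. Everything in it is an assumption rather than ml_switcheroo's API: PadCallRewriter and rewrite_source are hypothetical names; it matches only F.pad(x, (l, r, t, b)) with a literal 4-element tuple and no keyword arguments (assuming torch.nn.functional is imported as F); it assumes a 4D NCHW input; it emits jnp.pad as the target; and it does not use HookContext.

```python
import ast


def _pair(a: ast.expr, b: ast.expr) -> ast.Tuple:
    """Build a (before, after) tuple node."""
    return ast.Tuple(elts=[a, b], ctx=ast.Load())


class PadCallRewriter(ast.NodeTransformer):
    """Hypothetical sketch: F.pad(x, (l, r, t, b)) -> jnp.pad(x, ((0, 0), (0, 0), (t, b), (l, r)))."""

    def visit_Call(self, node: ast.Call) -> ast.AST:
        self.generic_visit(node)  # rewrite any nested calls first
        func = node.func
        is_f_pad = (
            isinstance(func, ast.Attribute)
            and func.attr == "pad"
            and isinstance(func.value, ast.Name)
            and func.value.id == "F"
        )
        # Keyword args (mode=, value=) are left alone: their names differ between the APIs.
        if not is_f_pad or len(node.args) != 2 or node.keywords:
            return node
        pad = node.args[1]
        # Only a literal 4-element tuple is handled; anything else is returned unchanged.
        if not isinstance(pad, ast.Tuple) or len(pad.elts) != 4:
            return node
        left, right, top, bottom = pad.elts
        pad_width = ast.Tuple(
            elts=[
                _pair(ast.Constant(0), ast.Constant(0)),  # N: unpadded
                _pair(ast.Constant(0), ast.Constant(0)),  # C: unpadded
                _pair(top, bottom),                       # H
                _pair(left, right),                       # W
            ],
            ctx=ast.Load(),
        )
        new_func = ast.Attribute(value=ast.Name(id="jnp", ctx=ast.Load()), attr="pad", ctx=ast.Load())
        new_call = ast.Call(func=new_func, args=[node.args[0], pad_width], keywords=[])
        return ast.copy_location(new_call, node)


def rewrite_source(src: str) -> str:
    """Parse, rewrite matching pad calls, and print the source back out."""
    tree = PadCallRewriter().visit(ast.parse(src))
    return ast.unparse(ast.fix_missing_locations(tree))
```

The actual hook receives libcst nodes, as its signature shows; libcst keeps comments and formatting that a round trip through ast discards, which is the usual reason to prefer it for source-to-source rewriting.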