ml_switcheroo.plugins.padding

Plugin for normalizing Padding arguments.

Addresses the semantic mismatch between:

1. PyTorch (pad(x, (left, right, top, bottom))): A flat tuple that pads starting from the last dimension.
2. JAX/NumPy (pad(x, ((n_b, n_a), (c_b, c_a), …))): Explicit per-dimension (before, after) tuples.

This plugin transforms standard 4D tensor padding (images) from the flat PyTorch form into the explicit tuple-of-tuples format used by JAX/NumPy-style pad, as required by XLA compilers.
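The reordering described above can be sketched in plain Python. This is a minimal illustration, not the plugin's actual code: the helper name torch_pad_to_pad_width and its ndim parameter are hypothetical, and it assumes the padding values are known integers.

```python
from typing import Sequence, Tuple


def torch_pad_to_pad_width(
    pad: Sequence[int], ndim: int = 4
) -> Tuple[Tuple[int, int], ...]:
    """Convert a flat PyTorch-style pad tuple to per-dimension (before, after) pairs.

    Hypothetical helper for illustration only; not part of ml_switcheroo.
    """
    if len(pad) % 2 != 0:
        raise ValueError("pad must contain an even number of values")
    n_padded = len(pad) // 2
    if n_padded > ndim:
        raise ValueError("pad describes more dimensions than the tensor has")

    # PyTorch lists pairs starting from the LAST dimension: (left, right, top, bottom, ...).
    pairs = [(pad[i], pad[i + 1]) for i in range(0, len(pad), 2)]

    # JAX/NumPy lists pairs starting from the FIRST dimension, so reverse the
    # pairs and give every leading, unpadded dimension a (0, 0) entry.
    leading = [(0, 0)] * (ndim - n_padded)
    return tuple(leading + pairs[::-1])


# (left=1, right=2, top=3, bottom=4) on a 4D tensor
print(torch_pad_to_pad_width((1, 2, 3, 4)))  # ((0, 0), (0, 0), (3, 4), (1, 2))
```

For a 4D image tensor, the first two dimensions receive (0, 0), the height dimension receives (top, bottom), and the width dimension receives (left, right).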

Functions

transform_padding(→ libcst.Call)

Hook: Transforms padding coordinate format.

Module Contents

ml_switcheroo.plugins.padding.transform_padding(node: libcst.Call, ctx: ml_switcheroo.core.hooks.HookContext) → libcst.Call

Hook: Transforms padding coordinate format.

Trigger: Operations mapped to ‘pad’ with requires_plugin: “padding_converter”.
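To show the kind of call rewrite such a hook performs, here is a rough sketch that uses the standard-library ast module as a stand-in for libcst. It is not the plugin's implementation: the function name rewrite_pad_call is hypothetical, the HookContext argument is omitted, only literal padding tuples are handled, and the callee name is left unchanged.

```python
import ast


def rewrite_pad_call(source: str, ndim: int = 4) -> str:
    """Rewrite the padding argument of a single `pad(x, (...))` call expression.

    Hypothetical stand-in for a libcst hook, built on the stdlib `ast` module.
    Requires Python 3.9+ for ast.unparse.
    """
    tree = ast.parse(source, mode="eval")
    call = tree.body
    if not isinstance(call, ast.Call) or len(call.args) < 2:
        return source  # not a pad-like call; leave untouched

    try:
        flat = ast.literal_eval(call.args[1])
    except ValueError:
        return source  # padding is not a literal (e.g. a variable name)

    if not isinstance(flat, tuple) or len(flat) % 2 != 0 or len(flat) > 2 * ndim:
        return source  # unexpected shape; leave untouched

    # Group the flat values into pairs (last dimension first), then reverse
    # the pairs and prepend (0, 0) for every leading, unpadded dimension.
    pairs = [(flat[i], flat[i + 1]) for i in range(0, len(flat), 2)]
    pad_width = tuple([(0, 0)] * (ndim - len(pairs)) + pairs[::-1])

    # Replace the second positional argument with the new tuple-of-tuples literal.
    call.args[1] = ast.parse(repr(pad_width), mode="eval").body
    return ast.unparse(tree)


print(rewrite_pad_call("F.pad(x, (1, 2, 3, 4))"))
```

Calls whose padding is not a literal tuple are returned unchanged in this sketch; how the real hook treats them is not stated in this page.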