ml_switcheroo.plugins.reshape

Plugin for “View” semantics and reshape strictness.

Addresses:

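PyTorch tensor.view(*shape) requires the new shape to be compatible with the tensor's existing size and strides (in practice, usually a contiguous tensor) and returns a tensor that shares storage with the input. JAX/NumPy reshape(arr, shape) accepts any array: NumPy returns a view when the layout allows it and a copy otherwise, and JAX arrays are immutable, so any sharing is unobservable.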
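The gap between the two semantics can be seen with NumPy alone. The sketch below does not run PyTorch or JAX. It shows that np.reshape returns a view when the layout allows it and silently copies when it does not, which is exactly the case where PyTorch's view would refuse.

```python
import numpy as np

a = np.arange(6).reshape(2, 3)  # C-contiguous 2x3 array
t = a.T                         # transposed: no longer C-contiguous

# Compatible layout: reshape returns a view sharing a's buffer.
v = np.reshape(a, (3, 2))
print(np.shares_memory(a, v))   # True

# Incompatible layout: reshape still succeeds, but has to copy.
# (On the equivalent PyTorch tensor, t.view(6) raises a RuntimeError.)
flat = np.reshape(t, (6,))
print(flat.tolist())             # [0, 3, 1, 4, 2, 5]
print(np.shares_memory(t, flat)) # False
```

This is why a blind view -> reshape rewrite is safe for correctness. However, it can change aliasing: code that relied on writes through a view propagating back to the original tensor will not see that behaviour after translation.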

Plugin Logic:
  1. Strict Mode: If config.strict_mode is enabled, this plugin can inject a synchronization call (e.g. .block_until_ready()) when the target framework's traits define one.

  2. Argument Packing: Packs varargs into a single shape tuple, e.g. x.view(a, b) -> reshape(x, (a, b)).

  3. Decoupling: Strictly relies on the lookup API. If neither Reshape nor View is mapped in the semantics for the target framework, the hook returns the original node unchanged.
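Argument packing (step 2) can be sketched as a small AST rewrite. To stay dependency-free, this uses the standard-library ast module rather than libcst, which is what the plugin actually operates on. The function name pack_view_call and the target_api string are hypothetical; the real hook obtains the target API from the semantics lookup.

```python
import ast

def pack_view_call(call: ast.Call, target_api: str) -> ast.Call:
    """Rewrite `x.view(a, b)` into `<target_api>(x, (a, b))` (sketch)."""
    func = call.func
    # Only touch method calls of the form <receiver>.view(...).
    if not (isinstance(func, ast.Attribute) and func.attr == "view"):
        return call
    receiver = func.value
    args = call.args
    if len(args) == 1 and not isinstance(args[0], ast.Starred):
        # view(shape) or view(n): forward the single argument as-is.
        shape = args[0]
    else:
        # view(a, b, ...) or view(*s): pack everything into a tuple literal.
        shape = ast.Tuple(elts=list(args), ctx=ast.Load())
    # Parse the dotted target name, e.g. "jnp.reshape", into an expression.
    new_func = ast.parse(target_api, mode="eval").body
    return ast.Call(func=new_func, args=[receiver, shape], keywords=[])

expr = ast.parse("x.view(a, b)", mode="eval").body
print(ast.unparse(pack_view_call(expr, "jnp.reshape")))  # jnp.reshape(x, (a, b))
```

Starred arguments are packed too, so x.view(*s) becomes jnp.reshape(x, (*s,)) rather than leaking the star into the reshape call.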

Functions

transform_view_semantics(node, ctx) → libcst.Call

Hook: Maps view -> reshape with optional strictness injections.

Module Contents

ml_switcheroo.plugins.reshape.transform_view_semantics(node: libcst.Call, ctx: ml_switcheroo.core.hooks.HookContext) → libcst.Call

Hook: Maps view -> reshape with optional strictness injections.

Transformation:

Input: x.view(a, b)
Standard Output: target_api(x, (a, b))
Strict Output: target_api(x, (a, b)).block_until_ready() [If defined in Traits]
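The strict output differs from the standard one only by a trailing zero-argument method call. Below is a minimal sketch using the stdlib ast module in place of libcst. The strict_mode flag mirrors config.strict_mode. The sync_method parameter is a hypothetical stand-in for whatever the target framework's traits define; the real Traits schema is not shown on this page.

```python
import ast
from typing import Optional

def wrap_strict(expr: ast.expr, strict_mode: bool, sync_method: Optional[str]) -> ast.expr:
    """Return `<expr>.<sync_method>()` in strict mode, else `expr` unchanged."""
    # Inject synchronization only when strict mode is on AND the
    # traits name a sync method (hypothetical parameter).
    if not (strict_mode and sync_method):
        return expr
    # Build a zero-argument method call on the reshaped expression.
    return ast.Call(
        func=ast.Attribute(value=expr, attr=sync_method, ctx=ast.Load()),
        args=[],
        keywords=[],
    )

standard = ast.parse("jnp.reshape(x, (a, b))", mode="eval").body
print(ast.unparse(wrap_strict(standard, False, "block_until_ready")))
# jnp.reshape(x, (a, b))
print(ast.unparse(wrap_strict(standard, True, "block_until_ready")))
# jnp.reshape(x, (a, b)).block_until_ready()
```
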

Decoupling:

Lookup precedence: “Reshape” -> “View”. If neither lookup succeeds, the hook aborts the transformation and returns the original node.
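The fallback order can be sketched against a plain mapping. The semantics table shape ({op_name: {framework: api_path}}) and the function name resolve_target_api are assumptions for illustration. ml_switcheroo's actual lookup API is not reproduced here.

```python
from typing import Mapping, Optional

def resolve_target_api(
    semantics: Mapping[str, Mapping[str, str]], framework: str
) -> Optional[str]:
    """Try "Reshape" first, then "View"; None means abort (sketch)."""
    for op_name in ("Reshape", "View"):
        api = semantics.get(op_name, {}).get(framework)
        if api:
            return api
    # No mapping: the caller should return the original node unchanged.
    return None

# Hypothetical semantics table for illustration only.
semantics = {
    "Reshape": {"jax": "jax.numpy.reshape"},
    "View": {"numpy": "numpy.reshape"},
}
print(resolve_target_api(semantics, "jax"))    # jax.numpy.reshape
print(resolve_target_api(semantics, "numpy"))  # numpy.reshape (View fallback)
print(resolve_target_api(semantics, "other"))  # None
```
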