ml_switcheroo.plugins.reshape

Plugin for “view” semantics and reshape strictness.

Addresses:

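PyTorch tensor.view(*shape) requires a memory layout compatible with the new shape (in practice, usually a contiguous tensor) and returns a tensor that shares storage with the original. JAX jnp.reshape(arr, shape) works on any array, copying if necessary, and produces an immutable output.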
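The layout requirement can be made concrete by modelling a tensor as a (shape, strides) pair in element units. The sketch below is a simplified check for row-major (C-order) contiguity. is_c_contiguous is a hypothetical helper, not part of this plugin, and PyTorch's actual view rule is somewhat more permissive: it only requires the dimensions being merged or split to be laid out compatibly.

```python
def is_c_contiguous(shape: tuple[int, ...], strides: tuple[int, ...]) -> bool:
    """Return True if (shape, strides), in element units, describe a
    row-major contiguous layout. Hypothetical illustration only."""
    # A tensor with zero elements is trivially contiguous.
    if 0 in shape:
        return True
    expected = 1
    # Walk dimensions from innermost to outermost.
    for size, stride in zip(reversed(shape), reversed(strides)):
        # Size-1 dimensions may carry any stride without affecting layout.
        if size != 1 and stride != expected:
            return False
        expected *= size
    return True
```

For example, a 2×3 tensor has strides (3, 1), while its transpose x.t() has shape (3, 2) and strides (1, 3). The transpose fails this check, and in PyTorch x.t().view(6) raises an error, whereas x.t().reshape(6) copies and succeeds.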

Semantic Mismatch:

In PyTorch, view is often used as an implicit assertion that reshaping happens without a copy, since it raises an error when the layout is incompatible. In JAX, the copy/view distinction does not affect correctness, because arrays are immutable, but it can still affect performance.

Plugin Logic:
  1. Strict Mode: If config.strict_mode is enabled, the plugin can inject synchronization (block_until_ready()) or explicit copies, depending on the configuration policy, so that performance artifacts can be isolated.

  2. Argument Packing: Converts the varargs form view(a, b) into the tuple form reshape(x, (a, b)), unless a prior pass has already done so.
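A minimal sketch of the argument-packing rule, assuming the arguments arrive as nodes from Python's standard-library ast module (the plugin itself works on libcst nodes; pack_view_args is a hypothetical name). A single non-starred argument is passed through unchanged, because it may already hold a whole shape (a tuple variable, for example) or a single int, both of which jnp.reshape accepts. Multiple or starred arguments are packed into one tuple.

```python
import ast


def pack_view_args(args: list[ast.expr]) -> ast.expr:
    """Turn view(...) positional arguments into a single shape expression.
    Hypothetical helper operating on stdlib ast nodes."""
    # view(s) or view((a, b)): one argument already denotes the full shape.
    if len(args) == 1 and not isinstance(args[0], ast.Starred):
        return args[0]
    # view(a, b) or view(*shape): pack the varargs into one tuple.
    return ast.Tuple(elts=list(args), ctx=ast.Load())


# Usage: extract the arguments of a parsed call and pack them.
call = ast.parse("x.view(a, b)", mode="eval").body
print(ast.unparse(pack_view_args(call.args)))
```

Passing a lone argument through, rather than always wrapping it, avoids turning view(shape_var) into the incorrect reshape(x, (shape_var,)).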

Functions

transform_view_semantics(node, ctx) → libcst.Call

Hook: Maps view -> reshape with optional strictness injections.

Module Contents

ml_switcheroo.plugins.reshape.transform_view_semantics(node: libcst.Call, ctx: ml_switcheroo.core.hooks.HookContext) → libcst.Call

Hook: Maps view -> reshape with optional strictness injections.

Transformation:

Input: x.view(a, b)
Standard Output: jax.numpy.reshape(x, (a, b))
Strict Output: jax.numpy.reshape(x, (a, b)).block_until_ready()
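The input/output pairs above can be reproduced with a small source-to-source rewrite. This sketch uses the standard-library ast module instead of libcst, which the plugin actually uses, so it runs without extra dependencies. The names ViewToReshape and rewrite, and the strict flag standing in for config.strict_mode, are hypothetical. The match is purely syntactic: any .view(...) method call without keyword arguments is rewritten, with no type information.

```python
import ast


def _reshape_func() -> ast.expr:
    # Build the attribute chain jax.numpy.reshape.
    jax_numpy = ast.Attribute(value=ast.Name(id="jax", ctx=ast.Load()), attr="numpy", ctx=ast.Load())
    return ast.Attribute(value=jax_numpy, attr="reshape", ctx=ast.Load())


class ViewToReshape(ast.NodeTransformer):
    """Hypothetical transformer: x.view(...) -> jax.numpy.reshape(x, shape)."""

    def __init__(self, strict: bool = False) -> None:
        self.strict = strict  # stands in for config.strict_mode

    def visit_Call(self, node: ast.Call) -> ast.expr:
        self.generic_visit(node)  # rewrite nested calls first
        func = node.func
        if not (isinstance(func, ast.Attribute) and func.attr == "view") or node.keywords:
            return node
        args = node.args
        # Argument packing: pass a lone argument through, otherwise build a tuple.
        if len(args) == 1 and not isinstance(args[0], ast.Starred):
            shape = args[0]
        else:
            shape = ast.Tuple(elts=list(args), ctx=ast.Load())
        new: ast.expr = ast.Call(func=_reshape_func(), args=[func.value, shape], keywords=[])
        if self.strict:
            # Strict mode: append .block_until_ready() to force synchronization.
            sync = ast.Attribute(value=new, attr="block_until_ready", ctx=ast.Load())
            new = ast.Call(func=sync, args=[], keywords=[])
        return ast.copy_location(new, node)


def rewrite(src: str, strict: bool = False) -> str:
    tree = ViewToReshape(strict).visit(ast.parse(src))
    return ast.unparse(ast.fix_missing_locations(tree))


print(rewrite("y = x.view(a, b)"))
print(rewrite("y = x.view(a, b)", strict=True))
```

Because the match is syntactic, this sketch would also rewrite Tensor.view(dtype), the dtype-reinterpreting overload of view, which has no reshape equivalent; a real pass would need to exclude that case.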