ml_switcheroo.plugins.reshape
Plugin for “view” semantics and reshape strictness.
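- Addresses:
PyTorch tensor.view(*shape) requires a memory layout compatible with the new shape (in practice, usually a contiguous tensor) and returns a tensor that shares data with the original. JAX/NumPy reshape(arr, shape) works on any array: NumPy returns a view when it can and a copy otherwise, and JAX arrays are immutable, so aliasing is never observable.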
- Plugin Logic:
Strict Mode: If config.strict_mode is enabled, this plugin can append a synchronization call (e.g. block_until_ready()) to the result.
Argument Packing: Packs varargs into a single shape tuple, e.g. view(a, b) -> reshape((a, b)).
Decoupling: Relies strictly on the semantics lookup API. If neither Reshape nor View is mapped for the target framework, the hook returns the original node unchanged.
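The decoupling rule works as a lookup with a fallback. It uses the “Reshape” before “View” precedence documented for transform_view_semantics. The sketch below is a minimal illustration of that rule. SEMANTICS, resolve_target_api, and the framework keys ("jax", "other_fw", "unmapped") are hypothetical stand-ins, not ml_switcheroo's lookup API. Returning None models the case where the hook gives back the original node.

```python
from typing import Mapping, Optional

# Hypothetical semantics table standing in for ml_switcheroo's lookup API:
# abstract operation name -> {target framework: API path}.
SEMANTICS: Mapping[str, Mapping[str, str]] = {
    "Reshape": {"jax": "jnp.reshape"},
    "View": {"jax": "hypothetical.view", "other_fw": "other.view"},
}


def resolve_target_api(framework: str) -> Optional[str]:
    """Return the mapped API, trying "Reshape" before "View".

    None means neither operation is mapped, so the hook should hand back
    the original node unchanged.
    """
    for op in ("Reshape", "View"):
        api = SEMANTICS.get(op, {}).get(framework)
        if api is not None:
            return api
    return None


print(resolve_target_api("jax"))       # jnp.reshape (Reshape wins over View)
print(resolve_target_api("other_fw"))  # other.view  (falls back to View)
print(resolve_target_api("unmapped"))  # None        (no rewrite)
```

Keeping the table outside the hook means a new target framework only needs a mapping entry, not a code change.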
Functions
- transform_view_semantics(node, ctx): Hook: Maps view -> reshape with optional strictness injections.
Module Contents
- ml_switcheroo.plugins.reshape.transform_view_semantics(node: libcst.Call, ctx: ml_switcheroo.core.hooks.HookContext) -> libcst.Call
Hook: Maps view -> reshape with optional strictness injections.
- Transformation:
Input: x.view(a, b)
Standard Output: target_api(x, (a, b))
Strict Output: target_api(x, (a, b)).block_until_ready() [If defined in Traits]
- Decoupling:
Lookup precedence: “Reshape” -> “View”. If neither lookup succeeds, the transformation is aborted and the original node is returned.
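The transformation above can be sketched with Python's standard-library ast module (Python 3.9+ for ast.unparse). This is a rough analogue, not the plugin's code: the real hook operates on libcst.Call nodes. The names rewrite_view and sync_method, and the "jnp.reshape" target, are hypothetical choices for illustration. A target_api of None models a failed lookup. Passing a single argument through unchanged is also an assumption, and keyword and starred (*shape) arguments are not handled.

```python
import ast
from typing import Optional


def rewrite_view(call: ast.Call, target_api: Optional[str],
                 sync_method: Optional[str] = None) -> ast.expr:
    """Rewrite ``x.view(a, b)`` to ``target_api(x, (a, b))``.

    Returns ``call`` unchanged when it is not a ``.view(...)`` call or when
    ``target_api`` is None (lookup failed). If ``sync_method`` is given, the
    result is wrapped as ``target_api(...).<sync_method>()``.
    """
    if target_api is None:
        return call
    if not (isinstance(call.func, ast.Attribute) and call.func.attr == "view"):
        return call
    receiver = call.func.value  # the tensor expression, e.g. ``x``
    # Argument packing: several varargs become one tuple literal; a single
    # argument (an int or an existing shape tuple) is passed through as-is.
    # Keyword and starred arguments are not handled in this sketch.
    if len(call.args) == 1:
        shape = call.args[0]
    else:
        shape = ast.Tuple(elts=list(call.args), ctx=ast.Load())
    func = ast.parse(target_api, mode="eval").body  # "jnp.reshape" -> Attribute
    new_call: ast.expr = ast.Call(func=func, args=[receiver, shape], keywords=[])
    if sync_method is not None:
        # Strict output: append a zero-argument synchronization call.
        new_call = ast.Call(
            func=ast.Attribute(value=new_call, attr=sync_method, ctx=ast.Load()),
            args=[],
            keywords=[],
        )
    return new_call


call = ast.parse("x.view(a, b)", mode="eval").body
print(ast.unparse(rewrite_view(call, "jnp.reshape")))
# jnp.reshape(x, (a, b))
print(ast.unparse(rewrite_view(call, "jnp.reshape", "block_until_ready")))
# jnp.reshape(x, (a, b)).block_until_ready()
print(ast.unparse(rewrite_view(call, None)))
# x.view(a, b)
```

Returning the untouched node on a failed lookup keeps the hook safe to run on every call site, because it never emits a call to an API the target framework does not define.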