ml_switcheroo.plugins.inplace_unroll¶
Plugin for unrolling in-place tensor operations to functional assignments.
PyTorch uses a trailing underscore convention (e.g., x.add_(y)) to denote in-place mutation. JAX and other functional frameworks require immutable operations, where the result must be assigned back to the variable (e.g., x = x.add(y)).
This plugin:
1. Detects calls ending in _ (e.g., add_).
2. Checks validity (excludes special methods like __init__).
3. Transforms the expression statement x.op_(y) into the assignment x = x.op(y).
Functions¶
unroll_inplace_ops | Plugin Hook: Transforms in-place method calls to functional assignments.
Module Contents¶
- ml_switcheroo.plugins.inplace_unroll.unroll_inplace_ops(node: libcst.Call | libcst.Expr, ctx: ml_switcheroo.core.hooks.HookContext) libcst.Call | libcst.Assign | libcst.Expr | libcst.BinaryOperation¶
Plugin Hook: Transforms in-place method calls to functional assignments.
- Scope:
This hook mechanism in ml-switcheroo typically operates on cst.Call nodes. However, to transform a standalone expression statement x.add_(y) into an assignment x = ..., we ideally need access to the statement container.
If the hook infrastructure passes cst.Call, we can only mutate the call itself. Replacing x.add_(y) with x = x.add(y) inside another expression is invalid syntax. Therefore, this logic assumes usage primarily in top-level expression statements, or relies on the Rewriter's ability to handle statement expansion if this hook returns a wrapper node.
Current Strategy: We strip the underscore to make the call functional. If the Call is part of a standalone expression statement, the result of x.add(y) is computed but discarded unless it is assigned.
Refined Strategy: Since we cannot easily ascend to the statement level from a Call hook, we strip the trailing _ so that the API mapping (e.g. torch.add) works. The user receives x.add(y), which executes validly but discards the result. WARNING: This is a limitation of Call-level hooks; ideally this case is flagged.
Caveat: wrapping the call in an Assign node (returning cst.Assign in place of cst.Call) is only valid if the Call is the root of an Expr statement. If the Call appears inside z = x.add_(y), replacing it would produce z = (x = x.add(y)), which is a SyntaxError.
To support this robustly, we limit assignment wrapping to cases where safety can be inferred, or we simply return the functional call x.add(y) and rely on the PivotRewriter (which invokes this plugin) to handle the fact that in-place ops typically return self.
Implementation:
1. Strip the trailing _ from the method name: x.add_(y) -> x.add(y).
2. If the usage context can be detected, assignment injection may be attempted, but replacing a Call with an Assign is risky in nested contexts. However, standard PyTorch in-place ops such as x.add_(y) return x, so z = x.add_(y) -> z = x.add(y) is a semantically correct functional conversion. The only loss is that x itself is not updated in the enclosing scope.
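The claim above, that converting z = x.add_(y) to z = x.add(y) preserves z but loses the update to x, can be demonstrated with a toy class standing in for a tensor (a hypothetical ToyTensor, not a real torch.Tensor):

```python
class ToyTensor:
    """Toy stand-in for a tensor, mimicking PyTorch's in-place naming convention."""

    def __init__(self, value):
        self.value = value

    def add_(self, other):
        # In-place: mutate self and return self, as PyTorch in-place ops do.
        self.value += other.value
        return self

    def add(self, other):
        # Functional: leave self untouched and return a fresh object.
        return ToyTensor(self.value + other.value)


x, y = ToyTensor(1), ToyTensor(2)
z = x.add_(y)        # z is x; both now hold 3
assert z is x and x.value == 3

a, b = ToyTensor(1), ToyTensor(2)
c = a.add(b)         # c holds 3, but a still holds 1: the "loss" noted above
assert c.value == 3 and a.value == 1
```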
Note on wrapping the call in an Assign node targeting the receiver (i.e. x = x.add(y)): this is ONLY valid when the node is a standalone expression statement. The plugin therefore implements the strip logic; generating assignment code from a Call-level hook is architecturally constrained.