ml_switcheroo.plugins.inplace_unroll
Plugin for unrolling in-place tensor operations to functional assignments.
PyTorch uses a trailing-underscore convention (e.g., x.add_(y)) to denote in-place mutation. JAX and other functional frameworks use immutable arrays, so the result of an operation must be assigned back to the variable (e.g., x = x.add(y)).
This plugin:
1. Detects calls ending in _ (e.g., add_).
2. Checks validity (excludes special methods such as __init__).
3. Transforms the expression statement x.op_(y) into an assignment x = x.op(y).
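The three steps above can be sketched with Python's standard-library `ast` module. This is an illustration only: the actual plugin operates on libcst nodes through ml_switcheroo's hook system, and the names `is_inplace_name`, `UnrollInplace`, and `unroll` are hypothetical. Unlike libcst, `ast` does not preserve comments or formatting, so the sketch only shows the shape of the transformation.

```python
import ast
import copy


def is_inplace_name(name: str) -> bool:
    """True for PyTorch-style in-place names like 'add_'; False for dunders like '__init__'."""
    return len(name) > 1 and name.endswith("_") and not name.startswith("__")


class UnrollInplace(ast.NodeTransformer):
    """Rewrite standalone `x.op_(args)` statements into `x = x.op(args)` (hypothetical)."""

    def visit_Expr(self, node: ast.Expr) -> ast.AST:
        call = node.value
        if (
            isinstance(call, ast.Call)
            and isinstance(call.func, ast.Attribute)
            and is_inplace_name(call.func.attr)
            # Only simple receivers (x or self.w) can serve as assignment targets.
            and isinstance(call.func.value, (ast.Name, ast.Attribute))
        ):
            receiver = call.func.value
            # Step 3a: the functional call, with the trailing underscore removed.
            functional_call = ast.Call(
                func=ast.Attribute(value=receiver, attr=call.func.attr[:-1], ctx=ast.Load()),
                args=call.args,
                keywords=call.keywords,
            )
            # Step 3b: reuse a copy of the receiver as the assignment target.
            target = copy.deepcopy(receiver)
            target.ctx = ast.Store()
            assign = ast.Assign(targets=[target], value=functional_call)
            return ast.copy_location(assign, node)
        return node


def unroll(source: str) -> str:
    """Apply the statement-level rewrite to a source string and return new source."""
    tree = UnrollInplace().visit(ast.parse(source))
    return ast.unparse(ast.fix_missing_locations(tree))


print(unroll("x.add_(y)"))       # x = x.add(y)
print(unroll("obj.__init__()"))  # obj.__init__()
```

Because the rewrite happens at the statement level (`visit_Expr`), it never tries to place an assignment inside a larger expression; an already-assigned call such as `z = x.add_(y)` is left untouched by this pass.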
Functions
| unroll_inplace_ops(node, ctx) | Plugin Hook: Transforms in-place method calls to functional assignments. |
Module Contents
- ml_switcheroo.plugins.inplace_unroll.unroll_inplace_ops(node: libcst.Call | libcst.Expr, ctx: ml_switcheroo.core.hooks.HookContext) → libcst.Call | libcst.Assign | libcst.Expr | libcst.BinaryOperation
Plugin Hook: Transforms in-place method calls to functional assignments.
Scope
The hook mechanism in ml-switcheroo typically operates on cst.Call nodes. However, transforming a standalone expression statement x.add_(y) into an assignment x = … ideally requires access to the enclosing statement.
If the hook infrastructure passes a cst.Call, only the call itself can be replaced. Assignment is a statement in Python, not an expression, so replacing x.add_(y) with x = x.add(y) inside another expression produces invalid syntax. This logic therefore assumes usage primarily in top-level expression statements, or relies on the Rewriter's ability to handle statement expansion if the hook returns a wrapper.
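The syntactic constraint can be checked directly with the standard-library `ast` parser. `parses` is a hypothetical helper written only for this illustration.

```python
import ast


def parses(source: str) -> bool:
    """Return True if `source` is syntactically valid Python."""
    try:
        ast.parse(source)
    except SyntaxError:
        return False
    return True


# A standalone statement can safely become an assignment.
print(parses("x = x.add(y)"))             # True

# Inside a larger expression, an assignment statement is not allowed.
print(parses("z = (x = x.add(y)) * 2"))   # False
print(parses("print((x = x.add(y)))"))    # False
```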
Current Strategy: Strip the underscore to make the call functional. If the Call forms a standalone expression statement, the rewritten x.add(y) is computed and its result discarded, so the statement has no effect unless the result is assigned.
Refined Strategy: Because a Call hook cannot easily ascend to the statement level, the _ is stripped so that the API mapping (e.g., torch.add) applies. The user receives x.add(y), which executes without error but discards its result. WARNING: This is a limitation of Call-level hooks; ideally, such cases should be flagged.
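One possible way to flag these cases is a separate pass over whole statements rather than individual calls. The sketch below is hypothetical (it is not part of ml_switcheroo) and uses the standard-library `ast` module: it reports in-place calls that appear as standalone expression statements, which are exactly the cases where stripping the underscore discards the result.

```python
import ast


def find_discarded_inplace_calls(source: str) -> list[tuple[int, str]]:
    """List (line, method) for in-place calls used as standalone statements.

    After a Call-level rewrite (add_ -> add), these statements would compute
    a result and throw it away, so they are candidates for a warning.
    """
    flagged = []
    for node in ast.walk(ast.parse(source)):
        if (
            isinstance(node, ast.Expr)
            and isinstance(node.value, ast.Call)
            and isinstance(node.value.func, ast.Attribute)
        ):
            name = node.value.func.attr
            # Same validity check as before: trailing underscore, not a dunder.
            if len(name) > 1 and name.endswith("_") and not name.startswith("__"):
                flagged.append((node.lineno, name))
    return flagged


code = """\
x.add_(y)
z = x.mul_(2)
obj.__init__()
"""
print(find_discarded_inplace_calls(code))  # [(1, 'add_')]
```

The assigned call on line 2 is not reported, because its result is kept by the Call-level rewrite.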
Implementation:
Strip _ from method name: x.add_(y) -> x.add(y).
If the context allows, or if the usage context can be detected, assignment injection might be attempted, but replacing a Call with an Assign is risky in nested contexts. However, standard PyTorch in-place ops such as x.add_(y) return x itself. So rewriting z = x.add_(y) to z = x.add(y) is a semantically correct functional conversion: z receives the same value. The only “loss” is that x itself is not updated in the enclosing scope.
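The Call-level implementation described above (strip the underscore, leave the surrounding statement alone) can be approximated with `ast` as follows. `StripInplaceSuffix` and `strip_inplace` are hypothetical names for this illustration, not the plugin's code, and the real hook works on libcst nodes. The assigned form keeps its meaning, while the standalone form loses the update to x.

```python
import ast


class StripInplaceSuffix(ast.NodeTransformer):
    """Call-level rewrite: rename `recv.op_(...)` to `recv.op(...)` and nothing else."""

    def visit_Call(self, node: ast.Call) -> ast.Call:
        # Visit children first so nested calls such as x.add_(y.mul_(2)) are rewritten too.
        self.generic_visit(node)
        func = node.func
        if (
            isinstance(func, ast.Attribute)
            and len(func.attr) > 1
            and func.attr.endswith("_")
            and not func.attr.startswith("__")
        ):
            func.attr = func.attr[:-1]
        return node


def strip_inplace(source: str) -> str:
    """Apply the Call-level rewrite to a source string and return new source."""
    return ast.unparse(StripInplaceSuffix().visit(ast.parse(source)))


print(strip_inplace("z = x.add_(y)"))  # z = x.add(y)
print(strip_inplace("x.add_(y)"))      # x.add(y)
```

The second output is valid code, but without an assignment its result is discarded, which is the limitation the section describes.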