ml_switcheroo.plugins.inplace_unroll
====================================

.. py:module:: ml_switcheroo.plugins.inplace_unroll

.. autoapi-nested-parse::

   Plugin for unrolling in-place tensor operations to functional assignments.

   PyTorch uses a trailing underscore convention (e.g., `x.add_(y)`) to denote in-place mutation.
   JAX and other functional frameworks require immutable operations, where the result must be
   assigned back to the variable (e.g., `x = x.add(y)`).

   This plugin:

   1. Detects calls ending in `_` (e.g., `add_`).
   2. Checks validity (excludes special methods such as `__init__`).
   3. Transforms the expression statement `x.op_(y)` into an assignment `x = x.op(y)`.


Functions
---------

.. autoapisummary::

   ml_switcheroo.plugins.inplace_unroll.unroll_inplace_ops


Module Contents
---------------

.. py:function:: unroll_inplace_ops(node: Union[libcst.Call, libcst.Expr], ctx: ml_switcheroo.core.hooks.HookContext) -> Union[libcst.Call, libcst.Assign, libcst.Expr, libcst.BinaryOperation]

   Plugin Hook: Transforms in-place method calls into functional assignments.

   Scope: Hooks in `ml-switcheroo` typically operate on `cst.Call` nodes. Turning a standalone
   expression statement `x.add_(y)` into the assignment `x = x.add(y)` requires access to the
   enclosing statement, but a hook that receives a `cst.Call` can only replace the call itself.
   Replacing `x.add_(y)` with `x = x.add(y)` *inside* another expression is invalid syntax.
   This logic therefore assumes the call is a top-level expression statement, or relies on the
   Rewriter to expand a returned wrapper into a statement.

   *Current Strategy*: The hook strips the trailing underscore to make the call functional. If
   the call is a standalone expression statement, the result (`x.add(y)`) is computed but
   discarded unless it is assigned.
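
   The underscore-stripping step can be sketched as follows. This is a minimal illustration,
   not the plugin's actual code: it uses the standard-library `ast` module (Python 3.9+ for
   `ast.unparse`) as a stand-in for libcst so that it runs without extra dependencies, and,
   unlike libcst, `ast` does not preserve comments or formatting. The names `is_inplace_name`,
   `StripInplaceSuffix`, and `strip_inplace` are hypothetical. The sketch assumes that any
   attribute call whose name ends in `_` and does not start with `__` is an in-place op.

   .. code-block:: python

      import ast


      def is_inplace_name(name: str) -> bool:
          """True for PyTorch-style in-place names like ``add_``; False for dunders."""
          return len(name) > 1 and name.endswith("_") and not name.startswith("__")


      class StripInplaceSuffix(ast.NodeTransformer):
          """Hypothetical sketch: rewrite ``recv.op_(args)`` to ``recv.op(args)``."""

          def visit_Call(self, node: ast.Call) -> ast.Call:
              # Visit children first so chained calls like x.mul_(2).add_(1) are handled.
              self.generic_visit(node)
              func = node.func
              if isinstance(func, ast.Attribute) and is_inplace_name(func.attr):
                  func.attr = func.attr[:-1]  # drop the trailing underscore only
              return node


      def strip_inplace(source: str) -> str:
          """Parse, strip in-place suffixes, and unparse (requires Python 3.9+)."""
          tree = StripInplaceSuffix().visit(ast.parse(source))
          return ast.unparse(tree)

   Because only the method name changes, this rewrite is valid in any expression context:
   `z = x.add_(y)` becomes `z = x.add(y)`, while the standalone statement `x.add_(y)` becomes
   `x.add(y)`, whose result is then discarded.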
   *Refined Strategy*: Because a Call-level hook cannot easily ascend to the statement level, it
   strips the `_` so that the API mapping (e.g., `torch.add`) still applies. The user receives
   `x.add(y)`, which executes correctly but discards its result. This is a limitation of
   Call-level hooks and should ideally be flagged.

   Returning a `cst.Assign` in place of the `cst.Call` (i.e., wrapping the call in an `Assign`
   node that targets the receiver) is only valid if the `cst.Call` is the root of an `Expr`
   statement. If the call is nested, as in `z = x.add_(y)`, the replacement would yield
   `z = (x = x.add(y))`, which is a `SyntaxError`. Assignment wrapping is therefore limited to
   cases where safety can be inferred; otherwise the hook returns the functional call `x.add(y)`
   and relies on the `PivotRewriter` (which calls this plugin) to account for in-place ops
   returning `self`.

   *Implementation*:

   1. Strip the `_` from the method name: `x.add_(y)` -> `x.add(y)`.
   2. Inject an assignment only when the usage context is known to be safe, because replacing
      a `Call` with an `Assign` is invalid in nested contexts.

   Standard PyTorch in-place ops such as `x.add_(y)` return `x`, so converting `z = x.add_(y)`
   to `z = x.add(y)` is a semantically correct functional translation for `z`. The only loss is
   that `x` itself is not updated in the enclosing scope.

   In summary, the target form `x = x.add(y)` is only valid when the node is a standalone
   expression statement. This hook implements the underscore-stripping logic, since generating
   assignment code from a Call hook is architecturally constrained.
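
   The statement-level constraint can be sketched with the same stand-in approach. In an `ast`
   tree a standalone expression statement is an `ast.Expr` node, so the hypothetical
   `UnrollInplaceStatements` transformer below only emits `recv = recv.op(args)` from
   `visit_Expr`, and only when the receiver is a valid assignment target (a name, attribute,
   or subscript). Nested calls such as `z = x.add_(y)` only lose the underscore. This again
   uses the standard-library `ast` module rather than libcst, and all names (`is_inplace_name`,
   `ASSIGNABLE`, `UnrollInplaceStatements`, `unroll`) are illustrative, not `ml-switcheroo` APIs.

   .. code-block:: python

      import ast
      import copy


      def is_inplace_name(name: str) -> bool:
          """True for PyTorch-style in-place names like ``add_``; False for dunders."""
          return len(name) > 1 and name.endswith("_") and not name.startswith("__")


      # Receivers that can legally appear on the left-hand side of ``=``.
      ASSIGNABLE = (ast.Name, ast.Attribute, ast.Subscript)


      class UnrollInplaceStatements(ast.NodeTransformer):
          """Hypothetical sketch: statement ``recv.op_(a)`` -> ``recv = recv.op(a)``."""

          def visit_Call(self, node: ast.Call) -> ast.Call:
              self.generic_visit(node)
              func = node.func
              if isinstance(func, ast.Attribute) and is_inplace_name(func.attr):
                  func.attr = func.attr[:-1]  # strip the in-place suffix
              return node

          def visit_Expr(self, node: ast.Expr) -> ast.stmt:
              call = node.value
              # Decide before visiting children, since visiting strips the suffix.
              wrap = (
                  isinstance(call, ast.Call)
                  and isinstance(call.func, ast.Attribute)
                  and is_inplace_name(call.func.attr)
                  and isinstance(call.func.value, ASSIGNABLE)
              )
              self.generic_visit(node)  # strips the suffix via visit_Call
              if not wrap:
                  return node
              # Reuse a copy of the receiver as the assignment target.
              target = copy.deepcopy(call.func.value)
              target.ctx = ast.Store()
              assign = ast.Assign(targets=[target], value=node.value, type_comment=None)
              return ast.copy_location(assign, node)


      def unroll(source: str) -> str:
          """Parse, unroll in-place statements, and unparse (requires Python 3.9+)."""
          tree = UnrollInplaceStatements().visit(ast.parse(source))
          return ast.unparse(ast.fix_missing_locations(tree))

   The decision to emit an `Assign` is made where the enclosing statement is visible, never at
   the call level, which is the same constraint that limits a Call-level hook. Receivers that
   are not assignable, such as `f().add_(1)`, fall back to the plain functional call.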