ml_switcheroo.plugins.inplace_unroll

Plugin for unrolling in-place tensor operations to functional assignments.

PyTorch uses a trailing underscore convention (e.g., x.add_(y)) to denote in-place mutation. JAX and other functional frameworks require immutable operations, where the result must be assigned back to the variable (e.g., x = x.add(y)).

This plugin:

  1. Detects calls whose method name ends in _ (e.g., add_).

  2. Checks validity, excluding special methods such as __init__.

  3. Transforms the expression statement x.op_(y) into the assignment x = x.op(y).
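The three steps can be sketched with Python's standard-library ast module (Python 3.9+ for ast.unparse) standing in for libcst, which the plugin itself works with. The names is_inplace_name, InplaceUnroller and unroll are hypothetical and not part of ml-switcheroo. To keep the sketch short, it only rewrites calls whose receiver is a plain name:

```python
import ast


def is_inplace_name(name: str) -> bool:
    """Steps 1-2: a trailing "_" marks an in-place op; dunders are excluded."""
    is_dunder = name.startswith("__") and name.endswith("__")
    return len(name) > 1 and name.endswith("_") and not is_dunder


class InplaceUnroller(ast.NodeTransformer):
    """Hypothetical sketch: rewrites statements x.op_(...) into x = x.op(...)."""

    def visit_Expr(self, node: ast.Expr) -> ast.AST:
        call = node.value
        if not (
            isinstance(call, ast.Call)
            and isinstance(call.func, ast.Attribute)
            and isinstance(call.func.value, ast.Name)  # plain-name receivers only
            and is_inplace_name(call.func.attr)
        ):
            return node  # not an in-place method statement; leave it alone
        receiver = call.func.value.id
        # Step 3a: strip the trailing underscore from the method name.
        functional = ast.Call(
            func=ast.Attribute(
                value=ast.Name(id=receiver, ctx=ast.Load()),
                attr=call.func.attr[:-1],
                ctx=ast.Load(),
            ),
            args=call.args,
            keywords=call.keywords,
        )
        # Step 3b: assign the functional result back to the receiver.
        assign = ast.Assign(
            targets=[ast.Name(id=receiver, ctx=ast.Store())],
            value=functional,
        )
        return ast.copy_location(assign, node)


def unroll(source: str) -> str:
    """Parse source, apply the rewrite, and return the new source text."""
    tree = InplaceUnroller().visit(ast.parse(source))
    return ast.unparse(ast.fix_missing_locations(tree))


print(unroll("x.add_(y)\nobj.__init__()\nx.clamp_(min=0)"))
# x = x.add(y)
# obj.__init__()
# x = x.clamp(min=0)
```

Because the sketch works on whole expression statements, z = x.add_(y) is left untouched: it is an assignment, not an expression statement.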

Functions

unroll_inplace_ops(node, ctx) → Union[libcst.Call, libcst.Assign, ...]

Plugin Hook: Transforms in-place method calls to functional assignments.

Module Contents

ml_switcheroo.plugins.inplace_unroll.unroll_inplace_ops(node: libcst.Call | libcst.Expr, ctx: ml_switcheroo.core.hooks.HookContext) → libcst.Call | libcst.Assign | libcst.Expr | libcst.BinaryOperation

Plugin Hook: Transforms in-place method calls to functional assignments.

Scope

The hook mechanism in ml-switcheroo typically operates on cst.Call nodes. However, to transform a standalone expression statement x.add_(y) into the assignment x = x.add(y), we ideally need access to the enclosing statement.

If the hook infrastructure passes a cst.Call, we can only replace the call itself. Replacing x.add_(y) with x = x.add(y) inside another expression would produce invalid syntax, because an assignment is a statement, not an expression. Therefore, this logic assumes usage primarily in top-level expression statements, or relies on the Rewriter’s ability to handle statement expansion if this hook returns a wrapper.
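The syntax restriction can be confirmed with Python's own parser. The helper name parses is hypothetical and exists only for this check:

```python
import ast


def parses(source: str) -> bool:
    """Return True if source is syntactically valid Python."""
    try:
        ast.parse(source)
        return True
    except SyntaxError:
        return False


# As a standalone statement, the assignment form is valid.
print(parses("x = x.add(y)"))            # True
# Nested inside another expression, the same rewrite is a syntax error.
print(parses("z = (x = x.add(y)) * 2"))  # False
print(parses("y = (x = x.add(1)) + 1"))  # False
```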

Current Strategy: We strip the underscore to make the call functional. If the call forms a standalone expression statement, it effectively becomes a no-op: x.add(y) is computed, but the result is lost unless it is assigned.

Refined Strategy: Since we cannot easily ascend to the statement level from a Call hook, we strip the _ so that the API mapping (e.g., torch.add) still applies. The user receives x.add(y), which executes correctly but discards the result. WARNING: This is a limitation of Call-level hooks; ideally, such calls should be flagged.
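A minimal sketch of this Call-level fallback, again using the standard-library ast module (Python 3.9+) in place of libcst. Every in-place method call has its underscore stripped wherever it appears. Standalone statements whose result is now discarded are recorded by line number so they could be flagged. CallLevelStripper, strip_calls and the line-number flagging are hypothetical illustrations of the "ideally flagged" idea, not ml-switcheroo's implementation:

```python
import ast


def _is_inplace(name: str) -> bool:
    # Trailing "_" marks an in-place op; dunders such as __init__ are excluded.
    is_dunder = name.startswith("__") and name.endswith("__")
    return len(name) > 1 and name.endswith("_") and not is_dunder


class CallLevelStripper(ast.NodeTransformer):
    """Strips "_" from in-place method calls and flags discarded results."""

    def __init__(self) -> None:
        self.flagged_lines: list[int] = []

    def visit_Expr(self, node: ast.Expr) -> ast.AST:
        call = node.value
        if (
            isinstance(call, ast.Call)
            and isinstance(call.func, ast.Attribute)
            and _is_inplace(call.func.attr)
        ):
            # Standalone statement: after stripping, its result is thrown away.
            self.flagged_lines.append(node.lineno)
        return self.generic_visit(node)  # still strip the call itself

    def visit_Call(self, node: ast.Call) -> ast.AST:
        self.generic_visit(node)  # handle nested calls in arguments first
        if isinstance(node.func, ast.Attribute) and _is_inplace(node.func.attr):
            node.func.attr = node.func.attr[:-1]  # add_ -> add
        return node


def strip_calls(source: str) -> tuple[str, list[int]]:
    """Return the rewritten source and the line numbers worth flagging."""
    stripper = CallLevelStripper()
    tree = stripper.visit(ast.parse(source))
    return ast.unparse(tree), stripper.flagged_lines


code, flagged = strip_calls("x.add_(y)\nz = x.mul_(2)")
print(code)     # x.add(y) / z = x.mul(2), one per line
print(flagged)  # [1]: only the standalone statement loses its result
```

The assignment on line 2 is not flagged, because its result is still captured in z.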

Implementation:

  1. Strip the trailing _ from the method name: x.add_(y) -> x.add(y).

  2. If the context allows it, or if we can detect how the call is used, we might attempt assignment injection, but replacing a Call with an Assign is risky in nested contexts. However, standard PyTorch in-place ops such as x.add_(y) return x itself. So rewriting z = x.add_(y) as z = x.add(y) is a semantically correct functional conversion for the value of z. The only loss is that x itself is no longer updated in the enclosing scope.
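The aliasing argument above can be checked with a hypothetical ToyTensor class that mimics the PyTorch convention: add_ mutates and returns self, and add returns a new object. It is a stand-in for illustration, not torch.Tensor:

```python
class ToyTensor:
    """Hypothetical stand-in for a tensor holding a single float."""

    def __init__(self, value: float) -> None:
        self.value = value

    def add_(self, other: float) -> "ToyTensor":
        # PyTorch-style in-place op: mutate self and return it.
        self.value += other
        return self

    def add(self, other: float) -> "ToyTensor":
        # Functional op: leave self untouched and return a new object.
        return ToyTensor(self.value + other)


# Original in-place code: z aliases x, and x is updated.
x = ToyTensor(1.0)
z = x.add_(2.0)
print(z is x, x.value, z.value)     # True 3.0 3.0

# Rewritten functional code: z gets the same value, but x is not updated.
x2 = ToyTensor(1.0)
z2 = x2.add(2.0)
print(z2 is x2, x2.value, z2.value)  # False 1.0 3.0
```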