ml_switcheroo.plugins.static_unroll
Plugin for Statically Unrolling Loops.
JAX and other XLA-based frameworks generally require loops to be written functionally (e.g., scan, fori_loop).
Such loops are difficult to generate automatically from imperative Python for loops
because of complex variable-scoping and state-containment rules.
However, many neural network definitions use loops over fixed constants
(e.g., for i in range(3): layer(x)). Unrolling these provides a valid,
optimizable graph structure without requiring complex functional rewrite logic.
Usage
This plugin registers unroll_static_loops under the hook key transform_for_loop_static.
The ControlFlowMixin invokes it before the general loop safety scanners run.
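The ordering described above can be pictured with a small dispatch sketch. Everything in it is hypothetical: HOOKS, register_hook, transform_for_loop, and toy_hook are illustrative names, not ml_switcheroo's actual registry API. The identity check used to detect "declined" is an assumption that mirrors the documented behavior of returning the original node when a loop cannot be unrolled.

```python
from typing import Any, Callable, Dict

# Hypothetical hook registry: key -> hook(node, ctx). A hook returns the very
# same node object to signal "not handled".
HOOKS: Dict[str, Callable[[Any, Any], Any]] = {}


def register_hook(key: str) -> Callable[[Callable], Callable]:
    """Hypothetical decorator that records a hook under `key`."""

    def decorator(fn: Callable) -> Callable:
        HOOKS[key] = fn
        return fn

    return decorator


def transform_for_loop(node: Any, ctx: Any, safety_scan: Callable[[Any], Any]) -> Any:
    """Hypothetical ControlFlowMixin step: static hook first, scanners second."""
    hook = HOOKS.get("transform_for_loop_static")
    if hook is not None:
        result = hook(node, ctx)
        if result is not node:
            # The hook rewrote the loop (e.g. unrolled it); no scan needed.
            return result
    # No hook registered, or it declined: run the general safety scan.
    return safety_scan(node)


# Toy usage: a stand-in hook that only "handles" the string "static".
@register_hook("transform_for_loop_static")
def toy_hook(node: Any, ctx: Any) -> Any:
    return "unrolled" if node == "static" else node
```

Using identity rather than equality keeps the "declined" signal cheap and unambiguous: a hook that builds a new but structurally equal node still counts as having handled the loop.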
Process
- Analysis: Detects for i in range(N), where N is a static integer literal.
- Safety: Checks that N is within a reasonable limit to prevent code explosion.
- Expansion: Duplicates the loop body N times.
- Substitution: Replaces usages of the loop variable (i) with the literal integer for that iteration (0, 1, etc.).
- Output: Returns a cst.FlattenSentinel containing the list of statements.
Classes
LoopVarReplacer – Helper visitor to replace loop variable instances with a constant integer.
Functions
unroll_static_loops(node, ctx) – Hook: Unrolls loops with static ranges.
Module Contents
- class ml_switcheroo.plugins.static_unroll.LoopVarReplacer(var_name: str, value: int)
Bases: libcst.CSTTransformer
Helper visitor to replace loop variable instances with a constant integer.
- var_name¶
The name of the loop variable to target (e.g., ‘i’).
- Type:
str
- value¶
The integer constant to substitute (e.g., 0).
- Type:
int
- leave_Name(original_node: libcst.Name, updated_node: libcst.Name) → libcst.BaseExpression
Replace occurrences of names matching var_name with a libcst.Integer node holding value.
- ml_switcheroo.plugins.static_unroll.unroll_static_loops(node: libcst.For, ctx: ml_switcheroo.core.hooks.HookContext) → libcst.For | libcst.FlattenSentinel
Hook: Unrolls loops with static ranges.
- Triggers:
Invoked by ControlFlowMixin via the transform_for_loop_static key.
- Transformation:
- Input:
    for i in range(2):
        x = f(x, i)
- Output:
    x = f(x, 0)
    x = f(x, 1)
- Parameters:
node (cst.For) – The original For loop node.
ctx (HookContext) – The execution context (unused in this logic but required by protocol).
- Returns:
A FlattenSentinel containing the unrolled statements if successful; the original
node if the loop is dynamic or too large.
- Return type:
Union[cst.For, cst.FlattenSentinel]