ml_switcheroo.plugins.static_unroll

Plugin for Statically Unrolling Loops.

JAX and XLA-based frameworks generally require functional loop primitives (scan, fori_loop). These are difficult to generate automatically from imperative Python for loops because of complex variable-scoping and state-containment rules.

However, many neural network definitions use loops over fixed constants (e.g., for i in range(3): layer(x)). Unrolling these provides a valid, optimizable graph structure without requiring complex functional rewrite logic.

Usage

This plugin registers the hook transform_for_loop_static. The ControlFlowMixin invokes it before the general loop-safety scanners run.
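
The dispatch order can be pictured with the toy sketch below. Everything in it is hypothetical and is not ml_switcheroo's API: HOOKS, register_hook, toy_unroll, leave_for and general_loop_scanner are invented names, plain strings stand in for CST nodes, and a list stands in for cst.FlattenSentinel. The only convention borrowed from this page is that a hook returning the original node means "not handled", so control falls through to the general scanner.

```python
from typing import Callable, Dict, List, Union

# Hypothetical stand-ins: strings play the role of CST nodes and a list
# plays the role of cst.FlattenSentinel.
Node = str
Result = Union[Node, List[Node]]
Hook = Callable[[Node, dict], Result]

# Hypothetical registry keyed by hook name (not ml_switcheroo's real API).
HOOKS: Dict[str, Hook] = {}


def register_hook(name: str) -> Callable[[Hook], Hook]:
    """Decorator that stores a hook function under a string key."""
    def decorator(fn: Hook) -> Hook:
        HOOKS[name] = fn
        return fn
    return decorator


@register_hook("transform_for_loop_static")
def toy_unroll(node: Node, ctx: dict) -> Result:
    # Toy hook: rewrites one known static loop and declines everything else.
    if node == "for i in range(2): x = f(x, i)":
        return ["x = f(x, 0)", "x = f(x, 1)"]
    return node  # returning the original node means "not handled"


def general_loop_scanner(node: Node) -> Result:
    # Placeholder for the general loop-safety checks that run afterwards.
    return f"<scanned> {node}"


def leave_for(node: Node, ctx: dict) -> Result:
    """Mixin-style dispatch: try the static-unroll hook before the scanner."""
    hook = HOOKS.get("transform_for_loop_static")
    if hook is not None:
        result = hook(node, ctx)
        if result is not node:  # the hook produced a rewrite
            return result
    return general_loop_scanner(node)


print(leave_for("for i in range(2): x = f(x, i)", {}))  # ['x = f(x, 0)', 'x = f(x, 1)']
print(leave_for("for i in range(n): x = f(x, i)", {}))  # <scanned> for i in range(n): x = f(x, i)
```

Using identity (`is not node`) rather than equality keeps the "declined" signal unambiguous even if a hook were to build an equal but new node.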

Process

  1. Analysis: Detects for i in range(N) where N is a static integer literal.

  2. Safety: Checks if N is within a reasonable limit to prevent code explosion.

  3. Expansion: Duplicates the loop body N times.

  4. Substitution: Replaces usages of the loop variable (i) with the literal integer for that iteration (0, 1, etc.).

  5. Output: Returns a cst.FlattenSentinel containing the list of statements.
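
The five steps can be sketched end to end as follows. This is an illustration, not the plugin's code: it swaps libcst for the standard-library ast module, returns a plain list where the plugin returns cst.FlattenSentinel, and uses a hypothetical MAX_UNROLL of 16 because the real limit is not documented. The helper names (static_trip_count, body_is_unrollable, unroll) are invented, and body_is_unrollable adds conservative checks (no break/continue, no reassignment of the loop variable) that the steps above do not mention.

```python
import ast
import copy
from typing import List, Optional, Union

MAX_UNROLL = 16  # hypothetical limit; the plugin's actual threshold is not documented here


class _ReadReplacer(ast.NodeTransformer):
    """Step 4: replace reads of the loop variable with an integer constant."""

    def __init__(self, var_name: str, value: int) -> None:
        self.var_name = var_name
        self.value = value

    def visit_Name(self, node: ast.Name) -> ast.AST:
        if node.id == self.var_name and isinstance(node.ctx, ast.Load):
            return ast.copy_location(ast.Constant(self.value), node)
        return node


def static_trip_count(loop: ast.For) -> Optional[int]:
    """Step 1: return N for `for <name> in range(N)` where N is an int literal."""
    it = loop.iter
    if (
        isinstance(loop.target, ast.Name)
        and isinstance(it, ast.Call)
        and isinstance(it.func, ast.Name)
        and it.func.id == "range"
        and len(it.args) == 1
        and not it.keywords
        and isinstance(it.args[0], ast.Constant)
        and type(it.args[0].value) is int  # excludes bools and floats
        and not loop.orelse  # for/else is left alone in this sketch
    ):
        return it.args[0].value
    return None


def body_is_unrollable(loop: ast.For, var: str) -> bool:
    """Extra conservative checks added for this sketch (an assumption)."""
    for stmt in loop.body:
        for sub in ast.walk(stmt):
            if isinstance(sub, (ast.Break, ast.Continue)):
                return False  # early exits cannot be expressed in straight-line code
            if isinstance(sub, ast.Name) and sub.id == var and not isinstance(sub.ctx, ast.Load):
                return False  # the loop variable is reassigned inside the body
    return True


def unroll(loop: ast.For) -> Union[ast.For, List[ast.stmt]]:
    n = static_trip_count(loop)  # Step 1: analysis
    if n is None or n > MAX_UNROLL:  # Step 2: safety
        return loop  # dynamic or too large: hand back the original node
    var = loop.target.id
    if not body_is_unrollable(loop, var):
        return loop
    unrolled: List[ast.stmt] = []
    for k in range(n):  # Step 3: one copy of the body per iteration
        for stmt in loop.body:
            unrolled.append(_ReadReplacer(var, k).visit(copy.deepcopy(stmt)))
    return unrolled  # Step 5: flat statement list (the FlattenSentinel's role)


loop = ast.parse("for i in range(2):\n    x = f(x, i)\n").body[0]
print(ast.unparse(ast.Module(body=unroll(loop), type_ignores=[])))
# x = f(x, 0)
# x = f(x, 1)
```

Two consequences of straight-line expansion are worth keeping in mind: range(0) unrolls to an empty statement list, and this sketch does not recreate the final value of i that a real Python loop leaves bound after it finishes.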

Classes

LoopVarReplacer

Helper visitor to replace loop variable instances with a constant integer.

Functions

unroll_static_loops(→ Union[libcst.For, ...)

Hook: Unrolls loops with static ranges.

Module Contents

class ml_switcheroo.plugins.static_unroll.LoopVarReplacer(var_name: str, value: int)

Bases: libcst.CSTTransformer

Helper visitor to replace loop variable instances with a constant integer.

var_name

The name of the loop variable to target (e.g., ‘i’).

Type:

str

value

The integer constant to substitute (e.g., 0).

Type:

int

leave_Name(original_node: libcst.Name, updated_node: libcst.Name) → libcst.BaseExpression

Replaces occurrences of names matching var_name with an Integer node holding value.
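
The leave_Name behaviour can be approximated with the standard-library ast module standing in for libcst. LoopVarReplacerSketch and substitute are hypothetical names, and restricting the rewrite to names in load (read) context is an added assumption not stated for the documented leave_Name. A convenient property of ast is that the attribute in obj.i and the keyword in f(i=...) are stored as plain strings, so they are never visited. In libcst those positions are themselves Name nodes (Attribute.attr and Arg.keyword), so a leave_Name-based replacer would see them; whether the plugin guards against that is not stated here. In libcst the replacement node is written cst.Integer(value=str(value)), since Integer stores its literal text as a string.

```python
import ast


class LoopVarReplacerSketch(ast.NodeTransformer):
    """ast analogue of LoopVarReplacer (hypothetical, not the plugin's class)."""

    def __init__(self, var_name: str, value: int) -> None:
        self.var_name = var_name
        self.value = value

    def visit_Name(self, node: ast.Name) -> ast.AST:
        # Only bare reads of the loop variable are rewritten; assignments keep
        # their target, and ast stores the attribute part of `obj.i` as a str.
        if node.id == self.var_name and isinstance(node.ctx, ast.Load):
            return ast.copy_location(ast.Constant(self.value), node)
        return node


def substitute(src: str, var_name: str, value: int) -> str:
    """Parse src, replace reads of var_name with value, and print it back."""
    tree = LoopVarReplacerSketch(var_name, value).visit(ast.parse(src))
    return ast.unparse(tree)


print(substitute("y = obj.i + i * 2", "i", 0))  # y = obj.i + 0 * 2
print(substitute("f(i=i)", "i", 3))  # f(i=3)
```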

ml_switcheroo.plugins.static_unroll.unroll_static_loops(node: libcst.For, ctx: ml_switcheroo.core.hooks.HookContext) → libcst.For | libcst.FlattenSentinel

Hook: Unrolls loops with static ranges.

Triggers:

Invoked by ControlFlowMixin via the transform_for_loop_static key.

Transformation:

Input:

    for i in range(2):
        x = f(x, i)

Output:

    x = f(x, 0)
    x = f(x, 1)

Parameters:
  • node (cst.For) – The original For loop node.

  • ctx (HookContext) – The execution context (unused by this hook, but required by the hook protocol).

Returns:

  • FlattenSentinel containing unrolled statements if successful.

  • Original node if the loop is dynamic or too large.

Return type:

Union[cst.For, cst.FlattenSentinel]