ml_switcheroo.plugins.loop_unroll

Plugin for Unrolling Loops (Functional Control Flow Enforcement).

This module handles the structural transformation of Python control flow into the functional primitives required by XLA-based frameworks. PyTorch and TensorFlow Eager execute imperative Python loops directly. Frameworks that compile through XLA (such as JAX) instead need loops expressed as functional operators (jax.lax.scan or jax.lax.fori_loop). Under tracing, a plain Python loop is either unrolled or, if its bounds depend on traced values, raises an error.
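To show what "functional" means here, the sketch below gives simplified pure-Python reference semantics for jax.lax.fori_loop and jax.lax.scan. It follows the reference semantics in the JAX documentation but is not the compiled primitive. The helper names fori_loop_reference and scan_reference are hypothetical, and the real scan stacks per-step outputs into arrays rather than a list. The key point is that every piece of loop state must be threaded explicitly through the carry.

```python
from typing import Callable, Iterable, List, Tuple, TypeVar

C = TypeVar("C")  # carry (loop state) type
X = TypeVar("X")  # per-step input type
Y = TypeVar("Y")  # per-step output type


def fori_loop_reference(lower: int, upper: int,
                        body_fun: Callable[[int, C], C], init_val: C) -> C:
    """Simplified pure-Python semantics of jax.lax.fori_loop."""
    val = init_val
    for i in range(lower, upper):
        val = body_fun(i, val)  # state only flows through the return value
    return val


def scan_reference(f: Callable[[C, X], Tuple[C, Y]], init: C,
                   xs: Iterable[X]) -> Tuple[C, List[Y]]:
    """Simplified pure-Python semantics of jax.lax.scan (outputs as a list)."""
    carry = init
    ys: List[Y] = []
    for x in xs:
        carry, y = f(carry, x)  # new carry plus one output per step
        ys.append(y)
    return carry, ys


# Imperative version: state lives in a mutable local variable.
total = 0
for i in range(5):
    total += i * i

# Functional version: the same state is passed explicitly as the carry.
total_fn = fori_loop_reference(0, 5, lambda i, acc: acc + i * i, 0)
assert total == total_fn == 30

# scan additionally returns per-step outputs (here, running sums).
final, running = scan_reference(lambda c, x: (c + x, c + x), 0, [1, 2, 3])
# final == 6, running == [1, 3, 6]
```

Use fori_loop when only the final state matters and the trip count is an integer range. Use scan when iterating over a sequence or when per-step outputs are needed.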

The strategy is Safety-First:

  1. Analysis: Inspects each for loop and its iterator.

  2. Trait Check: Determines if the target framework requires functional control flow.

  3. Handling: Wraps imperative loops in an EscapeHatch warning rather than attempting unsafe auto-conversion, because inferring which variables form a loop's “carry state” is often undecidable.
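To make the carry-state problem concrete, here is a deliberately naive heuristic that guesses carry variables as names that are both read and rebound inside a loop body. It uses the standard-library ast module. The function naive_carry_candidates is hypothetical and is not part of ml_switcheroo. It correctly finds total but misses history, which is carried state mutated in place through a method call. Cases like this make automatic conversion unsafe.

```python
import ast
from typing import List, Set


def naive_carry_candidates(src: str) -> List[str]:
    """Hypothetical heuristic: names both read and rebound in a for body.

    Illustrative only; it ignores in-place mutation, attributes,
    closures and aliasing, which is exactly why it is unreliable.
    """
    tree = ast.parse(src)
    loop = next(n for n in ast.walk(tree) if isinstance(n, ast.For))
    stored: Set[str] = set()
    loaded: Set[str] = set()
    for stmt in loop.body:
        for node in ast.walk(stmt):
            if isinstance(node, ast.AugAssign) and isinstance(node.target, ast.Name):
                # "x += y" both reads and rebinds x, but its target Name
                # only carries a Store context, so record the read here.
                stored.add(node.target.id)
                loaded.add(node.target.id)
            elif isinstance(node, ast.Name):
                if isinstance(node.ctx, ast.Store):
                    stored.add(node.id)
                elif isinstance(node.ctx, ast.Load):
                    loaded.add(node.id)
    return sorted(stored & loaded)


SRC = """
for x in xs:
    total += x
    history.append(total)
"""

print(naive_carry_candidates(SRC))  # ['total']  (history is missed)
```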

Functions

transform_loops(node, ctx) → Union[libcst.For, ...]

Plugin hook: transforms or flags for loops that must comply with functional control flow.

Module Contents

ml_switcheroo.plugins.loop_unroll.transform_loops(node: libcst.For, ctx: ml_switcheroo.core.hooks.HookContext) → libcst.For | libcst.FlattenSentinel

Plugin hook: transforms or flags for loops that must comply with functional control flow.

Triggered by the ControlFlowMixin when visiting For nodes.

Strategy:

  • Check ctx.plugin_traits.requires_functional_control_flow.

  • If False: Pass through (imperative loops are valid).

  • If True:
    • Analyze the iterator.

    • If range(): Warn that jax.lax.fori_loop (or equivalent) is required.

    • Else: Warn that jax.lax.scan (or equivalent) is required.

    • Wrap in EscapeHatch to prevent compilation errors in the target.
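The strategy above can be sketched as follows. This is an illustrative approximation, not the plugin's code. It uses the standard-library ast module in place of libcst and a hypothetical PluginTraits dataclass in place of ctx.plugin_traits. It also returns an invented warning string, whereas the real hook wraps the loop in an EscapeHatch.

```python
import ast
from dataclasses import dataclass
from typing import Optional


@dataclass
class PluginTraits:
    # Hypothetical stand-in for ctx.plugin_traits.
    requires_functional_control_flow: bool


def loop_warning(loop: ast.For, traits: PluginTraits) -> Optional[str]:
    """Sketch of the decision logic; None means pass the loop through."""
    if not traits.requires_functional_control_flow:
        return None  # imperative loops are valid in the target
    it = loop.iter
    is_range = (isinstance(it, ast.Call)
                and isinstance(it.func, ast.Name)
                and it.func.id == "range")
    if is_range:
        return "range() loop: rewrite with jax.lax.fori_loop (or equivalent)"
    return "non-range loop: rewrite with jax.lax.scan (or equivalent)"


def first_for(src: str) -> ast.For:
    """Return the first for loop found in a source snippet."""
    return next(n for n in ast.walk(ast.parse(src)) if isinstance(n, ast.For))


jax_like = PluginTraits(requires_functional_control_flow=True)
torch_like = PluginTraits(requires_functional_control_flow=False)

print(loop_warning(first_for("for i in range(10):\n    pass\n"), jax_like))
print(loop_warning(first_for("for x in xs:\n    pass\n"), jax_like))
print(loop_warning(first_for("for i in range(10):\n    pass\n"), torch_like))
```

Only a bare range(...) call is treated as a fori_loop candidate in this sketch. An aliased or qualified range (for example builtins.range) falls through to the scan branch, a simplification a fuller implementation might need to handle.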

Parameters:
  • node – The original CST For loop node.

  • ctx – Hook context containing framework configuration traits.

Returns:

The original node when it passes through unchanged, or a libcst.FlattenSentinel produced by wrapping the loop in an EscapeHatch.