ml_switcheroo.plugins.loop_unroll
Plugin for Unrolling Loops (Functional Control Flow Enforcement).
This module handles the structural transformation of Python control flow into the functional primitives required by XLA-based frameworks. PyTorch and TensorFlow in eager mode execute imperative Python loops directly. Frameworks that compile through XLA (such as JAX) generally require loops to be expressed as functional operators (scan or fori_loop) to compile well: a plain Python loop is either unrolled during tracing or, if its bounds depend on traced values, fails to compile.
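The difference between the two styles is easiest to see in plain Python. The sketch below needs no JAX installation. It mirrors the pure-Python reference semantics that the JAX documentation gives for jax.lax.fori_loop(lower, upper, body_fun, init_val) and jax.lax.scan(f, init, xs). In both, the state that an imperative loop mutates implicitly becomes an explicit value threaded through a body function. The names fori_loop_reference, scan_reference and the example functions are hypothetical. Unlike the real scan, scan_reference returns the per-step outputs as a Python list rather than a stacked array.

```python
# Hypothetical pure-Python stand-ins that mirror the documented semantics of
# jax.lax.fori_loop and jax.lax.scan. They are NOT JAX APIs: they run eagerly,
# with no tracing or compilation.


def fori_loop_reference(lower, upper, body_fun, init_val):
    """Semantics of fori_loop: thread a carried value through body_fun(i, val)."""
    val = init_val
    for i in range(lower, upper):
        val = body_fun(i, val)
    return val


def scan_reference(f, init, xs):
    """Semantics of scan (list-based): f(carry, x) returns (new_carry, y)."""
    carry = init
    ys = []
    for x in xs:
        carry, y = f(carry, x)
        ys.append(y)
    return carry, ys


# Imperative style: the running total is implicit mutable state.
def sum_of_squares_imperative(n):
    total = 0
    for i in range(n):
        total += i * i
    return total


# Functional style: the carried state (total) is an explicit argument and
# return value of the loop body, which is the shape fori_loop expects.
def sum_of_squares_functional(n):
    return fori_loop_reference(0, n, lambda i, total: total + i * i, 0)


# Iterating over data instead of a range maps onto scan. Here the body also
# emits a per-step output: the running prefix sum.
def prefix_sums_functional(xs):
    def step(carry, x):
        new_carry = carry + x
        return new_carry, new_carry

    return scan_reference(step, 0, xs)
```

For example, sum_of_squares_imperative(4) and sum_of_squares_functional(4) both return 14. prefix_sums_functional([1, 2, 3]) returns (6, [1, 3, 6]).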
The strategy is safety-first:
- Analysis: Inspects each for-loop.
- Trait check: Determines whether the target framework requires functional control flow.
- Handling: Wraps imperative loops in an EscapeHatch warning rather than attempting an unsafe automatic conversion. Inferring which variables a loop carries between iterations (the “carry state” problem) cannot be done reliably in general.
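Consider the most obvious heuristic for finding carried state: treat every name that the loop body both reads and assigns as loop-carried. The sketch below implements it and shows why the plugin stops at a warning. It is hypothetical and uses the standard-library ast module rather than libcst. Even this simple heuristic fails on ordinary code. It misses state mutated through method calls, and it flags temporaries that are recomputed on every iteration.

```python
import ast

# Hypothetical, deliberately naive heuristic: any name the loop body both
# reads and assigns is treated as candidate carried state. It ignores
# attribute/subscript mutation, aliasing, closures and side-effecting calls.


def candidate_carry_names(loop_source: str) -> set[str]:
    tree = ast.parse(loop_source)
    # Take the first for-loop in the snippet (raises StopIteration if none).
    loop = next(node for node in ast.walk(tree) if isinstance(node, ast.For))
    # The loop target (e.g. `i`) is rebound on each iteration; exclude it.
    target_names = {n.id for n in ast.walk(loop.target) if isinstance(n, ast.Name)}
    read, written = set(), set()
    for stmt in loop.body:
        for node in ast.walk(stmt):
            if isinstance(node, ast.Name):
                if isinstance(node.ctx, ast.Load):
                    read.add(node.id)
                elif isinstance(node.ctx, ast.Store):
                    written.add(node.id)
            elif isinstance(node, ast.AugAssign) and isinstance(node.target, ast.Name):
                # `total += x` reads and writes `total`, but ast marks the
                # target with a Store context only, so record the read here.
                read.add(node.target.id)
    return (read & written) - target_names
```

For example, the heuristic reports {"total"} for a body of total += i * i. It returns an empty set for a loop that calls acc.append(x), even though acc carries state. It also wrongly reports y for a body of y = i followed by z = y + 1.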
Functions
- transform_loops(node, ctx)
  Plugin Hook: Transforms or flags for-loops for functional compliance.
Module Contents
- ml_switcheroo.plugins.loop_unroll.transform_loops(node: libcst.For, ctx: ml_switcheroo.core.hooks.HookContext) -> libcst.For | libcst.FlattenSentinel
Plugin Hook: Transforms or flags for-loops for functional compliance.
Triggered by the ControlFlowMixin when visiting For nodes.
Strategy:
- Check ctx.plugin_traits.requires_functional_control_flow.
- If False: Pass the loop through unchanged (imperative loops are valid for the target).
- If True:
  - Analyze the iterator.
  - If it is a range() call: Warn that jax.lax.fori_loop (or an equivalent) is required.
  - Otherwise: Warn that scan is required.
  - Wrap the loop in an EscapeHatch so that the unconverted code does not cause compilation errors in the target.
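The decision logic in the Strategy list can be sketched as a pure function. This is a minimal sketch that uses the standard-library ast module in place of libcst. The names loop_warning and first_for are hypothetical. The boolean argument stands in for ctx.plugin_traits.requires_functional_control_flow, and the warning texts are illustrative rather than the plugin's actual messages.

```python
import ast
from typing import Optional

# Hypothetical sketch of the documented strategy, written against the stdlib
# ast module. `requires_functional_control_flow` stands in for
# ctx.plugin_traits.requires_functional_control_flow.


def loop_warning(loop: ast.For, requires_functional_control_flow: bool) -> Optional[str]:
    if not requires_functional_control_flow:
        return None  # Pass through: imperative loops are valid for this target.
    it = loop.iter
    # Match a direct call to a name `range`, e.g. `for i in range(n):`.
    is_range = (
        isinstance(it, ast.Call)
        and isinstance(it.func, ast.Name)
        and it.func.id == "range"
    )
    if is_range:
        return "Loop over range(): rewrite with jax.lax.fori_loop (or an equivalent)."
    return "Loop over an iterable: rewrite with scan."


def first_for(source: str) -> ast.For:
    """Return the first for-loop found in `source` (helper for the example)."""
    return next(n for n in ast.walk(ast.parse(source)) if isinstance(n, ast.For))
```

Matching on the bare name range keeps the check simple. However, it treats any callable named range as the builtin and misses aliases such as r = range. That is one more reason to warn rather than rewrite automatically.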
- Parameters:
node – The original CST For loop node.
ctx – Hook context containing framework configuration traits.
- Returns:
The original node when the loop is passed through, or an EscapeHatch sentinel (libcst.FlattenSentinel) when the loop is flagged.
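The libcst.FlattenSentinel return type lets a single visited node be replaced by several statements. The standard-library ast module has a comparable mechanism. When a NodeTransformer visit_* method returns a list, those statements are spliced into the parent's body in place of the original node. The hypothetical FlagLoops transformer below uses this to place a marker statement before each loop, as a rough stand-in for an EscapeHatch. It uses a string expression as the marker because ast does not preserve comments. This is an analogy, not the plugin's implementation.

```python
import ast

# Hypothetical analogue of returning libcst.FlattenSentinel: a
# NodeTransformer.visit_* method that returns a list splices those statements
# into the parent's body in place of the visited node.


class FlagLoops(ast.NodeTransformer):
    def __init__(self, requires_functional_control_flow: bool):
        super().__init__()
        self.requires_functional_control_flow = requires_functional_control_flow

    def visit_For(self, node: ast.For):
        self.generic_visit(node)  # Flag nested loops inside the body first.
        if not self.requires_functional_control_flow:
            return node  # Pass through unchanged.
        # String-expression marker standing in for an EscapeHatch warning.
        marker = ast.Expr(value=ast.Constant(value="ESCAPE HATCH: convert this loop manually"))
        return [marker, node]  # One node in, two statements out.


def flag_loops(source: str, requires_functional_control_flow: bool) -> str:
    tree = FlagLoops(requires_functional_control_flow).visit(ast.parse(source))
    return ast.unparse(ast.fix_missing_locations(tree))  # Python 3.9+
```

With True, flag_loops returns the source with a marker before every for-loop, including nested ones. With False, the loops come back unmarked.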