ml_switcheroo.plugins.loop_unroll
Plugin for Unrolling Loops (JAX scan/fori_loop).
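This module handles the structural transformation of Python control flow into JAX-compatible primitives. PyTorch executes imperative Python loops eagerly. Under JIT compilation, JAX instead traces a Python loop by unrolling it, which inflates compile time and fails outright when the trip count depends on a traced value. Loops should therefore be expressed as functional operators (jax.lax.scan or jax.lax.fori_loop) for XLA compilation (jax.jit).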
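Here is a minimal sketch of the difference. The function names are illustrative and not part of ml_switcheroo. The same sum of squares is written first as a Python loop and then with jax.lax.fori_loop, whose body takes the loop index and the carried value and returns the updated carry:

```python
import jax
import jax.numpy as jnp


def sum_squares_python(n: int) -> jnp.ndarray:
    # Imperative loop: fine eagerly. Under jax.jit it would be unrolled at
    # trace time, and range(n) fails if n is a traced value.
    total = jnp.array(0.0)
    for i in range(n):
        total = total + i * i
    return total


@jax.jit
def sum_squares_fori(n):
    # Functional form: body(i, carry) -> new carry.
    def body(i, total):
        return total + i * i  # int32 index promotes to the float32 carry
    return jax.lax.fori_loop(0, n, body, jnp.array(0.0))


if __name__ == "__main__":
    print(sum_squares_python(5), sum_squares_fori(5))
```

Called with n=5, both versions return 30.0. Only the fori_loop version also works when n is a traced value inside jax.jit.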
Logic:
1. Analysis: Inspects for loops to determine whether they iterate over a range (candidates for fori_loop) or over a general iterable (candidates for scan).
2. Safety Check: Automated loop conversion requires solving the “Carry State” problem, which means identifying the variables that are mutated across iterations. For this reason, the plugin currently defaults to a Safety-First strategy.
3. Transformation: Rather than generating JAX code that may be incorrect, the plugin wraps the loop in an EscapeHatch. This preserves the original logic and explicitly flags the loop as a blocker for JIT compliance, guiding the user to refactor it manually into a functional pattern.
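The following hedged sketch illustrates the “Carry State” problem and the kind of manual refactor the EscapeHatch points to. The function names are illustrative, and this is not code generated by the plugin. The accumulator acc, which is mutated across iterations, becomes the explicit carry of jax.lax.scan. The per-iteration values become scan's stacked outputs.

```python
import jax
import jax.numpy as jnp


def running_sum_python(xs):
    # `acc` is mutated across iterations: this is the carry state.
    acc = 0.0
    out = []
    for x in xs:
        acc = acc + x
        out.append(acc)
    return acc, jnp.stack(out)


@jax.jit
def running_sum_scan(xs):
    # The carry is passed in and returned explicitly by each step;
    # scan stacks the per-step outputs along a new leading axis.
    def step(acc, x):
        acc = acc + x
        return acc, acc  # (new carry, per-step output)

    return jax.lax.scan(step, jnp.zeros((), xs.dtype), xs)


if __name__ == "__main__":
    xs = jnp.array([1.0, 2.0, 3.0])
    print(running_sum_python(xs))
    print(running_sum_scan(xs))
```

For xs = [1.0, 2.0, 3.0], both versions yield a final carry of 6.0 and per-step outputs [1.0, 3.0, 6.0]. Identifying acc as the carry is the analysis the plugin currently declines to automate.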
Functions
| transform_loops(node, ctx) | Plugin Hook: Transforms Python for loops for JAX compliance. |
Module Contents
- ml_switcheroo.plugins.loop_unroll.transform_loops(node: libcst.For, ctx: ml_switcheroo.core.hooks.HookContext) → libcst.For | libcst.FlattenSentinel
Plugin Hook: Transforms Python for loops for JAX compliance.
Triggered by the ControlFlowMixin when visiting For nodes.
- Strategy:
  - If target != JAX: pass the node through unchanged (Python loops are valid in PyTorch and TensorFlow eager mode).
  - If target == JAX:
    1. Analyze the loop's iterator.
    2. If it is a range() call, flag the loop as a candidate for jax.lax.fori_loop.
    3. Otherwise, flag it as a candidate for jax.lax.scan.
    4. Wrap the loop in an EscapeHatch to prevent compilation errors in JAX/XLA.
- Parameters:
node – The original CST For loop node.
ctx – Hook context containing framework targets.
- Returns:
The transformed node or an EscapeHatch sentinel.