ml_switcheroo.plugins.loop_unroll

Plugin for Unrolling Loops (JAX Scan/Fori_loop).

This module handles the structural transformation of Python control flow into JAX-compatible primitives. PyTorch executes imperative Python loops directly. Under jax.jit, a Python for loop is unrolled during tracing, and a loop whose bounds depend on traced values fails to trace. Expressing the loop as a functional operator (jax.lax.scan or jax.lax.fori_loop) lets XLA compile it as a single loop.
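To make the functional form concrete, here is a minimal pure-Python sketch of the semantics of the two primitives, modeled on the reference semantics given in the JAX documentation. The functions below are stand-ins written for illustration, not the JAX implementations: they use plain lists and do not trace, stack outputs into arrays, or compile anything.

```python
def fori_loop(lower, upper, body_fun, init_val):
    """Pure-Python model of jax.lax.fori_loop(lower, upper, body_fun, init_val)."""
    val = init_val
    for i in range(lower, upper):
        val = body_fun(i, val)  # the body receives the index and the carried value
    return val


def scan(f, init, xs):
    """Pure-Python model of jax.lax.scan(f, init, xs); outputs kept as a list."""
    carry = init
    ys = []
    for x in xs:
        carry, y = f(carry, x)  # the body returns (new_carry, per-step output)
        ys.append(y)
    return carry, ys


# Imperative loop: `total` is mutated across iterations.
total = 0
for i in range(5):
    total = total + i * i

# Functional form: the mutated variable becomes the explicit carry.
total_functional = fori_loop(0, 5, lambda i, acc: acc + i * i, 0)


# Running sum over an iterable: the carry is the sum so far.
def step(carry, x):
    new_carry = carry + x
    return new_carry, new_carry


final, running = scan(step, 0, [1, 2, 3, 4])
```

In both cases the variable that the imperative loop mutates has to be threaded through the body function explicitly, which is what makes automated conversion non-trivial.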

Logic:

  1. Analysis: Inspects for loops to determine whether they iterate over a range (candidates for fori_loop) or over another iterable (candidates for scan).

  2. Safety Check: Because automated loop conversion requires solving the “Carry State” problem (identifying which variables are mutated across iterations), this plugin currently defaults to a Safety-First strategy.

  3. Transformation: Instead of hallucinating broken JAX code, it wraps the loop in an EscapeHatch. This preserves the original logic while explicitly flagging it as a blocker for JIT compliance, guiding the user to manually refactor it into a functional pattern.
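The analysis step and the carry-state problem can be sketched as follows. This is an illustration only: it uses the standard-library ast module rather than libcst, and the helper names classify_loop and carry_candidates are hypothetical, not part of ml_switcheroo. The carry detection is a deliberate over-approximation (every name assigned in the body), which hints at why a fully automatic rewrite is hard to get right.

```python
import ast


def classify_loop(loop: ast.For) -> str:
    """Return 'fori_loop' for `for ... in range(...)`, otherwise 'scan'."""
    it = loop.iter
    if (
        isinstance(it, ast.Call)
        and isinstance(it.func, ast.Name)
        and it.func.id == "range"
    ):
        return "fori_loop"
    return "scan"


def carry_candidates(loop: ast.For) -> set:
    """Names assigned inside the loop body, excluding the loop target.

    Over-approximates the carry state: a name that is only written and
    never read in a later iteration is still reported.
    """
    targets = {n.id for n in ast.walk(loop.target) if isinstance(n, ast.Name)}
    assigned = set()
    for stmt in loop.body:
        for node in ast.walk(stmt):
            if isinstance(node, ast.Name) and isinstance(node.ctx, ast.Store):
                assigned.add(node.id)
    return assigned - targets


source = """
for i in range(n):
    total = total + i
for x in xs:
    h = h * x
    last = x
"""

tree = ast.parse(source)
loops = [stmt for stmt in tree.body if isinstance(stmt, ast.For)]
report = [(classify_loop(lp), sorted(carry_candidates(lp))) for lp in loops]
```

Here `last` is reported as a carry candidate even though no iteration reads its previous value; telling such cases apart requires data-flow analysis, which is the gap the Safety-First strategy sidesteps.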

Functions

transform_loops(node, ctx) → Union[libcst.For, ...]

Plugin Hook: Transforms for loops for JAX compliance.

Module Contents

ml_switcheroo.plugins.loop_unroll.transform_loops(node: libcst.For, ctx: ml_switcheroo.core.hooks.HookContext) → libcst.For | libcst.FlattenSentinel

Plugin Hook: Transforms for loops for JAX compliance.

Triggered by the ControlFlowMixin when visiting For nodes.

Strategy:
  • If Target != JAX: Pass through (Python loops are valid in Torch/TF Eager).

  • If Target == JAX:
    • Analyze the iterator.

    • If range(): Flag as a candidate for jax.lax.fori_loop.

    • Else: Flag as a candidate for jax.lax.scan.

    • Wrap in EscapeHatch to prevent compilation errors in JAX/XLA.
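The decision tree above might look roughly like the sketch below. Everything in it is an assumption made for illustration: it operates on source strings with the standard-library ast module instead of libcst nodes, the function name plan_for_loop is hypothetical, and the BEGIN/END marker comments are invented placeholders, not the real EscapeHatch output format.

```python
import ast


def plan_for_loop(loop_source: str, target_framework: str) -> str:
    """Hypothetical stand-in for the hook's strategy, on source strings."""
    # Target != JAX: pass through unchanged.
    if target_framework != "jax":
        return loop_source

    stmt = ast.parse(loop_source).body[0]
    if not isinstance(stmt, ast.For):
        return loop_source

    # Analyze the iterator: range() suggests fori_loop, anything else scan.
    it = stmt.iter
    is_range = (
        isinstance(it, ast.Call)
        and isinstance(it.func, ast.Name)
        and it.func.id == "range"
    )
    hint = "jax.lax.fori_loop" if is_range else "jax.lax.scan"

    # Wrap rather than rewrite: the original loop is preserved verbatim
    # between placeholder marker comments (format invented for this sketch).
    begin = f"# BEGIN ESCAPE HATCH: manual refactor needed, candidate for {hint}"
    end = "# END ESCAPE HATCH"
    return f"{begin}\n{loop_source}\n{end}"


range_loop = "for i in range(3):\n    x = x + i"
iter_loop = "for item in items:\n    x = x + item"

torch_out = plan_for_loop(range_loop, "torch")
jax_range_out = plan_for_loop(range_loop, "jax")
jax_iter_out = plan_for_loop(iter_loop, "jax")
```

Because the wrapper only adds comments, the output stays valid Python and keeps the original behavior in eager mode, while the markers give the user a searchable list of loops to refactor by hand.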

Parameters:
  • node – The original CST For loop node.

  • ctx – Hook context containing framework targets.

Returns:

The transformed node or an EscapeHatch sentinel.