ml_switcheroo.plugins.loop_unroll
=================================

.. py:module:: ml_switcheroo.plugins.loop_unroll

.. autoapi-nested-parse::

   Plugin for Unrolling Loops (JAX ``scan``/``fori_loop``).

   This module handles the structural transformation of Python control flow into
   JAX-compatible primitives. PyTorch executes imperative Python loops eagerly.
   Under JAX's JIT, a Python loop is unrolled at trace time. This inflates compile
   time, and it fails outright when the trip count depends on a traced value. A
   loop that should compile to a single XLA loop must therefore be expressed with
   a functional operator (``jax.lax.scan`` or ``jax.lax.fori_loop``).

   Logic:

   1. **Analysis**: Inspects ``for`` loops to determine whether they iterate over
      a ``range`` (candidates for ``fori_loop``) or over a general iterable
      (candidates for ``scan``).
   2. **Safety Check**: Automated loop conversion requires solving the "Carry
      State" problem (identifying which variables are mutated across iterations),
      so this plugin currently defaults to a **Safety-First** strategy.
   3. **Transformation**: Instead of emitting speculative and likely broken JAX
      code, it wraps the loop in an ``EscapeHatch``. This preserves the original
      logic while explicitly flagging the loop as a blocker for JIT compliance,
      guiding the user to refactor it manually into a functional pattern.



Functions
---------

.. autoapisummary::

   ml_switcheroo.plugins.loop_unroll.transform_loops


Module Contents
---------------

.. py:function:: transform_loops(node: libcst.For, ctx: ml_switcheroo.core.hooks.HookContext) -> Union[libcst.For, libcst.FlattenSentinel]

   Plugin hook: transforms ``for`` loops for JAX compliance.

   Triggered by the ``ControlFlowMixin`` when visiting ``For`` nodes.

   Strategy:

   - If the target is not JAX: pass the loop through unchanged (Python loops are
     valid in PyTorch and in TensorFlow eager mode).
   - If the target is JAX:

     - Analyze the iterator.
     - If it is a ``range()`` call: flag the loop as a candidate for
       ``jax.lax.fori_loop``.
     - Otherwise: flag the loop as a candidate for ``jax.lax.scan``.
     - Wrap the loop in an ``EscapeHatch`` to prevent compilation errors in
       JAX/XLA.

   :param node: The original CST ``For`` loop node.
   :param ctx: Hook context containing the framework targets.
   :returns: The transformed node or an ``EscapeHatch`` sentinel.
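
The Analysis step and the "Carry State" problem can be illustrated with a small, self-contained sketch. It is not the plugin's implementation. It uses the standard-library ``ast`` module as a stand-in for ``libcst``, which the plugin actually operates on. The names ``LoopReport``, ``classify_loops`` and ``_assigned_names`` are hypothetical helpers and are not part of ``ml_switcheroo``. The sketch labels a loop a ``fori_loop`` candidate when its iterator is a direct call to ``range``, and a ``scan`` candidate otherwise. It also lists the names each loop body assigns, which approximate the variables a manual refactor would need to thread through the carry.

.. code-block:: python

   # Assumption: stdlib ``ast`` stands in for libcst; all names here are
   # hypothetical and not part of ml_switcheroo.
   import ast
   from dataclasses import dataclass, field
   from typing import List, Set


   @dataclass
   class LoopReport:
       """Summary of one ``for`` loop (hypothetical helper)."""

       lineno: int
       candidate: str  # "fori_loop" for range(...) iterators, otherwise "scan"
       assigned: Set[str] = field(default_factory=set)  # names written in the body


   def _assigned_names(body: List[ast.stmt]) -> Set[str]:
       """Collect plain names stored anywhere in the loop body (possible carry)."""
       names: Set[str] = set()
       for stmt in body:
           for sub in ast.walk(stmt):
               # Covers ``x = ...``, ``x += ...`` and nested loop targets.
               if isinstance(sub, ast.Name) and isinstance(sub.ctx, ast.Store):
                   names.add(sub.id)
       return names


   def classify_loops(source: str) -> List[LoopReport]:
       """Classify every ``for`` loop as a fori_loop or scan candidate."""
       reports: List[LoopReport] = []
       for node in ast.walk(ast.parse(source)):
           if not isinstance(node, ast.For):
               continue
           it = node.iter
           # Only a bare ``range(...)`` call counts as a fori_loop candidate.
           is_range = (
               isinstance(it, ast.Call)
               and isinstance(it.func, ast.Name)
               and it.func.id == "range"
           )
           reports.append(
               LoopReport(
                   lineno=node.lineno,
                   candidate="fori_loop" if is_range else "scan",
                   assigned=_assigned_names(node.body),
               )
           )
       # ast.walk is breadth-first, so sort to report loops in source order.
       return sorted(reports, key=lambda r: r.lineno)


   if __name__ == "__main__":
       demo = "\ntotal = 0\nfor i in range(10):\n    total += i\nfor x in xs:\n    acc = acc * x\n"
       for report in classify_loops(demo):
           print(report.lineno, report.candidate, sorted(report.assigned))
       # 3 fori_loop ['total']
       # 5 scan ['acc']

Collecting assigned names is only a first approximation of the carry. It over-reports temporaries that live within a single iteration, and it misses mutation through attributes or subscripts (e.g. ``buf[i] = ...``). This shows why solving the carry automatically is hard, and why a Safety-First strategy that flags the loop is the conservative default.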