ml_switcheroo.plugins.static_unroll
===================================

.. py:module:: ml_switcheroo.plugins.static_unroll

.. autoapi-nested-parse::

   Plugin for Statically Unrolling Loops.

   JAX and XLA-based frameworks generally require functional loops (``scan``, ``fori_loop``).
   These are difficult to generate automatically from imperative Python ``for`` loops because of
   complex variable-scoping and state-containment rules.

   However, many neural network definitions loop over **fixed constants**
   (e.g., ``for i in range(3): layer(x)``). Unrolling such loops yields a valid, optimizable graph
   structure without requiring complex functional rewrite logic.

   Usage
   -----
   This plugin registers the hook ``transform_for_loop_static``.
   The ``ControlFlowMixin`` invokes it before the general loop-safety scanners run.

   Process
   -------
   1. **Analysis**: Detects ``for i in range(N)`` where ``N`` is a static integer literal.
   2. **Safety**: Checks that ``N`` is within a reasonable limit to prevent code explosion.
   3. **Expansion**: Duplicates the loop body ``N`` times.
   4. **Substitution**: Replaces uses of the loop variable (``i``) with the integer literal for that
      iteration (``0``, ``1``, etc.).
   5. **Output**: Returns a ``cst.FlattenSentinel`` containing the list of statements.


Classes
-------

.. autoapisummary::

   ml_switcheroo.plugins.static_unroll.LoopVarReplacer


Functions
---------

.. autoapisummary::

   ml_switcheroo.plugins.static_unroll.unroll_static_loops


Module Contents
---------------

.. py:class:: LoopVarReplacer(var_name: str, value: int)

   Bases: :py:obj:`libcst.CSTTransformer`

   Helper visitor that replaces instances of the loop variable with a constant integer.

   .. attribute:: var_name

      The name of the loop variable to target (e.g., ``'i'``).

      :type: str

   .. attribute:: value

      The integer constant to substitute (e.g., ``0``).

      :type: int

   .. py:attribute:: var_name

   .. py:attribute:: value

   .. py:method:: leave_Name(original_node: libcst.Name, updated_node: libcst.Name) -> libcst.BaseExpression

      Replace occurrences of variables matching ``var_name`` with ``Integer(value)``.


.. py:function:: unroll_static_loops(node: libcst.For, ctx: ml_switcheroo.core.hooks.HookContext) -> Union[libcst.For, libcst.FlattenSentinel]

   Hook: Unrolls loops with static ranges.

   Triggers:
       Invoked by ``ControlFlowMixin`` via the ``transform_for_loop_static`` key.

   Transformation:
       Input::

           for i in range(2):
               x = f(x, i)

       Output::

           x = f(x, 0)
           x = f(x, 1)

   :param node: The original ``For`` loop node.
   :type node: cst.For
   :param ctx: The execution context (unused in this logic but required by the hook protocol).
   :type ctx: HookContext

   :returns:
       - A ``FlattenSentinel`` containing the unrolled statements if unrolling succeeds.
       - The original ``node`` if the loop is dynamic or too large.
   :rtype: Union[cst.For, cst.FlattenSentinel]
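
The five-step process described in the module docstring can be sketched with libcst as follows. This is a minimal, hypothetical sketch rather than the plugin's actual implementation: the helper names ``static_trip_count``, ``unroll`` and ``StaticUnrollTransformer``, and the ``MAX_UNROLL`` limit of 16, are assumptions made for illustration. ``unroll`` also omits the ``ctx`` argument of the real hook and only handles a single-argument ``range(N)`` over an indented loop body.

.. code-block:: python

   import libcst as cst

   # Hypothetical safety limit; the plugin's real threshold is not documented here.
   MAX_UNROLL = 16


   class LoopVarReplacer(cst.CSTTransformer):
       """Replace every Name equal to ``var_name`` with an Integer literal."""

       def __init__(self, var_name: str, value: int) -> None:
           super().__init__()
           self.var_name = var_name
           self.value = value

       def leave_Name(
           self, original_node: cst.Name, updated_node: cst.Name
       ) -> cst.BaseExpression:
           if original_node.value == self.var_name:
               return cst.Integer(value=str(self.value))
           return updated_node


   def static_trip_count(node: cst.For):
       """Analysis: return N for ``range(N)`` with an integer literal N, else None."""
       call = node.iter
       if not (
           isinstance(call, cst.Call)
           and isinstance(call.func, cst.Name)
           and call.func.value == "range"
           and len(call.args) == 1
           and isinstance(call.args[0].value, cst.Integer)
       ):
           return None
       return call.args[0].value.evaluated_value


   def unroll(node: cst.For):
       """Hypothetical hook body: a FlattenSentinel of statements, or the original node."""
       n = static_trip_count(node)
       if (
           n is None
           or not (1 <= n <= MAX_UNROLL)  # Safety: skip empty or oversized loops
           or not isinstance(node.target, cst.Name)
           or not isinstance(node.body, cst.IndentedBlock)
           or node.orelse is not None
       ):
           return node
       statements = []
       for k in range(n):  # Expansion: one copy of the body per iteration
           replacer = LoopVarReplacer(node.target.value, k)  # Substitution
           statements.extend(stmt.visit(replacer) for stmt in node.body.body)
       return cst.FlattenSentinel(statements)  # Output


   class StaticUnrollTransformer(cst.CSTTransformer):
       """Hypothetical driver that applies ``unroll`` to every for loop."""

       def leave_For(self, original_node: cst.For, updated_node: cst.For):
           return unroll(updated_node)


   source = "for i in range(2):\n    x = f(x, i)\n"
   result = cst.parse_module(source).visit(StaticUnrollTransformer())
   print(result.code)

Running the snippet prints the two assignments from the Transformation example, ``x = f(x, 0)`` and ``x = f(x, 1)``. A loop such as ``for i in range(n)`` is returned unchanged because ``n`` is not a literal. This sketch does not rebind ``i`` after the loop, and it does not guard against ``break``/``continue`` or against attribute and keyword names that happen to equal the loop variable.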