ml_switcheroo.core.rewriter
===========================

.. py:module:: ml_switcheroo.core.rewriter

.. autoapi-nested-parse::

   Rewriter Package.

   This package provides the `PivotRewriter` class, composed of several mixins
   to handle specific aspects of the AST transformation:

   - Structure: Class and Function definitions.
   - Calls: Function invocations.
   - Contexts: With blocks.
   - Attributes: Attribute access.
   - Normalization: Argument mapping.
   - ControlFlow: Loop and branching logic.
   - Decorators: Decorator handling.


Submodules
----------

.. toctree::
   :maxdepth: 1

   /api/ml_switcheroo/core/rewriter/attributes/index
   /api/ml_switcheroo/core/rewriter/base/index
   /api/ml_switcheroo/core/rewriter/calls/index
   /api/ml_switcheroo/core/rewriter/contexts/index
   /api/ml_switcheroo/core/rewriter/control_flow/index
   /api/ml_switcheroo/core/rewriter/decorators/index
   /api/ml_switcheroo/core/rewriter/initializers/index
   /api/ml_switcheroo/core/rewriter/normalization/index
   /api/ml_switcheroo/core/rewriter/structure/index
   /api/ml_switcheroo/core/rewriter/structure_class/index
   /api/ml_switcheroo/core/rewriter/structure_func/index
   /api/ml_switcheroo/core/rewriter/structure_types/index
   /api/ml_switcheroo/core/rewriter/types/index


Classes
-------

.. autoapisummary::

   ml_switcheroo.core.rewriter.CallMixin
   ml_switcheroo.core.rewriter.NormalizationMixin
   ml_switcheroo.core.rewriter.AttributeMixin
   ml_switcheroo.core.rewriter.StructureMixin
   ml_switcheroo.core.rewriter.DecoratorMixin
   ml_switcheroo.core.rewriter.ControlFlowMixin
   ml_switcheroo.core.rewriter.ContextMixin
   ml_switcheroo.core.rewriter.BaseRewriter
   ml_switcheroo.core.rewriter.PivotRewriter


Package Contents
----------------

.. py:class:: CallMixin(semantics: ml_switcheroo.semantics.manager.SemanticsManager, config: ml_switcheroo.config.RuntimeConfig)

   Bases: :py:obj:`ml_switcheroo.core.rewriter.normalization.NormalizationMixin`, :py:obj:`ml_switcheroo.core.rewriter.base.BaseRewriter`


   Mixin for transforming Call nodes and unpacking Assignments.

   Responsible for:

   1. Handling functional `apply` patterns (Flax).
   2. Lifecycle method stripping (`.to()`, `.cuda()`).
   3. Plugin dispatch.
   4. Standard API pivoting (Lookup -> Normalize -> Rewrite).
   5. Output Transformation (Indexing, Casting).


   .. py:method:: leave_Assign(original_node: libcst.Assign, updated_node: libcst.Assign) -> libcst.Assign

      Handles assignment unwrapping for Functional -> OOP transitions.

      Scenario: `y, updates = layer.apply(vars, x)`

      Target: `y = layer(x)` (NNX/Torch style)


   .. py:method:: leave_Call(original: libcst.Call, updated: libcst.Call) -> Union[libcst.Call, libcst.BinaryOperation, libcst.UnaryOperation, libcst.CSTNode]

      Rewrites function calls with detailed Trace Logging.

      Implements Logic 4: Layer Init State Threading.


.. py:class:: NormalizationMixin(semantics: ml_switcheroo.semantics.manager.SemanticsManager, config: ml_switcheroo.config.RuntimeConfig)

   Bases: :py:obj:`ml_switcheroo.core.rewriter.base.BaseRewriter`


   Mixin class providing argument normalization and operator transformation logic.


.. py:class:: AttributeMixin(semantics: ml_switcheroo.semantics.manager.SemanticsManager, config: ml_switcheroo.config.RuntimeConfig)

   Bases: :py:obj:`ml_switcheroo.core.rewriter.base.BaseRewriter`


   Mixin for transforming attributes and tracking assignments.


   .. py:method:: leave_Assign(original_node: libcst.Assign, updated_node: libcst.Assign) -> libcst.Assign

      Tracks stateful assignments (e.g. ``self.layer = Linear(...)``) to determine
      whether an object variable requires special handling later.


   .. py:method:: leave_Attribute(original: libcst.Attribute, updated: libcst.Attribute) -> libcst.BaseExpression

      Visits attributes (e.g. ``torch.float32``).


.. py:class:: StructureMixin(semantics: ml_switcheroo.semantics.manager.SemanticsManager, config: ml_switcheroo.config.RuntimeConfig)

   Bases: :py:obj:`ml_switcheroo.core.rewriter.structure_class.ClassStructureMixin`, :py:obj:`ml_switcheroo.core.rewriter.structure_func.FuncStructureMixin`, :py:obj:`ml_switcheroo.core.rewriter.structure_types.TypeStructureMixin`


   Composite mixin for all structural rewriting tasks.
   Inherits from granular mixins to assemble the full feature set.


.. py:class:: DecoratorMixin(semantics: ml_switcheroo.semantics.manager.SemanticsManager, config: ml_switcheroo.config.RuntimeConfig)

   Bases: :py:obj:`ml_switcheroo.core.rewriter.base.BaseRewriter`


   Mixin for transforming Decorator nodes.
   Part of PivotRewriter.


   .. py:method:: leave_Decorator(original_node: libcst.Decorator, updated_node: libcst.Decorator) -> Union[libcst.Decorator, libcst.RemovalSentinel]

      Processes decorators attached to functions or classes.

      Logic:

      1. Identifies the decorator name from `original_node` to ensure we key off
         the Source Framework API, even if `CallMixin` modified the children in
         `updated_node`.
      2. Looks up the semantic definition.
      3. If the target variant is explicitly `null` (None), removes the decorator.
      4. If the target variant specifies a new API, rewrites the name.


.. py:class:: ControlFlowMixin(semantics: ml_switcheroo.semantics.manager.SemanticsManager, config: ml_switcheroo.config.RuntimeConfig)

   Bases: :py:obj:`ml_switcheroo.core.rewriter.base.BaseRewriter`


   Mixin for visiting Control Flow nodes (For, While, If).


   .. py:method:: leave_For(original_node: libcst.For, updated_node: libcst.For) -> Union[libcst.For, libcst.CSTNode]

      Invokes loop transformation logic.

      Implements a priority chain:

      1. **Static Unroll**: Checks ``transform_for_loop_static``.
         If the loop index is constant, unrolls it (Optimization).
      2. **General Transform**: Checks ``transform_for_loop``.
         Handles general case logic or applies safety warnings
         (e.g., Escape Hatch for JAX).

      :param original_node: The node before transformation.
      :type original_node: cst.For
      :param updated_node: The node after child visitors have run.
      :type updated_node: cst.For

      :returns: The transformed node (potentially a FlattenSentinel of unrolled
                statements) or the original node if untouched.
      :rtype: Union[cst.For, cst.CSTNode]


.. py:class:: ContextMixin(semantics: ml_switcheroo.semantics.manager.SemanticsManager, config: ml_switcheroo.config.RuntimeConfig)

   Bases: :py:obj:`ml_switcheroo.core.rewriter.base.BaseRewriter`


   Mixin for transforming `With` blocks.


   .. py:method:: leave_With(original_node: libcst.With, updated_node: libcst.With) -> Union[libcst.With, libcst.FlattenSentinel]

      Processes 'with' statements.

      Logic:

      1. Iterate over `with` items (expressions).
      2. Identify if the expression corresponds to a `CONTEXT` OpType in semantics.
      3. Check transformation logic:

         - If the identifying `strip_context` marker is found, lift the body.
         - Otherwise, rely on `CallMixin` (which runs on children) having
           already renamed the API.


.. py:class:: BaseRewriter(semantics: ml_switcheroo.semantics.manager.SemanticsManager, config: ml_switcheroo.config.RuntimeConfig)

   Bases: :py:obj:`libcst.CSTTransformer`


   The base class for AST transformation traversal.

   Provides common utilities for scope tracking, alias resolution, and error
   bubbling, which are used by the specific Mixins (CallMixin,
   StructureMixin, etc.).


   .. py:attribute:: semantics


   .. py:attribute:: config


   .. py:attribute:: source_fw
      :value: ''


   .. py:attribute:: target_fw
      :value: ''


   .. py:attribute:: strict_mode


   .. py:attribute:: ctx


   .. py:method:: leave_Module(original_node: libcst.Module, updated_node: libcst.Module) -> libcst.Module

      Injects module-level preambles (e.g. Shim classes) requested by plugins.
      Ensures injection happens after docstrings to maintain valid Python help text.

      :param original_node: The module node before transformation.
      :param updated_node: The module node after transformation.

      :returns: The module with injected preambles.
      :rtype: cst.Module


   .. py:method:: visit_Import(node: libcst.Import) -> Optional[bool]

      Scans ``import ...`` statements to populate the alias map.

      Example: ``import torch.nn as nn`` -> ``_alias_map['nn'] = 'torch.nn'``.

      :param node: Import statement node.

      :returns: False to stop traversal of children.
      :rtype: Optional[bool]


   .. py:method:: visit_ImportFrom(node: libcst.ImportFrom) -> Optional[bool]

      Scans ``from ... import ...`` statements to populate the alias map.

      Example: ``from torch import nn`` -> ``_alias_map['nn'] = 'torch.nn'``.

      :param node: ImportFrom statement node.

      :returns: False to stop traversal of children.
      :rtype: Optional[bool]


   .. py:method:: visit_SimpleStatementLine(node: libcst.SimpleStatementLine) -> Optional[bool]

      Resets error tracking at the start of each line.

      Errors bubble up from children (Expressions) to this Statement handler.

      :param node: The statement line node.

      :returns: True to continue traversal.
      :rtype: Optional[bool]


   .. py:method:: leave_SimpleStatementLine(original_node: libcst.SimpleStatementLine, updated_node: libcst.SimpleStatementLine) -> Union[libcst.SimpleStatementLine, libcst.FlattenSentinel]

      Handles error bubbling from expression rewrites.

      If errors occurred while processing this line's children (e.g. a function
      call could not be rewritten), wraps the line in an ``EscapeHatch``.
      Prioritizes errors (reverting to original code) over warnings (using
      updated code).

      :param original_node: The node before children were visited.
      :param updated_node: The node after children transformation.

      :returns: The resulting node (possibly wrapped with comments).
      :rtype: Union[cst.SimpleStatementLine, cst.FlattenSentinel]


.. py:class:: PivotRewriter(semantics: ml_switcheroo.semantics.manager.SemanticsManager, config: ml_switcheroo.config.RuntimeConfig)

   Bases: :py:obj:`contexts.ContextMixin`, :py:obj:`control_flow.ControlFlowMixin`, :py:obj:`decorators.DecoratorMixin`, :py:obj:`calls.CallMixin`, :py:obj:`normalization.NormalizationMixin`, :py:obj:`attributes.AttributeMixin`, :py:obj:`structure.StructureMixin`, :py:obj:`base.BaseRewriter`


   The main AST transformer for ml-switcheroo.

   Inherits functionality from the component mixins and the base transformer
   logic. This class is the entry point for the ASTEngine.
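The mixin composition described on this page can be illustrated with a minimal sketch. It is not the package's implementation: it uses the standard-library ``ast`` module as a stand-in for ``libcst`` so that it runs without third-party dependencies, and every name in it (``API_MAP``, ``BaseSketch``, ``CallSketch``, ``DecoratorSketch``, ``PivotSketch``, ``pivot``) is hypothetical. It shows a base transformer that builds an alias map from imports, one mixin that renames calls, one mixin that drops decorators mapped to ``None``, and a composite class that combines them.

```python
import ast

# Hypothetical API table (not from ml_switcheroo): qualified source name -> target.
# A value of None means "remove", mirroring the explicit-null decorator rule.
API_MAP = {"torch.abs": "jax.numpy.abs", "torch.no_grad": None}


class BaseSketch(ast.NodeTransformer):
    """Shared state for all mixins: the API table and the import alias map."""

    def __init__(self, api_map):
        super().__init__()
        self.api_map = api_map
        self.alias_map = {}

    def visit_Import(self, node):
        # ``import torch.nn as nn`` -> alias_map['nn'] = 'torch.nn'
        for alias in node.names:
            if alias.asname:
                self.alias_map[alias.asname] = alias.name
            else:
                # ``import torch.nn`` only binds the top-level name ``torch``.
                root = alias.name.split(".")[0]
                self.alias_map[root] = root
        return node

    def visit_ImportFrom(self, node):
        # ``from torch import nn`` -> alias_map['nn'] = 'torch.nn'
        if node.module and node.level == 0:  # skip relative imports
            for alias in node.names:
                key = alias.asname or alias.name
                self.alias_map[key] = f"{node.module}.{alias.name}"
        return node

    def resolve(self, node):
        """Return the alias-expanded dotted name of ``node``, or None."""
        parts = []
        while isinstance(node, ast.Attribute):
            parts.append(node.attr)
            node = node.value
        if not isinstance(node, ast.Name):
            return None
        parts.append(self.alias_map.get(node.id, node.id))
        return ".".join(reversed(parts))


class CallSketch(BaseSketch):
    """Renames calls whose resolved name has a target in the API table."""

    def visit_Call(self, node):
        self.generic_visit(node)  # children first, like a ``leave_*`` hook
        target = self.api_map.get(self.resolve(node.func))
        if target:
            node.func = ast.parse(target, mode="eval").body
        return node


class DecoratorSketch(BaseSketch):
    """Removes decorators that the API table maps explicitly to None."""

    def visit_FunctionDef(self, node):
        kept = []
        for dec in node.decorator_list:
            # Key off the original decorator name, before children are rewritten.
            name = self.resolve(dec.func if isinstance(dec, ast.Call) else dec)
            if name in self.api_map and self.api_map[name] is None:
                continue
            kept.append(dec)
        node.decorator_list = kept
        return self.generic_visit(node)


class PivotSketch(DecoratorSketch, CallSketch):
    """Composite transformer assembled from the mixins above."""


def pivot(source):
    """Parse ``source``, apply the composite transformer, return new code."""
    tree = PivotSketch(API_MAP).visit(ast.parse(source))
    return ast.unparse(ast.fix_missing_locations(tree))


SAMPLE = "import torch as t\n\n@t.no_grad()\ndef f(x):\n    return t.abs(x)\n"

if __name__ == "__main__":
    print(pivot(SAMPLE))
```

Because each mixin overrides a different visitor hook, the composite class needs no body of its own. The sketch leaves import statements untouched and requires Python 3.9 or later for ``ast.unparse``.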