ml_switcheroo.core.rewriter.passes
==================================

.. py:module:: ml_switcheroo.core.rewriter.passes

.. autoapi-nested-parse::

   Transformation Passes Package.


Submodules
----------

.. toctree::
   :maxdepth: 1

   /api/ml_switcheroo/core/rewriter/passes/api/index
   /api/ml_switcheroo/core/rewriter/passes/auxiliary/index
   /api/ml_switcheroo/core/rewriter/passes/structure/index


Classes
-------

.. autoapisummary::

   ml_switcheroo.core.rewriter.passes.StructuralPass
   ml_switcheroo.core.rewriter.passes.StructuralTransformer
   ml_switcheroo.core.rewriter.passes.ApiPass
   ml_switcheroo.core.rewriter.passes.ApiTransformer
   ml_switcheroo.core.rewriter.passes.AuxiliaryPass
   ml_switcheroo.core.rewriter.passes.AuxiliaryTransformer


Package Contents
----------------

.. py:class:: StructuralPass

   Bases: :py:obj:`ml_switcheroo.core.rewriter.interface.RewriterPass`

   Pass responsible for modifying the structural scaffolding of the code.

   It transforms class definitions, function signatures, and type hints to match the idioms of the target framework.

   .. py:method:: transform(module: libcst.Module, context: ml_switcheroo.core.rewriter.context.RewriterContext) -> libcst.Module

      Executes the structural transformation.

      :param module: The source CST.
      :param context: Shared rewriter state.
      :returns: The transformed CST.


.. py:class:: StructuralTransformer(context: ml_switcheroo.core.rewriter.context.RewriterContext)

   Bases: :py:obj:`libcst.CSTTransformer`

   LibCST Transformer encapsulating all structural logic.

   Maintains internal state for scope depth, annotation context, and target framework traits.

   .. py:attribute:: context

   .. py:property:: target_traits
      :type: ml_switcheroo.semantics.schema.StructuralTraits

      Lazily loads the structural traits for the target framework.

   .. py:method:: leave_Module(original_node: libcst.Module, updated_node: libcst.Module) -> libcst.Module

      Injects accumulated module-level preamble statements (e.g. imports, shim classes) requested by plugins during the rewrite.

      Flushes and clears the buffer to prevent double injection in subsequent passes.

   .. py:method:: visit_Annotation(node: libcst.Annotation) -> Optional[bool]

      Flags entry into a type annotation context.

   .. py:method:: leave_Annotation(original_node: libcst.Annotation, updated_node: libcst.Annotation) -> libcst.Annotation

      Flags exit from a type annotation context.

   .. py:method:: leave_Name(original_node: libcst.Name, updated_node: libcst.Name) -> libcst.BaseExpression

      Rewrites type names if inside an annotation.

   .. py:method:: leave_Attribute(original_node: libcst.Attribute, updated_node: libcst.Attribute) -> libcst.BaseExpression

      Rewrites dotted type attributes (e.g. ``torch.Tensor``) if inside an annotation.

      This is needed because ``leave_Name`` only sees each component of a dotted name individually, without the surrounding context, so compound types such as ``torch.Tensor`` would otherwise fail to rewrite.

   .. py:method:: visit_ClassDef(node: libcst.ClassDef) -> Optional[bool]

      Detects Module inheritance to set processing state.

   .. py:method:: leave_ClassDef(original_node: libcst.ClassDef, updated_node: libcst.ClassDef) -> Union[libcst.ClassDef, libcst.CSTNode]

      Rewrites class inheritance if necessary.

   .. py:method:: visit_FunctionDef(node: libcst.FunctionDef) -> Optional[bool]

      Pushes the function context onto the signature stack.

   .. py:method:: leave_FunctionDef(original_node: libcst.FunctionDef, updated_node: libcst.FunctionDef) -> libcst.FunctionDef

      Applies signature, name, and body rewrites to functions.


.. py:class:: ApiPass

   Bases: :py:obj:`ml_switcheroo.core.rewriter.interface.RewriterPass`

   Transformation pass for rewiring API usage.

   Resolves function calls to Abstract Operations (the Hub) and then projects them onto the Target Framework (the Spoke). It also handles attribute renaming and stateful assignment tracking.

   .. py:method:: transform(module: libcst.Module, context: ml_switcheroo.core.rewriter.context.RewriterContext) -> libcst.Module

      Executes the API transformation logic.

      :param module: The source CST.
      :param context: Shared rewriter state.
      :returns: The transformed CST.


.. py:class:: ApiTransformer(context: ml_switcheroo.core.rewriter.context.RewriterContext)

   Bases: :py:obj:`libcst.CSTTransformer`

   LibCST Transformer for API logic.

   This class centralizes the logic for:

   - Resolving names and aliases.
   - Tracking scope and state.
   - Rewriting calls, attributes, and assignments.

   .. py:attribute:: context

   .. py:property:: semantics
      :type: ml_switcheroo.semantics.manager.SemanticsManager

      Accessor for the semantics manager.

   .. py:property:: config
      :type: ml_switcheroo.config.RuntimeConfig

      Accessor for the runtime config.

   .. py:property:: source_fw
      :type: str

      Accessor for the source framework key.

   .. py:property:: target_fw
      :type: str

      Accessor for the target framework key.

   .. py:property:: strict_mode
      :type: bool

      Accessor for the strict mode flag.

   .. py:property:: source_traits
      :type: ml_switcheroo.semantics.schema.StructuralTraits

      Lazily loads the source framework traits.

   .. py:method:: check_version_constraints(min_v: Optional[str], max_v: Optional[str]) -> Optional[str]

      Checks whether the target version requirements are met.

   .. py:method:: leave_Module(original_node: libcst.Module, updated_node: libcst.Module) -> libcst.Module

      Injects accumulated module-level preamble statements if they have not already been flushed by a prior pass (such as StructuralPass). Statements are deduplicated by their string content.

   .. py:method:: visit_ClassDef(node: libcst.ClassDef) -> Optional[bool]

      Enters class scope and detects Module inheritance.

   .. py:method:: leave_ClassDef(original_node: libcst.ClassDef, updated_node: libcst.ClassDef) -> libcst.ClassDef

      Exits class scope.

   .. py:method:: visit_FunctionDef(node: libcst.FunctionDef) -> Optional[bool]

      Enters function scope.

   .. py:method:: leave_FunctionDef(original_node: libcst.FunctionDef, updated_node: libcst.FunctionDef) -> libcst.FunctionDef

      Exits function scope.

      Flushes any pending preamble statements requested by plugins during this pass, and applies any pending signature injections (arguments).

   .. py:method:: visit_SimpleStatementLine(node: libcst.SimpleStatementLine) -> Optional[bool]

      Resets statement-level error buffers.

   .. py:method:: leave_SimpleStatementLine(original_node: libcst.SimpleStatementLine, updated_node: libcst.SimpleStatementLine) -> Union[libcst.SimpleStatementLine, libcst.FlattenSentinel]

      Checks for errors generated by child expressions and wraps the statement if needed.

   .. py:method:: visit_Import(node: libcst.Import) -> Optional[bool]

      Tracks import aliases.

   .. py:method:: visit_ImportFrom(node: libcst.ImportFrom) -> Optional[bool]

      Tracks from-import aliases.

   .. py:method:: leave_Assign(original_node: libcst.Assign, updated_node: libcst.Assign) -> libcst.Assign

      Handles assignment rewriting.

      1. Tracks stateful initializations (e.g. ``self.layer = Linear(...)``).
      2. Unwraps functional returns (e.g. ``y, state = layer.apply(...)``).

   .. py:method:: leave_Attribute(original_node: libcst.Attribute, updated_node: libcst.Attribute) -> libcst.BaseExpression

      Rewrites attributes and constants (e.g. ``torch.float32``).

      Skips rewriting if the attribute looks like a function call, since calls are handled by ``leave_Call``.

   .. py:method:: leave_Call(original_node: libcst.Call, updated_node: libcst.Call) -> Union[libcst.Call, libcst.BinaryOperation, libcst.UnaryOperation, libcst.CSTNode]

      Main entry point for function call rewriting.

   .. py:property:: ctx
      :type: Any

      Exposes the hook context for strategy invocation.


.. py:class:: AuxiliaryPass

   Bases: :py:obj:`ml_switcheroo.core.rewriter.interface.RewriterPass`

   Pass dealing with auxiliary syntax constructs: decorators and control flow.

   .. py:method:: transform(module: libcst.Module, context: ml_switcheroo.core.rewriter.context.RewriterContext) -> libcst.Module

      Executes the auxiliary transformation logic.

      :param module: The source CST.
      :param context: Shared state.
      :returns: The transformed CST.


.. py:class:: AuxiliaryTransformer(context: ml_switcheroo.core.rewriter.context.RewriterContext)

   Bases: :py:obj:`libcst.CSTTransformer`

   LibCST Transformer for auxiliary constructs.

   .. py:attribute:: context

   .. py:method:: visit_SimpleStatementLine(node: libcst.SimpleStatementLine) -> Optional[bool]

      Resets statement buffers.

   .. py:method:: leave_SimpleStatementLine(original_node: libcst.SimpleStatementLine, updated_node: libcst.SimpleStatementLine) -> Union[libcst.SimpleStatementLine, libcst.FlattenSentinel]

      Processes statement errors.

   .. py:method:: leave_Decorator(original_node: libcst.Decorator, updated_node: libcst.Decorator) -> Union[libcst.Decorator, libcst.RemovalSentinel]

      Rewrites decorators.

      Logic:

      1. Resolve the decorator name (e.g. ``torch.jit.script``).
      2. Look up its semantics.
      3. If the target variant is ``None``, remove the decorator.
      4. If the target variant has an API, rename the decorator.

   .. py:method:: leave_For(original_node: libcst.For, updated_node: libcst.For) -> Union[libcst.For, libcst.CSTNode, libcst.FlattenSentinel]

      Processes ``for`` loops for safety checks and unrolling.
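
The four-step decorator logic documented for ``AuxiliaryTransformer.leave_Decorator`` can be illustrated with a small standalone LibCST transformer. This is a minimal sketch, not the project's implementation. ``DecoratorSketch``, ``DECORATOR_MAP``, and ``run_decorator_pass`` are hypothetical names, and a plain dictionary stands in for the semantics lookup. The mapping entries are illustrative only and are not taken from ml_switcheroo's semantics data.

.. code-block:: python

   from typing import Dict, Optional, Union

   import libcst as cst
   from libcst.helpers import get_full_name_for_node

   # Hypothetical stand-in for the semantics lookup (illustrative entries only):
   # None means "no target variant, drop the decorator"; a string is the target API.
   DECORATOR_MAP: Dict[str, Optional[str]] = {
       "torch.jit.script": None,
       "torch.compile": "jax.jit",
   }


   class DecoratorSketch(cst.CSTTransformer):
       """Hypothetical transformer mirroring the documented decorator logic."""

       def __init__(self, mapping: Dict[str, Optional[str]]) -> None:
           super().__init__()
           self.mapping = mapping

       def leave_Decorator(
           self, original_node: cst.Decorator, updated_node: cst.Decorator
       ) -> Union[cst.Decorator, cst.RemovalSentinel]:
           expr = updated_node.decorator
           # 1. Resolve the dotted name; for "@deco(...)" inspect the callee.
           callee = expr.func if isinstance(expr, cst.Call) else expr
           name = get_full_name_for_node(callee)
           # 2. Look up the name; unknown decorators are left untouched.
           if name is None or name not in self.mapping:
               return updated_node
           target_api = self.mapping[name]
           # 3. No target variant: remove the decorator entirely.
           if target_api is None:
               return cst.RemoveFromParent()
           # 4. Target variant has an API: rename it, keeping any call arguments.
           new_callee = cst.parse_expression(target_api)
           if isinstance(expr, cst.Call):
               return updated_node.with_changes(decorator=expr.with_changes(func=new_callee))
           return updated_node.with_changes(decorator=new_callee)


   def run_decorator_pass(module: cst.Module) -> cst.Module:
       # Same shape as the documented passes: take a CST, return a transformed CST.
       return module.visit(DecoratorSketch(DECORATOR_MAP))


   source = (
       "@torch.jit.script\n"
       "def f(x):\n"
       "    return x\n"
       "\n"
       "\n"
       "@torch.compile\n"
       "def g(x):\n"
       "    return x\n"
   )
   result = run_decorator_pass(cst.parse_module(source))
   print(result.code)

Returning ``cst.RemoveFromParent()`` works here because decorators live in a sequence on the ``FunctionDef``, so LibCST can drop the node from its parent, which matches the ``libcst.RemovalSentinel`` in the documented return type.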
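
The preamble handling described for ``StructuralTransformer.leave_Module`` and ``ApiTransformer.leave_Module`` (inject buffered statements, deduplicate them by string content, then clear the buffer) might look roughly like the following. This sketch assumes the buffer is a plain list of statement source strings; ``PreambleInjector`` is a hypothetical name, and the real ``RewriterContext`` may store preamble entries differently.

.. code-block:: python

   from typing import List

   import libcst as cst


   class PreambleInjector(cst.CSTTransformer):
       """Hypothetical transformer: inject buffered preamble statements once."""

       def __init__(self, preamble: List[str]) -> None:
           super().__init__()
           # Shared buffer of statement source strings (an assumption of this sketch).
           self.preamble = preamble

       def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module:
           # Deduplicate by string content against statements already in the module.
           seen = {updated_node.code_for_node(stmt).strip() for stmt in updated_node.body}
           new_stmts = []
           for src in self.preamble:
               key = src.strip()
               if key in seen:
                   continue
               seen.add(key)
               new_stmts.append(cst.parse_statement(key))
           # Flush and clear the buffer so a later pass cannot inject it again.
           self.preamble.clear()
           if not new_stmts:
               return updated_node
           return updated_node.with_changes(body=[*new_stmts, *updated_node.body])


   buffer = ["import jax.numpy as jnp", "import numpy as np", "import jax.numpy as jnp"]
   module = cst.parse_module("import numpy as np\n\nx = np.zeros(3)\n")
   first = module.visit(PreambleInjector(buffer))
   second = first.visit(PreambleInjector(buffer))  # buffer is empty by now
   print(first.code)

Clearing the list in place is what turns a second injection attempt into a no-op. The sketch simply prepends statements, so it would place them ahead of a module docstring or ``from __future__`` imports; a production version would need to account for that.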