ml_switcheroo.core.rewriter.passes

Transformation Passes Package.

Classes

StructuralPass

Pass responsible for modifying the structural scaffolding of the code.

StructuralTransformer

LibCST Transformer encapsulating all structural logic.

ApiPass

Transformation pass for rewiring API usage.

ApiTransformer

LibCST Transformer for API Logic.

AuxiliaryPass

Pass dealing with auxiliary syntax constructs: decorators and control flow.

AuxiliaryTransformer

LibCST Transformer for auxiliary constructs.

Package Contents

class ml_switcheroo.core.rewriter.passes.StructuralPass

Bases: ml_switcheroo.core.rewriter.interface.RewriterPass

Pass responsible for modifying the structural scaffolding of the code.

It transforms class definitions, function signatures, and type hints to match the idioms of the target framework.

transform(module: libcst.Module, context: ml_switcheroo.core.rewriter.context.RewriterContext) → libcst.Module

Executes the structural transformation.

Parameters:
  • module – The source CST.

  • context – Shared rewriter state.

Returns:

The transformed CST.

class ml_switcheroo.core.rewriter.passes.StructuralTransformer(context: ml_switcheroo.core.rewriter.context.RewriterContext)

Bases: libcst.CSTTransformer

LibCST Transformer encapsulating all structural logic.

Maintains internal state for scope depth, annotation context, and target framework traits.

context

property target_traits: ml_switcheroo.semantics.schema.StructuralTraits

Lazily load structural traits for the target framework.

leave_Module(original_node: libcst.Module, updated_node: libcst.Module) → libcst.Module

Injects accumulated module-level preamble statements (e.g. imports, shim classes) requested by plugins during the rewrite. Flushes and clears the buffer to prevent double injection in subsequent passes.

visit_Annotation(node: libcst.Annotation) → bool | None

Flag entry into a type annotation context.

leave_Annotation(original_node: libcst.Annotation, updated_node: libcst.Annotation) → libcst.Annotation

Flag exit from a type annotation context.

leave_Name(original_node: libcst.Name, updated_node: libcst.Name) → libcst.BaseExpression

Rewrite type names if inside an annotation.

leave_Attribute(original_node: libcst.Attribute, updated_node: libcst.Attribute) → libcst.BaseExpression

Rewrite dotted type attributes (e.g. torch.Tensor) if inside an annotation. This fixes an issue where compound types such as torch.Tensor failed to rewrite: leave_Name only sees the individual name leaves (torch, Tensor), not the dotted name they form together.

visit_ClassDef(node: libcst.ClassDef) → bool | None

Detect Module inheritance to set processing state.

leave_ClassDef(original_node: libcst.ClassDef, updated_node: libcst.ClassDef) → libcst.ClassDef | libcst.CSTNode

Rewrite class inheritance if necessary.

visit_FunctionDef(node: libcst.FunctionDef) → bool | None

Push function context onto signature stack.

leave_FunctionDef(original_node: libcst.FunctionDef, updated_node: libcst.FunctionDef) → libcst.FunctionDef

Apply signature, name, and body rewrites to functions.

class ml_switcheroo.core.rewriter.passes.ApiPass

Bases: ml_switcheroo.core.rewriter.interface.RewriterPass

Transformation pass for rewiring API usage.

Handles resolving function calls to Abstract Operations (The Hub) and then projecting them to the Target Framework (The Spoke). Also handles attribute renaming and stateful assignment tracking.

transform(module: libcst.Module, context: ml_switcheroo.core.rewriter.context.RewriterContext) → libcst.Module

Executes the API transformation logic.

Parameters:
  • module – The source CST.

  • context – Shared rewriter state.

Returns:

The transformed CST.

class ml_switcheroo.core.rewriter.passes.ApiTransformer(context: ml_switcheroo.core.rewriter.context.RewriterContext)

Bases: libcst.CSTTransformer

LibCST Transformer for API Logic.

This class centralizes the logic for:
  • Resolving names/aliases.

  • Tracking scope/state.

  • Rewriting Calls, Attributes, and Assignments.

context

property semantics: ml_switcheroo.semantics.manager.SemanticsManager

Accessor for semantics manager.

property config: ml_switcheroo.config.RuntimeConfig

Accessor for runtime config.

property source_fw: str

Accessor for source framework key.

property target_fw: str

Accessor for target framework key.

property strict_mode: bool

Accessor for strict mode flag.

property source_traits: ml_switcheroo.semantics.schema.StructuralTraits

Lazily loads source framework traits.

check_version_constraints(min_v: str | None, max_v: str | None) → str | None

Checks if target version requirements are met.
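The `str | None` return type suggests that a failed constraint is reported as a message and that None means the check passed. That reading is an assumption. The stdlib-only sketch below passes the installed version explicitly (the real method presumably reads it from the runtime configuration) and treats both bounds as inclusive. `version_constraint_error` is a hypothetical helper, not the library's method.

```python
from typing import Optional, Tuple


def _parse(version: str) -> Tuple[int, ...]:
    # Simplification: plain dotted integers only ("0.4.20"); not full PEP 440.
    return tuple(int(part) for part in version.split("."))


def version_constraint_error(
    installed: str, min_v: Optional[str], max_v: Optional[str]
) -> Optional[str]:
    """Return an error message if `installed` violates the bounds, else None."""
    current = _parse(installed)
    if min_v is not None and current < _parse(min_v):
        return f"requires >= {min_v}, found {installed}"
    if max_v is not None and current > _parse(max_v):
        return f"requires <= {max_v}, found {installed}"
    return None


print(version_constraint_error("0.3.9", "0.4.0", None))  # requires >= 0.4.0, found 0.3.9
```

Tuples of integers compare element by element, so "0.10" correctly sorts above "0.9". Naive string comparison would get this wrong.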

leave_Module(original_node: libcst.Module, updated_node: libcst.Module) → libcst.Module

Injects accumulated module-level preamble statements if they have not already been flushed by a prior pass (such as StructuralPass). Statements are deduplicated by their string content.

visit_ClassDef(node: libcst.ClassDef) → bool | None

Enter class scope and detect Module inheritance.

leave_ClassDef(original_node: libcst.ClassDef, updated_node: libcst.ClassDef) → libcst.ClassDef

Exit class scope.

visit_FunctionDef(node: libcst.FunctionDef) → bool | None

Enter function scope.

leave_FunctionDef(original_node: libcst.FunctionDef, updated_node: libcst.FunctionDef) → libcst.FunctionDef

Exit function scope. Flush any pending preamble statements requested by plugins during this pass. Also apply any pending signature injections (arguments).

visit_SimpleStatementLine(node: libcst.SimpleStatementLine) → bool | None

Reset statement-level error buffers.

leave_SimpleStatementLine(original_node: libcst.SimpleStatementLine, updated_node: libcst.SimpleStatementLine) → libcst.SimpleStatementLine | libcst.FlattenSentinel

Check for errors generated by child expressions and wrap if needed.

visit_Import(node: libcst.Import) → bool | None

Track import aliases.

visit_ImportFrom(node: libcst.ImportFrom) → bool | None

Track from-import aliases.

leave_Assign(original_node: libcst.Assign, updated_node: libcst.Assign) → libcst.Assign

Handle assignment rewriting:
  1. Track stateful initializations (e.g. self.layer = Linear…).

  2. Unwrap functional returns (e.g. y, state = layer.apply…).

leave_Attribute(original_node: libcst.Attribute, updated_node: libcst.Attribute) → libcst.BaseExpression

Rewrites attributes and constants (e.g. torch.float32). Skips rewriting if the attribute looks like a function call (handled by leave_Call).

leave_Call(original_node: libcst.Call, updated_node: libcst.Call) → libcst.Call | libcst.BinaryOperation | libcst.UnaryOperation | libcst.CSTNode

Main entry point for function call rewriting.

property ctx: Any

Expose hook context for strategy invocation.

class ml_switcheroo.core.rewriter.passes.AuxiliaryPass

Bases: ml_switcheroo.core.rewriter.interface.RewriterPass

Pass dealing with auxiliary syntax constructs: decorators and control flow.

transform(module: libcst.Module, context: ml_switcheroo.core.rewriter.context.RewriterContext) → libcst.Module

Executes the auxiliary transformation logic.

Parameters:
  • module – The source CST.

  • context – Shared state.

Returns:

The transformed CST.

class ml_switcheroo.core.rewriter.passes.AuxiliaryTransformer(context: ml_switcheroo.core.rewriter.context.RewriterContext)

Bases: libcst.CSTTransformer

LibCST Transformer for auxiliary constructs.

context

visit_SimpleStatementLine(node: libcst.SimpleStatementLine) → bool | None

Reset statement buffers.

leave_SimpleStatementLine(original_node: libcst.SimpleStatementLine, updated_node: libcst.SimpleStatementLine) → libcst.SimpleStatementLine | libcst.FlattenSentinel

Process statement errors.

leave_Decorator(original_node: libcst.Decorator, updated_node: libcst.Decorator) → libcst.Decorator | libcst.RemovalSentinel

Rewrites decorators.

Logic:
  1. Resolve the decorator name (e.g. torch.jit.script).

  2. Look up its semantics.

  3. If the target variant is None, remove the decorator.

  4. If the target variant has an API, rename the decorator to that API.

leave_For(original_node: libcst.For, updated_node: libcst.For) → libcst.For | libcst.CSTNode | libcst.FlattenSentinel

Processes ‘for’ loops for safety checks and unrolling.