ml_switcheroo.core.rewriter.passes.structure

Structural Rewriting Pass.

This module consolidates all structural transformation logic, including:

1. Class Inheritance Rewriting: Swapping framework base classes (e.g., torch.nn.Module -> flax.nnx.Module).
2. Function Signature Rewriting: Injecting or stripping state/context arguments (e.g., rngs, ctx).
3. Method Renaming: Mapping lifecycle methods (e.g., forward -> __call__).
4. Body Injection: Handling super().__init__ calls and preamble injection.
5. Type Annotation Rewriting: Mapping framework-specific types (e.g., torch.Tensor -> jax.Array).

Classes

StructuralPass

Pass responsible for modifying the structural scaffolding of the code.

StructuralTransformer

LibCST Transformer encapsulating all structural logic.

Module Contents

class ml_switcheroo.core.rewriter.passes.structure.StructuralPass

Bases: ml_switcheroo.core.rewriter.interface.RewriterPass

Pass responsible for modifying the structural scaffolding of the code.

It transforms class definitions, function signatures, and type hints to match the idioms of the target framework.

transform(module: libcst.Module, context: ml_switcheroo.core.rewriter.context.RewriterContext) → libcst.Module

Executes the structural transformation.

Parameters:
  • module – The source CST.

  • context – Shared rewriter state.

Returns:

The transformed CST.
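The pass's relationship to StructuralTransformer is not spelled out here. Assuming transform simply runs a fresh transformer over the tree and returns the result, the flow can be sketched as below. The sketch uses the standard-library ast module rather than LibCST so it runs without extra dependencies (unlike LibCST, ast.unparse discards comments and formatting). SketchContext, SketchTransformer, and SketchPass are hypothetical names, and the forward -> __call__ rename stands in for the full set of structural rewrites.

```python
import ast
from dataclasses import dataclass, field


@dataclass
class SketchContext:
    """Hypothetical stand-in for RewriterContext (only a preamble buffer)."""

    preamble: list[str] = field(default_factory=list)


class SketchTransformer(ast.NodeTransformer):
    """Hypothetical stand-in for StructuralTransformer; keeps the shared context."""

    def __init__(self, context: SketchContext) -> None:
        self.context = context

    def visit_FunctionDef(self, node: ast.FunctionDef) -> ast.FunctionDef:
        self.generic_visit(node)  # rewrite nested nodes first
        if node.name == "forward":
            node.name = "__call__"  # example lifecycle-method rename
        return node


class SketchPass:
    """Assumed wiring: each call runs a fresh transformer over the tree."""

    def transform(self, module: ast.Module, context: SketchContext) -> ast.Module:
        new_module = SketchTransformer(context).visit(module)
        return ast.fix_missing_locations(new_module)


tree = ast.parse("class M:\n    def forward(self, x):\n        return x\n")
result = SketchPass().transform(tree, SketchContext())
print(ast.unparse(result))
```

Building a new transformer per call keeps per-module state, such as scope depth, from leaking between files in this sketch.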

class ml_switcheroo.core.rewriter.passes.structure.StructuralTransformer(context: ml_switcheroo.core.rewriter.context.RewriterContext)

Bases: libcst.CSTTransformer

LibCST Transformer encapsulating all structural logic.

Maintains internal state for scope depth, annotation context, and target framework traits.

context

property target_traits: ml_switcheroo.semantics.schema.StructuralTraits

Lazily load structural traits for the target framework.
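Lazy loading defers the traits lookup until the first access and then reuses the result. A minimal sketch, assuming a plain in-memory registry: TraitsSketch, TRAITS_REGISTRY, the "flax_nnx" key, and both field names are hypothetical and do not reflect the real StructuralTraits schema. functools.cached_property is one way to get this behavior; a @property that checks a private, None-initialized attribute would work the same way.

```python
from dataclasses import dataclass
from functools import cached_property


@dataclass(frozen=True)
class TraitsSketch:
    """Hypothetical subset of StructuralTraits; field names are illustrative."""

    module_base: str
    forward_method: str


# Hypothetical registry standing in for the real semantics lookup.
TRAITS_REGISTRY = {
    "flax_nnx": TraitsSketch(module_base="flax.nnx.Module", forward_method="__call__"),
}


class TransformerSketch:
    def __init__(self, target: str) -> None:
        self.target = target
        self.loads = 0  # counts registry lookups to make the caching visible

    @cached_property
    def target_traits(self) -> TraitsSketch:
        # Runs only on first access; the result is then stored on the instance.
        self.loads += 1
        return TRAITS_REGISTRY[self.target]


sketch = TransformerSketch("flax_nnx")
print(sketch.target_traits.module_base)     # flax.nnx.Module
print(sketch.target_traits.forward_method)  # __call__
print(sketch.loads)                         # 1
```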

leave_Module(original_node: libcst.Module, updated_node: libcst.Module) → libcst.Module

Injects accumulated module-level preamble statements (e.g. imports, shim classes) requested by plugins during the rewrite, then clears the buffer so that subsequent passes do not inject the same statements again.
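The flush-then-clear behavior can be sketched with the standard-library ast module (used instead of LibCST so the example has no dependencies). PreambleSketch and the list-of-source-strings buffer are hypothetical stand-ins for the real plugin buffer; the example runs the transformer twice to show that the second run injects nothing.

```python
import ast


class PreambleSketch(ast.NodeTransformer):
    """Hypothetical: `preamble` stands in for the buffer plugins append to."""

    def __init__(self, preamble: list[str]) -> None:
        self.preamble = preamble  # shared buffer, mutated in place

    def visit_Module(self, node: ast.Module) -> ast.Module:
        self.generic_visit(node)  # rewrite the body first, like a leave_ hook
        if not self.preamble:
            return node
        injected = [stmt for src in self.preamble for stmt in ast.parse(src).body]
        node.body = injected + node.body
        self.preamble.clear()  # flush so a later pass cannot inject twice
        return node


buffer = ["import jax.numpy as jnp"]
tree = ast.parse("y = jnp.ones(3)")
tree = PreambleSketch(buffer).visit(tree)
tree = PreambleSketch(buffer).visit(tree)  # second pass finds an empty buffer
print(ast.unparse(tree))
```

A fuller version would insert after any module docstring and `from __future__` imports, since those must stay at the top of the module.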

visit_Annotation(node: libcst.Annotation) → bool | None

Flag entry into a type annotation context.

leave_Annotation(original_node: libcst.Annotation, updated_node: libcst.Annotation) → libcst.Annotation

Flag exit from a type annotation context.

leave_Name(original_node: libcst.Name, updated_node: libcst.Name) → libcst.BaseExpression

Rewrite type names if inside an annotation.

leave_Attribute(original_node: libcst.Attribute, updated_node: libcst.Attribute) → libcst.BaseExpression

Rewrite dotted type attributes (e.g. torch.Tensor) if inside an annotation. This fixes an issue where dotted types such as torch.Tensor failed to rewrite, because leave_Name sees only the individual name components (torch, Tensor) without the surrounding attribute context.
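The annotation flag and the dotted-name matching can be combined in one sketch. It uses the standard-library ast module instead of LibCST, and the TYPE_MAP contents and the AnnotationRewriter name are hypothetical. A depth counter plays the role of the entry/exit flag, so only types inside parameter, return, and variable annotations are rewritten, and matching happens on the whole Attribute chain rather than on its leaf names.

```python
import ast
from typing import Optional

# Hypothetical type map; the real mappings come from the framework semantics.
TYPE_MAP = {"torch.Tensor": "jax.Array"}


class AnnotationRewriter(ast.NodeTransformer):
    def __init__(self) -> None:
        self.annotation_depth = 0  # > 0 while visiting inside an annotation

    def _visit_annotation(self, expr: Optional[ast.expr]) -> Optional[ast.expr]:
        if expr is None:
            return None
        self.annotation_depth += 1  # flag entry
        try:
            return self.visit(expr)
        finally:
            self.annotation_depth -= 1  # flag exit

    def visit_arg(self, node: ast.arg) -> ast.arg:
        node.annotation = self._visit_annotation(node.annotation)
        return node

    def visit_AnnAssign(self, node: ast.AnnAssign) -> ast.AnnAssign:
        node.target = self.visit(node.target)
        node.annotation = self._visit_annotation(node.annotation)
        if node.value is not None:
            node.value = self.visit(node.value)  # values are not annotations
        return node

    def visit_FunctionDef(self, node: ast.FunctionDef) -> ast.FunctionDef:
        returns, node.returns = node.returns, None
        self.generic_visit(node)  # parameters reach visit_arg from here
        node.returns = self._visit_annotation(returns)
        return node

    def visit_Attribute(self, node: ast.Attribute) -> ast.expr:
        # Match the whole dotted chain, not its leaf Names one at a time.
        target = TYPE_MAP.get(ast.unparse(node))
        if self.annotation_depth and target:
            return ast.parse(target, mode="eval").body
        return self.generic_visit(node)


src = (
    "def f(x: torch.Tensor) -> torch.Tensor:\n"
    "    y: torch.Tensor = torch.zeros(3)\n"
    "    return x + torch.Tensor.mean(y)\n"
)
out = ast.unparse(AnnotationRewriter().visit(ast.parse(src)))
print(out)
```

One difference from LibCST: its leave_Attribute runs after the children have been visited, whereas ast calls visit_Attribute on the outermost node first, so this sketch matches the full dotted name before descending. The torch.Tensor.mean call in the body is left untouched because it is not inside an annotation.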

visit_ClassDef(node: libcst.ClassDef) → bool | None

Detect whether the class inherits from a framework Module base and set the processing state accordingly.

leave_ClassDef(original_node: libcst.ClassDef, updated_node: libcst.ClassDef) → libcst.ClassDef | libcst.CSTNode

Rewrite class inheritance if necessary.
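A minimal sketch of base-class swapping with the standard-library ast module (not LibCST), assuming a simple string-keyed map: BASE_MAP and BaseClassRewriter are hypothetical. Bases not in the map are kept, and the names of rewritten classes are recorded, loosely mirroring the state that visit_ClassDef sets up.

```python
import ast

# Hypothetical base-class map (source framework -> target framework).
BASE_MAP = {"torch.nn.Module": "flax.nnx.Module"}


class BaseClassRewriter(ast.NodeTransformer):
    def __init__(self) -> None:
        self.rewritten: list[str] = []  # classes whose bases were swapped

    def visit_ClassDef(self, node: ast.ClassDef) -> ast.ClassDef:
        self.generic_visit(node)  # handle nested classes first
        changed = False
        new_bases = []
        for base in node.bases:
            target = BASE_MAP.get(ast.unparse(base))
            if target is None:
                new_bases.append(base)  # unrelated base: keep as-is
            else:
                new_bases.append(ast.parse(target, mode="eval").body)
                changed = True
        if changed:
            node.bases = new_bases
            self.rewritten.append(node.name)
        return node


src = (
    "class Net(torch.nn.Module):\n"
    "    pass\n"
    "class Helper(object):\n"
    "    pass\n"
)
rewriter = BaseClassRewriter()
out = ast.unparse(rewriter.visit(ast.parse(src)))
print(out)
print(rewriter.rewritten)  # ['Net']
```

The sketch only recognizes the literal spelling torch.nn.Module; an alias such as nn.Module after `from torch import nn` would need import-aware name resolution, and the torch import itself is not rewritten.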

visit_FunctionDef(node: libcst.FunctionDef) → bool | None

Push the function context onto the signature stack.

leave_FunctionDef(original_node: libcst.FunctionDef, updated_node: libcst.FunctionDef) → libcst.FunctionDef

Apply signature, name, and body rewrites to functions.
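The signature stack and the signature/body rewrites can be sketched together with the standard-library ast module (used instead of LibCST to stay dependency-free). Everything here is hypothetical: SignatureRewriter, the choice to inject a keyword-only rngs parameter into every __init__, and the choice to drop bare super().__init__() calls all assume a target base class that neither requires that call nor already receives rngs. The stack is what lets the body rewrite ask which function it is currently inside.

```python
import ast
from typing import Optional


def is_super_init(stmt: ast.stmt) -> bool:
    """True for a bare `super().__init__(...)` expression statement."""
    if not (isinstance(stmt, ast.Expr) and isinstance(stmt.value, ast.Call)):
        return False
    func = stmt.value.func
    return (
        isinstance(func, ast.Attribute)
        and func.attr == "__init__"
        and isinstance(func.value, ast.Call)
        and isinstance(func.value.func, ast.Name)
        and func.value.func.id == "super"
    )


class SignatureRewriter(ast.NodeTransformer):
    def __init__(self, param: str = "rngs") -> None:
        self.param = param
        self.signature_stack: list[str] = []  # names of enclosing functions

    def visit_FunctionDef(self, node: ast.FunctionDef) -> ast.FunctionDef:
        self.signature_stack.append(node.name)  # push on entry
        try:
            self.generic_visit(node)
        finally:
            self.signature_stack.pop()  # pop on exit
        if not node.body:
            node.body = [ast.Pass()]  # keep the def valid if the body was emptied
        if node.name == "__init__":
            self.inject_param(node)
        return node

    def visit_Expr(self, node: ast.Expr) -> Optional[ast.Expr]:
        # Body rewrite: drop super().__init__() only inside an __init__.
        in_init = bool(self.signature_stack) and self.signature_stack[-1] == "__init__"
        if in_init and is_super_init(node):
            return None  # NodeTransformer removes statements mapped to None
        return self.generic_visit(node)

    def inject_param(self, node: ast.FunctionDef) -> None:
        args = node.args
        existing = {a.arg for a in args.posonlyargs + args.args + args.kwonlyargs}
        if self.param not in existing:
            args.kwonlyargs.append(ast.arg(arg=self.param, annotation=None))
            args.kw_defaults.append(None)  # None: keyword-only, no default


src = (
    "class Dense(nnx.Module):\n"
    "    def __init__(self, features):\n"
    "        super().__init__()\n"
    "        self.features = features\n"
)
out = ast.unparse(SignatureRewriter().visit(ast.parse(src)))
print(out)
```

A real pass would presumably also check the Module inheritance state recorded by visit_ClassDef, so that plain Python classes are left alone; the sketch rewrites every __init__ it finds.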