ml_switcheroo.core.rewriter.passes.structure
Structural Rewriting Pass.
This module consolidates all structural transformation logic, including:
1. Class Inheritance Rewriting: Swapping framework base classes (e.g., torch.nn.Module -> flax.nnx.Module).
2. Function Signature Rewriting: Injecting or stripping state/context arguments (e.g., rngs, ctx).
3. Method Renaming: Mapping lifecycle methods (e.g., forward -> __call__).
4. Body Injection: Handling super().__init__ calls and preamble injection.
5. Type Annotation Rewriting: Mapping framework-specific types (e.g., torch.Tensor -> jax.Array).
Classes
- StructuralPass – Pass responsible for modifying the structural scaffolding of the code.
- StructuralTransformer – LibCST Transformer encapsulating all structural logic.
Module Contents
- class ml_switcheroo.core.rewriter.passes.structure.StructuralPass
Bases: ml_switcheroo.core.rewriter.interface.RewriterPass
Pass responsible for modifying the structural scaffolding of the code.
It transforms class definitions, function signatures, and type hints to match the idioms of the target framework.
- transform(module: libcst.Module, context: ml_switcheroo.core.rewriter.context.RewriterContext) → libcst.Module
Executes the structural transformation.
- Parameters:
module – The source CST.
context – Shared rewriter state.
- Returns:
The transformed CST.
- class ml_switcheroo.core.rewriter.passes.structure.StructuralTransformer(context: ml_switcheroo.core.rewriter.context.RewriterContext)
Bases: libcst.CSTTransformer
LibCST Transformer encapsulating all structural logic.
Maintains internal state for scope depth, annotation context, and target framework traits.
- context
- property target_traits: ml_switcheroo.semantics.schema.StructuralTraits
Lazily load structural traits for the target framework.
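The lazy-loading pattern behind `target_traits` can be sketched as a property that caches its first result. `TraitsSketch`, its fields, and `_REGISTRY` are hypothetical. The real `StructuralTraits` schema and the way it is looked up are not shown on this page.

```python
from dataclasses import dataclass
from typing import Dict, Optional


@dataclass(frozen=True)
class TraitsSketch:
    """Hypothetical subset of StructuralTraits; the real schema's fields may differ."""

    module_base: str
    forward_method: str


# Hypothetical registry standing in for wherever the semantics layer keeps traits.
_REGISTRY: Dict[str, TraitsSketch] = {
    "flax_nnx": TraitsSketch(module_base="flax.nnx.Module", forward_method="__call__"),
}


class LazyTraitsSketch:
    def __init__(self, target: str) -> None:
        self.target = target
        self._traits: Optional[TraitsSketch] = None  # nothing loaded yet
        self.load_count = 0  # instrumentation for this demo only

    @property
    def target_traits(self) -> TraitsSketch:
        # Load on first access, then serve the cached object.
        if self._traits is None:
            self.load_count += 1
            self._traits = _REGISTRY[self.target]
        return self._traits


t = LazyTraitsSketch("flax_nnx")
print(t.load_count)  # 0: constructing the object loads nothing
print(t.target_traits.module_base)
```

`functools.cached_property` gives the same behavior with less code. The explicit form shown here makes the load-once behavior visible.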
- leave_Module(original_node: libcst.Module, updated_node: libcst.Module) → libcst.Module
Injects the module-level preamble statements (e.g. imports, shim classes) that plugins requested during the rewrite. It then clears the buffer, so that subsequent passes do not inject the same statements a second time.
- visit_Annotation(node: libcst.Annotation) → bool | None
Flag entry into a type annotation context.
- leave_Annotation(original_node: libcst.Annotation, updated_node: libcst.Annotation) → libcst.Annotation
Flag exit from a type annotation context.
- leave_Name(original_node: libcst.Name, updated_node: libcst.Name) → libcst.BaseExpression
Rewrite type names if inside an annotation.
- leave_Attribute(original_node: libcst.Attribute, updated_node: libcst.Attribute) → libcst.BaseExpression
Rewrite dotted type attributes (e.g. torch.Tensor) if inside an annotation. This fixes an issue where dotted types such as torch.Tensor failed to rewrite: leave_Name sees each component (torch, Tensor) individually, without the enclosing attribute as context.
- visit_ClassDef(node: libcst.ClassDef) → bool | None
Detect Module inheritance to set processing state.
- leave_ClassDef(original_node: libcst.ClassDef, updated_node: libcst.ClassDef) → libcst.ClassDef | libcst.CSTNode
Rewrite class inheritance if necessary.