ml_switcheroo.core.rewriter¶
Rewriter Package.
This package provides the PivotRewriter class, composed of several mixins that each handle one aspect of the AST transformation:
- Structure: Class and Function definitions.
- Calls: Function invocations.
- Contexts: With blocks.
- Attributes: Attribute access.
- Normalization: Argument mapping.
- ControlFlow: Loop and branching logic.
- Decorators: Decorator handling.
Submodules¶
- ml_switcheroo.core.rewriter.attributes
- ml_switcheroo.core.rewriter.base
- ml_switcheroo.core.rewriter.calls
- ml_switcheroo.core.rewriter.contexts
- ml_switcheroo.core.rewriter.control_flow
- ml_switcheroo.core.rewriter.decorators
- ml_switcheroo.core.rewriter.initializers
- ml_switcheroo.core.rewriter.normalization
- ml_switcheroo.core.rewriter.structure
- ml_switcheroo.core.rewriter.structure_class
- ml_switcheroo.core.rewriter.structure_func
- ml_switcheroo.core.rewriter.structure_types
- ml_switcheroo.core.rewriter.types
Classes¶
- CallMixin: Mixin for transforming Call nodes and unpacking Assignments.
- NormalizationMixin: Mixin class providing argument normalization and operator transformation logic.
- AttributeMixin: Mixin for transforming attributes and tracking assignments.
- StructureMixin: Composite mixin for all structural rewriting tasks.
- DecoratorMixin: Mixin for transforming Decorator nodes.
- ControlFlowMixin: Mixin for visiting Control Flow nodes (For, While, If).
- ContextMixin: Mixin for transforming With blocks.
- BaseRewriter: The base class for AST transformation traversal.
- PivotRewriter: The main AST transformer for ml-switcheroo.
Package Contents¶
- class ml_switcheroo.core.rewriter.CallMixin(semantics: ml_switcheroo.semantics.manager.SemanticsManager, config: ml_switcheroo.config.RuntimeConfig)¶
Bases: ml_switcheroo.core.rewriter.normalization.NormalizationMixin, ml_switcheroo.core.rewriter.base.BaseRewriter
Mixin for transforming Call nodes and unpacking Assignments.
Responsible for:
1. Handling functional apply patterns (Flax).
2. Lifecycle method stripping (.to(), .cuda()).
3. Plugin dispatch.
4. Standard API pivoting (Lookup -> Normalize -> Rewrite).
5. Output Transformation (Indexing, Casting).
- leave_Assign(original_node: libcst.Assign, updated_node: libcst.Assign) -> libcst.Assign¶
Handles assignment unwrapping for Functional -> OOP transitions.
Scenario: y, updates = layer.apply(vars, x)
Target: y = layer(x) (NNX/Torch style)
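The unwrapping in the scenario above can be sketched with the standard-library ast module. This is an illustration only: the package itself operates on libcst nodes, and the names ApplyUnwrapper and unwrap_apply, as well as the exact pattern test, are assumptions rather than the package's implementation.

```python
import ast


class ApplyUnwrapper(ast.NodeTransformer):
    """Hypothetical sketch: `y, updates = layer.apply(vars, x)` -> `y = layer(x)`."""

    def visit_Assign(self, node: ast.Assign) -> ast.AST:
        self.generic_visit(node)
        target = node.targets[0]
        call = node.value
        # The left-hand side must be a two-element tuple: (output, state).
        is_pair_target = (
            len(node.targets) == 1
            and isinstance(target, ast.Tuple)
            and len(target.elts) == 2
        )
        # The right-hand side must be `<obj>.apply(<variables>, ...)`.
        is_apply_call = (
            isinstance(call, ast.Call)
            and isinstance(call.func, ast.Attribute)
            and call.func.attr == "apply"
            and len(call.args) >= 1
        )
        if not (is_pair_target and is_apply_call):
            return node
        # Call the layer object directly and drop the leading variables argument.
        new_call = ast.Call(
            func=call.func.value, args=call.args[1:], keywords=call.keywords
        )
        # Keep only the first target (the output); the state target is dropped.
        new_assign = ast.Assign(targets=[target.elts[0]], value=new_call)
        return ast.copy_location(new_assign, node)


def unwrap_apply(source: str) -> str:
    tree = ApplyUnwrapper().visit(ast.parse(source))
    return ast.unparse(ast.fix_missing_locations(tree))


print(unwrap_apply("y, updates = layer.apply(vars, x)"))
```

A real rewriter would also need to confirm that the dropped state variable is not used later; this sketch does not check that.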
- leave_Call(original: libcst.Call, updated: libcst.Call) -> libcst.Call | libcst.BinaryOperation | libcst.UnaryOperation | libcst.CSTNode¶
Rewrites function calls with detailed Trace Logging. Implements Logic 4: Layer Init State Threading.
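One of the call-level responsibilities listed for CallMixin, lifecycle method stripping, can be sketched as follows. The sketch uses the standard-library ast module as a stand-in for libcst, matches purely on the method name, and assumes a hypothetical LIFECYCLE_METHODS set; the package's actual matching rules are not documented here.

```python
import ast

# Assumed set of lifecycle methods, taken from the examples (.to(), .cuda()).
LIFECYCLE_METHODS = {"to", "cuda"}


class LifecycleStripper(ast.NodeTransformer):
    """Hypothetical sketch: replaces `obj.cuda()` or `obj.to(...)` with `obj`."""

    def visit_Call(self, node: ast.Call) -> ast.AST:
        # Visit children first so chained calls are stripped inside-out.
        self.generic_visit(node)
        func = node.func
        if isinstance(func, ast.Attribute) and func.attr in LIFECYCLE_METHODS:
            # Drop the call and keep only the receiver expression.
            return func.value
        return node


def strip_lifecycle(source: str) -> str:
    tree = LifecycleStripper().visit(ast.parse(source))
    return ast.unparse(ast.fix_missing_locations(tree))


print(strip_lifecycle("model = Net().to(device).cuda()"))
```

Name-only matching would also strip unrelated methods that happen to be called to or cuda, which is why a production rewriter would consult semantic definitions first.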
- class ml_switcheroo.core.rewriter.NormalizationMixin(semantics: ml_switcheroo.semantics.manager.SemanticsManager, config: ml_switcheroo.config.RuntimeConfig)¶
Bases: ml_switcheroo.core.rewriter.base.BaseRewriter
Mixin class providing argument normalization and operator transformation logic.
- class ml_switcheroo.core.rewriter.AttributeMixin(semantics: ml_switcheroo.semantics.manager.SemanticsManager, config: ml_switcheroo.config.RuntimeConfig)¶
Bases: ml_switcheroo.core.rewriter.base.BaseRewriter
Mixin for transforming attributes and tracking assignments.
- leave_Assign(original_node: libcst.Assign, updated_node: libcst.Assign) -> libcst.Assign¶
Track stateful assignments (e.g. self.layer = Linear(…)) to determine if an object variable requires special handling later.
- leave_Attribute(original: libcst.Attribute, updated: libcst.Attribute) -> libcst.BaseExpression¶
Visits attributes (e.g. torch.float32).
- class ml_switcheroo.core.rewriter.StructureMixin(semantics: ml_switcheroo.semantics.manager.SemanticsManager, config: ml_switcheroo.config.RuntimeConfig)¶
Bases: ml_switcheroo.core.rewriter.structure_class.ClassStructureMixin, ml_switcheroo.core.rewriter.structure_func.FuncStructureMixin, ml_switcheroo.core.rewriter.structure_types.TypeStructureMixin
Composite mixin for all structural rewriting tasks. Inherits from granular mixins to assemble the full feature set.
- class ml_switcheroo.core.rewriter.DecoratorMixin(semantics: ml_switcheroo.semantics.manager.SemanticsManager, config: ml_switcheroo.config.RuntimeConfig)¶
Bases: ml_switcheroo.core.rewriter.base.BaseRewriter
Mixin for transforming Decorator nodes. Part of PivotRewriter.
- leave_Decorator(original_node: libcst.Decorator, updated_node: libcst.Decorator) -> libcst.Decorator | libcst.RemovalSentinel¶
Processes decorators attached to functions or classes.
Logic:
1. Identifies the decorator name from original_node to ensure we key off the Source Framework API, even if CallMixin modified the children in updated_node.
2. Looks up the semantic definition.
3. If the target variant is explicitly null (None), removes the decorator.
4. If the target variant specifies a new API, rewrites the name.
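The four steps above can be sketched with the standard-library ast module standing in for libcst. The DECORATOR_MAP table, its entries, and the helper names are hypothetical placeholders for the semantic lookup; they are not taken from the package.

```python
import ast
from typing import Optional

# Hypothetical lookup table: source decorator -> target name, or None to remove.
# The entries are illustrative and do not reflect the package's semantics data.
DECORATOR_MAP = {
    "torch.no_grad": None,
    "torch.jit.script": "jax.jit",
}


def dotted_name(node: ast.AST) -> Optional[str]:
    """Returns 'a.b.c' for a Name/Attribute chain (or a call on one), else None."""
    if isinstance(node, ast.Call):
        node = node.func
    parts = []
    while isinstance(node, ast.Attribute):
        parts.append(node.attr)
        node = node.value
    if isinstance(node, ast.Name):
        parts.append(node.id)
        return ".".join(reversed(parts))
    return None


class DecoratorRewriter(ast.NodeTransformer):
    def visit_FunctionDef(self, node: ast.FunctionDef) -> ast.AST:
        # Step 1: read names before children are visited (the "original" view).
        source_names = [dotted_name(d) for d in node.decorator_list]
        self.generic_visit(node)
        kept = []
        for name, dec in zip(source_names, node.decorator_list):
            if name not in DECORATOR_MAP:        # Step 2: no definition, keep as is
                kept.append(dec)
            elif DECORATOR_MAP[name] is not None:  # Step 4: rewrite the name
                # Simplification: any call arguments on the decorator are dropped.
                kept.append(ast.parse(DECORATOR_MAP[name], mode="eval").body)
            # Step 3: a None target removes the decorator (nothing appended).
        node.decorator_list = kept
        return node


def rewrite_decorators(source: str) -> str:
    tree = DecoratorRewriter().visit(ast.parse(source))
    return ast.unparse(ast.fix_missing_locations(tree))
```

Removing an entry from a list here plays the role that returning libcst.RemovalSentinel plays in the documented return type.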
- class ml_switcheroo.core.rewriter.ControlFlowMixin(semantics: ml_switcheroo.semantics.manager.SemanticsManager, config: ml_switcheroo.config.RuntimeConfig)¶
Bases: ml_switcheroo.core.rewriter.base.BaseRewriter
Mixin for visiting Control Flow nodes (For, While, If).
- leave_For(original_node: libcst.For, updated_node: libcst.For) -> libcst.For | libcst.CSTNode¶
Invokes loop transformation logic.
Implements a priority chain:
1. Static Unroll: Checks transform_for_loop_static. If the loop index is constant, unrolls it (Optimization).
2. General Transform: Checks transform_for_loop. Handles general case logic or applies safety warnings (e.g., Escape Hatch for JAX).
- Parameters:
original_node (cst.For) – The node before transformation.
updated_node (cst.For) – The node after child visitors have run.
- Returns:
The transformed node (potentially a FlattenSentinel of unrolled statements) or the original node if untouched.
- Return type:
Union[cst.For, cst.CSTNode]
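The static-unroll branch of the priority chain can be illustrated with the standard-library ast module. This is a sketch under assumptions: "constant index" is taken to mean a loop over range(N) with a literal integer N, and the class and function names are hypothetical. How transform_for_loop_static actually decides is not documented here.

```python
import ast
import copy


class _Substitute(ast.NodeTransformer):
    """Replaces reads of one variable name with a constant value."""

    def __init__(self, name: str, value: int) -> None:
        self.name = name
        self.value = value

    def visit_Name(self, node: ast.Name) -> ast.AST:
        if node.id == self.name and isinstance(node.ctx, ast.Load):
            return ast.copy_location(ast.Constant(self.value), node)
        return node


class StaticUnroller(ast.NodeTransformer):
    def visit_For(self, node: ast.For):
        self.generic_visit(node)
        it = node.iter
        has_jump = any(
            isinstance(n, (ast.Break, ast.Continue))
            for stmt in node.body
            for n in ast.walk(stmt)
        )
        is_static = (
            isinstance(node.target, ast.Name)
            and not node.orelse
            and not has_jump
            and isinstance(it, ast.Call)
            and isinstance(it.func, ast.Name)
            and it.func.id == "range"
            and len(it.args) == 1
            and not it.keywords
            and isinstance(it.args[0], ast.Constant)
            and type(it.args[0].value) is int
        )
        if not is_static:
            return node  # fall through to the general path (left untouched here)
        unrolled = []
        for i in range(it.args[0].value):
            for stmt in node.body:
                unrolled.append(
                    _Substitute(node.target.id, i).visit(copy.deepcopy(stmt))
                )
        # Returning a list splices the statements in place of the loop.
        return unrolled or [ast.Pass()]


def unroll(source: str) -> str:
    tree = StaticUnroller().visit(ast.parse(source))
    return ast.unparse(ast.fix_missing_locations(tree))
```

Returning a list of statements corresponds to the FlattenSentinel mentioned in the Returns entry. Unlike the original loop, the unrolled code does not leave the loop variable bound afterwards, which a real implementation would have to account for.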
- class ml_switcheroo.core.rewriter.ContextMixin(semantics: ml_switcheroo.semantics.manager.SemanticsManager, config: ml_switcheroo.config.RuntimeConfig)¶
Bases: ml_switcheroo.core.rewriter.base.BaseRewriter
Mixin for transforming With blocks.
- leave_With(original_node: libcst.With, updated_node: libcst.With) -> libcst.With | libcst.FlattenSentinel¶
Processes ‘with’ statements.
Logic:
1. Iterate over with items (expressions).
2. Identify if the expression corresponds to a CONTEXT OpType in semantics.
3. Check transformation logic:
- If the strip_context identifying marker is found, lift the body.
- Otherwise, leave the node unchanged; CallMixin (which runs on children) has already renamed the API.
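Lifting the body of a stripped context can be sketched with the standard-library ast module in place of libcst. The STRIPPED_CONTEXTS set is a hypothetical stand-in for the semantic lookup and the strip marker; the names below are assumptions.

```python
import ast
from typing import Optional

# Hypothetical set of context managers whose `with` wrapper should be removed.
STRIPPED_CONTEXTS = {"torch.no_grad"}


def dotted_name(node: ast.AST) -> Optional[str]:
    """Returns 'a.b.c' for a Name/Attribute chain (or a call on one), else None."""
    if isinstance(node, ast.Call):
        node = node.func
    parts = []
    while isinstance(node, ast.Attribute):
        parts.append(node.attr)
        node = node.value
    if isinstance(node, ast.Name):
        parts.append(node.id)
        return ".".join(reversed(parts))
    return None


class ContextStripper(ast.NodeTransformer):
    def visit_With(self, node: ast.With):
        self.generic_visit(node)
        kept = [
            item
            for item in node.items
            if dotted_name(item.context_expr) not in STRIPPED_CONTEXTS
        ]
        if not kept:
            # Every item was stripped: lift the body into the parent block.
            return node.body
        node.items = kept
        return node


def strip_contexts(source: str) -> str:
    tree = ContextStripper().visit(ast.parse(source))
    return ast.unparse(ast.fix_missing_locations(tree))
```

Returning the body as a list is the ast equivalent of the libcst.FlattenSentinel in the documented return type.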
- class ml_switcheroo.core.rewriter.BaseRewriter(semantics: ml_switcheroo.semantics.manager.SemanticsManager, config: ml_switcheroo.config.RuntimeConfig)¶
Bases: libcst.CSTTransformer
The base class for AST transformation traversal.
Provides common utilities for scope tracking, alias resolution, and error bubbling, which are utilized by the specific Mixins (CallMixin, StructureMixin, etc.).
- semantics¶
- config¶
- source_fw = ''¶
- target_fw = ''¶
- strict_mode¶
- ctx¶
- leave_Module(original_node: libcst.Module, updated_node: libcst.Module) -> libcst.Module¶
Injects module-level preambles (e.g. Shim classes) requested by plugins. Ensures injection happens after docstrings to maintain valid Python help text.
- Parameters:
original_node – The module node before transformation.
updated_node – The module node after transformation.
- Returns:
The module with injected preambles.
- Return type:
cst.Module
- visit_Import(node: libcst.Import) -> bool | None¶
Scans import ... statements to populate the alias map. Example: import torch.nn as nn -> _alias_map['nn'] = 'torch.nn'.
- Parameters:
node – Import statement node.
- Returns:
False to stop traversal of children.
- Return type:
Optional[bool]
- visit_ImportFrom(node: libcst.ImportFrom) -> bool | None¶
Scans from ... import ... statements to populate the alias map. Example: from torch import nn -> _alias_map['nn'] = 'torch.nn'.
- Parameters:
node – ImportFrom statement node.
- Returns:
False to stop traversal of children.
- Return type:
Optional[bool]
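The alias map built by the two import visitors can be sketched as follows, using the standard-library ast module rather than libcst. The function names build_alias_map and resolve are hypothetical, and the handling of un-aliased and relative imports is an assumption.

```python
import ast
from typing import Dict


def build_alias_map(source: str) -> Dict[str, str]:
    """Maps each locally bound import name to its fully qualified path."""
    alias_map: Dict[str, str] = {}
    for node in ast.walk(ast.parse(source)):
        if isinstance(node, ast.Import):
            for alias in node.names:
                if alias.asname:
                    # import torch.nn as nn  ->  nn: torch.nn
                    alias_map[alias.asname] = alias.name
                else:
                    # import torch.nn binds only the root name `torch`.
                    root = alias.name.split(".")[0]
                    alias_map[root] = root
        elif isinstance(node, ast.ImportFrom) and node.module and node.level == 0:
            for alias in node.names:
                if alias.name == "*":
                    continue  # star imports bind unknown names; skipped here
                # from torch import nn  ->  nn: torch.nn
                alias_map[alias.asname or alias.name] = f"{node.module}.{alias.name}"
    return alias_map


def resolve(name: str, alias_map: Dict[str, str]) -> str:
    """Expands the leading alias of a dotted name, e.g. nn.Linear -> torch.nn.Linear."""
    head, _, rest = name.partition(".")
    full = alias_map.get(head, head)
    return f"{full}.{rest}" if rest else full


aliases = build_alias_map("import torch.nn as nn\nfrom torch import optim")
print(resolve("nn.Linear", aliases))
```

Resolving names through such a map is what lets later lookups key off the fully qualified source API regardless of how the user imported it.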
- visit_SimpleStatementLine(node: libcst.SimpleStatementLine) -> bool | None¶
Resets error tracking at the start of each line. Errors bubble up from children (Expressions) to this Statement handler.
- Parameters:
node – The statement line node.
- Returns:
True to continue traversal.
- Return type:
Optional[bool]
- leave_SimpleStatementLine(original_node: libcst.SimpleStatementLine, updated_node: libcst.SimpleStatementLine) -> libcst.SimpleStatementLine | libcst.FlattenSentinel¶
Handles error bubbling from expression rewrites.
If errors occurred during processing of this line’s children (e.g. failing to rewrite a function call), wrap the line in an EscapeHatch. Prioritizes errors (reverting to original code) over warnings (using updated code).
- Parameters:
original_node – The node before children were visited.
updated_node – The node after children transformation.
- Returns:
The resulting node (possibly wrapped with comments).
- Return type:
Union[cst.SimpleStatementLine, cst.FlattenSentinel]
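The reset-then-bubble pattern of visit_SimpleStatementLine and leave_SimpleStatementLine can be sketched at the string level. Everything named here is hypothetical: LineState, finalize_line, and the comment marker text are placeholders, and the real EscapeHatch output format is not documented in this section.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class LineState:
    """Per-statement diagnostics collected while child expressions are rewritten."""

    errors: List[str] = field(default_factory=list)
    warnings: List[str] = field(default_factory=list)

    def reset(self) -> None:
        # Called when a new statement line is entered.
        self.errors.clear()
        self.warnings.clear()


def _wrap(code: str, reasons: List[str]) -> List[str]:
    # Placeholder marker format; not the package's actual comment text.
    return [f"# <escape hatch: {'; '.join(reasons)}>", code, "# </escape hatch>"]


def finalize_line(original: str, updated: str, state: LineState) -> List[str]:
    """Returns the lines to emit for one statement."""
    if state.errors:
        # Errors take priority: revert to the original code and flag it.
        return _wrap(original, state.errors)
    if state.warnings:
        # Warnings keep the rewritten code but still flag it.
        return _wrap(updated, state.warnings)
    return [updated]


state = LineState()
state.errors.append("no mapping for torch.foo")
for line in finalize_line("y = torch.foo(x)", "y = jnp.foo(x)", state):
    print(line)
```

Returning several lines for one statement mirrors the FlattenSentinel variant of the documented return type.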
- class ml_switcheroo.core.rewriter.PivotRewriter(semantics: ml_switcheroo.semantics.manager.SemanticsManager, config: ml_switcheroo.config.RuntimeConfig)¶
Bases: contexts.ContextMixin, control_flow.ControlFlowMixin, decorators.DecoratorMixin, calls.CallMixin, normalization.NormalizationMixin, attributes.AttributeMixin, structure.StructureMixin, base.BaseRewriter
The main AST transformer for ml-switcheroo.
Inherits functionality from the component classes (Mixins) and the base transformer logic. This class is the entry point for the ASTEngine.
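The composition pattern, one transformer assembled from per-node-type mixins that share a base class, can be sketched with the standard-library ast module. All class names, the two lookup tables, and their entries are hypothetical; PivotRewriter itself is built on libcst and takes a SemanticsManager and a RuntimeConfig.

```python
import ast
from typing import Dict


class BaseSketch(ast.NodeTransformer):
    """Shared state, loosely playing the role of the semantics/config arguments."""

    def __init__(self, attr_map: Dict[str, str], call_map: Dict[str, str]) -> None:
        self.attr_map = attr_map
        self.call_map = call_map


class AttributeSketchMixin(BaseSketch):
    def visit_Attribute(self, node: ast.Attribute) -> ast.AST:
        self.generic_visit(node)
        name = ast.unparse(node)
        if name in self.attr_map:
            replacement = ast.parse(self.attr_map[name], mode="eval").body
            return ast.copy_location(replacement, node)
        return node


class CallSketchMixin(BaseSketch):
    def visit_Call(self, node: ast.Call) -> ast.AST:
        # Read the source-framework name before children are rewritten.
        name = ast.unparse(node.func)
        self.generic_visit(node)
        if name in self.call_map:
            node.func = ast.parse(self.call_map[name], mode="eval").body
        return node


class PivotSketch(CallSketchMixin, AttributeSketchMixin, BaseSketch):
    """Each mixin contributes the visitor for one node type."""


def pivot(source: str, attr_map: Dict[str, str], call_map: Dict[str, str]) -> str:
    tree = PivotSketch(attr_map, call_map).visit(ast.parse(source))
    return ast.unparse(ast.fix_missing_locations(tree))


print(
    pivot(
        "y = torch.zeros(3, dtype=torch.float32)",
        {"torch.float32": "jnp.float32"},  # illustrative mapping only
        {"torch.zeros": "jnp.zeros"},      # illustrative mapping only
    )
)
```

Because every mixin derives from the same base, Python's method resolution order places the shared base after all mixins, so each visitor sees the same state.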