ml_switcheroo.core.rewriter

Rewriter Package.

This package provides the PivotRewriter class, which is composed of several mixins, each handling one aspect of the AST transformation:
  • Structure: class and function definitions.
  • Calls: function invocations.
  • Contexts: with blocks.
  • Attributes: attribute access.
  • Normalization: argument mapping.
  • ControlFlow: loop and branching logic.
  • Decorators: decorator handling.

Classes

CallMixin

Mixin for transforming Call nodes and unpacking Assignments.

NormalizationMixin

Mixin class providing argument normalization and operator transformation logic.

AttributeMixin

Mixin for transforming attributes and tracking assignments.

StructureMixin

Composite mixin for all structural rewriting tasks.

DecoratorMixin

Mixin for transforming Decorator nodes.

ControlFlowMixin

Mixin for visiting Control Flow nodes (For, While, If).

ContextMixin

Mixin for transforming With blocks.

BaseRewriter

The base class for AST transformation traversal.

PivotRewriter

The main AST transformer for ml-switcheroo.

Package Contents

class ml_switcheroo.core.rewriter.CallMixin(semantics: ml_switcheroo.semantics.manager.SemanticsManager, config: ml_switcheroo.config.RuntimeConfig)

Bases: ml_switcheroo.core.rewriter.normalization.NormalizationMixin, ml_switcheroo.core.rewriter.base.BaseRewriter

Mixin for transforming Call nodes and unpacking Assignments.

Responsible for:
  1. Handling functional apply patterns (Flax).
  2. Stripping lifecycle methods (.to(), .cuda()).
  3. Plugin dispatch.
  4. Standard API pivoting (Lookup -> Normalize -> Rewrite).
  5. Output transformation (indexing, casting).

leave_Assign(original_node: libcst.Assign, updated_node: libcst.Assign) → libcst.Assign

Handles assignment unwrapping for Functional -> OOP transitions.

Scenario: y, updates = layer.apply(vars, x)
Target: y = layer(x) (NNX/Torch style)

leave_Call(original: libcst.Call, updated: libcst.Call) → libcst.Call | libcst.BinaryOperation | libcst.UnaryOperation | libcst.CSTNode

Rewrites function calls, with detailed trace logging. Also implements Logic 4: layer init state threading.

class ml_switcheroo.core.rewriter.NormalizationMixin(semantics: ml_switcheroo.semantics.manager.SemanticsManager, config: ml_switcheroo.config.RuntimeConfig)

Bases: ml_switcheroo.core.rewriter.base.BaseRewriter

Mixin class providing argument normalization and operator transformation logic.

class ml_switcheroo.core.rewriter.AttributeMixin(semantics: ml_switcheroo.semantics.manager.SemanticsManager, config: ml_switcheroo.config.RuntimeConfig)

Bases: ml_switcheroo.core.rewriter.base.BaseRewriter

Mixin for transforming attributes and tracking assignments.

leave_Assign(original_node: libcst.Assign, updated_node: libcst.Assign) → libcst.Assign

Tracks stateful assignments (e.g. self.layer = Linear(…)) to determine whether an object variable requires special handling later.

leave_Attribute(original: libcst.Attribute, updated: libcst.Attribute) → libcst.BaseExpression

Visits attribute accesses (e.g. torch.float32) and rewrites them to their target-framework equivalents.

class ml_switcheroo.core.rewriter.StructureMixin(semantics: ml_switcheroo.semantics.manager.SemanticsManager, config: ml_switcheroo.config.RuntimeConfig)

Bases: ml_switcheroo.core.rewriter.structure_class.ClassStructureMixin, ml_switcheroo.core.rewriter.structure_func.FuncStructureMixin, ml_switcheroo.core.rewriter.structure_types.TypeStructureMixin

Composite mixin for all structural rewriting tasks. Inherits from granular mixins to assemble the full feature set.

class ml_switcheroo.core.rewriter.DecoratorMixin(semantics: ml_switcheroo.semantics.manager.SemanticsManager, config: ml_switcheroo.config.RuntimeConfig)

Bases: ml_switcheroo.core.rewriter.base.BaseRewriter

Mixin for transforming Decorator nodes. Part of PivotRewriter.

leave_Decorator(original_node: libcst.Decorator, updated_node: libcst.Decorator) → libcst.Decorator | libcst.RemovalSentinel

Processes decorators attached to functions or classes.

Logic:
  1. Identifies the decorator name from original_node, so that the lookup keys off the source framework API even if CallMixin has already modified the children in updated_node.
  2. Looks up the semantic definition.
  3. If the target variant is explicitly null (None), removes the decorator.
  4. If the target variant specifies a new API, rewrites the name.

class ml_switcheroo.core.rewriter.ControlFlowMixin(semantics: ml_switcheroo.semantics.manager.SemanticsManager, config: ml_switcheroo.config.RuntimeConfig)

Bases: ml_switcheroo.core.rewriter.base.BaseRewriter

Mixin for visiting Control Flow nodes (For, While, If).

leave_For(original_node: libcst.For, updated_node: libcst.For) → libcst.For | libcst.CSTNode

Invokes loop transformation logic.

Implements a priority chain:
  1. Static unroll: checks transform_for_loop_static. If the loop's iteration range is constant, unrolls the loop (an optimization).
  2. General transform: checks transform_for_loop. Handles the general case or applies safety warnings (e.g., an Escape Hatch for JAX).

Parameters:
  • original_node (cst.For) – The node before transformation.

  • updated_node (cst.For) – The node after child visitors have run.

Returns:

The transformed node (potentially a FlattenSentinel of unrolled statements) or the original node if untouched.

Return type:

Union[cst.For, cst.CSTNode]

class ml_switcheroo.core.rewriter.ContextMixin(semantics: ml_switcheroo.semantics.manager.SemanticsManager, config: ml_switcheroo.config.RuntimeConfig)

Bases: ml_switcheroo.core.rewriter.base.BaseRewriter

Mixin for transforming With blocks.

leave_With(original_node: libcst.With, updated_node: libcst.With) → libcst.With | libcst.FlattenSentinel

Processes ‘with’ statements.

Logic:
  1. Iterates over the with items (expressions).
  2. Identifies whether each expression corresponds to a CONTEXT OpType in the semantics.
  3. Applies the transformation logic:
     • If the strip_context marker is found, lifts the body out of the with block.
     • Otherwise, relies on CallMixin (which has already run on the children) to have renamed the API.

class ml_switcheroo.core.rewriter.BaseRewriter(semantics: ml_switcheroo.semantics.manager.SemanticsManager, config: ml_switcheroo.config.RuntimeConfig)

Bases: libcst.CSTTransformer

The base class for AST transformation traversal.

Provides common utilities for scope tracking, alias resolution, and error bubbling, which are utilized by the specific Mixins (CallMixin, StructureMixin, etc.).

semantics
config
source_fw = ''
target_fw = ''
strict_mode
ctx

leave_Module(original_node: libcst.Module, updated_node: libcst.Module) → libcst.Module

Injects module-level preambles (e.g. shim classes) requested by plugins. Injection happens after any module docstring, so the docstring stays the first statement and is still picked up as the module's help text.

Parameters:
  • original_node – The module before transformation.

  • updated_node – The module after transformation.

Returns:

The module with injected preambles.

Return type:

cst.Module

visit_Import(node: libcst.Import) → bool | None

Scans import ... statements to populate the alias map. Example: import torch.nn as nn -> _alias_map['nn'] = 'torch.nn'.

Parameters:

node – Import statement node.

Returns:

False to stop traversal of children.

Return type:

Optional[bool]

visit_ImportFrom(node: libcst.ImportFrom) → bool | None

Scans from ... import ... statements to populate the alias map. Example: from torch import nn -> _alias_map['nn'] = 'torch.nn'.

Parameters:

node – ImportFrom statement node.

Returns:

False to stop traversal of children.

Return type:

Optional[bool]

visit_SimpleStatementLine(node: libcst.SimpleStatementLine) → bool | None

Resets error tracking at the start of each line. Errors bubble up from children (Expressions) to this Statement handler.

Parameters:

node – The statement line node.

Returns:

True to continue traversal.

Return type:

Optional[bool]

leave_SimpleStatementLine(original_node: libcst.SimpleStatementLine, updated_node: libcst.SimpleStatementLine) → libcst.SimpleStatementLine | libcst.FlattenSentinel

Handles error bubbling from expression rewrites.

If errors occurred while processing this line’s children (e.g. a function call that could not be rewritten), wraps the line in an EscapeHatch.

Prioritizes errors (reverting to original code) over warnings (using updated code).

Parameters:
  • original_node – The node before children were visited.

  • updated_node – The node after children transformation.

Returns:

The resulting node (possibly wrapped with comments).

Return type:

Union[cst.SimpleStatementLine, cst.FlattenSentinel]

class ml_switcheroo.core.rewriter.PivotRewriter(semantics: ml_switcheroo.semantics.manager.SemanticsManager, config: ml_switcheroo.config.RuntimeConfig)

Bases: contexts.ContextMixin, control_flow.ControlFlowMixin, decorators.DecoratorMixin, calls.CallMixin, normalization.NormalizationMixin, attributes.AttributeMixin, structure.StructureMixin, base.BaseRewriter

The main AST transformer for ml-switcheroo.

Inherits functionality from the component mixins and the base transformer logic. This class is the entry point for the ASTEngine.