ml_switcheroo.core.rewriter.base

Base Rewriter Implementation with Alias and Scope Resolution.

This module provides the BaseRewriter class, which serves as the foundation for the PivotRewriter. It handles:

  1. State Management: Tracking the current scope (global, class, or function) to support detection of stateful variables.

  2. Alias Resolution: Tracking import ... as statements so that t.abs resolves back to torch.abs and np.sum to numpy.sum.

  3. Error Reporting: Collecting failures encountered during the AST walk so they can be bubbled up to the ASTEngine.

  4. Hook Infrastructure: Initializing the HookContext used by plugins.

  5. Global Injection: Handling file-level preamble injection (leave_Module).

  6. Import Injection: Processing dynamic import requirements from variants.
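The scope tracking in item 1 can be sketched with Python's standard-library ast module, used here as a stand-in for libcst. This is a minimal illustration under assumptions, not the BaseRewriter implementation. The class name ScopeTracker and its "global/class/function" path format are hypothetical. The idea is to push a scope on entering a class or function, pop it on leaving, and record the scope path for each assignment. An assignment to self.w inside a method then shows up as potential model state.

```python
import ast


class ScopeTracker(ast.NodeVisitor):
    """Hypothetical sketch: record the scope path of every simple assignment."""

    def __init__(self) -> None:
        self.scope_stack: list[str] = ["global"]
        self.assignments: list[tuple[str, str]] = []

    def _visit_scoped(self, node: ast.AST, kind: str) -> None:
        # Push the scope, walk the children, then pop on the way out.
        self.scope_stack.append(kind)
        self.generic_visit(node)
        self.scope_stack.pop()

    def visit_ClassDef(self, node: ast.ClassDef) -> None:
        self._visit_scoped(node, "class")

    def visit_FunctionDef(self, node: ast.FunctionDef) -> None:
        self._visit_scoped(node, "function")

    # Treat async functions like regular functions for scope purposes.
    visit_AsyncFunctionDef = visit_FunctionDef

    def visit_Assign(self, node: ast.Assign) -> None:
        for target in node.targets:
            if isinstance(target, ast.Attribute) and isinstance(target.value, ast.Name):
                name = f"{target.value.id}.{target.attr}"  # e.g. self.w
            elif isinstance(target, ast.Name):
                name = target.id
            else:
                continue  # tuple unpacking etc. is ignored in this sketch
            self.assignments.append((name, "/".join(self.scope_stack)))
        self.generic_visit(node)


source = """
x = 1
class Net:
    def __init__(self):
        self.w = 0
"""
tracker = ScopeTracker()
tracker.visit(ast.parse(source))
print(tracker.assignments)  # [('x', 'global'), ('self.w', 'global/class/function')]
```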

Classes

BaseRewriter

The base class for AST transformation traversal.

Module Contents

class ml_switcheroo.core.rewriter.base.BaseRewriter(semantics: ml_switcheroo.semantics.manager.SemanticsManager, config: ml_switcheroo.config.RuntimeConfig)

Bases: libcst.CSTTransformer

The base class for AST transformation traversal.

Provides common utilities for scope tracking, alias resolution, and error bubbling, which are used by the specific mixins (CallMixin, StructureMixin, etc.).
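The division of labour between a base class and its mixins can be illustrated with a small sketch. It assumes a stdlib ast.NodeTransformer in place of libcst.CSTTransformer. SketchBase, CallSketchMixin, resolve(), report() and the SUPPORTED table are all hypothetical names, not the real CallMixin API. The base owns the shared state (alias map, error list). The mixin only calls the helpers that the base provides.

```python
import ast


class SketchBase(ast.NodeTransformer):
    """Hypothetical base: shared alias map and error list for all mixins."""

    def __init__(self) -> None:
        self.alias_map: dict[str, str] = {}
        self.errors: list[str] = []

    def resolve(self, dotted: str) -> str:
        # Replace the leading alias segment ("t" in "t.abs") with its full path.
        head, _, rest = dotted.partition(".")
        full = self.alias_map.get(head, head)
        return f"{full}.{rest}" if rest else full

    def report(self, message: str) -> None:
        self.errors.append(message)


class CallSketchMixin:
    """Hypothetical mixin: relies on resolve()/report() supplied by the base."""

    SUPPORTED = {"torch.abs"}  # illustrative mapping table, not real semantics

    def visit_Call(self, node: ast.Call) -> ast.Call:
        self.generic_visit(node)
        name = self.resolve(ast.unparse(node.func))
        if name not in self.SUPPORTED:
            self.report(f"no mapping for {name}")
        return node


# Mixins come first in the bases so their visit_* methods take precedence.
class SketchRewriter(CallSketchMixin, SketchBase):
    pass


rw = SketchRewriter()
rw.alias_map["t"] = "torch"
rw.visit(ast.parse("y = t.abs(x) + t.foo(x)"))
print(rw.errors)  # ['no mapping for torch.foo']
```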

semantics
config
source_fw = ''
target_fw = ''
strict_mode
ctx
leave_Module(original_node: libcst.Module, updated_node: libcst.Module) → libcst.Module

Injects the module-level preambles (e.g. shim classes) requested by plugins. The injection happens after any module docstring, so the docstring remains the first statement and is still picked up as the module's help text.

Parameters:
  • original_node – The module node before transformation.

  • updated_node – The module node after transformation.

Returns:

The module with injected preambles.

Return type:

cst.Module
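Docstring-aware injection can be sketched on a stdlib ast tree. This is an illustration only. The helper inject_after_docstring is hypothetical, and a libcst version would more likely splice updated_node.body and return updated_node.with_changes(body=...). The rule is simple: if the first statement is a bare string expression, insert after it; otherwise insert at the top.

```python
import ast


def inject_after_docstring(module: ast.Module, preamble: list[ast.stmt]) -> ast.Module:
    """Hypothetical helper: insert preamble statements after the module docstring."""
    body = module.body
    has_docstring = (
        bool(body)
        and isinstance(body[0], ast.Expr)
        and isinstance(body[0].value, ast.Constant)
        and isinstance(body[0].value.value, str)
    )
    # Keep the docstring first so it stays the module's help text.
    insert_at = 1 if has_docstring else 0
    module.body = body[:insert_at] + preamble + body[insert_at:]
    return module


source = '"""My module."""\nimport torch\n'
shim = ast.parse("class Shim:\n    pass\n").body
tree = inject_after_docstring(ast.parse(source), shim)
print(ast.get_docstring(tree))       # My module.
print(type(tree.body[1]).__name__)   # ClassDef
```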

visit_Import(node: libcst.Import) → bool | None

Scans import ... statements to populate the alias map. Example: import torch.nn as nn -> _alias_map['nn'] = 'torch.nn'.

Parameters:

node – Import statement node.

Returns:

False to stop traversal of children.

Return type:

Optional[bool]
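The alias-map population for plain import statements can be sketched with the stdlib ast module. This is a stand-in, not the libcst-based method. The function name aliases_from_import is hypothetical. The sketch follows Python's binding rules: import a.b as c binds c to a.b, while import a.b without as binds only the top-level name a.

```python
import ast


def aliases_from_import(node: ast.Import) -> dict[str, str]:
    """Hypothetical helper: map local names to module paths for `import ...`."""
    mapping: dict[str, str] = {}
    for alias in node.names:
        if alias.asname:
            # import torch.nn as nn  ->  nn -> torch.nn
            mapping[alias.asname] = alias.name
        else:
            # import os.path binds only the top-level name "os"
            top = alias.name.split(".")[0]
            mapping[top] = top
    return mapping


stmt = ast.parse("import torch.nn as nn, numpy as np, os.path").body[0]
print(aliases_from_import(stmt))
# {'nn': 'torch.nn', 'np': 'numpy', 'os': 'os'}
```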

visit_ImportFrom(node: libcst.ImportFrom) → bool | None

Scans from ... import ... statements to populate the alias map. Example: from torch import nn -> _alias_map['nn'] = 'torch.nn'.

Parameters:

node – ImportFrom statement node.

Returns:

False to stop traversal of children.

Return type:

Optional[bool]
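The from ... import ... case, together with the lookup that turns nn.Linear back into torch.nn.Linear, can be sketched as follows. Again this uses stdlib ast as a stand-in for libcst. Both helper names are hypothetical. For simplicity the sketch skips relative imports and star imports, which a full implementation would need to handle.

```python
import ast


def aliases_from_import_from(node: ast.ImportFrom) -> dict[str, str]:
    """Hypothetical helper: map local names for `from pkg import name [as alias]`."""
    mapping: dict[str, str] = {}
    if node.level or node.module is None:
        return mapping  # relative imports (from . import x) are skipped here
    for alias in node.names:
        if alias.name == "*":
            continue  # star imports cannot be mapped name by name
        local = alias.asname or alias.name
        mapping[local] = f"{node.module}.{alias.name}"
    return mapping


def resolve(dotted: str, alias_map: dict[str, str]) -> str:
    """Expand the first segment of a dotted name using the alias map."""
    head, sep, rest = dotted.partition(".")
    return alias_map.get(head, head) + sep + rest


stmt = ast.parse("from torch import nn, tensor as T").body[0]
amap = aliases_from_import_from(stmt)
print(amap)                         # {'nn': 'torch.nn', 'T': 'torch.tensor'}
print(resolve("nn.Linear", amap))   # torch.nn.Linear
```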

visit_SimpleStatementLine(node: libcst.SimpleStatementLine) → bool | None

Resets error tracking at the start of each statement line. Errors raised while rewriting child expressions bubble up to this statement-level handler.

Parameters:

node – The statement line node.

Returns:

True to continue traversal.

Return type:

Optional[bool]

leave_SimpleStatementLine(original_node: libcst.SimpleStatementLine, updated_node: libcst.SimpleStatementLine) → libcst.SimpleStatementLine | libcst.FlattenSentinel

Handles error bubbling from expression rewrites.

If errors occurred while processing this line’s children (e.g. a function call that could not be rewritten), the line is wrapped in an EscapeHatch.

Errors take precedence over warnings: if any error occurred, the original code is kept; if only warnings occurred, the updated code is used.

Parameters:
  • original_node – The node before children were visited.

  • updated_node – The node after children transformation.

Returns:

The resulting node (possibly wrapped with comments).

Return type:

Union[cst.SimpleStatementLine, cst.FlattenSentinel]
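The per-line decision made by visit_SimpleStatementLine and leave_SimpleStatementLine can be sketched in plain Python. It works on source strings rather than CST nodes. LineErrors, finish_line and the "# escape-hatch:" comment format are all hypothetical stand-ins; the real EscapeHatch output format is not documented here. In libcst, returning several statements from a leave_ method is what cst.FlattenSentinel allows. The sketch mimics this by returning a list of lines.

```python
from dataclasses import dataclass, field


@dataclass
class LineErrors:
    """Hypothetical per-line diagnostics; a fresh one per statement mirrors the reset."""
    errors: list[str] = field(default_factory=list)
    warnings: list[str] = field(default_factory=list)


def finish_line(original: str, updated: str, diag: LineErrors) -> list[str]:
    """Decide what to emit for one statement line.

    Errors win over warnings: on any error the original line is kept;
    with only warnings the updated line is kept. In both cases the
    messages are prepended as comments (a stand-in for the EscapeHatch).
    """
    if diag.errors:
        body, messages = original, diag.errors
    elif diag.warnings:
        body, messages = updated, diag.warnings
    else:
        return [updated]  # clean line: emit the rewrite unchanged
    return [f"# escape-hatch: {m}" for m in messages] + [body]


print(finish_line("y = t.foo(x)", "y = ???", LineErrors(errors=["no mapping for torch.foo"])))
# ['# escape-hatch: no mapping for torch.foo', 'y = t.foo(x)']
print(finish_line("y = t.abs(x)", "y = jnp.abs(x)", LineErrors()))
# ['y = jnp.abs(x)']
```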