ml_switcheroo.analysis.symbol_table
Symbol Table and Type Inference Analysis with Control Flow Support.
This module provides a static analysis pass to infer variable types and scopes before rewriting occurs. It builds a mapping of AST nodes to inferred type objects, allowing the rewriter to make decisions based on the semantic type of a variable (e.g., “is this a Tensor?”) rather than just its lexical name.
The SymbolTableAnalyzer visitor populates a SymbolTable by tracking:
1. Imports: mapping module aliases to ModuleType.
2. Assignments: propagating types from the RHS to the LHS.
3. Scopes: handling nested function and class definitions.
4. Control flow: handling type ambiguity in branches (Phi nodes) via UnionType.
Classes
- SymbolType: Base class for inferred types.
- TensorType: Represents a Tensor object from a specific framework.
- ModuleType: Represents an imported module or sub-module.
- UnionType: Represents a union of potential types resulting from control flow divergence.
- Scope: Represents a variable scope (Global, Class, or Function).
- SymbolTable: Container for analysis results. Maps CST nodes (by identity) to inferred types.
- SymbolTableAnalyzer: Static analysis pass to populate the SymbolTable.
Module Contents
- class ml_switcheroo.analysis.symbol_table.SymbolType
Base class for inferred types.
- name: str
A string representation of the type (e.g., ‘Tensor’).
- class ml_switcheroo.analysis.symbol_table.TensorType
Bases: SymbolType
Represents a Tensor object from a specific framework.
- framework: str
The framework key (e.g. “torch” or “jax”) responsible for this tensor.
- class ml_switcheroo.analysis.symbol_table.ModuleType
Bases: SymbolType
Represents an imported module or sub-module.
- path: str
Fully qualified path string (e.g. “torch.nn”).
- class ml_switcheroo.analysis.symbol_table.UnionType(types: List[SymbolType])
Bases: SymbolType
Represents a union of potential types resulting from control flow divergence.
- types: List[SymbolType]
The candidate types a value may have after divergent control-flow paths rejoin.
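The following is a simplified sketch of the type hierarchy and a hypothetical merge_types helper that shows how a Phi-node merge could produce a UnionType. The dataclass layout, the field defaults, and merge_types itself are assumptions made for illustration and are not the module's actual code. In particular, the sketch ignores a missing (None) side, whereas the real analyzer may handle a possibly undefined variable differently.

```python
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class SymbolType:
    name: str  # e.g. "Tensor"


@dataclass
class TensorType(SymbolType):
    framework: str = ""  # e.g. "torch" or "jax"


@dataclass
class ModuleType(SymbolType):
    path: str = ""  # e.g. "torch.nn"


@dataclass
class UnionType(SymbolType):
    # name is fixed, so the constructor only takes the member types.
    name: str = field(default="Union", init=False)
    types: List[SymbolType] = field(default_factory=list)


def merge_types(a: Optional[SymbolType], b: Optional[SymbolType]) -> Optional[SymbolType]:
    """Hypothetical Phi-node join of the types reaching a merge point."""
    candidates: List[SymbolType] = []
    for t in (a, b):
        if t is None:
            continue  # assumption: an unbound side contributes nothing
        # Flatten nested unions so Union[Union[A, B], C] becomes Union[A, B, C].
        members = t.types if isinstance(t, UnionType) else [t]
        for m in members:
            if m not in candidates:  # dataclass __eq__ gives structural dedup
                candidates.append(m)
    if not candidates:
        return None
    if len(candidates) == 1:
        return candidates[0]  # both paths agree: no union needed
    return UnionType(candidates)


torch_t = TensorType(name="Tensor", framework="torch")
jax_t = TensorType(name="Tensor", framework="jax")
print(merge_types(torch_t, jax_t))
```

Flattening and deduplicating keeps a union small when the same type reaches a merge point along several paths.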
- class ml_switcheroo.analysis.symbol_table.Scope(parent: Scope | None = None, name: str = '<root>')
Represents a variable scope (Global, Class, or Function).
- parent = None
- name = '<root>'
- symbols: Dict[str, SymbolType]
- set(name: str, sym_type: SymbolType) -> None
Register a symbol in the current scope.
- Parameters:
name – Variable identifier.
sym_type – Inferred Type object.
- get(name: str) -> SymbolType | None
Resolve a symbol, traversing parent scopes.
- Parameters:
name – Variable identifier to look up.
- Returns:
The SymbolType if found, else None.
- snapshot() -> Dict[str, SymbolType]
Returns a shallow copy of this scope's symbols mapping, used to save state before a branch is analyzed.
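The lookup and snapshot behavior described above can be sketched as a minimal re-implementation. It rests on two assumptions: types are plain strings instead of SymbolType objects, and the outward walk through parent scopes is written as a loop. The module's actual code may differ.

```python
from typing import Dict, Optional


class Scope:
    """Sketch of a lexically nested scope; types are plain strings here."""

    def __init__(self, parent: Optional["Scope"] = None, name: str = "<root>") -> None:
        self.parent = parent
        self.name = name
        self.symbols: Dict[str, str] = {}

    def set(self, name: str, sym_type: str) -> None:
        # Bindings always go into this scope and never leak outward.
        self.symbols[name] = sym_type

    def get(self, name: str) -> Optional[str]:
        # Walk outward through enclosing scopes until the name is found.
        scope: Optional[Scope] = self
        while scope is not None:
            if name in scope.symbols:
                return scope.symbols[name]
            scope = scope.parent
        return None

    def snapshot(self) -> Dict[str, str]:
        # Shallow copy: a new dict whose values are shared with the original.
        return dict(self.symbols)


root = Scope()
root.set("torch", "Module")
forward = Scope(parent=root, name="forward")
forward.set("x", "Tensor")
print(forward.get("torch"), root.get("x"))  # Module None
```

Because snapshot() copies only the dictionary, rebinding a name after the snapshot leaves the saved state untouched, which is what branch analysis needs. The type objects themselves are shared, which is safe as long as they are treated as immutable.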
- class ml_switcheroo.analysis.symbol_table.SymbolTable
Container for analysis results. Maps CST Nodes (by identity) to inferred Types.
- record_type(node: libcst.CSTNode, sym_type: SymbolType) -> None
Associates a CST node with a type.
- Parameters:
node – The CST node.
sym_type – The determined type.
- get_type(node: libcst.CSTNode) -> SymbolType | None
Retrieves the inferred type for a CST node.
- Parameters:
node – The CST node to inspect.
- Returns:
The stored SymbolType or None.
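The sketch below keys a node-to-type map by identity. It assumes nodes are arbitrary Python objects and types are plain strings. Identity matters because the same text, such as a name x, can appear at several places in a file with different types (for example, before and after a reassignment), so structurally equal nodes must not share an entry. The storage layout is an assumption and is not taken from the module.

```python
from dataclasses import dataclass
from typing import Any, Dict, Optional, Tuple


class SymbolTable:
    """Sketch: map nodes to types by object identity, not structural equality."""

    def __init__(self) -> None:
        # id(node) -> (node, type); keeping the node referenced prevents its
        # id from being reused by another object while the entry exists.
        self._entries: Dict[int, Tuple[Any, str]] = {}

    def record_type(self, node: Any, sym_type: str) -> None:
        self._entries[id(node)] = (node, sym_type)

    def get_type(self, node: Any) -> Optional[str]:
        entry = self._entries.get(id(node))
        return entry[1] if entry is not None else None


@dataclass
class Name:
    """Hypothetical stand-in for a CST name node (compares by value)."""

    value: str


first_x, second_x = Name("x"), Name("x")  # same text, two source positions
table = SymbolTable()
table.record_type(first_x, "Tensor")
print(first_x == second_x, table.get_type(first_x), table.get_type(second_x))
# True Tensor None
```

The sketch stores the node in the value, not just its id(). This keeps the node alive, so no other object can reuse its id while the entry exists.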
- class ml_switcheroo.analysis.symbol_table.SymbolTableAnalyzer(semantics: ml_switcheroo.semantics.manager.SemanticsManager)
Bases: libcst.CSTVisitor
Static analysis pass to populate the SymbolTable. Types are propagated bottom-up in post-order, via the leave_* methods. Implements shallow control flow inference for if/else branches and loops.
- semantics
- table
- root_scope
- current_scope
- visit_If(node: libcst.If) -> bool
Handle branching logic:
1. Snapshot the current state.
2. Visit the body, producing State_Body.
3. Revert to the snapshot.
4. Visit the else branch (if any), producing State_Else.
5. Merge State_Body and State_Else.
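The five branching steps can be sketched on plain dictionaries. In this sketch, each variable maps to a frozenset of possible type names, which stands in for either a single type or a UnionType. The body and else branches are modeled as callables that mutate a state. These representations and the names merge_states and analyze_if are hypothetical. A name bound on only one path keeps just that path's types, because the sketch does not track "possibly undefined".

```python
from typing import Callable, Dict, FrozenSet, Optional

# Each variable maps to the set of type names it may hold; a one-element set
# is a definite type, a larger set plays the role of a UnionType.
State = Dict[str, FrozenSet[str]]
Branch = Callable[[State], None]


def merge_states(a: State, b: State) -> State:
    """Phi-style join: a name's possible types are the union over both paths."""
    return {
        name: a.get(name, frozenset()) | b.get(name, frozenset())
        for name in a.keys() | b.keys()
    }


def analyze_if(state: State, body: Branch, orelse: Optional[Branch] = None) -> State:
    snapshot = dict(state)                 # 1. snapshot the incoming state
    body_state = dict(snapshot)
    body(body_state)                       # 2. analyze the if-body
    else_state = dict(snapshot)            # 3. revert to the snapshot
    if orelse is not None:
        orelse(else_state)                 # 4. analyze the else branch
    return merge_states(body_state, else_state)  # 5. merge both paths


start: State = {"x": frozenset({"Tensor"})}
after = analyze_if(start, lambda s: s.update(x=frozenset({"Module"})))
print(sorted(after["x"]))  # ['Module', 'Tensor']
```

When there is no else branch, the fall-through state is simply the snapshot. As a result, a variable reassigned only inside the if still ends up with both its old and new types.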
- visit_For(node: libcst.For) -> bool
Handle loop logic. A loop body may execute zero or more times, which makes variable types after the loop ambiguous. The state after the loop body is therefore merged with the state before the loop.
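This sketch reuses the hypothetical frozenset-of-type-names representation. Under that representation, the loop rule reduces to joining the pre-loop state with the state after a single pass of the body. The names merge_states, analyze_for, and loop_body are illustrative and are not the module's API.

```python
from typing import Callable, Dict, FrozenSet

State = Dict[str, FrozenSet[str]]  # variable -> possible type names (hypothetical)


def merge_states(a: State, b: State) -> State:
    """Join two states by taking the per-name union of possible types."""
    return {
        name: a.get(name, frozenset()) | b.get(name, frozenset())
        for name in a.keys() | b.keys()
    }


def analyze_for(state: State, body: Callable[[State], None]) -> State:
    before = dict(state)        # zero iterations: the loop body never runs
    after_pass = dict(before)
    body(after_pass)            # one iteration: the body's effects apply once
    return merge_states(before, after_pass)


def loop_body(s: State) -> None:
    # Models a body of:  y = x ; x = <some module>
    s["y"] = s.get("x", frozenset())
    s["x"] = frozenset({"Module"})


result = analyze_for({"x": frozenset({"Tensor"})}, loop_body)
print(sorted(result["x"]), sorted(result["y"]))  # ['Module', 'Tensor'] ['Tensor']
```

The join correctly reports that x may hold either type. However, y only picks up Tensor, even though on a second iteration it would receive Module. This illustrates one limit of joining after a single pass of the body.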
- leave_Import(node: libcst.Import) -> None
Track imports. For example, import torch sets symbols['torch'] = ModuleType(name='Module', path='torch').
- leave_ImportFrom(node: libcst.ImportFrom) -> None
Track from-imports. For example, from torch import nn sets symbols['nn'] = ModuleType(name='Module', path='torch.nn').
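The two import rules can be sketched with Python's standard-library ast module, used here in place of libcst so the example runs without extra dependencies. Two parts of the sketch are assumptions: the function name collect_imports, and the plain-string result, which maps each local name to a dotted path instead of a ModuleType. Like the from-import example above, the sketch records from torch import nn as the path torch.nn without checking whether nn is actually a module.

```python
import ast
from typing import Dict


def collect_imports(source: str) -> Dict[str, str]:
    """Map each name bound by an import to a dotted module path (sketch)."""
    symbols: Dict[str, str] = {}
    for node in ast.walk(ast.parse(source)):
        if isinstance(node, ast.Import):
            for alias in node.names:
                if alias.asname:
                    # import jax.numpy as jnp -> jnp refers to jax.numpy
                    symbols[alias.asname] = alias.name
                else:
                    # import torch.nn binds only the top-level name torch
                    top = alias.name.split(".")[0]
                    symbols[top] = top
        elif isinstance(node, ast.ImportFrom) and node.module and node.level == 0:
            for alias in node.names:
                if alias.name == "*":
                    continue  # star imports bind unknown names; skipped here
                # from torch import nn -> nn refers to torch.nn
                symbols[alias.asname or alias.name] = f"{node.module}.{alias.name}"
    return symbols


print(collect_imports("import torch\nfrom torch import nn\nimport jax.numpy as jnp\n"))
# {'torch': 'torch', 'nn': 'torch.nn', 'jnp': 'jax.numpy'}
```

Relative imports (from . import x) are skipped because resolving them requires knowing the importing module's package.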
- leave_Assign(node: libcst.Assign) -> None
Propagate the type from the RHS to the LHS. For example, after x = torch.randn(), x is inferred to be a Tensor.
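This is a minimal sketch of RHS-to-LHS propagation, again using the standard-library ast module instead of libcst. It makes three simplifying assumptions. First, the TENSOR_FACTORIES table is a hypothetical stand-in for whatever the real analyzer looks up through its SemanticsManager. Second, module aliases are passed in already resolved. Third, a variable's type is reduced to the framework name of the tensor it holds.

```python
import ast
from typing import Dict, Optional

# Hypothetical stand-in for the SemanticsManager lookup: fully qualified
# callables assumed to return a tensor, mapped to their framework key.
TENSOR_FACTORIES: Dict[str, str] = {"torch.randn": "torch", "jax.numpy.zeros": "jax"}


def qualified_name(expr: ast.expr, modules: Dict[str, str]) -> Optional[str]:
    """Resolve an alias.attr chain such as jnp.zeros to a dotted path."""
    if isinstance(expr, ast.Name):
        return modules.get(expr.id)
    if isinstance(expr, ast.Attribute):
        base = qualified_name(expr.value, modules)
        return f"{base}.{expr.attr}" if base else None
    return None


class AssignPropagator(ast.NodeVisitor):
    """Infer the RHS of each simple assignment and bind it to Name targets."""

    def __init__(self, modules: Dict[str, str]) -> None:
        self.modules = modules  # alias -> module path, e.g. "jnp" -> "jax.numpy"
        self.variables: Dict[str, str] = {}  # variable -> tensor framework

    def infer(self, expr: ast.expr) -> Optional[str]:
        if isinstance(expr, ast.Call):
            path = qualified_name(expr.func, self.modules)
            return TENSOR_FACTORIES.get(path) if path else None
        if isinstance(expr, ast.Name):
            return self.variables.get(expr.id)  # y = x copies x's type
        return None

    def visit_Assign(self, node: ast.Assign) -> None:
        inferred = self.infer(node.value)
        for target in node.targets:
            if not isinstance(target, ast.Name):
                continue  # tuple unpacking, attributes, subscripts are skipped
            if inferred is None:
                # Rebinding to an unknown type clears any stale entry.
                self.variables.pop(target.id, None)
            else:
                self.variables[target.id] = inferred


propagator = AssignPropagator({"torch": "torch", "jnp": "jax.numpy"})
propagator.visit(ast.parse("x = torch.randn(3)\ny = x\nz = jnp.zeros(2)\n"))
print(propagator.variables)
```

Rebinding a name to an expression of unknown type drops its previous entry, so a stale Tensor type is not carried past the reassignment.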