ml_switcheroo.analysis.symbol_table

Symbol Table and Type Inference Analysis with Control Flow Support.

This module provides a static analysis pass to infer variable types and scopes before rewriting occurs. It builds a mapping of AST nodes to inferred type objects, allowing the rewriter to make decisions based on the semantic type of a variable (e.g., “is this a Tensor?”) rather than just its lexical name.

The SymbolTableAnalyzer visitor populates a SymbolTable by tracking:

1. Imports: mapping module aliases to ModuleType.
2. Assignments: propagating types from the RHS to the LHS.
3. Scopes: handling nested function/class definitions.
4. Control flow: handling type ambiguity in branches (Phi nodes) via Union types.

Classes

SymbolType

Base class for inferred types.

TensorType

Represents a Tensor object from a specific framework.

ModuleType

Represents an imported module or sub-module.

UnionType

Represents a union of potential types resulting from control flow divergence.

Scope

Represents a variable scope (Global, Class, or Function).

SymbolTable

Container for analysis results. Maps CST Nodes (by identity) to inferred Types.

SymbolTableAnalyzer

Static Analysis pass to populate the SymbolTable.

Module Contents

class ml_switcheroo.analysis.symbol_table.SymbolType

Base class for inferred types.

name: str

A string representation of the type (e.g., ‘Tensor’).

__str__() -> str

Returns the type name.

__eq__(other: object) -> bool

class ml_switcheroo.analysis.symbol_table.TensorType

Bases: SymbolType

Represents a Tensor object from a specific framework.

framework: str

The framework key (e.g. “torch” or “jax”) responsible for this tensor.

__eq__(other: object) -> bool

class ml_switcheroo.analysis.symbol_table.ModuleType

Bases: SymbolType

Represents an imported module or sub-module.

path: str

Fully qualified path string (e.g. “torch.nn”).

__eq__(other: object) -> bool
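
The equality rules behind these __eq__ methods are not spelled out above. A minimal sketch, assuming the types behave like frozen dataclasses whose generated __eq__ compares the exact class and every field (so a torch tensor and a jax tensor differ, and a TensorType never equals a ModuleType); these definitions are illustrative stand-ins, not the library's code:

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class SymbolType:
    # Hypothetical stand-in: the dataclass-generated __eq__ compares the exact
    # class and all fields, so instances of different subclasses never match.
    name: str

    def __str__(self) -> str:
        return self.name


@dataclass(frozen=True)
class TensorType(SymbolType):
    name: str = "Tensor"
    framework: str = "torch"  # e.g. "torch" or "jax"


@dataclass(frozen=True)
class ModuleType(SymbolType):
    name: str = "Module"
    path: str = ""  # fully qualified, e.g. "torch.nn"


print(TensorType(framework="torch") == TensorType(framework="torch"))  # True
print(TensorType(framework="torch") == TensorType(framework="jax"))    # False
print(ModuleType(path="torch") == ModuleType(path="torch.nn"))         # False
print(str(ModuleType(path="torch.nn")))                               # Module
```

Because the classes are frozen, they are also hashable, which lets them be stored in sets when building unions.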

class ml_switcheroo.analysis.symbol_table.UnionType(types: List[SymbolType])

Bases: SymbolType

Represents a union of potential types resulting from control flow divergence.

types: List[SymbolType]

__str__() -> str

Returns the type name.

__eq__(other: object) -> bool
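
How UnionType normalizes its members is not documented here. The sketch below is one plausible design, assuming that nested unions are flattened, duplicate members are dropped, and member order does not affect equality; the plain SymbolType stand-in (compared by class and name) is likewise hypothetical:

```python
from typing import List


class SymbolType:
    """Minimal stand-in for the base type; compares by class and name."""

    def __init__(self, name: str) -> None:
        self.name = name

    def __str__(self) -> str:
        return self.name

    def __eq__(self, other: object) -> bool:
        return type(other) is type(self) and getattr(other, "name", None) == self.name

    def __hash__(self) -> int:
        return hash((type(self), self.name))


class UnionType(SymbolType):
    """Types a variable may hold after control flow diverges (hypothetical sketch)."""

    def __init__(self, types: List[SymbolType]) -> None:
        flat: List[SymbolType] = []
        for t in types:
            # Assumption: nested unions are flattened and duplicates removed.
            members = t.types if isinstance(t, UnionType) else [t]
            for m in members:
                if m not in flat:
                    flat.append(m)
        self.types = flat
        super().__init__(" | ".join(str(t) for t in flat))

    def __eq__(self, other: object) -> bool:
        # Assumption: member order is irrelevant for equality.
        return isinstance(other, UnionType) and set(self.types) == set(other.types)

    def __hash__(self) -> int:
        return hash(frozenset(self.types))


tensor, integer = SymbolType("Tensor"), SymbolType("int")
u = UnionType([tensor, UnionType([integer, tensor])])
print(u)                                  # Tensor | int
print(u == UnionType([integer, tensor]))  # True
```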

class ml_switcheroo.analysis.symbol_table.Scope(parent: Scope | None = None, name: str = '<root>')

Represents a variable scope (Global, Class, or Function).

parent = None
name = '<root>'
symbols: Dict[str, SymbolType]

set(name: str, sym_type: SymbolType) -> None

Register a symbol in the current scope.

Parameters:
  • name – Variable identifier.

  • sym_type – Inferred Type object.

get(name: str) -> SymbolType | None

Resolve a symbol, traversing parent scopes.

Parameters:

name – Variable identifier to lookup.

Returns:

The SymbolType if found, else None.

snapshot() -> Dict[str, SymbolType]

Returns a shallow copy of this scope's symbols dictionary, so that its state can be restored when analyzing branches.
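
The set/get/snapshot contract can be illustrated with a small stand-in. This is a hypothetical sketch, not the library's class: types are plain strings for brevity, get walks outward through parent scopes as described above, and snapshot copies only the current scope's dictionary (the type objects themselves are shared, which is what "shallow" means here):

```python
from typing import Dict, Optional


class Scope:
    """Hypothetical stand-in for Scope; types are plain strings here."""

    def __init__(self, parent: Optional["Scope"] = None, name: str = "<root>") -> None:
        self.parent = parent
        self.name = name
        self.symbols: Dict[str, str] = {}

    def set(self, name: str, sym_type: str) -> None:
        # Bindings always go into this (innermost) scope.
        self.symbols[name] = sym_type

    def get(self, name: str) -> Optional[str]:
        # Walk outward through enclosing scopes until the name is found.
        scope: Optional[Scope] = self
        while scope is not None:
            if name in scope.symbols:
                return scope.symbols[name]
            scope = scope.parent
        return None

    def snapshot(self) -> Dict[str, str]:
        # New dict, same values: later set() calls do not affect the copy.
        return dict(self.symbols)


root = Scope()
root.set("torch", "Module(torch)")
forward = Scope(parent=root, name="forward")
forward.set("x", "Tensor")
print(forward.get("torch"))  # Module(torch)
print(root.get("x"))         # None
saved = forward.snapshot()
forward.set("x", "int")
print(saved["x"], forward.get("x"))  # Tensor int
```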

class ml_switcheroo.analysis.symbol_table.SymbolTable

Container for analysis results. Maps CST Nodes (by identity) to inferred Types.

record_type(node: libcst.CSTNode, sym_type: SymbolType) -> None

Associates a CST node with a type.

Parameters:
  • node – The CST node.

  • sym_type – The determined type.

get_type(node: libcst.CSTNode) -> SymbolType | None

Retrieves the inferred type for a CST node.

Parameters:

node – The CST node to inspect.

Returns:

The stored SymbolType or None.
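
Keying by node identity matters because two textually identical expressions (for example, two occurrences of x) are distinct nodes that may need different types. A minimal sketch of an identity-keyed table, using the standard-library ast module in place of libcst so it runs without extra dependencies; the class and its internals are hypothetical:

```python
import ast
from typing import Dict, Optional, Tuple


class SymbolTable:
    """Hypothetical sketch: maps nodes to types by object identity."""

    def __init__(self) -> None:
        # Keyed by id(node). The node is stored in the value as well, which keeps
        # it alive so its id cannot be reused by another object meanwhile.
        self._types: Dict[int, Tuple[ast.AST, str]] = {}

    def record_type(self, node: ast.AST, sym_type: str) -> None:
        self._types[id(node)] = (node, sym_type)

    def get_type(self, node: ast.AST) -> Optional[str]:
        entry = self._types.get(id(node))
        return entry[1] if entry is not None else None


binop = ast.parse("y = x + x").body[0].value
left, right = binop.left, binop.right
print(ast.dump(left) == ast.dump(right))  # True: structurally identical nodes

table = SymbolTable()
table.record_type(left, "Tensor")
print(table.get_type(left))   # Tensor
print(table.get_type(right))  # None: a different node object
```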

class ml_switcheroo.analysis.symbol_table.SymbolTableAnalyzer(semantics: ml_switcheroo.semantics.manager.SemanticsManager)

Bases: libcst.CSTVisitor

Static Analysis pass to populate the SymbolTable. Runs post-order traversal logic (via leave methods) to propagate types bottom-up. Implements shallow control flow inference for If/Else and Loops.

semantics
table
root_scope
current_scope
visit_ClassDef(node: libcst.ClassDef) -> None

Enters class scope.

leave_ClassDef(node: libcst.ClassDef) -> None

Exits class scope.

visit_FunctionDef(node: libcst.FunctionDef) -> None

Enters function scope.

leave_FunctionDef(node: libcst.FunctionDef) -> None

Exits function scope.
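
In libcst, a visitor's visit_<Node> method runs before a node's children and leave_<Node> runs after them, so pushing a scope in visit_ and popping it in leave_ brackets exactly the nested definitions. The standard-library ast.NodeVisitor has no leave_ hooks, so this hypothetical sketch emulates the pair inside one method (push, visit children, pop); ScopeTracker and its helpers are illustrative names:

```python
import ast
from typing import Dict, List, Optional


class Scope:
    def __init__(self, parent: Optional["Scope"] = None, name: str = "<root>") -> None:
        self.parent = parent
        self.name = name
        self.symbols: Dict[str, str] = {}


class ScopeTracker(ast.NodeVisitor):
    """Hypothetical sketch of scope enter/exit around class and function bodies."""

    def __init__(self) -> None:
        self.root_scope = Scope()
        self.current_scope = self.root_scope
        self.seen: List[str] = []  # qualified scope names, recorded for inspection

    def _enter(self, name: str) -> None:
        self.current_scope = Scope(parent=self.current_scope, name=name)
        self.seen.append(self._qualname())

    def _exit(self) -> None:
        assert self.current_scope.parent is not None
        self.current_scope = self.current_scope.parent

    def _qualname(self) -> str:
        parts = []
        scope: Optional[Scope] = self.current_scope
        while scope is not None:
            parts.append(scope.name)
            scope = scope.parent
        return ".".join(reversed(parts))

    def visit_ClassDef(self, node: ast.ClassDef) -> None:
        self._enter(node.name)   # plays the role of visit_ClassDef
        self.generic_visit(node)
        self._exit()             # plays the role of leave_ClassDef

    def visit_FunctionDef(self, node: ast.FunctionDef) -> None:
        self._enter(node.name)
        self.generic_visit(node)
        self._exit()


src = "class Net:\n    def forward(self, x):\n        def inner():\n            pass\n"
tracker = ScopeTracker()
tracker.visit(ast.parse(src))
print(tracker.seen)  # ['<root>.Net', '<root>.Net.forward', '<root>.Net.forward.inner']
```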

visit_If(node: libcst.If) -> bool

Handles branching logic:

1. Snapshot the state.
2. Visit the body, producing State_Body.
3. Revert to the snapshot.
4. Visit the else branch (if any), producing State_Else.
5. Merge State_Body and State_Else.
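
The five steps can be traced on toy data. In this hypothetical sketch, a state maps each name to a frozenset of type names (a set with more than one member stands in for a UnionType), and the "body" and "else" branches are given directly as assignments instead of being visited as CST nodes; merge, analyze_if, and State are illustrative names:

```python
from typing import Dict, FrozenSet, Optional

# Hypothetical state: name -> set of type names it may hold.
State = Dict[str, FrozenSet[str]]


def merge(a: State, b: State) -> State:
    """Join two branch states: each name gets the union of its types in both."""
    return {n: a.get(n, frozenset()) | b.get(n, frozenset()) for n in a.keys() | b.keys()}


def analyze_if(state: State, body: Dict[str, str], orelse: Optional[Dict[str, str]] = None) -> State:
    """Toy branches are given as {name: assigned type name}."""
    snapshot = dict(state)                     # 1. snapshot the state
    state_body = dict(snapshot)
    for name, t in body.items():               # 2. visit the body
        state_body[name] = frozenset({t})
    state_else = dict(snapshot)                # 3. revert to the snapshot
    for name, t in (orelse or {}).items():     # 4. visit the else branch, if any
        state_else[name] = frozenset({t})
    return merge(state_body, state_else)       # 5. merge the two states


before = {"x": frozenset({"int"})}
print(sorted(analyze_if(before, {"x": "Tensor"})["x"]))                   # ['Tensor', 'int']
print(sorted(analyze_if(before, {"x": "Tensor"}, {"x": "Tensor"})["x"]))  # ['Tensor']
```

With no else branch, the pre-branch type survives on the fall-through path, so x becomes a union. This sketch does not model a name that is bound in only one branch as "possibly unbound".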

visit_For(node: libcst.For) -> bool

Handles loop logic. A loop body may execute zero or more times, so a variable's type after the loop is ambiguous. The state after the loop body is therefore merged with the state before the loop.
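
A hypothetical sketch of that merge on toy data (analyze_loop and State are illustrative names; types are sets of type names). Zero iterations leave the pre-loop state in effect and one or more leave the post-body state, so the result is their union. In this sketch only a single pass over the body is examined, so types that first appear on a later iteration are missed; a fuller analysis would iterate to a fixed point:

```python
from typing import Dict, FrozenSet

State = Dict[str, FrozenSet[str]]


def analyze_loop(before: State, body: Dict[str, str]) -> State:
    """Merge the pre-loop state with the state after one pass over the body."""
    after = dict(before)
    for name, t in body.items():
        after[name] = frozenset({t})
    return {n: before.get(n, frozenset()) | after.get(n, frozenset()) for n in before.keys() | after.keys()}


state = {"acc": frozenset({"int"}), "w": frozenset({"Tensor"})}
out = analyze_loop(state, {"acc": "Tensor"})
print(sorted(out["acc"]), sorted(out["w"]))  # ['Tensor', 'int'] ['Tensor']
```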

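visit_While(node: libcst.While) -> bool

Handles while-loop logic. As with for loops, the body may execute zero or more times, so the state after the body is merged with the state before the loop.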

leave_IfExp(node: libcst.IfExp) -> None

Infers the type of a ternary expression A if C else B.
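
Since the condition C only selects which branch runs, the expression may have the type of A or of B. A hypothetical bottom-up sketch over the standard-library ast module (in place of libcst), with types represented as sets of type names so that a two-element set stands for a union; infer_expr and Env are illustrative names:

```python
import ast
from typing import Dict, FrozenSet

Env = Dict[str, FrozenSet[str]]


def infer_expr(node: ast.expr, env: Env) -> FrozenSet[str]:
    """Hypothetical bottom-up inference over a tiny subset of expressions."""
    if isinstance(node, ast.Name):
        return env.get(node.id, frozenset())
    if isinstance(node, ast.IfExp):
        # node.test decides only which branch runs; the result type is
        # the union of the two branch types.
        return infer_expr(node.body, env) | infer_expr(node.orelse, env)
    return frozenset()  # other expression kinds are untyped in this sketch


expr = ast.parse("a if flag else b", mode="eval").body
env = {"a": frozenset({"Tensor"}), "b": frozenset({"int"}), "flag": frozenset({"bool"})}
print(sorted(infer_expr(expr, env)))  # ['Tensor', 'int']
```

When both branches have the same type, the union collapses to that single type.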

leave_Import(node: libcst.Import) -> None

Tracks imports, e.g. import torch -> symbols['torch'] = ModuleType(name='Module', path='torch').

leave_ImportFrom(node: libcst.ImportFrom) -> None

Tracks from-imports, e.g. from torch import nn -> symbols['nn'] = ModuleType(name='Module', path='torch.nn').
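
Both rules can be sketched with the standard-library ast module in place of libcst. This hypothetical track_imports maps each bound alias to a module path string instead of a ModuleType. It follows Python's binding rules (import a.b binds only a), skips star imports and relative imports, and, like the documented example, treats every from-imported name as a sub-module even though it could be a function or class:

```python
import ast
from typing import Dict


def track_imports(source: str) -> Dict[str, str]:
    """Hypothetical sketch: alias -> fully qualified module path."""
    symbols: Dict[str, str] = {}
    for node in ast.walk(ast.parse(source)):
        if isinstance(node, ast.Import):
            for alias in node.names:
                if alias.asname:
                    symbols[alias.asname] = alias.name  # import numpy as np
                else:
                    top = alias.name.split(".")[0]      # import os.path binds 'os'
                    symbols[top] = top
        elif isinstance(node, ast.ImportFrom) and node.module and node.level == 0:
            for alias in node.names:
                if alias.name == "*":
                    continue  # star imports cannot be resolved statically here
                symbols[alias.asname or alias.name] = f"{node.module}.{alias.name}"
    return symbols


src = "import torch\nimport numpy as np\nimport os.path\nfrom torch import nn\nfrom jax import numpy as jnp\n"
print(track_imports(src))
```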

leave_Assign(node: libcst.Assign) -> None

Propagates the type from the RHS to the LHS, e.g. x = torch.randn() -> x is a Tensor.
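
A hypothetical sketch of RHS-to-LHS propagation using the standard-library ast module. RETURN_TYPES is an illustrative stand-in for the return-type knowledge the real analyzer obtains from the SemanticsManager, and aliases are not resolved here. The sketch also drops a name's old type when it is rebound to something it cannot type, so a stale Tensor type does not survive x = 1:

```python
import ast
from typing import Dict, Optional

# Hypothetical stand-in for SemanticsManager data: dotted call name -> return type.
RETURN_TYPES = {"torch.randn": "Tensor(torch)", "torch.zeros": "Tensor(torch)"}


def dotted_name(node: ast.expr) -> Optional[str]:
    """Turn Attribute(Name('torch'), 'randn') into 'torch.randn'."""
    if isinstance(node, ast.Name):
        return node.id
    if isinstance(node, ast.Attribute):
        base = dotted_name(node.value)
        return f"{base}.{node.attr}" if base else None
    return None


def propagate(source: str) -> Dict[str, str]:
    symbols: Dict[str, str] = {}
    for stmt in ast.parse(source).body:
        if not isinstance(stmt, ast.Assign):
            continue
        rhs_type: Optional[str] = None
        if isinstance(stmt.value, ast.Call):
            rhs_type = RETURN_TYPES.get(dotted_name(stmt.value.func) or "")
        elif isinstance(stmt.value, ast.Name):
            rhs_type = symbols.get(stmt.value.id)  # y = x copies x's type
        for target in stmt.targets:  # a = b = ... binds every target
            if isinstance(target, ast.Name):
                if rhs_type is None:
                    symbols.pop(target.id, None)  # unknown value: forget the old type
                else:
                    symbols[target.id] = rhs_type
    return symbols


print(propagate("x = torch.randn(3)\ny = x\nz = 5\nx = 1\n"))  # {'y': 'Tensor(torch)'}
```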

leave_Name(node: libcst.Name) -> None

Looks up the variable in the current scope, falling back to enclosing scopes.

leave_Attribute(node: libcst.Attribute) -> None

Resolves an attribute based on the type of its receiver. If x is Module('torch'), then x.nn is Module('torch.nn'). If x is a Tensor, attributes such as x.shape may also be recorded.
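
The module case can be sketched bottom-up: resolve the receiver first, then extend its dotted path. This hypothetical resolve_attribute uses the standard-library ast module and encodes types as ("module", path) or ("tensor", framework) tuples; tensor attributes such as x.shape are deliberately left unresolved:

```python
import ast
from typing import Dict, Optional, Tuple

# Hypothetical type encoding: ("module", path) or ("tensor", framework).
SymType = Tuple[str, str]


def resolve_attribute(node: ast.expr, env: Dict[str, SymType]) -> Optional[SymType]:
    """Resolve x.y.z from the type of the innermost receiver."""
    if isinstance(node, ast.Name):
        return env.get(node.id)
    if isinstance(node, ast.Attribute):
        receiver = resolve_attribute(node.value, env)
        if receiver is not None and receiver[0] == "module":
            # Module receiver: extend the dotted path (torch -> torch.nn).
            return ("module", f"{receiver[1]}.{node.attr}")
        return None  # tensor or unknown receivers are not handled in this sketch
    return None


env = {"torch": ("module", "torch"), "x": ("tensor", "torch")}
print(resolve_attribute(ast.parse("torch.nn.functional", mode="eval").body, env))
# ('module', 'torch.nn.functional')
print(resolve_attribute(ast.parse("x.shape", mode="eval").body, env))  # None
```

Under this rule, torch.randn also comes out as a dotted path, which a call-inference step can then treat as the function's fully qualified name.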

leave_Call(node: libcst.Call) -> None

Infers the return type of a call:

1. Resolve the function's fully qualified name.
2. Look up its return type in the SemanticsManager.
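
A hypothetical sketch of the two steps with the standard-library ast module. The aliases dict stands in for what import tracking produced, and the semantics dict stands in for the SemanticsManager's return-type knowledge; neither reflects the real SemanticsManager API:

```python
import ast
from typing import Dict, List, Optional

# Hypothetical stand-ins for import tracking results and SemanticsManager data.
aliases: Dict[str, str] = {"np": "numpy", "torch": "torch", "nn": "torch.nn"}
semantics: Dict[str, str] = {"numpy.zeros": "ndarray", "torch.randn": "Tensor", "torch.nn.Linear": "Linear"}


def fully_qualified_name(func: ast.expr) -> Optional[str]:
    """Step 1: expand the leftmost alias, e.g. np.zeros -> 'numpy.zeros'."""
    parts: List[str] = []
    while isinstance(func, ast.Attribute):
        parts.append(func.attr)
        func = func.value
    if not isinstance(func, ast.Name) or func.id not in aliases:
        return None
    parts.append(aliases[func.id])
    return ".".join(reversed(parts))


def infer_call(source: str) -> Optional[str]:
    call = ast.parse(source, mode="eval").body
    if not isinstance(call, ast.Call):
        return None
    fqn = fully_qualified_name(call.func)
    # Step 2: look the name up in the (stand-in) semantics table.
    return semantics.get(fqn) if fqn else None


print(infer_call("np.zeros(3)"))     # ndarray
print(infer_call("nn.Linear(4, 2)")) # Linear
print(infer_call("len(x)"))          # None
```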