ml_switcheroo.analysis.symbol_table
===================================

.. py:module:: ml_switcheroo.analysis.symbol_table

.. autoapi-nested-parse::

   Symbol Table and Type Inference Analysis with Control Flow Support.

   This module provides a static analysis pass to infer variable types and scopes
   before rewriting occurs. It builds a mapping of AST nodes to inferred type objects,
   allowing the rewriter to make decisions based on the semantic type of a variable
   (e.g., "is this a Tensor?") rather than just its lexical name.

   The `SymbolTableAnalyzer` visitor populates a `SymbolTable` by tracking:

   1. **Imports**: Mapping module aliases to `ModuleType`.
   2. **Assignments**: Propagating types from RHS to LHS.
   3. **Scopes**: Handling nested function/class definitions.
   4. **Control Flow**: Handling type ambiguity in branches (Phi nodes) via Union types.


Classes
-------

.. autoapisummary::

   ml_switcheroo.analysis.symbol_table.SymbolType
   ml_switcheroo.analysis.symbol_table.TensorType
   ml_switcheroo.analysis.symbol_table.ModuleType
   ml_switcheroo.analysis.symbol_table.UnionType
   ml_switcheroo.analysis.symbol_table.Scope
   ml_switcheroo.analysis.symbol_table.SymbolTable
   ml_switcheroo.analysis.symbol_table.SymbolTableAnalyzer


Module Contents
---------------

.. py:class:: SymbolType

   Base class for inferred types.


   .. py:attribute:: name
      :type: str

      A string representation of the type (e.g., 'Tensor').


   .. py:method:: __str__() -> str

      Returns the type name.


   .. py:method:: __eq__(other: object) -> bool


.. py:class:: TensorType

   Bases: :py:obj:`SymbolType`

   Represents a Tensor object from a specific framework.


   .. py:attribute:: framework
      :type: str

      The framework key (e.g. "torch" or "jax") responsible for this tensor.


   .. py:method:: __eq__(other: object) -> bool


.. py:class:: ModuleType

   Bases: :py:obj:`SymbolType`

   Represents an imported module or sub-module.


   .. py:attribute:: path
      :type: str

      Fully qualified path string (e.g. "torch.nn").


   .. py:method:: __eq__(other: object) -> bool


.. py:class:: UnionType(types: List[SymbolType])

   Bases: :py:obj:`SymbolType`

   Represents a union of potential types resulting from control flow divergence.


   .. py:attribute:: types
      :type: List[SymbolType]


   .. py:method:: __str__() -> str

      Returns the type name.


   .. py:method:: __eq__(other: object) -> bool


.. py:class:: Scope(parent: Optional[Scope] = None, name: str = '')

   Represents a variable scope (Global, Class, or Function).


   .. py:attribute:: parent
      :value: None


   .. py:attribute:: name
      :value: ''


   .. py:attribute:: symbols
      :type: Dict[str, SymbolType]


   .. py:method:: set(name: str, sym_type: SymbolType) -> None

      Register a symbol in the current scope.

      :param name: Variable identifier.
      :param sym_type: Inferred Type object.


   .. py:method:: get(name: str) -> Optional[SymbolType]

      Resolve a symbol, traversing parent scopes.

      :param name: Variable identifier to lookup.

      :returns: The SymbolType if found, else None.


   .. py:method:: snapshot() -> Dict[str, SymbolType]

      Returns a shallow copy of the current symbol table for branching.


.. py:class:: SymbolTable

   Container for analysis results.
   Maps CST Nodes (by identity) to inferred Types.


   .. py:method:: record_type(node: libcst.CSTNode, sym_type: SymbolType) -> None

      Associates a CST node with a type.

      :param node: The CST node.
      :param sym_type: The determined type.


   .. py:method:: get_type(node: libcst.CSTNode) -> Optional[SymbolType]

      Retrieves the inferred type for a CST node.

      :param node: The CST node to inspect.

      :returns: The stored SymbolType or None.


.. py:class:: SymbolTableAnalyzer(semantics: ml_switcheroo.semantics.manager.SemanticsManager)

   Bases: :py:obj:`libcst.CSTVisitor`

   Static Analysis pass to populate the SymbolTable.

   Runs post-order traversal logic (via leave methods) to propagate types bottom-up.
   Implements shallow control flow inference for If/Else and Loops.


   .. py:attribute:: semantics


   .. py:attribute:: table


   .. py:attribute:: root_scope


   .. py:attribute:: current_scope


   .. py:method:: visit_ClassDef(node: libcst.ClassDef) -> None

      Enters class scope.


   .. py:method:: leave_ClassDef(node: libcst.ClassDef) -> None

      Exits class scope.


   .. py:method:: visit_FunctionDef(node: libcst.FunctionDef) -> None

      Enters function scope.


   .. py:method:: leave_FunctionDef(node: libcst.FunctionDef) -> None

      Exits function scope.


   .. py:method:: visit_If(node: libcst.If) -> bool

      Handle branching logic.

      1. Snapshot state.
      2. Visit body -> State_Body.
      3. Revert to Snapshot.
      4. Visit Else (if any) -> State_Else.
      5. Merge (State_Body, State_Else).


   .. py:method:: visit_For(node: libcst.For) -> bool

      Handle loop logic.
      Loops may execute 0 times or N times, introducing potential ambiguity.
      We merge the state after loop body with the state before loop.


   .. py:method:: visit_While(node: libcst.While) -> bool

      Handle while loop logic.


   .. py:method:: leave_IfExp(node: libcst.IfExp) -> None

      Infers type for ternary expression: `A if C else B`.


   .. py:method:: leave_Import(node: libcst.Import) -> None

      Track imports.
      e.g. `import torch` -> symbols['torch'] = ModuleType(name='Module', path='torch')


   .. py:method:: leave_ImportFrom(node: libcst.ImportFrom) -> None

      Track from-imports.
      e.g. `from torch import nn` -> symbols['nn'] = ModuleType(name='Module', path='torch.nn')


   .. py:method:: leave_Assign(node: libcst.Assign) -> None

      Propagate type from RHS to LHS.
      x = torch.randn() -> x is Tensor.


   .. py:method:: leave_Name(node: libcst.Name) -> None

      Look up variable in scope.


   .. py:method:: leave_Attribute(node: libcst.Attribute) -> None

      Resolve attributes based on their receiver type.
      If `x` is Module('torch'), `x.nn` is Module('torch.nn').
      If `x` is Tensor, `x.shape` might be recorded etc.


   .. py:method:: leave_Call(node: libcst.Call) -> None

      Infer return type of a call.

      1. Resolve function fully qualified name.
      2. Check SemanticsManager for return type.
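
The snapshot, revert and merge steps listed under `visit_If`, together with the parent-chain lookup described for `Scope.get`, can be illustrated with a small standalone sketch. It does not import `ml_switcheroo`. The classes below are simplified stand-ins that reuse the documented names, and `merge_states` is a hypothetical helper invented for this illustration. In particular, how the real analyzer treats a symbol that is bound in only one branch is not documented here, so keeping that symbol's type as-is is an assumption.

```python
from dataclasses import dataclass, field
from typing import Dict, List, Optional


@dataclass
class SymbolType:
    """Simplified stand-in for the documented base type."""
    name: str

    def __str__(self) -> str:
        return self.name


@dataclass
class TensorType(SymbolType):
    framework: str = "torch"


@dataclass
class UnionType(SymbolType):
    types: List[SymbolType] = field(default_factory=list)


class Scope:
    """Stand-in scope mirroring the documented set/get/snapshot methods."""

    def __init__(self, parent: Optional["Scope"] = None, name: str = "") -> None:
        self.parent = parent
        self.name = name
        self.symbols: Dict[str, SymbolType] = {}

    def set(self, name: str, sym_type: SymbolType) -> None:
        self.symbols[name] = sym_type

    def get(self, name: str) -> Optional[SymbolType]:
        # Walk outwards through parent scopes until the name is found.
        scope: Optional[Scope] = self
        while scope is not None:
            if name in scope.symbols:
                return scope.symbols[name]
            scope = scope.parent
        return None

    def snapshot(self) -> Dict[str, SymbolType]:
        # Shallow copy: type objects are shared, the dict is not.
        return dict(self.symbols)


def merge_states(
    body: Dict[str, SymbolType], orelse: Dict[str, SymbolType]
) -> Dict[str, SymbolType]:
    """Hypothetical merge: differing types for one name become a UnionType."""
    merged: Dict[str, SymbolType] = {}
    for name in list(body) + [n for n in orelse if n not in body]:
        a, b = body.get(name), orelse.get(name)
        if a is None or b is None or a == b:
            # Assumption: a symbol seen in only one branch keeps its type.
            merged[name] = a if a is not None else b
        else:
            merged[name] = UnionType("Union", [a, b])
    return merged


global_scope = Scope(name="global")
global_scope.set("x", TensorType("Tensor", "torch"))

before = global_scope.snapshot()                       # 1. snapshot state
global_scope.set("x", TensorType("Tensor", "jax"))     # 2. visit body
state_body = global_scope.snapshot()
global_scope.symbols = dict(before)                    # 3. revert to snapshot
state_else = global_scope.snapshot()                   # 4. no else branch here
merged = merge_states(state_body, state_else)          # 5. merge
global_scope.symbols = merged

# A nested function scope resolves `x` through its parent.
inner = Scope(parent=global_scope, name="fn")
resolved = inner.get("x")
```

After the merge, `x` carries a `UnionType` holding both the `jax` and the `torch` tensor types, which is the ambiguity the module docstring refers to as a Phi node. The same merge applied to the pre-loop and post-loop states would model the zero-or-N-iterations case described for `visit_For`.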