ml_switcheroo.core.graph
========================

.. py:module:: ml_switcheroo.core.graph

.. autoapi-nested-parse::

   Graph Extraction Frontend.

   This module is responsible for analyzing Python Abstract Syntax Trees (ASTs) using LibCST
   and extracting a ``LogicalGraph`` Intermediate Representation via the ``GraphExtractor``.

   It performs Provenance Tracking, mapping logical nodes back to their source CST nodes,
   enabling surgical patching later in the pipeline.


Classes
-------

.. autoapisummary::

   ml_switcheroo.core.graph.LogicalNode
   ml_switcheroo.core.graph.LogicalEdge
   ml_switcheroo.core.graph.LogicalGraph
   ml_switcheroo.core.graph.GraphExtractor


Functions
---------

.. autoapisummary::

   ml_switcheroo.core.graph.topological_sort


Module Contents
---------------

.. py:class:: LogicalNode

   Represents a computation unit (Layer) in the graph.


   .. py:attribute:: id
      :type: str

      Unique identifier (e.g. ``'conv1'``).


   .. py:attribute:: kind
      :type: str

      Operation type (e.g. ``'Conv2d'``, ``'Input'``, ``'Output'``).


   .. py:attribute:: metadata
      :type: Dict[str, str]

      Dictionary of configuration parameters (e.g. ``kernel_size=3``).


.. py:class:: LogicalEdge

   Represents data flow between two nodes.


   .. py:attribute:: source
      :type: str

      Source node ID.


   .. py:attribute:: target
      :type: str

      Target node ID.


.. py:class:: LogicalGraph

   Language-agnostic representation of the neural network structure.


   .. py:attribute:: name
      :type: str
      :value: 'Model'

      Name of the graph model/class.


   .. py:attribute:: nodes
      :type: List[LogicalNode]
      :value: []

      Ordered list of nodes in the graph.


   .. py:attribute:: edges
      :type: List[LogicalEdge]
      :value: []

      List of directed edges between nodes.


.. py:function:: topological_sort(graph: LogicalGraph) -> List[LogicalNode]

   Sorts graph nodes by dependency order.

   Ensures that for every edge ``u -> v``, ``u`` appears before ``v`` in the returned list.
   Handles disconnected components and cycles gracefully by appending unreachable nodes
   in their original definition order.

   :param graph: The logical graph to sort.
   :returns: List of nodes in execution order.


.. py:class:: GraphExtractor

   Bases: :py:obj:`libcst.CSTVisitor`

   LibCST Visitor that extracts a ``LogicalGraph`` from Python source code.

   Two-Pass Logic:

   1. **Init Pass**: Scans ``__init__`` or ``setup`` to register named layers assigned to ``self``.
      Populates the node registry and provenance map.
   2. **Forward Pass**: Scans ``forward`` or ``__call__`` to trace variable usage.
      Builds edges between registered nodes based on data flow.

   .. attribute:: graph

      The constructed intermediate representation.

      :type: LogicalGraph

   .. attribute:: layer_registry

      Mapping of node IDs to LogicalNodes.

      :type: Dict[str, LogicalNode]

   .. attribute:: provenance

      Mapping of variable names to producer node IDs.

      :type: Dict[str, str]

   .. attribute:: node_map

      Provenance registry mapping Node ID -> CST Node.

      :type: Dict[str, cst.CSTNode]


   .. py:attribute:: graph


   .. py:attribute:: layer_registry
      :type: Dict[str, ml_switcheroo.compiler.ir.LogicalNode]


   .. py:attribute:: provenance
      :type: Dict[str, str]


   .. py:attribute:: node_map
      :type: Dict[str, libcst.CSTNode]


   .. py:attribute:: model_name
      :type: str
      :value: 'GeneratedNet'


   .. py:method:: visit_ClassDef(node: libcst.ClassDef) -> Optional[bool]

      Capture the model class name.


   .. py:method:: leave_ClassDef(node: libcst.ClassDef) -> None

      Exit class scope.


   .. py:method:: leave_Module(original_node: libcst.Module) -> None

      Finalize graph construction.


   .. py:method:: visit_FunctionDef(node: libcst.FunctionDef) -> Optional[bool]

      Detects entry into lifecycle methods.


   .. py:method:: leave_FunctionDef(node: libcst.FunctionDef) -> None

      Resets context flags upon exiting methods.


   .. py:method:: visit_Assign(node: libcst.Assign) -> Optional[bool]

      Handles assignment logic for both layer definition and data flow.


   .. py:method:: visit_Expr(node: libcst.Expr) -> Optional[bool]

      Handles standalone expression statements (e.g. ``func(x)`` without assignment).
      Used for 1:1 translations where top-level expressions are valid (e.g. MLIR roundtrips).


   .. py:method:: visit_Return(node: libcst.Return) -> Optional[bool]

      Handles return statements to identify Output nodes.
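
The two-pass logic of ``GraphExtractor`` can be illustrated with a much smaller LibCST visitor. This is a minimal sketch, not the module's implementation: ``MiniExtractor``, ``Node``, ``Edge``, ``_self_attr`` and ``_call_name`` are hypothetical names; only ``self.<name> = Layer(...)`` assignments and single-target ``y = self.<name>(x)`` calls are recognized; and both passes are folded into a single traversal on the assumption that ``__init__`` is defined before ``forward`` in the source. It relies only on public LibCST APIs (``libcst.parse_module``, ``libcst.CSTVisitor`` and its ``visit_<Node>`` / ``leave_<Node>`` hooks).

.. code-block:: python

   from dataclasses import dataclass
   from typing import Dict, List, Optional

   import libcst as cst


   @dataclass
   class Node:  # hypothetical stand-in for LogicalNode
       id: str
       kind: str


   @dataclass
   class Edge:  # hypothetical stand-in for LogicalEdge
       source: str
       target: str


   def _self_attr(expr: cst.BaseExpression) -> Optional[str]:
       """Return 'name' if expr is ``self.name``, else None."""
       if (
           isinstance(expr, cst.Attribute)
           and isinstance(expr.value, cst.Name)
           and expr.value.value == "self"
       ):
           return expr.attr.value
       return None


   def _call_name(expr: cst.BaseExpression) -> Optional[str]:
       """Return the called name, e.g. 'Conv2d' for ``nn.Conv2d(...)``."""
       if isinstance(expr, cst.Call):
           if isinstance(expr.func, cst.Attribute):
               return expr.func.attr.value
           if isinstance(expr.func, cst.Name):
               return expr.func.value
       return None


   class MiniExtractor(cst.CSTVisitor):  # hypothetical, simplified extractor
       def __init__(self) -> None:
           super().__init__()
           self.nodes: Dict[str, Node] = {"input": Node("input", "Input")}
           self.edges: List[Edge] = []
           self.provenance: Dict[str, str] = {}  # variable name -> producer node ID
           self.in_init = False
           self.in_forward = False

       def visit_FunctionDef(self, node: cst.FunctionDef) -> Optional[bool]:
           name = node.name.value
           self.in_init = name in ("__init__", "setup")
           self.in_forward = name in ("forward", "__call__")
           if self.in_forward:
               # Every forward argument except ``self`` is produced by the Input node.
               for param in node.params.params:
                   if param.name.value != "self":
                       self.provenance[param.name.value] = "input"
           return True

       def leave_FunctionDef(self, original_node: cst.FunctionDef) -> None:
           self.in_init = self.in_forward = False

       def visit_Assign(self, node: cst.Assign) -> Optional[bool]:
           target = node.targets[0].target
           if self.in_init:
               # Init pass: ``self.conv1 = nn.Conv2d(...)`` registers a layer node.
               layer, kind = _self_attr(target), _call_name(node.value)
               if layer and kind:
                   self.nodes[layer] = Node(layer, kind)
           elif self.in_forward and isinstance(target, cst.Name) and isinstance(node.value, cst.Call):
               # Forward pass: ``y = self.conv1(x)`` adds producer(x) -> conv1.
               layer = _self_attr(node.value.func)
               if layer in self.nodes:
                   for arg in node.value.args:
                       if isinstance(arg.value, cst.Name) and arg.value.value in self.provenance:
                           self.edges.append(Edge(self.provenance[arg.value.value], layer))
                   self.provenance[target.value] = layer
           return True

       def visit_Return(self, node: cst.Return) -> Optional[bool]:
           # ``return y`` adds an Output node fed by whatever produced ``y``.
           if self.in_forward and isinstance(node.value, cst.Name):
               self.nodes["output"] = Node("output", "Output")
               src = self.provenance.get(node.value.value)
               if src:
                   self.edges.append(Edge(src, "output"))
           return True


   SOURCE = '''
   import torch.nn as nn

   class Net(nn.Module):
       def __init__(self):
           super().__init__()
           self.conv1 = nn.Conv2d(1, 8, 3)
           self.relu = nn.ReLU()

       def forward(self, x):
           x = self.conv1(x)
           y = self.relu(x)
           return y
   '''

   extractor = MiniExtractor()
   cst.parse_module(SOURCE).visit(extractor)  # parsing only; torch is never imported
   print([(n.id, n.kind) for n in extractor.nodes.values()])
   # [('input', 'Input'), ('conv1', 'Conv2d'), ('relu', 'ReLU'), ('output', 'Output')]
   print([(e.source, e.target) for e in extractor.edges])
   # [('input', 'conv1'), ('conv1', 'relu'), ('relu', 'output')]

The ``provenance`` dictionary is what turns variable reuse (``x = self.conv1(x)``) into edges: each assignment rebinds the variable to the node that produced it. The real extractor additionally records the CST node behind each logical node in ``node_map``; the sketch omits that step.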
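
The ordering contract documented for ``topological_sort`` (parents before children, disconnected components included, cycle members appended in definition order) can be met with Kahn's algorithm. This is a minimal sketch under that assumption; the module may use a different algorithm, and ``Node``, ``Edge``, ``Graph`` and ``topo_sort`` are hypothetical stand-ins for ``LogicalNode``, ``LogicalEdge``, ``LogicalGraph`` and ``topological_sort``.

.. code-block:: python

   from collections import deque
   from dataclasses import dataclass, field
   from typing import Dict, List


   @dataclass
   class Node:  # hypothetical stand-in for LogicalNode
       id: str
       kind: str = "Layer"


   @dataclass
   class Edge:  # hypothetical stand-in for LogicalEdge
       source: str
       target: str


   @dataclass
   class Graph:  # hypothetical stand-in for LogicalGraph
       nodes: List[Node] = field(default_factory=list)
       edges: List[Edge] = field(default_factory=list)


   def topo_sort(graph: Graph) -> List[Node]:
       by_id = {n.id: n for n in graph.nodes}
       indegree: Dict[str, int] = {n.id: 0 for n in graph.nodes}
       children: Dict[str, List[str]] = {n.id: [] for n in graph.nodes}
       for e in graph.edges:
           # Skip edges that reference unknown node IDs.
           if e.source in by_id and e.target in by_id:
               children[e.source].append(e.target)
               indegree[e.target] += 1

       # Seed with every zero-indegree node in definition order; this also
       # picks up the roots of disconnected components.
       queue = deque(n.id for n in graph.nodes if indegree[n.id] == 0)
       ordered: List[str] = []
       while queue:
           nid = queue.popleft()
           ordered.append(nid)
           for child in children[nid]:
               indegree[child] -= 1
               if indegree[child] == 0:
                   queue.append(child)

       # Nodes on (or downstream of) a cycle never reach indegree 0:
       # append them in their original definition order instead of raising.
       seen = set(ordered)
       ordered.extend(n.id for n in graph.nodes if n.id not in seen)
       return [by_id[nid] for nid in ordered]


   g = Graph(
       nodes=[Node("relu"), Node("conv1"), Node("input"), Node("a"), Node("b")],
       edges=[Edge("input", "conv1"), Edge("conv1", "relu"), Edge("a", "b"), Edge("b", "a")],
   )
   print([n.id for n in topo_sort(g)])  # ['input', 'conv1', 'relu', 'a', 'b']

Seeding the queue in definition order and draining it first-in, first-out keeps the result deterministic for independent branches; a depth-first sort would also satisfy every edge constraint but could order unrelated branches differently.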