ml_switcheroo.core.graph

Graph Extraction Frontend.

This module is responsible for analyzing Python Abstract Syntax Trees (ASTs) using LibCST and extracting a LogicalGraph Intermediate Representation via the GraphExtractor.

It performs Provenance Tracking, mapping logical nodes back to their source CST nodes, enabling surgical patching later in the pipeline.

Classes

LogicalNode

Represents a computation unit (Layer) in the graph.

LogicalEdge

Represents data flow between two nodes.

LogicalGraph

Language-agnostic representation of the neural network structure.

GraphExtractor

LibCST Visitor that extracts a LogicalGraph from Python source code.

Functions

topological_sort(→ List[LogicalNode])

Sorts graph nodes by dependency order.

Module Contents

class ml_switcheroo.core.graph.LogicalNode

Represents a computation unit (Layer) in the graph.

id: str

Unique identifier (e.g. ‘conv1’).

kind: str

Operation type (e.g. ‘Conv2d’, ‘Input’, ‘Output’).

metadata: Dict[str, str]

Dictionary of configuration parameters (e.g. kernel_size=3).

class ml_switcheroo.core.graph.LogicalEdge

Represents data flow between two nodes.

source: str

Source node ID.

target: str

Target node ID.

class ml_switcheroo.core.graph.LogicalGraph

Language-agnostic representation of the neural network structure.

name: str = 'Model'

Name of the graph model/class.

nodes: List[LogicalNode] = []

Ordered list of nodes in the graph.

edges: List[LogicalEdge] = []

List of directed edges between nodes.
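A minimal sketch of the three containers, assuming they behave like plain dataclasses (this page does not say whether the real classes are dataclasses, Pydantic models, or something else, and the empty-dict default for `metadata` is an assumption). It also shows why the documented `[]` defaults need care: in a dataclass, each graph only gets its own lists if the field uses `default_factory`.

```python
from dataclasses import dataclass, field
from typing import Dict, List


@dataclass
class LogicalNode:
    id: str                      # e.g. "conv1"
    kind: str                    # e.g. "Conv2d", "Input", "Output"
    # Assumed default; values are strings because the field is Dict[str, str].
    metadata: Dict[str, str] = field(default_factory=dict)


@dataclass
class LogicalEdge:
    source: str                  # producer node ID
    target: str                  # consumer node ID


@dataclass
class LogicalGraph:
    name: str = "Model"
    # default_factory gives every graph a fresh list instead of one shared list.
    nodes: List[LogicalNode] = field(default_factory=list)
    edges: List[LogicalEdge] = field(default_factory=list)


# A three-node chain: Input -> Conv2d -> Output.
graph = LogicalGraph(name="Net")
graph.nodes.extend([
    LogicalNode("x", "Input"),
    LogicalNode("conv1", "Conv2d", {"kernel_size": "3"}),
    LogicalNode("out", "Output"),
])
graph.edges.extend([LogicalEdge("x", "conv1"), LogicalEdge("conv1", "out")])
```

Edges refer to nodes by ID rather than by object reference, which keeps the structure trivially serializable and independent of any particular framework.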

ml_switcheroo.core.graph.topological_sort(graph: LogicalGraph) → List[LogicalNode]

Sorts graph nodes by dependency order.

Ensures that for every edge u -> v, u appears before v in the returned list. Handles disconnected components and cycles gracefully by appending unreachable nodes in their original definition order.

Parameters:

graph – The logical graph to sort.

Returns:

List of nodes in execution order.
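The ordering contract can be illustrated with Kahn's algorithm. This is a sketch of the general technique, not necessarily how `topological_sort` is implemented: the hypothetical `topological_sort_ids` works on plain ID strings and `(source, target)` tuples rather than `LogicalNode`/`LogicalEdge` objects, and it treats any node that never becomes ready (for example, one on a cycle) as unplaceable and appends it in definition order.

```python
from collections import deque
from typing import Dict, List, Tuple


def topological_sort_ids(node_ids: List[str], edges: List[Tuple[str, str]]) -> List[str]:
    """Kahn's algorithm with a definition-order fallback for unplaceable nodes."""
    known = set(node_ids)
    indegree: Dict[str, int] = {nid: 0 for nid in node_ids}
    successors: Dict[str, List[str]] = {nid: [] for nid in node_ids}
    for src, dst in edges:
        # Skip edges that mention nodes not present in the graph.
        if src in known and dst in known:
            successors[src].append(dst)
            indegree[dst] += 1

    # Seeding with every zero in-degree node (in definition order)
    # also covers disconnected components.
    ready = deque(nid for nid in node_ids if indegree[nid] == 0)
    order: List[str] = []
    while ready:
        nid = ready.popleft()
        order.append(nid)
        for nxt in successors[nid]:
            indegree[nxt] -= 1
            if indegree[nxt] == 0:
                ready.append(nxt)

    # Nodes on (or downstream of) a cycle never reach in-degree 0;
    # append them in their original definition order.
    placed = set(order)
    order.extend(nid for nid in node_ids if nid not in placed)
    return order


print(topological_sort_ids(["out", "conv1", "x"], [("x", "conv1"), ("conv1", "out")]))
# ['x', 'conv1', 'out']
```

A breadth-first queue seeded in definition order keeps the result deterministic: among nodes that are ready at the same time, the one declared first is emitted first.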

class ml_switcheroo.core.graph.GraphExtractor

Bases: libcst.CSTVisitor

LibCST Visitor that extracts a LogicalGraph from Python source code.

Two-Pass Logic:

1. Init Pass: Scans __init__ or setup to register named layers assigned to self. Populates the node registry and provenance map.

2. Forward Pass: Scans forward or __call__ to trace variable usage. Builds edges between registered nodes based on data flow.
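The two passes can be sketched with Python's standard `ast` module standing in for LibCST. This is a deliberate substitution to keep the example dependency-free; unlike `ast`, LibCST also preserves whitespace and comments. The class `TwoPassSketch`, the synthetic `"Input"`/`"Output"` node IDs, and the rule that every `forward` argument originates at `Input` are illustrative assumptions, not the extractor's actual behavior.

```python
import ast
from typing import Dict, List, Tuple

SOURCE = '''
class Net:
    def __init__(self):
        self.conv1 = Conv2d(1, 8, kernel_size=3)
        self.fc = Linear(8, 10)

    def forward(self, x):
        y = self.conv1(x)
        z = self.fc(y)
        return z
'''

# Which method names each pass inspects (mirrors the docstring above).
PASS_METHODS = {"init": ("__init__", "setup"), "forward": ("forward", "__call__")}


class TwoPassSketch(ast.NodeVisitor):
    """Hypothetical, simplified stand-in for GraphExtractor using stdlib ast."""

    def __init__(self) -> None:
        self.layers: Dict[str, str] = {}         # layer id -> kind (constructor name)
        self.provenance: Dict[str, str] = {}     # variable name -> producer node id
        self.edges: List[Tuple[str, str]] = []   # (source id, target id)
        self.current_pass = "init"
        self.active = False                      # True while inside a relevant method

    def run(self, source: str) -> None:
        tree = ast.parse(source)
        for pass_name in ("init", "forward"):
            self.current_pass = pass_name
            self.visit(tree)

    def visit_FunctionDef(self, node: ast.FunctionDef) -> None:
        if node.name not in PASS_METHODS[self.current_pass]:
            return  # this method belongs to the other pass
        if self.current_pass == "forward":
            # Every argument after `self` is treated as coming from the graph input.
            for arg in node.args.args[1:]:
                self.provenance[arg.arg] = "Input"
        self.active = True
        self.generic_visit(node)
        self.active = False

    def visit_Assign(self, node: ast.Assign) -> None:
        if not self.active or not isinstance(node.value, ast.Call):
            return
        target, func = node.targets[0], node.value.func
        if self.current_pass == "init":
            # self.conv1 = Conv2d(...) registers layer "conv1" of kind "Conv2d".
            if isinstance(target, ast.Attribute) and isinstance(func, ast.Name):
                self.layers[target.attr] = func.id
        elif isinstance(target, ast.Name) and isinstance(func, ast.Attribute):
            # y = self.conv1(x) links the producer of x to conv1, then records y.
            layer_id = func.attr
            if layer_id not in self.layers:
                return
            for arg in node.value.args:
                if isinstance(arg, ast.Name) and arg.id in self.provenance:
                    self.edges.append((self.provenance[arg.id], layer_id))
            self.provenance[target.id] = layer_id

    def visit_Return(self, node: ast.Return) -> None:
        # `return z` connects whatever produced z to a synthetic Output node.
        if self.active and self.current_pass == "forward" and isinstance(node.value, ast.Name):
            producer = self.provenance.get(node.value.id)
            if producer is not None:
                self.edges.append((producer, "Output"))


sketch = TwoPassSketch()
sketch.run(SOURCE)
print(sketch.layers)  # {'conv1': 'Conv2d', 'fc': 'Linear'}
print(sketch.edges)   # [('Input', 'conv1'), ('conv1', 'fc'), ('fc', 'Output')]
```

Running the passes as two separate traversals means the forward pass sees every registered layer even if `forward` happens to be defined before `__init__` in the class body.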

graph¶

The constructed intermediate representation.

Type:

LogicalGraph

layer_registry¶

Mapping of node IDs to LogicalNodes.

Type:

Dict[str, LogicalNode]

provenance¶

Mapping of variable names to producer node IDs.

Type:

Dict[str, str]

node_map¶

Provenance registry mapping Node ID -> CST Node.

Type:

Dict[str, cst.CSTNode]
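Mapping a node ID back to the syntax node that defined it is what makes surgical patching possible: only that node's source span needs rewriting. The sketch below illustrates the idea with the standard `ast` module and plain string splicing instead of a LibCST tree transformation. `build_node_map`, `patch_node`, and the replacement text are hypothetical, and the offset arithmetic assumes ASCII, newline-terminated source because `ast` column offsets count UTF-8 bytes.

```python
import ast
from typing import Dict

SOURCE = """class Net:
    def __init__(self):
        self.conv1 = Conv2d(1, 8, kernel_size=3)  # keep this comment
        self.fc = Linear(8, 10)
"""


def build_node_map(source: str) -> Dict[str, ast.Call]:
    """Map each `self.<name> = Call(...)` layer ID to its constructor Call node."""
    node_map: Dict[str, ast.Call] = {}
    for node in ast.walk(ast.parse(source)):
        if (isinstance(node, ast.Assign)
                and isinstance(node.targets[0], ast.Attribute)
                and isinstance(node.value, ast.Call)):
            node_map[node.targets[0].attr] = node.value
    return node_map


def patch_node(source: str, node: ast.AST, replacement: str) -> str:
    """Replace only the text spanned by `node`; everything else is kept verbatim."""
    lines = source.splitlines(keepends=True)
    # Absolute offset of the start of each line (assumes ASCII source).
    starts = [0]
    for line in lines:
        starts.append(starts[-1] + len(line))
    begin = starts[node.lineno - 1] + node.col_offset
    end = starts[node.end_lineno - 1] + node.end_col_offset
    return source[:begin] + replacement + source[end:]


node_map = build_node_map(SOURCE)
patched = patch_node(SOURCE, node_map["conv1"], "nn.Conv(features=8, kernel_size=(3, 3))")
print(patched)
```

Everything outside the patched span, including the trailing comment that `ast.unparse` would have discarded, survives unchanged. LibCST reaches the same result natively because its tree retains whitespace and comments.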

graph
layer_registry: Dict[str, ml_switcheroo.compiler.ir.LogicalNode]
provenance: Dict[str, str]
node_map: Dict[str, libcst.CSTNode]
model_name: str = 'GeneratedNet'
visit_ClassDef(node: libcst.ClassDef) → bool | None

Capture the model class name.

leave_ClassDef(node: libcst.ClassDef) → None

Exit class scope.

leave_Module(original_node: libcst.Module) → None

Finalize graph construction.

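visit_FunctionDef(node: libcst.FunctionDef) → bool | None

Detects entry into lifecycle methods (__init__ or setup for the init pass, forward or __call__ for the forward pass) and sets the corresponding context flags.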

leave_FunctionDef(node: libcst.FunctionDef) → None

Resets context flags upon exiting methods.

visit_Assign(node: libcst.Assign) → bool | None

Handles assignment logic for both layer definition and data flow.

visit_Expr(node: libcst.Expr) → bool | None

Handles standalone expression statements (e.g. func(x) without assignment). Used for 1:1 translations where top-level expressions are valid (e.g. MLIR roundtrips).
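The difference between an assignment and a standalone expression statement can be sketched at the statement level. In this hypothetical illustration (stdlib `ast`, Python 3.9+ for `ast.unparse`, and an invented `<callee>_<index>` node-naming scheme), both `y = f(x)` and a bare `f(x)` produce an operation node, but only the assignment binds a name whose later uses create edges.

```python
import ast
from typing import Dict, List, Tuple


def trace_top_level(source: str) -> Tuple[List[str], List[Tuple[str, str]]]:
    """Build nodes/edges from top-level call statements (hypothetical sketch)."""
    provenance: Dict[str, str] = {}   # variable name -> producer node id
    nodes: List[str] = []
    edges: List[Tuple[str, str]] = []
    for index, stmt in enumerate(ast.parse(source).body):
        # Both `y = f(x)` (ast.Assign) and bare `f(x)` (ast.Expr) wrap a Call.
        if isinstance(stmt, (ast.Assign, ast.Expr)) and isinstance(stmt.value, ast.Call):
            call = stmt.value
            node_id = f"{ast.unparse(call.func)}_{index}"
            nodes.append(node_id)
            for arg in call.args:
                if isinstance(arg, ast.Name) and arg.id in provenance:
                    edges.append((provenance[arg.id], node_id))
            if isinstance(stmt, ast.Assign):
                # Only assignments bind names that later statements can consume.
                for target in stmt.targets:
                    if isinstance(target, ast.Name):
                        provenance[target.id] = node_id
    return nodes, edges


nodes, edges = trace_top_level("y = relu(x)\nz = add(y, x)\nemit(z)\n")
print(nodes)  # ['relu_0', 'add_1', 'emit_2']
print(edges)  # [('relu_0', 'add_1'), ('add_1', 'emit_2')]
```

The bare `emit(z)` call still becomes a node with an incoming edge, which is what a 1:1 translation needs when the target format allows top-level expressions.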

visit_Return(node: libcst.Return) → bool | None

Handles return statements to identify Output nodes.