Architecture
============

**ml-switcheroo** is a deterministic, specification-driven transpiler designed to convert Deep Learning code between frameworks (e.g., PyTorch to JAX/Flax) with mathematical rigor.

It solves the $O(N^2)$ translation problem by decoupling **Specification** (the Abstract Operation) from **Implementation** (the Framework API) using a **Hub-and-Spoke** architecture. One alternative is to write a translator for every ordered pair of frameworks (e.g., Torch$\to$JAX, JAX$\to$TF, TF$\to$Torch), which for $N$ frameworks means $N(N-1)$ translators. Instead, we map every framework to a central "Abstract Standard", so adding a framework requires only one new mapping.

---

## 🏗️ The Semantic Pivot Strategy

The conversion process moves through an abstract intermediate state in three steps:

1. **Ingest (Source $\to$ Hub):** The system identifies a framework call (e.g., `torch.permute`) and maps it to an **Abstract Operation** (e.g., `permute_dims`) using the source framework's snapshot.
2. **Pivot (Normalization):** Arguments are reordered, renamed, and unpacked to match the Abstract Standard (the Hub signature).
3. **Project (Hub $\to$ Target):** The system looks up the target framework's implementation (e.g., `jax.numpy.transpose`) and generates the corresponding AST. It applies any necessary plugin logic, such as argument packing or state injection.

---

## 🧩 1. The Knowledge Base (Hub & Spoke)

The core dataset driving the transpiler is split across two layers. This separation lets the "What" (Standard) evolve independently of the "How" (Implementation).
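The Ingest, Pivot, and Project steps and the split between Hub and Spokes can be illustrated with plain Python dictionaries. This is a minimal sketch under stated assumptions. The `HUB` and `SPOKES` layouts, their keys (`api`, `args`), and the `ingest`/`pivot`/`project`/`transpile_call` helpers are hypothetical. They do not reproduce the real schema of `semantics/*.json` or `snapshots/*_mappings.json`. Here each Spoke's argument map runs from standard names to framework names.

```python
from __future__ import annotations

from typing import Any

# Hypothetical Hub: each abstract operation lists its standard argument names.
HUB: dict[str, list[str]] = {
    "permute_dims": ["x", "axes"],
    "abs": ["x"],
}

# Hypothetical Spokes: per framework, the API path implementing each abstract
# operation and a map from standard argument names to the framework's names.
SPOKES: dict[str, dict[str, dict[str, Any]]] = {
    "torch": {
        "permute_dims": {"api": "torch.permute", "args": {"x": "input", "axes": "dims"}},
        "abs": {"api": "torch.abs", "args": {"x": "input"}},
    },
    "jax": {
        "permute_dims": {"api": "jax.numpy.transpose", "args": {"x": "a", "axes": "axes"}},
        "abs": {"api": "jax.numpy.abs", "args": {"x": "x"}},
    },
}


def ingest(api: str, source: str) -> str:
    """Ingest: resolve a source API path to its abstract operation ID."""
    for op_id, variant in SPOKES[source].items():
        if variant["api"] == api:
            return op_id
    raise KeyError(f"{api} is not mapped for {source}")


def pivot(op_id: str, source: str, kwargs: dict[str, Any]) -> dict[str, Any]:
    """Pivot: rename source keywords to standard names, in Hub order."""
    to_std = {fw: std for std, fw in SPOKES[source][op_id]["args"].items()}
    normalized = {to_std[name]: value for name, value in kwargs.items()}
    return {name: normalized[name] for name in HUB[op_id] if name in normalized}


def project(op_id: str, target: str, std_kwargs: dict[str, Any]) -> tuple[str, dict[str, Any]]:
    """Project: map standard names onto the target framework's API."""
    variant = SPOKES[target][op_id]
    return variant["api"], {variant["args"][name]: v for name, v in std_kwargs.items()}


def transpile_call(api: str, kwargs: dict[str, Any], source: str, target: str) -> tuple[str, dict[str, Any]]:
    op_id = ingest(api, source)
    return project(op_id, target, pivot(op_id, source, kwargs))


print(transpile_call("torch.permute", {"input": "t", "dims": (1, 0)}, "torch", "jax"))
# ('jax.numpy.transpose', {'a': 't', 'axes': (1, 0)})
```

The same two spokes also work in the opposite direction. `transpile_call("jax.numpy.transpose", {"a": "t", "axes": (1, 0)}, "jax", "torch")` yields `torch.permute` with `input`/`dims`, which is why each framework needs only one spoke.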
```mermaid
graph TD
    %% --- STYLE DEFINITIONS ---
    classDef default font-family:'Google Sans',color:#20344b,stroke:#20344b;
    classDef input fill:#ea4335,stroke:#20344b,stroke-width:2px,color:#ffffff,font-family:'Google Sans',rx:5px;
    classDef build fill:#4285f4,stroke:#20344b,stroke-width:2px,color:#ffffff,font-family:'Google Sans',rx:5px;
    classDef hub fill:#f9ab00,stroke:#20344b,stroke-width:2px,color:#20344b,font-family:'Google Sans',rx:5px;
    classDef spoke fill:#fff4c7,stroke:#f9ab00,stroke-width:2px,stroke-dasharray: 5 5,color:#20344b,font-family:'Google Sans',rx:5px;

    %% --- PHASE 1: DISCOVERY ---
    subgraph P1 ["1. Ingestion Phase"]
        direction TB
        STANDARDS("External Specs<br/>(ONNX / Array API)"):::input
        LIBS("Installed Libs<br/>(Torch / JAX)"):::input
        INSPECTOR("Inspector &<br/>Scaffolder"):::build
        STANDARDS --> INSPECTOR
        LIBS --> INSPECTOR
    end

    %% --- PHASE 2: STORAGE ---
    subgraph P2 ["2. Distributed Storage"]
        direction TB
        HUB[("The Hub (Specs)<br/>semantics/*.json<br/>Abstract Operations")]:::hub
        SPOKE[("The Spokes (Overlays)<br/>snapshots/*_mappings.json<br/>Framework Variants")]:::spoke
        %% Internal Context Link
        SPOKE -.->|" Hydrates "| HUB
    end

    %% Flow P1 -> P2 (Creates Vertical Spine)
    INSPECTOR -->|"Populate"| HUB
    INSPECTOR -->|"Populate"| SPOKE

    %% --- PHASE 3: VERIFICATION ---
    subgraph P3 ["3. Verification Phase"]
        direction TB
        TESTER("TestGen & Fuzzer"):::build
    end

    %% Flow P2 -> P3 (Continues Vertical Spine)
    HUB -.->|"Read Spec"| TESTER
    SPOKE -.->|"Read Variant"| TESTER
```

### The Hub: Semantic Specifications

Located in `src/ml_switcheroo/semantics/*.json`, the Hub defines **WHAT** an operation is.

* **Tier A (Math):** `k_array_api.json`, imported from the Python Array API Standard (NumPy-like ops).
* **Tier B (Neural):** `k_neural_net.json`, imported from the ONNX operators (layers, activations).
* **Tier C (Extras):** `k_framework_extras.json`, covering framework utilities, IO, and internal consensus standards.
* **Discovery:** `k_discovered.json`, generated by the Consensus Engine.

### The Spokes: Framework Overlays

Located in `src/ml_switcheroo/snapshots/{framework}_mappings.json`, each Spoke defines **HOW** a specific framework implements the standard.

* **API Path:** e.g., `torch.abs`, `jax.numpy.abs`.
* **Argument Map:** e.g., `{"input": "x", "dim": "axis"}`.
* **Plugin Hooks:** links to complex logic (e.g., `requires_plugin: "decompose_alpha"`).

This architecture supports **Ghost Mode**. The engine can transpile code even when the source or target framework libraries are not installed locally, because the API signatures are captured in these JSON snapshots.

---

## ⚡ 2. The Transpilation Engine

The `ASTEngine` orchestrates the conversion pipeline. It parses source code into a concrete syntax tree with LibCST, which, unlike Python's built-in `ast`, preserves comments and formatting. It then performs safety analysis, transforms the tree, and refines the output.
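The analysis, rewriting, and refinement stages can be sketched end to end with Python's standard-library `ast` module. This is an illustrative stand-in, not the engine's code. The real engine uses LibCST so that formatting survives. `API_MAP`, `IMPORTS`, `check_purity`, `CallRewriter`, and `fix_imports` are hypothetical names, and the mapping table and purity check are deliberately tiny.

```python
from __future__ import annotations

import ast

# Sketch only: stdlib `ast` stands in for LibCST, so comments/formatting are lost.
# Hypothetical flattened mapping: source API -> (target API, keyword renames).
API_MAP = {
    "torch.abs": ("jnp.abs", {"input": "x"}),
    "torch.permute": ("jnp.transpose", {"input": "a", "dims": "axes"}),
}
# Hypothetical import needed for each target module alias.
IMPORTS = {"jnp": "import jax.numpy as jnp"}


def dotted_name(node: ast.AST) -> str | None:
    """Return 'a.b.c' for a chain of Name/Attribute nodes, else None."""
    if isinstance(node, ast.Name):
        return node.id
    if isinstance(node, ast.Attribute):
        base = dotted_name(node.value)
        return f"{base}.{node.attr}" if base else None
    return None


def check_purity(tree: ast.AST) -> list[str]:
    """Analysis: flag a few side effects that are unsafe under jax.jit."""
    issues = []
    for node in ast.walk(tree):
        if isinstance(node, ast.Global):
            issues.append(f"line {node.lineno}: global {', '.join(node.names)}")
        elif isinstance(node, ast.Call) and dotted_name(node.func) == "print":
            issues.append(f"line {node.lineno}: print() only runs while tracing")
    return issues


class CallRewriter(ast.NodeTransformer):
    """Rewriting: swap mapped API paths and rename their keyword arguments."""

    def visit_Call(self, node: ast.Call) -> ast.Call:
        self.generic_visit(node)  # rewrite nested calls first
        mapping = API_MAP.get(dotted_name(node.func) or "")
        if mapping:
            target, renames = mapping
            node.func = ast.parse(target, mode="eval").body
            for kw in node.keywords:
                kw.arg = renames.get(kw.arg, kw.arg)
        return node


def bound_name(alias: ast.alias) -> str:
    """Name an import binds: `import a.b` binds `a`; `import a.b as c` binds `c`."""
    return alias.asname or alias.name.split(".")[0]


def fix_imports(tree: ast.Module) -> ast.Module:
    """Refinement: drop plain imports whose names are now unused,
    then prepend the imports that the target aliases need."""
    used = {n.id for n in ast.walk(tree) if isinstance(n, ast.Name)}
    body = [
        stmt for stmt in tree.body
        if not (isinstance(stmt, ast.Import)
                and all(bound_name(a) not in used for a in stmt.names))
    ]
    imported = {bound_name(a) for s in body if isinstance(s, ast.Import) for a in s.names}
    header = [ast.parse(IMPORTS[alias]).body[0]
              for alias in sorted(used & set(IMPORTS)) if alias not in imported]
    tree.body = header + body
    return tree


def transpile(source: str) -> str:
    tree = ast.parse(source)
    for issue in check_purity(tree):   # 1. analysis
        print("warning:", issue)
    tree = CallRewriter().visit(tree)  # 2. rewriting
    tree = fix_imports(tree)           # 3. refinement
    return ast.unparse(ast.fix_missing_locations(tree))


print(transpile("import torch\ny = torch.abs(input=x)\nz = torch.permute(y, dims=(1, 0))\n"))
```

Nested calls are rewritten before the outer call (`generic_visit` runs first), so arguments are already in target form when the outer call is mapped. Import fixing runs last because only the finished tree shows which names are still referenced.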
```mermaid
graph TD
    %% --- STYLE DEFINITIONS ---
    classDef default font-family:'Google Sans',color:#20344b,stroke:#20344b;
    classDef artifact fill:#ffffff,stroke:#20344b,stroke-width:1px,color:#20344b,font-family:'Roboto Mono',stroke-dasharray: 0;
    classDef process fill:#4285f4,stroke:#20344b,stroke-width:2px,color:#ffffff,font-family:'Google Sans',rx:5px;
    classDef kb fill:#f9ab00,stroke:#20344b,stroke-width:2px,color:#20344b,font-family:'Google Sans',rx:5px;
    classDef plugin fill:#57caff,stroke:#20344b,stroke-width:2px,color:#20344b,font-family:'Google Sans',rx:5px;
    classDef output fill:#34a853,stroke:#20344b,stroke-width:2px,color:#ffffff,font-family:'Google Sans',rx:5px;

    %% --- NODES ---
    SRC("Source Code"):::artifact

    subgraph ENGINE ["AST Engine"]
        direction TB
        ANALYSIS("1. Safety Analysis<br/>(Purity/Deps Check)"):::process
        SERVER[("Semantics<br/>Manager")]:::kb
        REWRITER("2. Pivot Rewriter"):::process
        PLUGINS{{Plugin System}}:::plugin
        FIXER("3. Refinement<br/>(Import Fixer)"):::process
    end

    TGT("Target Code"):::output

    %% --- EDGES ---
    SRC --> ANALYSIS
    ANALYSIS --> REWRITER
    SERVER -.->|"Lookup API"| REWRITER
    REWRITER <-->|"Complex Logic"| PLUGINS
    REWRITER --> FIXER
    FIXER --> TGT
```

### 1. Analysis Phase

Before touching the code, the engine scans for safety violations, particularly when targeting functional frameworks such as JAX.

* **PurityScanner:** Detects side effects (I/O, global state, in-place list mutation) that are unsafe under `jit`. For example, a Python `print` inside a jitted function runs only when the function is traced, not on every call.
* **LifecycleTracker:** Ensures that every class attribute used in `forward` is initialized in `__init__`.
* **DependencyScanner:** Checks for unmapped third-party imports (e.g., `pandas`, `cv2`).

### 2. Rewriting Phase (`PivotRewriter`)

The core transformer is built from mixins:

* **StructureMixin:** Handles class and function definitions. It converts `torch.nn.Module` to `flax.nnx.Module`, renames `forward` to `__call__`, and injects state arguments (`rngs`) into constructors.
* **CallMixin:** Handles function invocations. It resolves the source call to an Abstract ID, looks up the target implementation, builds the argument pivot, and dispatches plugins.
* **NormalizationMixin:** Aligns how arguments are passed (keyword vs. positional).
* **AttributesMixin:** Handles constant renaming (e.g., `torch.float32` $\to$ `jnp.float32`).

### 3. Refinement Phase

* **ImportFixer:** A pass over the *generated* tree. It injects required imports (e.g., `import jax.numpy as jnp`) only if they are used, and it prunes source imports that are no longer used (e.g., `import torch`). It also resolves alias conflicts and applies the standard naming conventions defined in `SemanticsManager`.
* **StructuralLinter:** A final sanity check that flags residual source-framework artifacts that failed to convert.

---

## 🔌 3. Framework Adapters (Traits & Hierarchy)

Support for specific libraries resides in `src/ml_switcheroo/frameworks/`. Adapters are Python classes that supply **Traits** (declarative configuration) to the engine instead of hardcoded logic.
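Trait-based adapters and `inherits_from` resolution can be pictured as plain data plus a lookup that walks the parent chain. In this sketch, only the `StructuralTraits` field names and the `inherits_from="jax"` relationship come from this document. The `Adapter` class, `REGISTRY`, `resolve_mapping`, the `"flax_nnx"` key, and the sample mappings are hypothetical.

```python
from __future__ import annotations

from dataclasses import dataclass, field


@dataclass(frozen=True)
class StructuralTraits:
    """Syntax-generation settings (a subset of the fields listed in this section)."""
    module_base: str
    forward_method: str
    inject_magic_args: tuple[tuple[str, str], ...] = ()


@dataclass
class Adapter:
    """Hypothetical adapter: pure data the engine reads, no rewrite logic."""
    name: str
    mappings: dict[str, str] = field(default_factory=dict)  # abstract op -> API path
    traits: StructuralTraits | None = None
    inherits_from: str | None = None


REGISTRY: dict[str, Adapter] = {}


def register(adapter: Adapter) -> Adapter:
    REGISTRY[adapter.name] = adapter
    return adapter


def resolve_mapping(framework: str, op_id: str) -> str:
    """Look up `op_id`, walking up `inherits_from` until some level maps it."""
    current: str | None = framework
    while current is not None:
        adapter = REGISTRY[current]
        if op_id in adapter.mappings:
            return adapter.mappings[op_id]
        current = adapter.inherits_from
    raise KeyError(f"{op_id!r} is not mapped for {framework!r}")


# Level 1 (hypothetical data): JAX Core supplies math and optimizer mappings, no layer traits.
register(Adapter("jax", mappings={"abs": "jax.numpy.abs", "sgd": "optax.sgd"}))

# Level 2 (hypothetical data): Flax NNX adds structural traits and layers, inheriting the rest.
register(Adapter(
    "flax_nnx",
    mappings={"Linear": "flax.nnx.Linear"},
    traits=StructuralTraits(
        module_base="flax.nnx.Module",
        forward_method="__call__",
        inject_magic_args=(("rngs", "nnx.Rngs"),),
    ),
    inherits_from="jax",
))

print(resolve_mapping("flax_nnx", "abs"))     # jax.numpy.abs (inherited)
print(resolve_mapping("flax_nnx", "Linear"))  # flax.nnx.Linear (own mapping)
```

Keeping adapters declarative leaves the rewrite code framework-agnostic. A new flavour only has to state what differs from its parent.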
### Structural Traits

Adapters define a `StructuralTraits` configuration object that controls syntax generation:

* `module_base`: The base class for layers (e.g., `"flax.nnx.Module"`).
* `forward_method`: The inference method name (`"forward"` vs. `"call"` vs. `"__call__"`).
* `inject_magic_args`: A sequence of `(name, type)` pairs to inject into signatures (e.g., `[("rngs", "nnx.Rngs")]`).
* `lifestyle_strip_methods`: Methods to silently remove (e.g., `.cuda()`, `.detach()`).

### Hierarchy & Flavours

The system supports hierarchical framework definitions:

* **JAX Core (Level 1):** Provides math mappings (`jnp`), optimizers (`optax`), and serialization (`orbax`).
* **Flax NNX (Level 2):** Inherits from JAX Core via `inherits_from="jax"`. It gains all of JAX Core's math and optimizer mappings and adds neural-network structural traits (`nnx.Module`).

---

## 🤖 4. Discovery & Consensus

The system includes an automated pipeline to grow the Knowledge Base.

### Ghost Protocol

The `GhostInspector` can introspect the APIs of installed libraries (Live Mode) or load API signatures from JSON files (Ghost Mode). This lets the Scaffolder and Consensus Engine run in restricted environments (CI, WebAssembly) without requiring heavy dependencies such as PyTorch or TensorFlow.

### Consensus Engine

Located in `src/ml_switcheroo/discovery/consensus.py`, the Consensus Engine implements a voting algorithm to discover "Unofficial Standards":

1. **Cluster:** Groups APIs across frameworks by normalized name (e.g., `HuberLoss`, `huber_loss`, `Huber` $\to$ cluster "Huber").
2. **Align:** Analyzes signatures to find common parameters (e.g., if 3 of 4 frameworks use `epsilon`, it becomes a standard argument).
3. **Persist:** Writes the new standard to `k_discovered.json`.

---

## 🧠 5. Plugin System

For operations that cannot be mapped 1:1 (e.g., architectural differences or complex argument logic), the engine delegates to **Hook Functions** in `src/ml_switcheroo/plugins/`.
* **HookContext:** Gives plugins access to the `SemanticsManager` and global configuration.
* **Common Plugins:**
  * `decompose_alpha`: `add(x, y, alpha=2)` $\to$ `add(x, y*2)`.
  * `pack_varargs`: `permute(x, 0, 1)` $\to$ `transpose(x, axes=(0, 1))`.
  * `state_flag_injection`: Injects `training=True`/`training=False` keyword arguments based on mode-switching calls such as `.eval()`.
  * `rng_threading`: Replaces global-seed RNG logic with explicitly threaded PRNG keys for JAX.
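The hooks listed above can be sketched as functions that take a call plus a context and return a rewritten call. Each hook is registered by name so that a Spoke's `requires_plugin` entry can select it. To stay short, the sketch works on a simplified `Call` record holding argument source strings instead of syntax-tree nodes. Its `HookContext` (reduced to a target API path), `PLUGINS` registry, and `dispatch` helper are hypothetical. Only the input/output behaviour of `decompose_alpha` and `pack_varargs` follows the examples above.

```python
from __future__ import annotations

from dataclasses import dataclass, field
from typing import Callable


@dataclass
class Call:
    """Simplified call (hypothetical): a function path plus argument source strings."""
    func: str
    args: tuple[str, ...] = ()
    kwargs: dict[str, str] = field(default_factory=dict)

    def render(self) -> str:
        parts = [*self.args, *(f"{k}={v}" for k, v in self.kwargs.items())]
        return f"{self.func}({', '.join(parts)})"


@dataclass
class HookContext:
    """Hypothetical, reduced context: just the target API chosen by lookup."""
    target_api: str


Hook = Callable[[Call, HookContext], Call]
PLUGINS: dict[str, Hook] = {}


def plugin(name: str) -> Callable[[Hook], Hook]:
    """Register a hook under the name a Spoke's `requires_plugin` would use."""
    def register(fn: Hook) -> Hook:
        PLUGINS[name] = fn
        return fn
    return register


@plugin("decompose_alpha")
def decompose_alpha(call: Call, ctx: HookContext) -> Call:
    """add(x, y, alpha=a) -> add(x, y * a); assumes y is the 2nd positional arg."""
    kwargs = dict(call.kwargs)
    alpha = kwargs.pop("alpha", None)
    args = list(call.args)
    if alpha is not None:
        # Parenthesize compound operands so `a + b` becomes `(a + b) * alpha`.
        operand = args[1] if args[1].isidentifier() else f"({args[1]})"
        args[1] = f"{operand} * {alpha}"
    return Call(ctx.target_api, tuple(args), kwargs)


@plugin("pack_varargs")
def pack_varargs(call: Call, ctx: HookContext) -> Call:
    """permute(x, 0, 1) -> transpose(x, axes=(0, 1)): pack trailing positionals."""
    first, *rest = call.args
    kwargs = dict(call.kwargs)
    if rest:
        trailing = "," if len(rest) == 1 else ""  # keep one-element tuples valid
        kwargs["axes"] = f"({', '.join(rest)}{trailing})"
    return Call(ctx.target_api, (first,), kwargs)


def dispatch(name: str, call: Call, ctx: HookContext) -> Call:
    return PLUGINS[name](call, ctx)


print(dispatch("decompose_alpha", Call("torch.add", ("x", "y"), {"alpha": "2"}),
               HookContext("jnp.add")).render())        # jnp.add(x, y * 2)
print(dispatch("pack_varargs", Call("torch.permute", ("x", "0", "1")),
               HookContext("jnp.transpose")).render())  # jnp.transpose(x, axes=(0, 1))
```

Registering hooks by name keeps the Spokes as plain JSON. A mapping only has to reference a hook by name, and the rewriter looks up the function at dispatch time.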