Architecture
ml-switcheroo is a deterministic, specification-driven transpiler designed to convert Deep Learning code between frameworks (e.g., PyTorch to JAX/Flax) with mathematical rigor.
It solves the \(O(N^2)\) translation problem by decoupling Specification (the Abstract Operation) from Implementation (the Framework API) using a Hub-and-Spoke architecture. Rather than writing translators for every pair of frameworks (Torch\(\to\)JAX, JAX\(\to\)TF, TF\(\to\)Torch), we map every framework to a central “Abstract Standard.”
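Here is a rough count. It assumes the pairwise design needs one directed translator per ordered pair of frameworks, and the Hub-and-Spoke design needs one overlay per framework, with the Hub shared by all of them:

```latex
\begin{aligned}
\text{Pairwise:}      &\quad N(N-1) \in O(N^2) \text{ translators} \\
\text{Hub-and-Spoke:} &\quad N \in O(N) \text{ overlays} \\
N = 4:                &\quad 4 \cdot 3 = 12 \text{ translators vs. } 4 \text{ overlays}
\end{aligned}
```

Under the same assumption, adding a fifth framework costs one new overlay. The pairwise design would need 8 new translators (20 minus 12).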
🏗️ The Semantic Pivot Strategy
The conversion process is a three-step movement through an abstract intermediate state:
1. Ingest (Source \(\to\) Hub): The system identifies a framework call (e.g., torch.permute) and maps it to an Abstract Operation (e.g., permute_dims) using the source framework’s snapshot.
2. Pivot (Normalization): Arguments are reordered, renamed, and unpacked to match the Abstract Standard (the “Hub” signature).
3. Project (Hub \(\to\) Target): The system looks up the implementation for the target framework (e.g., jax.numpy.transpose) and generates the corresponding AST, applying any necessary plugin logic (Argument Packing, State Injection).
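The three steps can be sketched in a few lines of Python. This is a minimal illustration that assumes toy in-memory snapshots. TORCH_SPOKE, JAX_SPOKE, ingest, pivot, project and translate are all hypothetical names. The sketch also renders calls as strings rather than AST nodes.

```python
# Minimal sketch of Ingest -> Pivot -> Project.
# The snapshot dictionaries and function names are hypothetical illustrations,
# not ml-switcheroo's real JSON schema or API.
from typing import Dict, Tuple

# Source spoke: framework API -> (abstract op, {framework arg: standard arg})
TORCH_SPOKE: Dict[str, Tuple[str, Dict[str, str]]] = {
    "torch.permute": ("permute_dims", {"input": "x", "dims": "axes"}),
}
# Target spoke: abstract op -> (framework API, {standard arg: framework arg})
JAX_SPOKE: Dict[str, Tuple[str, Dict[str, str]]] = {
    "permute_dims": ("jax.numpy.transpose", {"x": "a", "axes": "axes"}),
}


def ingest(api: str) -> Tuple[str, Dict[str, str]]:
    """Source -> Hub: find the abstract operation behind a framework call."""
    return TORCH_SPOKE[api]


def pivot(kwargs: Dict[str, str], arg_map: Dict[str, str]) -> Dict[str, str]:
    """Rename framework arguments to the standard (Hub) names."""
    return {arg_map.get(name, name): value for name, value in kwargs.items()}


def project(op: str, hub_kwargs: Dict[str, str]) -> str:
    """Hub -> Target: render the target call with target argument names."""
    target_api, arg_map = JAX_SPOKE[op]
    args = ", ".join(f"{arg_map.get(k, k)}={v}" for k, v in hub_kwargs.items())
    return f"{target_api}({args})"


def translate(api: str, kwargs: Dict[str, str]) -> str:
    op, source_map = ingest(api)
    return project(op, pivot(kwargs, source_map))


print(translate("torch.permute", {"input": "t", "dims": "(1, 0)"}))
# prints: jax.numpy.transpose(a=t, axes=(1, 0))
```

Each spoke only ever talks to the standard names (x, axes). Either side can therefore be swapped without touching the other.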
🧩 1. The Knowledge Base (Hub & Spoke)
The core dataset driving the transpiler is distributed across two layers. This separation allows the “What” (Standard) to evolve independently of the “How” (Implementation).
graph TD
%% --- STYLE DEFINITIONS ---
classDef default font-family:'Google Sans',color:#20344b,stroke:#20344b;
classDef input fill:#ea4335,stroke:#20344b,stroke-width:2px,color:#ffffff,font-family:'Google Sans',rx:5px;
classDef build fill:#4285f4,stroke:#20344b,stroke-width:2px,color:#ffffff,font-family:'Google Sans',rx:5px;
classDef hub fill:#f9ab00,stroke:#20344b,stroke-width:2px,color:#20344b,font-family:'Google Sans',rx:5px;
classDef spoke fill:#fff4c7,stroke:#f9ab00,stroke-width:2px,stroke-dasharray: 5 5,color:#20344b,font-family:'Google Sans',rx:5px;
%% --- PHASE 1: DISCOVERY ---
subgraph P1 ["1. Ingestion Phase"]
direction TB
STANDARDS("External Specs<br/>(ONNX / Array API)"):::input
LIBS("Installed Libs<br/>(Torch / JAX)"):::input
INSPECTOR("Inspector &<br/>Scaffolder"):::build
STANDARDS --> INSPECTOR
LIBS --> INSPECTOR
end
%% --- PHASE 2: STORAGE ---
subgraph P2 ["2. Distributed Storage"]
direction TB
HUB[("<b>The Hub (Specs)</b><br/>semantics/*.json<br/><i>Abstract Operations</i>")]:::hub
SPOKE[("<b>The Spokes (Overlays)</b><br/>snapshots/*_mappings.json<br/><i>Framework Variants</i>")]:::spoke
%% Internal Context Link
SPOKE -.->|" Hydrates "| HUB
end
%% Flow P1 -> P2 (Creates Vertical Spine)
INSPECTOR -->|"Populate"| HUB
INSPECTOR -->|"Populate"| SPOKE
%% --- PHASE 3: VERIFICATION ---
subgraph P3 ["3. Verification Phase"]
direction TB
TESTER("TestGen & Fuzzer"):::build
end
%% Flow P2 -> P3 (Continues Vertical Spine)
HUB -.->|"Read Spec"| TESTER
SPOKE -.->|"Read Variant"| TESTER
The Hub: Semantic Specifications
Located in src/ml_switcheroo/semantics/*.json. The Hub defines WHAT an operation is.
Tier A (Math): k_array_api.json, imported from the Python Array API Standard (NumPy-like ops).
Tier B (Neural): k_neural_net.json, imported from ONNX Operators (layers, activations).
Tier C (Extras): k_framework_extras.json, covering framework utilities, IO, and internal consensus standards.
Discovery: k_discovered.json, generated by the Consensus Engine.
The Spokes: Framework Overlays
Located in src/ml_switcheroo/snapshots/{framework}_mappings.json. The Spokes define HOW a specific framework implements the standard.
API Path: e.g., torch.abs, jax.numpy.abs.
Argument Map: e.g., {"input": "x", "dim": "axis"}.
Plugin Hooks: links to complex logic (e.g., requires_plugin: "decompose_alpha").
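Here is a hypothetical in-memory version of the two layers, which makes the Hub/Spoke split concrete. The field names (std_args, api, args, requires_plugin) and the hydrate helper are assumptions for illustration. They are not the actual JSON schema.

```python
# Sketch of a Spoke overlay "hydrating" a Hub spec. Both dictionaries and the
# field names are hypothetical stand-ins for the JSON files described above.
from typing import Any, Dict, Optional

# Hub: WHAT each abstract operation is (standard argument names).
HUB: Dict[str, Dict[str, Any]] = {
    "abs": {"description": "Element-wise absolute value", "std_args": ["x"]},
    "add": {"description": "Element-wise sum", "std_args": ["x1", "x2"]},
}

# Spokes: HOW each framework implements it (API path, argument map, plugins).
SPOKES: Dict[str, Dict[str, Dict[str, Any]]] = {
    "torch": {
        "abs": {"api": "torch.abs", "args": {"input": "x"}},
        "add": {"api": "torch.add", "args": {"input": "x1", "other": "x2"},
                "requires_plugin": "decompose_alpha"},
    },
    "jax": {
        "abs": {"api": "jax.numpy.abs", "args": {}},
    },
}


def hydrate(op: str, framework: str) -> Optional[Dict[str, Any]]:
    """Merge the abstract spec (what) with one framework's overlay (how).

    Returns None when the framework has no mapping for the operation.
    """
    overlay = SPOKES.get(framework, {}).get(op)
    if overlay is None:
        return None
    return {"op": op, "framework": framework, **HUB[op], **overlay}


print(hydrate("add", "torch")["requires_plugin"])  # decompose_alpha
print(hydrate("add", "jax"))  # None: this toy data has no JAX overlay for "add"
```

Only the overlay changes from framework to framework. The Hub entry stays the same.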
This architecture supports Ghost Mode: The engine can transpile code even if the source or target framework libraries are not installed locally, because the API signatures are captured in these JSON snapshots.
⚡ 2. The Transpilation Engine
The ASTEngine orchestrates the conversion pipeline. It parses source code into a LibCST syntax tree (a concrete syntax tree that preserves comments and formatting), performs safety analysis, transforms the tree, and refines the output.
graph TD
%% --- STYLE DEFINITIONS ---
classDef default font-family:'Google Sans',color:#20344b,stroke:#20344b;
classDef artifact fill:#ffffff,stroke:#20344b,stroke-width:1px,color:#20344b,font-family:'Roboto Mono',stroke-dasharray: 0;
classDef process fill:#4285f4,stroke:#20344b,stroke-width:2px,color:#ffffff,font-family:'Google Sans',rx:5px;
classDef kb fill:#f9ab00,stroke:#20344b,stroke-width:2px,color:#20344b,font-family:'Google Sans',rx:5px;
classDef plugin fill:#57caff,stroke:#20344b,stroke-width:2px,color:#20344b,font-family:'Google Sans',rx:5px;
classDef output fill:#34a853,stroke:#20344b,stroke-width:2px,color:#ffffff,font-family:'Google Sans',rx:5px;
%% --- NODES ---
SRC("Source Code"):::artifact
subgraph ENGINE ["AST Engine"]
direction TB
ANALYSIS("1. Safety Analysis<br/>(Purity/Deps Check)"):::process
SERVER[("Semantics<br/>Manager")]:::kb
REWRITER("2. Pivot Rewriter"):::process
PLUGINS{{Plugin System}}:::plugin
FIXER("3. Refinement<br/>(Import Fixer)"):::process
end
TGT("Target Code"):::output
%% --- EDGES ---
SRC --> ANALYSIS
ANALYSIS --> REWRITER
SERVER -.->|"Lookup API"| REWRITER
REWRITER <-->|"Complex Logic"| PLUGINS
REWRITER --> FIXER
FIXER --> TGT
1. Analysis Phase
Before touching the code, the engine scans for safety violations, particularly when targeting functional frameworks like JAX.
PurityScanner: Detects side effects (IO, globals, in-place list mutation) that break jit.
LifecycleTracker: Ensures all class attributes used in forward are initialized in __init__.
DependencyScanner: Checks for unmapped third-party imports (e.g., pandas, cv2).
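Below is a rough approximation of two of these checks. It uses Python's standard-library ast module in place of LibCST and assumes Python 3.9+. The rule sets and the function names (purity_violations, uninitialized_attributes) are hypothetical, and they are much narrower than a production scanner.

```python
# Simplified sketch of two analysis passes using the stdlib `ast` module
# (not LibCST). Function names and rule sets are hypothetical.
import ast
from typing import List, Set

IMPURE_CALLS = {"print", "open", "input"}                          # IO builtins
MUTATING_METHODS = {"append", "extend", "insert", "pop", "clear"}  # in-place list ops


def purity_violations(source: str) -> List[str]:
    """Flag `global` statements, IO calls, and in-place list mutation."""
    found = []
    for node in ast.walk(ast.parse(source)):
        if isinstance(node, ast.Global):
            found.append(f"line {node.lineno}: global {', '.join(node.names)}")
        elif isinstance(node, ast.Call):
            func = node.func
            if isinstance(func, ast.Name) and func.id in IMPURE_CALLS:
                found.append(f"line {node.lineno}: IO call {func.id}()")
            elif isinstance(func, ast.Attribute) and func.attr in MUTATING_METHODS:
                found.append(f"line {node.lineno}: in-place .{func.attr}()")
    return found


def _self_attrs(fn: ast.FunctionDef, ctx: type) -> Set[str]:
    """Names of `self.<name>` attributes accessed with the given context (Load/Store)."""
    return {
        n.attr for n in ast.walk(fn)
        if isinstance(n, ast.Attribute) and isinstance(n.value, ast.Name)
        and n.value.id == "self" and isinstance(n.ctx, ctx)
    }


def uninitialized_attributes(source: str, forward: str = "forward") -> Set[str]:
    """Attributes read in `forward` but never assigned in `__init__`."""
    cls = next(n for n in ast.walk(ast.parse(source)) if isinstance(n, ast.ClassDef))
    methods = {n.name: n for n in cls.body if isinstance(n, ast.FunctionDef)}
    assigned = _self_attrs(methods["__init__"], ast.Store) if "__init__" in methods else set()
    used = _self_attrs(methods[forward], ast.Load) if forward in methods else set()
    return used - assigned - set(methods)  # calls to sibling methods are not state
```

Consider a function containing global counter, print(x) and cache.append(x). It yields three findings. A forward that reads self.bias without assigning it in __init__ reports {"bias"}.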
2. Rewriting Phase (PivotRewriter)
The core transformer is built on a Mixin architecture:
StructureMixin: Handles class and function definitions. It converts torch.nn.Module to flax.nnx.Module, renames forward to __call__, and injects state arguments (rngs) into constructors.
CallMixin: Handles function invocations. It resolves the source call to an Abstract ID, looks up the target implementation, builds the argument pivots, and dispatches plugins.
NormalizationMixin: Handles argument type alignment (keyword vs. positional).
AttributesMixin: Handles constant renaming (e.g., torch.float32 \(\to\) jnp.float32).
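The mixin idea can be sketched with the standard library's ast.NodeTransformer (Python 3.9+), assuming each mixin owns a disjoint set of visit_* methods. StructureMixin, AttributesMixin, PivotRewriterSketch and rewrite are hypothetical. The mappings are hard-coded here, whereas the diagram above shows the rewriter looking them up in the Semantics Manager.

```python
# Sketch of a mixin-based rewriter on stdlib `ast.NodeTransformer`
# (the real PivotRewriter is built on LibCST). Classes and mappings are hypothetical.
import ast


class StructureMixin:
    BASE_MAP = {"torch.nn.Module": "flax.nnx.Module"}
    METHOD_MAP = {"forward": "__call__"}

    def visit_ClassDef(self, node: ast.ClassDef) -> ast.ClassDef:
        # Swap framework base classes, e.g. torch.nn.Module -> flax.nnx.Module.
        node.bases = [
            ast.parse(self.BASE_MAP[ast.unparse(b)], mode="eval").body
            if ast.unparse(b) in self.BASE_MAP else b
            for b in node.bases
        ]
        self.generic_visit(node)
        return node

    def visit_FunctionDef(self, node: ast.FunctionDef) -> ast.FunctionDef:
        node.name = self.METHOD_MAP.get(node.name, node.name)  # forward -> __call__
        self.generic_visit(node)
        return node


class AttributesMixin:
    CONSTANT_MAP = {"torch.float32": "jnp.float32"}

    def visit_Attribute(self, node: ast.Attribute) -> ast.AST:
        dotted = ast.unparse(node)
        if dotted in self.CONSTANT_MAP:
            return ast.parse(self.CONSTANT_MAP[dotted], mode="eval").body
        self.generic_visit(node)  # descend into e.g. torch.float32.something
        return node


class PivotRewriterSketch(StructureMixin, AttributesMixin, ast.NodeTransformer):
    """Each mixin contributes visit_* methods; NodeTransformer dispatches them."""


def rewrite(source: str) -> str:
    tree = PivotRewriterSketch().visit(ast.parse(source))
    return ast.unparse(ast.fix_missing_locations(tree))


print(rewrite("class Net(torch.nn.Module):\n    def forward(self, x):\n        return x.to(torch.float32)\n"))
```

Python's method resolution order lets NodeTransformer find visit_ClassDef, visit_FunctionDef and visit_Attribute on whichever mixin defines them. Each concern therefore stays in its own class.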
3. Refinement Phase
ImportFixer: A pass that scans the generated AST. It injects required imports (e.g., import jax.numpy as jnp) only if they are used and prunes unused source imports (e.g., import torch). It also handles alias conflicts and the standard naming conventions defined in SemanticsManager.
StructuralLinter: A final sanity check that flags any residual artifacts from the source framework that failed to convert.
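Here is a simplified version of the import-fixing pass. It assumes stdlib ast (Python 3.9+), plain import statements only, and a caller-supplied table of required aliases. fix_imports, required and source_roots are hypothetical names, and alias-conflict handling is omitted.

```python
# Simplified ImportFixer sketch on the stdlib `ast` module. It only handles
# plain `import ...` statements, not `from ... import ...`. Names are hypothetical.
import ast
from typing import Dict, Set


def fix_imports(source: str, required: Dict[str, str], source_roots: Set[str]) -> str:
    """Prune unused source-framework imports and inject target imports in use.

    `required` maps an alias used in code (e.g. "jnp") to the statement
    that binds it (e.g. "import jax.numpy as jnp").
    """
    tree = ast.parse(source)
    used = {n.id for n in ast.walk(tree) if isinstance(n, ast.Name)}

    def bound_name(alias: ast.alias) -> str:
        # `import a.b` binds `a`; `import a.b as c` binds `c`.
        return alias.asname or alias.name.split(".")[0]

    body, bound = [], set()
    for stmt in tree.body:
        if isinstance(stmt, ast.Import):
            stmt.names = [
                a for a in stmt.names
                if a.name.split(".")[0] not in source_roots or bound_name(a) in used
            ]
            if not stmt.names:
                continue  # every alias here was an unused source-framework import
            bound |= {bound_name(a) for a in stmt.names}
        body.append(stmt)

    # Inject only the target imports that the rewritten code actually uses.
    injected = [
        ast.parse(required[alias]).body[0]
        for alias in sorted(required)
        if alias in used and alias not in bound
    ]
    tree.body = injected + body
    return ast.unparse(tree)


print(fix_imports("import torch\nimport numpy as np\ny = jnp.abs(np.ones(3))\n",
                  {"jnp": "import jax.numpy as jnp"}, {"torch"}))
```

In the usage line, import torch is dropped because torch is never referenced. import numpy as np is kept because numpy is not a source framework, and import jax.numpy as jnp is added because jnp is used.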
🔌 3. Framework Adapters (Traits & Hierarchy)
Support for specific libraries resides in src/ml_switcheroo/frameworks/. Adapters are Python classes that provide Traits to the engine rather than hardcoded logic.
Structural Traits
Adapters define a StructuralTraits configuration object that controls syntax generation:
module_base: The base class for layers (e.g., "flax.nnx.Module").
forward_method: The inference method name ("forward" vs. "call" vs. "__call__").
inject_magic_args: (name, type) pairs to inject into signatures (e.g., [("rngs", "nnx.Rngs")]).
lifestyle_strip_methods: Methods to silently remove (e.g., .cuda(), .detach()).
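The following sketch shows how a traits object can drive generation without framework-specific branches. StructuralTraitsSketch and render_skeleton are hypothetical. Only the field names module_base, forward_method and inject_magic_args come from the list above, and the Keras values are an illustrative assumption.

```python
# Hypothetical traits object driving code generation; not ml-switcheroo's class.
from dataclasses import dataclass
from typing import Sequence, Tuple


@dataclass(frozen=True)
class StructuralTraitsSketch:
    module_base: str
    forward_method: str
    inject_magic_args: Tuple[Tuple[str, str], ...] = ()  # (name, type annotation)


FLAX_NNX = StructuralTraitsSketch(
    module_base="flax.nnx.Module",
    forward_method="__call__",
    inject_magic_args=(("rngs", "nnx.Rngs"),),
)
KERAS = StructuralTraitsSketch(module_base="keras.layers.Layer", forward_method="call")


def render_skeleton(name: str, init_params: Sequence[str], traits: StructuralTraitsSketch) -> str:
    """Emit a class header whose shape comes entirely from the traits."""
    params = ["self", *init_params, *(f"{arg}: {ann}" for arg, ann in traits.inject_magic_args)]
    return "\n".join([
        f"class {name}({traits.module_base}):",
        f"    def __init__({', '.join(params)}): ...",
        f"    def {traits.forward_method}(self, x): ...",
    ])


print(render_skeleton("Net", ["features"], FLAX_NNX))
```

With this design, supporting another framework means adding a traits instance rather than another code path.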
Hierarchy & Flavours
The system supports hierarchical framework definitions:
JAX Core (Level 1): Provides math mappings (jnp), optimizers (optax), and serialization (orbax).
Flax NNX (Level 2): Inherits from JAX Core via inherits_from="jax", gaining all math and optimizer capabilities while adding neural-network structural traits (nnx.Module).
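Inheritance can be modelled as a lookup that walks up the inherits_from chain. The REGISTRY layout, the "flax_nnx" key and the resolve function are hypothetical. Only the inherits_from key mirrors the text.

```python
# Sketch of flavour inheritance via an inherits_from chain (hypothetical layout).
from typing import Any, Dict, Optional

REGISTRY: Dict[str, Dict[str, Any]] = {
    "jax": {
        "inherits_from": None,
        "mappings": {"abs": "jax.numpy.abs", "Adam": "optax.adam"},
    },
    "flax_nnx": {
        "inherits_from": "jax",
        "mappings": {"Linear": "flax.nnx.Linear"},
    },
}


def resolve(framework: str, op: str) -> Optional[str]:
    """Return the first mapping found walking from `framework` up to its ancestors."""
    current: Optional[str] = framework
    while current is not None:
        entry = REGISTRY[current]
        if op in entry["mappings"]:
            return entry["mappings"][op]
        current = entry["inherits_from"]
    return None


print(resolve("flax_nnx", "abs"))  # jax.numpy.abs, inherited from Level 1
```

The child is searched first, so a Level 2 flavour can also override a parent mapping.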
🤖 4. Discovery & Consensus
The system includes an automated pipeline to grow the Knowledge Base.
Ghost Protocol
The GhostInspector can introspect APIs of installed libraries (Live Mode) or load API signatures from JSON files (Ghost Mode). This allows the Scaffolder and Consensus Engine to run in restricted environments (CI, WebAssembly) without requiring heavy dependencies like PyTorch or TensorFlow to be installed.
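The Live/Ghost split comes down to choosing where a signature comes from. This minimal sketch assumes a snapshot format of {api_path: [parameter names]}. SignatureSource is a hypothetical name, not GhostInspector's interface. Live Mode here uses importlib and inspect.signature, which only works for callables that Python can introspect.

```python
# Sketch of Live vs. Ghost signature lookup; names and snapshot format are hypothetical.
import importlib
import inspect
import json
from typing import List, Optional


class SignatureSource:
    def __init__(self, snapshot_json: Optional[str] = None) -> None:
        # Ghost Mode when a snapshot is supplied; Live Mode otherwise.
        self.snapshot = json.loads(snapshot_json) if snapshot_json is not None else None

    def params(self, api_path: str) -> List[str]:
        if self.snapshot is not None:
            return list(self.snapshot[api_path])  # never imports the framework
        module_name, _, attr = api_path.rpartition(".")
        obj = getattr(importlib.import_module(module_name), attr)
        return list(inspect.signature(obj).parameters)


live = SignatureSource()
print(live.params("json.dumps")[0])  # obj
ghost = SignatureSource('{"torch.abs": ["input", "out"]}')
print(ghost.params("torch.abs"))     # ['input', 'out'], answered without importing torch
```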
Consensus Engine
Located in src/ml_switcheroo/discovery/consensus.py, it implements a voting algorithm to discover “Unofficial Standards”:
Cluster: Groups APIs across frameworks by normalized name (e.g., HuberLoss, huber_loss, Huber \(\to\) cluster “Huber”).
Align: Analyzes signatures to find common parameters (e.g., if 3 of 4 frameworks use epsilon, it becomes a standard argument).
Persist: Writes the new standard to k_discovered.json.
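Here is one way the Cluster, Align and Persist steps could fit together (Python 3.9+). The normalization rule, the 0.75 vote threshold and the parameter lists are illustrative assumptions, not the Consensus Engine's actual heuristics or exact framework signatures.

```python
# Sketch of Cluster -> Align -> Persist with illustrative data and thresholds.
import json
from collections import Counter, defaultdict
from typing import Dict, List

Members = Dict[str, List[str]]  # framework -> parameter names

APIS: Dict[str, Dict[str, List[str]]] = {
    "torch": {"HuberLoss": ["delta", "reduction"]},
    "keras": {"Huber": ["delta", "reduction", "name"]},
    "optax": {"huber_loss": ["predictions", "targets", "delta"]},
    "mlx": {"huber_loss": ["inputs", "targets", "delta", "reduction"]},
}


def normalize(name: str) -> str:
    # "HuberLoss", "huber_loss" and "Huber" all become "huber".
    return name.replace("_", "").lower().removesuffix("loss")


def cluster(apis: Dict[str, Dict[str, List[str]]]) -> Dict[str, Members]:
    groups: Dict[str, Members] = defaultdict(dict)
    for framework, entries in apis.items():
        for name, params in entries.items():
            groups[normalize(name)][framework] = params
    return dict(groups)


def align(members: Members, threshold: float = 0.75) -> List[str]:
    # One vote per framework per parameter; keep those above the threshold.
    votes = Counter(p for params in members.values() for p in set(params))
    return sorted(p for p, n in votes.items() if n / len(members) >= threshold)


def persist(clusters: Dict[str, Members]) -> str:
    standards = {key: {"std_args": align(members)} for key, members in clusters.items()}
    return json.dumps(standards, indent=2)  # would be written to k_discovered.json


print(persist(cluster(APIS)))
```

With this toy data, delta (4 of 4) and reduction (3 of 4) clear the threshold. targets (2 of 4) does not.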
🧠 5. Plugin System
For operations that cannot be mapped 1:1 (e.g., architectural differences or complex argument logic), the engine delegates to Hook Functions in src/ml_switcheroo/plugins/.
HookContext: Provides plugins with access to the SemanticsManager and global configuration.
Common Plugins:
  decompose_alpha: add(x, y, alpha=2) \(\to\) add(x, y*2).
  pack_varargs: permute(x, 0, 1) \(\to\) transpose(x, axes=(0, 1)).
  state_flag_injection: Injects training=True/False kwargs based on context calls like .eval().
  rng_threading: Transforms global seed logic into explicitly threaded PRNG keys for JAX.
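Two of these hooks can be sketched as functions over stdlib ast.Call nodes (Python 3.9+). The hook signature, the PLUGINS registry and apply_plugin are hypothetical. According to the text, the real hooks receive a HookContext, and the target name "transpose" would come from the spoke rather than being hard-coded.

```python
# Sketch of two plugin hooks on stdlib `ast.Call` nodes (hypothetical signature).
import ast
from typing import Callable, Dict


def decompose_alpha(call: ast.Call) -> ast.Call:
    """add(x, y, alpha=a) -> add(x, y * a)."""
    alpha = next((k for k in call.keywords if k.arg == "alpha"), None)
    if alpha is None or len(call.args) < 2:
        return call  # nothing to decompose
    call.keywords = [k for k in call.keywords if k is not alpha]
    call.args[1] = ast.BinOp(left=call.args[1], op=ast.Mult(), right=alpha.value)
    return call


def pack_varargs(call: ast.Call) -> ast.Call:
    """permute(x, 0, 1) -> transpose(x, axes=(0, 1))."""
    if len(call.args) < 2:
        return call
    axes = ast.Tuple(elts=call.args[1:], ctx=ast.Load())  # pack trailing positionals
    return ast.Call(
        func=ast.Name(id="transpose", ctx=ast.Load()),
        args=call.args[:1],
        keywords=[ast.keyword(arg="axes", value=axes)],
    )


# Dispatch by the name stored in a spoke's requires_plugin field.
PLUGINS: Dict[str, Callable[[ast.Call], ast.Call]] = {
    "decompose_alpha": decompose_alpha,
    "pack_varargs": pack_varargs,
}


def apply_plugin(name: str, expr: str) -> str:
    call = ast.parse(expr, mode="eval").body
    if not isinstance(call, ast.Call):
        raise ValueError(f"expected a call expression, got: {expr}")
    return ast.unparse(PLUGINS[name](call))


print(apply_plugin("decompose_alpha", "add(x, y, alpha=2)"))  # add(x, y * 2)
print(apply_plugin("pack_varargs", "permute(x, 0, 1)"))       # transpose(x, axes=(0, 1))
```

Working on tree nodes rather than strings keeps precedence correct. For example, add(x, y + z, alpha=2) unparses as add(x, (y + z) * 2).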