Architecture

ml-switcheroo is a deterministic, specification-driven transpiler designed to convert Deep Learning code between frameworks (e.g., PyTorch to JAX/Flax) with mathematical rigor.

It solves the \(O(N^2)\) translation problem by decoupling Specification (the Abstract Operation) from Implementation (the Framework API) using a Hub-and-Spoke architecture. Rather than writing translators for every pair of frameworks (Torch\(\to\)JAX, JAX\(\to\)TF, TF\(\to\)Torch), we map every framework to a central “Abstract Standard.” For \(N\) frameworks, this reduces the effort from \(N(N-1)\) directed pairwise translators to \(2N\) hub mappings (one ingest rule and one projection rule per framework); with \(N = 5\), that is 10 mappings instead of 20 translators.


🏗️ The Semantic Pivot Strategy

The conversion process is a three-step movement through an abstract intermediate state:

  1. Ingest (Source \(\to\) Hub): The system identifies a framework call (e.g., torch.permute) and maps it to an Abstract Operation (e.g., permute_dims) using the source framework’s snapshot.

  2. Pivot (Normalization): Arguments are reordered, renamed, and unpacked to match the Abstract Standard (The “Hub” signature).

  3. Project (Hub \(\to\) Target): The system looks up the implementation for the target framework (e.g., jax.numpy.transpose) and generates the corresponding AST, applying any necessary plugin logic (Argument Packing, State Injection).
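
As a concrete illustration, the sketch below traces torch.permute through the pivot. The abstract permute_dims step is shown as comments and is illustrative of the flow, not the exact hub schema:

    import numpy as np
    import torch
    import jax.numpy as jnp

    x_t = torch.arange(6.0).reshape(1, 2, 3)

    # 1. Ingest: the source call is recognized as the abstract op "permute_dims".
    out_t = torch.permute(x_t, (0, 2, 1))

    # 2. Pivot: the positional dims are packed into the hub signature,
    #    conceptually permute_dims(x, axes=(0, 2, 1)).

    # 3. Project: the hub op is emitted as the JAX implementation.
    out_j = jnp.transpose(jnp.asarray(x_t.numpy()), axes=(0, 2, 1))

    assert np.allclose(out_t.numpy(), np.asarray(out_j))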


🧩 1. The Knowledge Base (Hub & Spoke)

The core dataset driving the transpiler is distributed across two layers. This separation allows the “What” (Standard) to evolve independently of the “How” (Implementation).

    graph TD
    %% --- STYLE DEFINITIONS ---
    classDef default font-family:'Google Sans',color:#20344b,stroke:#20344b;
    classDef input fill:#ea4335,stroke:#20344b,stroke-width:2px,color:#ffffff,font-family:'Google Sans',rx:5px;
    classDef build fill:#4285f4,stroke:#20344b,stroke-width:2px,color:#ffffff,font-family:'Google Sans',rx:5px;
    classDef hub fill:#f9ab00,stroke:#20344b,stroke-width:2px,color:#20344b,font-family:'Google Sans',rx:5px;
    classDef spoke fill:#fff4c7,stroke:#f9ab00,stroke-width:2px,stroke-dasharray: 5 5,color:#20344b,font-family:'Google Sans',rx:5px;

    %% --- PHASE 1: DISCOVERY ---
    subgraph P1 ["1. Ingestion Phase"]
        direction TB
        STANDARDS("External Specs<br/>(ONNX / Array API)"):::input
        LIBS("Installed Libs<br/>(Torch / JAX)"):::input
        
        INSPECTOR("Inspector &<br/>Scaffolder"):::build
        
        STANDARDS --> INSPECTOR
        LIBS --> INSPECTOR
    end

    %% --- PHASE 2: STORAGE ---
    subgraph P2 ["2. Distributed Storage"]
        direction TB
        HUB[("<b>The Hub (Specs)</b><br/>semantics/*.json<br/><i>Abstract Operations</i>")]:::hub
        SPOKE[("<b>The Spokes (Overlays)</b><br/>snapshots/*_mappings.json<br/><i>Framework Variants</i>")]:::spoke
        
        %% Internal Context Link
        SPOKE -.->|" Hydrates "| HUB
    end

    %% Flow P1 -> P2 (Creates Vertical Spine)
    INSPECTOR -->|"Populate"| HUB
    INSPECTOR -->|"Populate"| SPOKE

    %% --- PHASE 3: VERIFICATION ---
    subgraph P3 ["3. Verification Phase"]
        direction TB
        TESTER("TestGen & Fuzzer"):::build
    end

    %% Flow P2 -> P3 (Continues Vertical Spine)
    HUB -.->|"Read Spec"| TESTER
    SPOKE -.->|"Read Variant"| TESTER
    

The Hub: Semantic Specifications

Located in src/ml_switcheroo/semantics/*.json. Defines WHAT an operation is.

  • Tier A (Math): k_array_api.json — Imported from the Python Array API Standard (NumPy-like ops).

  • Tier B (Neural): k_neural_net.json — Imported from ONNX Operators (Layers, Activations).

  • Tier C (Extras): k_framework_extras.json — Framework utilities, IO, and internal consensus standards.

  • Discovery: k_discovered.json — Generated by the Consensus Engine.
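
A hub entry pairs an abstract operation ID with a canonical signature. The dict below is a hypothetical sketch of that shape (the field names are illustrative, not the exact schema of semantics/*.json):

    # Hypothetical hub entry for an element-wise op (illustrative field names):
    hub_spec = {
        "abs": {
            "tier": "A",                      # Tier A: k_array_api.json (math)
            "signature": ["x"],               # canonical hub argument names
            "description": "Element-wise absolute value.",
        }
    }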

The Spokes: Framework Overlays

Located in src/ml_switcheroo/snapshots/{framework}_mappings.json. Defines HOW a specific framework implements the standard.

  • API Path: E.g., torch.abs, jax.numpy.abs.

  • Argument Map: E.g., {"input": "x", "dim": "axis"}.

  • Plugin Hooks: Links to complex logic (e.g., requires_plugin: "decompose_alpha").
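
Putting these bullets together, a single spoke entry might look like the sketch below (a hypothetical shape, not the exact snapshot schema):

    # Hypothetical overlay entry mapping the abstract op "abs" to PyTorch:
    torch_overlay = {
        "abs": {
            "api_path": "torch.abs",
            "arg_map": {"input": "x"},        # framework name -> hub name
            "requires_plugin": None,          # no special handling needed
        }
    }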

This architecture supports Ghost Mode: The engine can transpile code even if the source or target framework libraries are not installed locally, because the API signatures are captured in these JSON snapshots.


⚡ 2. The Transpilation Engine

The ASTEngine orchestrates the conversion pipeline. It parses source code into a detailed Abstract Syntax Tree (LibCST), performs safety analysis, transforms the tree, and handles output refinement.

    graph TD
    %% --- STYLE DEFINITIONS ---
    classDef default font-family:'Google Sans',color:#20344b,stroke:#20344b;
    classDef artifact fill:#ffffff,stroke:#20344b,stroke-width:1px,color:#20344b,font-family:'Roboto Mono',stroke-dasharray: 0;
    classDef process fill:#4285f4,stroke:#20344b,stroke-width:2px,color:#ffffff,font-family:'Google Sans',rx:5px;
    classDef kb fill:#f9ab00,stroke:#20344b,stroke-width:2px,color:#20344b,font-family:'Google Sans',rx:5px;
    classDef plugin fill:#57caff,stroke:#20344b,stroke-width:2px,color:#20344b,font-family:'Google Sans',rx:5px;
    classDef output fill:#34a853,stroke:#20344b,stroke-width:2px,color:#ffffff,font-family:'Google Sans',rx:5px;

    %% --- NODES ---
    SRC("Source Code"):::artifact
    
    subgraph ENGINE ["AST Engine"]
        direction TB
        ANALYSIS("1. Safety Analysis<br/>(Purity/Deps Check)"):::process
        
        SERVER[("Semantics<br/>Manager")]:::kb
        
        REWRITER("2. Pivot Rewriter"):::process
        PLUGINS{{Plugin System}}:::plugin
        
        FIXER("3. Refinement<br/>(Import Fixer)"):::process
    end
    
    TGT("Target Code"):::output

    %% --- EDGES ---
    SRC --> ANALYSIS
    ANALYSIS --> REWRITER
    
    SERVER -.->|"Lookup API"| REWRITER
    REWRITER <-->|"Complex Logic"| PLUGINS
    
    REWRITER --> FIXER
    FIXER --> TGT
    

1. Analysis Phase

Before touching the code, the engine scans for safety violations, particularly when targeting functional frameworks like JAX.

  • PurityScanner: Detects side effects (IO, Globals, in-place list mutation) that break jit.

  • LifecycleTracker: Ensures all class attributes used in forward are initialized in __init__.

  • DependencyScanner: Checks for unmapped 3rd-party imports (e.g., pandas/cv2).
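
For intuition, the snippet below shows the kinds of patterns these scanners are described as catching; it is an illustrative example, not engine output:

    cache = []  # module-level state

    class Net:
        def forward(self, x):
            cache.append(x)        # PurityScanner: global mutation breaks jit tracing
            print(x.shape)         # PurityScanner: IO side effect
            return self.scale * x  # LifecycleTracker: `scale` never set in __init__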

2. Rewriting Phase (PivotRewriter)

The core transformer is built on a Mixin architecture:

  • StructureMixin: Handles Class/Function definitions. It converts torch.nn.Module to flax.nnx.Module, renames forward to __call__, and injects state arguments (rngs) into constructors.

  • CallMixin: Handles function invocations. Resolves the source call to an Abstract ID, looks up the target implementation, creates argument pivots, and dispatches plugins.

  • NormalizationMixin: Normalizes argument passing style (keyword vs. positional) to match the target signature.

  • AttributesMixin: Handles constant renaming (e.g., torch.float32 \(\to\) jnp.float32).
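
The before/after sketch below combines these rewrites on a toy module. It follows the transformations listed above but is illustrative, not literal engine output:

    # Before (PyTorch source):
    import torch.nn as nn

    class TorchBlock(nn.Module):
        def __init__(self):
            super().__init__()
            self.fc = nn.Linear(4, 4)

        def forward(self, x):                   # renamed by StructureMixin
            return self.fc(x)

    # After (Flax NNX target):
    from flax import nnx

    class NnxBlock(nnx.Module):                 # base class swapped
        def __init__(self, rngs: nnx.Rngs):     # state argument injected
            self.fc = nnx.Linear(4, 4, rngs=rngs)

        def __call__(self, x):                  # forward -> __call__
            return self.fc(x)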

3. Refinement Phase

  • ImportFixer: An intelligent pass that scans the generated AST, injects required imports (e.g., import jax.numpy as jnp) only where they are used, and prunes unused source imports (e.g., import torch). It also resolves alias conflicts and applies the standard naming conventions defined in SemanticsManager.

  • StructuralLinter: A final sanity check that flags any residual artifacts from the source framework that failed conversion.
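
A minimal illustration of the refinement pass, with the pre-refinement state shown in comments:

    # Rewriter output, before refinement:
    #     import torch              # now unused -> pruned by ImportFixer
    #     y = jnp.abs(x)            # `jnp` used but never imported
    #
    # After the ImportFixer pass:
    import jax.numpy as jnp          # injected because `jnp` is referenced

    x = jnp.asarray([-1.0, 2.0])
    y = jnp.abs(x)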


🔌 3. Framework Adapters (Traits & Hierarchy)

Support for specific libraries resides in src/ml_switcheroo/frameworks/. Adapters are Python classes that supply Traits to the engine, rather than hardcoding framework-specific logic into the engine itself.

Structural Traits

Adapters define a StructuralTraits configuration object that controls syntax generation:

  • module_base: The base class for layers (e.g., "flax.nnx.Module").

  • forward_method: The inference method name ("forward" vs "call" vs "__call__").

  • inject_magic_args: Tuple of arguments to inject into signatures (e.g., (("rngs", "nnx.Rngs"),)).

  • lifecycle_strip_methods: Methods to silently remove (e.g., .cuda(), .detach()).
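
Assembled from the fields above, a Flax NNX traits object might look like the following sketch (the dataclass shape is hypothetical):

    from dataclasses import dataclass

    @dataclass(frozen=True)
    class StructuralTraits:                     # hypothetical shape
        module_base: str
        forward_method: str
        inject_magic_args: tuple
        lifecycle_strip_methods: tuple

    FLAX_NNX_TRAITS = StructuralTraits(
        module_base="flax.nnx.Module",
        forward_method="__call__",
        inject_magic_args=(("rngs", "nnx.Rngs"),),
        lifecycle_strip_methods=(".cuda", ".detach"),
    )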

Hierarchy & Flavours

The system supports hierarchical framework definitions:

  • JAX Core (Level 1): Provides math mappings (jnp), optimizations (optax), and serialization (orbax).

  • Flax NNX (Level 2): Inherits from JAX Core via inherits_from="jax", gaining all math/opt capabilities while adding Neural Network structural traits (nnx.Module).
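
Conceptually, this behaves like ordinary class inheritance; the sketch below is illustrative only, not the real adapter classes:

    class JaxCore:                     # Level 1: math, optimization, serialization
        math_namespace = "jax.numpy"
        optimizer_lib = "optax"
        serialization_lib = "orbax"

    class FlaxNNX(JaxCore):            # Level 2: inherits_from="jax"
        module_base = "flax.nnx.Module"
        forward_method = "__call__"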


🤖 4. Discovery & Consensus

The system includes an automated pipeline to grow the Knowledge Base.

Ghost Protocol

The GhostInspector can introspect APIs of installed libraries (Live Mode) or load API signatures from JSON files (Ghost Mode). This allows the Scaffolder and Consensus Engine to run in restricted environments (CI, WebAssembly) without requiring heavy dependencies like PyTorch or TensorFlow to be installed.
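
A minimal sketch of the two modes, using hypothetical helper names (the real GhostInspector API may differ):

    import importlib
    import inspect
    import json

    def live_signatures(module_name):
        """Live Mode: introspect an installed library."""
        mod = importlib.import_module(module_name)
        sigs = {}
        for name, obj in vars(mod).items():
            if callable(obj):
                try:
                    sigs[name] = str(inspect.signature(obj))
                except (ValueError, TypeError):  # C-level callables without signatures
                    pass
        return sigs

    def ghost_signatures(snapshot_path):
        """Ghost Mode: load previously captured signatures from a JSON snapshot."""
        with open(snapshot_path) as fh:
            return json.load(fh)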

Consensus Engine

Located in src/ml_switcheroo/discovery/consensus.py. It implements a voting algorithm to discover “Unofficial Standards”.

  1. Cluster: Groups APIs across frameworks by normalized name (e.g., HuberLoss, huber_loss, Huber \(\to\) Cluster “Huber”).

  2. Align: Analyzes signatures to find common parameters (e.g., if 3/4 frameworks use epsilon, it becomes a standard argument).

  3. Persist: Writes the new standard to k_discovered.json.
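
A minimal sketch of the voting idea, with made-up candidate data (the real consensus.py operates on full API snapshots):

    from collections import Counter

    # Made-up per-framework candidates: (API name, parameter names).
    candidates = {
        "torch":      ("HuberLoss",  {"delta", "reduction"}),
        "tensorflow": ("Huber",      {"delta", "reduction"}),
        "jax":        ("huber_loss", {"delta"}),
        "mlx":        ("huber_loss", {"delta", "reduction"}),
    }

    def normalize(name: str) -> str:
        """Cluster key: lowercase, drop underscores and a 'Loss' suffix."""
        return name.lower().replace("_", "").removesuffix("loss")

    # 1. Cluster: all four names collapse to the key "huber".
    assert {normalize(n) for n, _ in candidates.values()} == {"huber"}

    # 2. Align: a parameter becomes standard if >= 3/4 of frameworks use it.
    votes = Counter(p for _, params in candidates.values() for p in params)
    standard_args = sorted(p for p, c in votes.items() if c >= 0.75 * len(candidates))
    assert standard_args == ["delta", "reduction"]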


🧠 5. Plugin System

For operations that cannot be mapped 1:1 (e.g., architectural differences or complex argument logic), the engine delegates to Hook Functions in src/ml_switcheroo/plugins/.

  • HookContext: Provides plugins with access to the SemanticsManager and global configuration.

  • Common Plugins:

    • decompose_alpha: add(x, y, alpha=2) \(\to\) add(x, y*2).

    • pack_varargs: permute(x, 0, 1) \(\to\) transpose(x, axes=(0, 1)).

    • state_flag_injection: Injects training=True/False kwargs based on context calls like .eval().

    • rng_threading: Transforms global seed logic to explicitly threaded PRNG keys for JAX.
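
The snippet below illustrates the semantics the decompose_alpha plugin preserves. The real hook rewrites LibCST nodes, so this value-level version is only a sketch:

    def torch_style_add(x, y, alpha=1):
        """Source semantics: torch.add(x, y, alpha=a) == x + a * y."""
        return x + alpha * y

    def hub_add(x, y):
        """Hub semantics: plain two-argument addition."""
        return x + y

    # decompose_alpha rewrites add(x, y, alpha=2) into add(x, y * 2):
    assert torch_style_add(3, 4, alpha=2) == hub_add(3, 4 * 2)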