ml-switcheroo
A Deterministic, Specification-Driven Transpiler for Deep Learning Frameworks.
ml-switcheroo is a rigorous AST-based transpiler designed to convert Deep Learning code between frameworks (e.g., PyTorch \(\leftrightarrow\) JAX, Keras \(\to\) TensorFlow) without hallucination.
It uses a Hub-and-Spoke architecture to solve the \(O(N^2)\) translation problem. Instead of writing translators for
every pair of frameworks, ml-switcheroo maps all frameworks to a central Abstract Standard (Hub). This allows
for “Zero-Edit” support for new frameworks via isolated JSON snapshots (Spokes).
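To make the scaling argument concrete, the sketch below counts the mappings each design needs. With N frameworks, direct translation needs one translator per ordered source/target pair, N(N-1), whereas the hub needs one spoke per framework. The function names are hypothetical and exist only for this illustration; the framework identifiers are informal labels for the frameworks listed in the support matrix.

```python
# Illustrative only: compares how many mappings each design needs.
# These helpers are hypothetical and not part of ml-switcheroo.

def pairwise_translators(n_frameworks: int) -> int:
    """One translator per ordered (source, target) pair: N * (N - 1)."""
    return n_frameworks * (n_frameworks - 1)


def hub_and_spoke_mappings(n_frameworks: int) -> int:
    """One spoke (framework <-> Abstract Standard mapping) per framework: N."""
    return n_frameworks


frameworks = ["torch", "jax", "tensorflow", "numpy", "keras", "mlx", "paxml"]
n = len(frameworks)
print(
    f"{n} frameworks: {pairwise_translators(n)} pairwise translators "
    f"vs {hub_and_spoke_mappings(n)} spokes"
)
```

Adding an eighth framework would add 14 translators in the pairwise design but only one spoke in the hub design.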
Key Features
No Hallucinations: Uses static analysis (AST) and deterministic mapping rules. If it compiles, it’s mathematically grounded.
ODL (Operation Definition Language): Define new mathematical operations using a simple YAML syntax without writing Python AST code.
Hub-and-Spoke Architecture: Decouples the semantic definition of an operation (e.g., Conv2d) from its implementation (e.g., torch.nn.Conv2d).
Ghost Mode: Can analyze and transpile code for frameworks not installed on the local machine using cached API snapshots.
Safety Logic: Automatically detects side-effects (IO, globals) that break functional compilation (JIT) via the Purity Scanner.
Structural Rewriting: Handles complex transformations for class hierarchies (e.g., nn.Module \(\leftrightarrow\) flax.nnx.Module), random number threading, and state management.
Architecture
Code is parsed into an Abstract Syntax Tree (AST), analyzed for safety, pivoted through the Abstract Standard, and reconstructed for the target framework.
graph TD
%% Theme
classDef default font-family: 'Google Sans Normal', color: #20344b, stroke: #20344b, stroke-width: 1px;
classDef source fill: #ea4335, stroke: #20344b, stroke-width: 2px, color: #ffffff, font-family: 'Google Sans Medium', rx: 5px;
classDef engine fill: #4285f4, stroke: #20344b, stroke-width: 2px, color: #ffffff, font-family: 'Google Sans Medium', rx: 5px;
classDef hub fill: #f9ab00, stroke: #20344b, stroke-width: 2px, color: #20344b, font-family: 'Google Sans Medium', rx: 5px;
classDef spoke fill: #fff4c7, stroke: #f9ab00, stroke-width: 2px, stroke-dasharray: 5 5, color: #20344b, font-family: 'Google Sans Medium', rx: 5px;
classDef target fill: #34a853, stroke: #20344b, stroke-width: 2px, color: #ffffff, font-family: 'Google Sans Medium', rx: 5px;
classDef codeBlock fill: #ffffff, stroke: #20344b, stroke-width: 1px, font-family: 'Roboto Mono Normal', text-align: left, font-size: 12px;
SRC_HEADER("<b>0. Source Code</b><br/>(e.g., PyTorch)"):::source
PARSER("<b>1. Analysis Phase</b><br/>Parsing, Purity Check,<br/>Lifecycle Scans"):::engine
SRC_HEADER --> PARSER
%% The Knowledge Base
subgraph KB [Distributed Knowledge Base]
direction TB
SPECS[("<b>The Hub (Specs)</b><br/>semantics/*.json<br/><i>Abstract Operations</i>")]:::hub
MAPS[("<b>The Spokes (Overlays)</b><br/>snapshots/*_mappings.json<br/><i>Framework Variants</i>")]:::spoke
MAPS -.->|" Hydrates "| SPECS
end
REWRITER("<b>2. Pivot Rewriter</b><br/><i>Semantic Translation</i>"):::engine
KB -.->|" Lookup API "| REWRITER
PARSER --> REWRITER
PIVOT_LOGIC("<b>1. Ingest:</b> torch.abs(x)<br/><b>2. Pivot:</b> Abs(x) [Standard]<br/><b>3. Project:</b> jnp.abs(x)"):::codeBlock
REWRITER --- PIVOT_LOGIC
FIXER("<b>3. Refinement</b><br/>Import Injection & Pruning"):::engine
PIVOT_LOGIC --> FIXER
TGT_HEADER("<b>4. Target Code</b><br/>(e.g., JAX/Flax)"):::target
FIXER --> TGT_HEADER
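The ingest, pivot, and project steps in the diagram can be sketched with Python's standard `ast` module. This is a minimal illustration, not the project's implementation: the `INGEST` and `PROJECT` tables, `dotted_name`, `PivotRewriter`, and `transpile` are all hypothetical names, and the real knowledge base is loaded from the JSON files named in the diagram rather than hard-coded.

```python
import ast
from typing import Optional

# Hypothetical, hand-written stand-ins for the knowledge base.
# INGEST maps a source API to an abstract operation (the Hub);
# PROJECT maps an abstract operation to a target API (a Spoke).
INGEST = {"torch": {"torch.abs": "Abs"}}
PROJECT = {"jax": {"Abs": "jnp.abs"}}


def dotted_name(node: ast.AST) -> Optional[str]:
    """Return 'a.b.c' for a Name/Attribute chain, or None for anything else."""
    parts = []
    while isinstance(node, ast.Attribute):
        parts.append(node.attr)
        node = node.value
    if isinstance(node, ast.Name):
        parts.append(node.id)
        return ".".join(reversed(parts))
    return None


class PivotRewriter(ast.NodeTransformer):
    """Ingest a call, pivot through the abstract op, project to the target."""

    def __init__(self, source: str, target: str) -> None:
        self.ingest = INGEST[source]
        self.project = PROJECT[target]

    def visit_Call(self, node: ast.Call) -> ast.AST:
        self.generic_visit(node)  # rewrite nested calls first
        abstract_op = self.ingest.get(dotted_name(node.func))
        target_api = self.project.get(abstract_op)
        if target_api is not None:
            # Swap only the callee; the arguments are left untouched.
            node.func = ast.parse(target_api, mode="eval").body
        return node


def transpile(code: str, source: str = "torch", target: str = "jax") -> str:
    tree = PivotRewriter(source, target).visit(ast.parse(code))
    return ast.unparse(ast.fix_missing_locations(tree))


print(transpile("y = torch.abs(x)"))
```

Calls with no entry in the tables are passed through unchanged in this sketch. Import handling is deliberately left out, since the diagram assigns it to the later refinement step.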
Installation
# Install from source
pip install .
# Install with testing dependencies (for running the fuzzer/verification)
pip install ".[test]"
CLI Usage
The ml_switcheroo tool provides a suite of commands for conversion, auditing, and knowledge base maintenance.
1. Transpilation (convert)
Convert a file or directory from one framework to another.
# Convert a PyTorch model to JAX (Flax NNX)
ml_switcheroo convert ./models/resnet.py \
--source torch \
--target jax \
--out ./resnet_jax.py
# Convert an entire directory, enabling strict mode
# Strict mode fails if an API mapping is missing, rather than passing it through.
ml_switcheroo convert ./src/ --out ./dst/ --strict
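The difference between the default pass-through behaviour and strict mode can be pictured as a lookup that either returns the original name or raises. The sketch below is an assumption about the semantics described in the comment above, not the tool's code; `MAPPINGS`, `MissingMappingError`, and `resolve` are hypothetical names.

```python
# Hypothetical mapping table; the real tool reads its knowledge base from JSON.
MAPPINGS = {"torch.abs": "jnp.abs"}


class MissingMappingError(LookupError):
    """Raised in strict mode when an API has no known mapping."""


def resolve(api: str, strict: bool = False) -> str:
    """Return the target API for `api`, or handle a missing mapping."""
    try:
        return MAPPINGS[api]
    except KeyError:
        if strict:
            # Strict mode: fail loudly instead of emitting untranslated code.
            raise MissingMappingError(f"no mapping for {api!r}") from None
        return api  # default: pass the original call through unchanged


print(resolve("torch.abs"))
print(resolve("torch.special.zeta"))
```

Strict mode is the safer choice in automated pipelines, because a passed-through call would otherwise only surface as an error when the converted code runs.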
2. Codebase Audit (audit)
Analyze a codebase to check “Translation Readiness”. This scans API calls and checks coverage against the Knowledge Base.
ml_switcheroo audit ./my_project/ --roots torch
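A coverage check of this kind can be approximated by collecting every call whose dotted name starts with one of the requested roots and comparing it with a set of known APIs. The sketch below is illustrative only: `KNOWN_APIS`, `dotted_name`, and `audit` are hypothetical, and import aliases (such as `import torch as t`) are not resolved.

```python
import ast
from collections import Counter

# Hypothetical set of APIs that the knowledge base can translate.
KNOWN_APIS = {"torch.abs", "torch.nn.Linear"}


def dotted_name(node):
    """Return 'a.b.c' for a Name/Attribute chain, or None otherwise."""
    parts = []
    while isinstance(node, ast.Attribute):
        parts.append(node.attr)
        node = node.value
    if isinstance(node, ast.Name):
        parts.append(node.id)
        return ".".join(reversed(parts))
    return None


def audit(code, roots, known=KNOWN_APIS):
    """Report which calls under `roots` are covered by `known`."""
    calls = Counter()
    for node in ast.walk(ast.parse(code)):
        if isinstance(node, ast.Call):
            name = dotted_name(node.func)
            if name and name.split(".")[0] in roots:
                calls[name] += 1
    supported = sorted(n for n in calls if n in known)
    missing = sorted(n for n in calls if n not in known)
    coverage = len(supported) / len(calls) if calls else 1.0
    return {"supported": supported, "missing": missing, "coverage": coverage}


sample = """
import torch
layer = torch.nn.Linear(4, 2)
y = torch.abs(layer(x))
z = torch.special.zeta(y, 2.0)
"""
report = audit(sample, roots={"torch"})
print(report["missing"], f"{report['coverage']:.0%}")
```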
3. Verification (ci)
The ci command runs the built-in Fuzzer. It generates random inputs (Tensors, Scalars) based on Type Hints in the Spec, feeds them into both the Source and Target frameworks, and checks that the outputs agree numerically.
# Run full verification suite on the Knowledge Base
ml_switcheroo ci
# Generate a lockfile of verified operations
ml_switcheroo ci --json-report verified_ops.json
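The idea behind this kind of check is differential testing: run two implementations on the same random inputs and compare the results within a tolerance. The sketch below uses only scalar inputs and the standard library; `fuzz_equivalent`, its parameters, and its default tolerances are assumptions made for illustration, not the Fuzzer's actual interface.

```python
import math
import random


def fuzz_equivalent(source_fn, target_fn, trials=200, seed=0,
                    rel_tol=1e-9, abs_tol=1e-12):
    """Compare two scalar functions on random inputs.

    Returns (True, None) if every trial agrees, else (False, failing_input).
    """
    rng = random.Random(seed)  # seeded so that failures are reproducible
    for _ in range(trials):
        x = rng.uniform(-1e3, 1e3)
        if not math.isclose(source_fn(x), target_fn(x),
                            rel_tol=rel_tol, abs_tol=abs_tol):
            return False, x
    return True, None


# Two implementations of the same operation should agree on every input.
ok, counterexample = fuzz_equivalent(abs, math.fabs)

# A wrong mapping (identity instead of abs) is caught on a negative input.
wrong_ok, witness = fuzz_equivalent(abs, lambda v: v)
print(ok, wrong_ok)
```

Passing trials give evidence of equivalence on the sampled inputs rather than a proof, which is why a fixed seed and a recorded counterexample are useful when a mapping fails.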
4. Knowledge Discovery (scaffold & wizard)
Populate the Knowledge Base automatically by scanning installed libraries.
# 1. Scaffold: Scan installed libs and generate JSON mappings via heuristics
ml_switcheroo scaffold --frameworks torch jax
# 2. Wizard: Interactive tool to manually categorize obscure APIs
ml_switcheroo wizard torch
5. Operation Definition (define)
Inject new operations into the Knowledge Base using declarative YAML files.
ml_switcheroo define my_ops.yaml
API Support Matrix
Supported Frameworks via Zero-Edit Adapters:
| Framework | Status | Specialized Features Supported |
|---|---|---|
| PyTorch | 🟢 Primary | Source/Target, |
| JAX / Flax | 🟢 Primary | Source/Target ( |
| TensorFlow | 🔵 Beta | Keras Layer conversion, |
| NumPy | 🟡 Stable | Array operations, fallback target for pure math |
| Keras 3 | 🔵 Beta | Multi-backend layers, |
| Apple MLX | 🔵 Beta | |
| PaxML | ⚪ Alpha | |
To view the live compatibility table for your installed version:
ml_switcheroo matrix
Advanced Capabilities
Functional Unwrapping
Frameworks like JAX require pure functions. ml-switcheroo automatically detects stateful imperative patterns (like
drop_last=True in loops or in-place lists) and warns via the Purity Scanner.
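A purity scan of this kind can be approximated by walking the AST and flagging constructs that touch state outside the function. The sketch below covers only the two side-effect classes named in the feature list (IO and globals); `IO_CALLS` and `scan_purity` are hypothetical, and the actual Purity Scanner may detect more patterns and report them differently.

```python
import ast

# Hypothetical list of builtin calls treated as IO side effects.
IO_CALLS = {"print", "open", "input"}


def scan_purity(code):
    """Return sorted (line, message) findings for impure constructs."""
    findings = []
    for node in ast.walk(ast.parse(code)):
        if isinstance(node, (ast.Global, ast.Nonlocal)):
            names = ", ".join(node.names)
            findings.append((node.lineno, f"outer-scope state: {names}"))
        elif (isinstance(node, ast.Call)
              and isinstance(node.func, ast.Name)
              and node.func.id in IO_CALLS):
            findings.append((node.lineno, f"IO call: {node.func.id}()"))
    return sorted(findings)


sample = """
counter = 0
def step(x):
    global counter
    counter += 1
    print(x)
    return x * 2
"""
for line, message in scan_purity(sample):
    print(f"line {line}: {message}")
```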
When converting Flax NNX (functional) to Torch (OO), it unwraps layer.apply(params, x) calls into standard
layer(x) calls using Assign restructuring.
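The call-site part of this unwrapping can be sketched as an AST transform that drops the leading parameters argument and calls the receiver directly. This is a simplified illustration: `ApplyUnwrapper` and `unwrap` are hypothetical names, and the sketch matches any `.apply(...)` call, whereas a real rewriter would first need to establish that the receiver is a layer.

```python
import ast


class ApplyUnwrapper(ast.NodeTransformer):
    """Rewrite obj.apply(params, *rest) into obj(*rest)."""

    def visit_Call(self, node: ast.Call) -> ast.AST:
        self.generic_visit(node)  # unwrap nested calls first
        func = node.func
        if isinstance(func, ast.Attribute) and func.attr == "apply" and node.args:
            # Drop the first positional argument (the parameters) and
            # call the receiver object directly.
            unwrapped = ast.Call(
                func=func.value, args=node.args[1:], keywords=node.keywords
            )
            return ast.copy_location(unwrapped, node)
        return node


def unwrap(code: str) -> str:
    tree = ApplyUnwrapper().visit(ast.parse(code))
    return ast.unparse(ast.fix_missing_locations(tree))


print(unwrap("y = layer.apply(params, x)"))
```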
State Injection (RNG Threading)
When converting PyTorch (global RNG state) to JAX (explicit RNG keys), the engine:
Detects stochastic operations (Dropout, Random init) via the Analyzer.
Injects an rng argument into function signatures.
Injects rng, key = jax.random.split(rng) preambles.
Threads the key argument into relevant function calls.
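The four steps above can be sketched as a single AST pass. This is an illustration under stated assumptions, not the engine's code: `STOCHASTIC_CALLS`, `is_stochastic`, `RngThreader`, and `thread_rng` are hypothetical, the `key=` keyword name is assumed, and `rng` is injected as a keyword-only parameter so that existing positional parameters and defaults are left alone.

```python
import ast

# Hypothetical set of call names treated as stochastic.
STOCHASTIC_CALLS = {"dropout"}


def is_stochastic(node: ast.AST) -> bool:
    """Step 1: detect calls to known stochastic operations."""
    return (
        isinstance(node, ast.Call)
        and isinstance(node.func, ast.Name)
        and node.func.id in STOCHASTIC_CALLS
    )


class RngThreader(ast.NodeTransformer):
    def visit_FunctionDef(self, node: ast.FunctionDef) -> ast.FunctionDef:
        self.generic_visit(node)
        calls = [n for n in ast.walk(node) if is_stochastic(n)]
        if not calls:
            return node  # deterministic function: leave it untouched
        # Step 2: add a keyword-only `rng` parameter to the signature.
        node.args.kwonlyargs.append(ast.arg(arg="rng", annotation=None))
        node.args.kw_defaults.append(None)
        # Step 3: split the key at the top of the function body.
        preamble = ast.parse("rng, key = jax.random.split(rng)").body[0]
        node.body.insert(0, preamble)
        # Step 4: pass `key` to each stochastic call.
        for call in calls:
            call.keywords.append(
                ast.keyword(arg="key", value=ast.Name(id="key", ctx=ast.Load()))
            )
        return node


def thread_rng(code: str) -> str:
    tree = RngThreader().visit(ast.parse(code))
    return ast.unparse(ast.fix_missing_locations(tree))


source = """
def forward(x):
    h = linear(x)
    return dropout(h, rate=0.1)
"""
print(thread_rng(source))
```

The sketch reuses one key for every stochastic call in a function and ignores nested functions and docstrings. Code with several stochastic calls would need a separate key per call, since reusing a JAX key produces correlated random values.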
Intelligent Import Management
The Import Fixer does not just swap strings; it analyzes usage logic:
Removes unused source imports (import torch).
Injects required target imports (import jax.numpy as jnp) only if referenced.
Handles alias conflicts (import torch as t).
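Usage-based import handling can be sketched by collecting the names the code actually references and then pruning or injecting imports to match. The function below is a hypothetical illustration (`fix_imports` and its parameters are not the Import Fixer's API); it only handles top-level `import ...` statements and ignores `from ... import ...`, docstrings, and `__future__` imports.

```python
import ast


def fix_imports(code, source_root="torch",
                target_import="import jax.numpy as jnp", target_alias="jnp"):
    """Drop unused source-framework imports; add the target import if used."""
    tree = ast.parse(code)
    # Every bare name referenced anywhere in the module.
    used = {n.id for n in ast.walk(tree) if isinstance(n, ast.Name)}

    new_body = []
    for stmt in tree.body:
        if isinstance(stmt, ast.Import):
            kept = []
            for alias in stmt.names:
                root = alias.name.split(".")[0]
                bound = alias.asname or root  # the name the import binds
                if root == source_root and bound not in used:
                    continue  # unused source import: prune it
                kept.append(alias)
            if not kept:
                continue
            stmt.names = kept
        new_body.append(stmt)

    already_imported = any(
        isinstance(s, ast.Import)
        and any((a.asname or a.name) == target_alias for a in s.names)
        for s in new_body
    )
    if target_alias in used and not already_imported:
        new_body.insert(0, ast.parse(target_import).body[0])

    tree.body = new_body
    return ast.unparse(ast.fix_missing_locations(tree))


print(fix_imports("import torch as t\nimport math\ny = jnp.abs(x) + math.pi"))
```

Checking the bound name rather than the module name is what lets the sketch handle aliases: `import torch as t` is kept while `t` is still referenced and removed once it is not.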
Extensibility
ml-switcheroo is designed to be extended without modifying the core engine.
Add Operations (ODL): Use the Operation Definition Language (YAML) to define math/neural ops. This is the recommended way to add missing functionality.
operation: "Erf"
std_args: [ "input" ]
variants:
  torch: { api: "torch.erf" }
  jax: { api: "jax.lax.erf" }
See EXTENDING_WITH_DSL.md for the full guide. As an alternative to the YAML DSL, you can manually update src/ml_switcheroo/semantics/standards_internal.py and src/ml_switcheroo/frameworks/*.py (for torch, mlx, tensorflow, jax, etc.).
Add a Framework: Create a class inheriting FrameworkAdapter in src/ml_switcheroo/frameworks/.
Add Logic: Write a localized hook in src/ml_switcheroo/plugins/ (e.g., for custom layer rewrites like MultiHeadAttention packing).
See EXTENDING.md for architectural details on Adapters and Plugins.
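To show how the Erf definition above splits into a hub entry and per-framework spokes, the sketch below mirrors it as a plain Python dict and checks its shape. The key names come from the example; treating all three as required, along with the `validate_definition` and `api_for` helpers, is an assumption made for illustration. EXTENDING_WITH_DSL.md remains the authority on the actual schema.

```python
# The Erf example, mirrored as a dict (what a YAML loader would produce).
erf_definition = {
    "operation": "Erf",
    "std_args": ["input"],
    "variants": {
        "torch": {"api": "torch.erf"},
        "jax": {"api": "jax.lax.erf"},
    },
}

# Assumption: these three keys are mandatory in every definition.
REQUIRED_KEYS = ("operation", "std_args", "variants")


def validate_definition(entry):
    """Return a list of problems found in an operation definition."""
    errors = [f"missing key: {key}" for key in REQUIRED_KEYS if key not in entry]
    for framework, variant in entry.get("variants", {}).items():
        if "api" not in variant:
            errors.append(f"variant {framework!r} has no 'api'")
    return errors


def api_for(entry, framework):
    """Look up the concrete API (spoke) for an abstract operation (hub)."""
    return entry["variants"][framework]["api"]


print(validate_definition(erf_definition), api_for(erf_definition, "jax"))
```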