ml-switcheroo 🔄🦘
==================
**A Deterministic, Specification-Driven Transpiler for Deep Learning Frameworks.**
[License: Apache-2.0](https://opensource.org/license/apache-2-0) ·
[Python](https://www.python.org/downloads/) ·
[Linted with Ruff](https://github.com/astral-sh/ruff)
**ml-switcheroo** is a rigorous AST-based transpiler designed to convert Deep Learning code between frameworks (e.g., **PyTorch** $\leftrightarrow$ **JAX**, **Keras** $\to$ **TensorFlow**) without hallucination.
It uses a **Hub-and-Spoke** architecture to solve the $O(N^2)$ translation problem. Instead of writing translators for
every pair of frameworks, `ml-switcheroo` maps all frameworks to a central **Abstract Standard** (Hub). This allows
for "Zero-Edit" support for new frameworks via isolated JSON snapshots (Spokes).
---
## 🚀 Key Features
* **🚫 No Hallucinations**: Uses static analysis (AST) and deterministic mapping rules; every emitted API call comes
  from a Knowledge Base lookup, never from a guess.
* **📝 ODL (Operation Definition Language)**: Define new mathematical operations using a simple YAML syntax without
writing Python AST code.
* **🔌 Hub-and-Spoke Architecture**: Decouples the *semantic definition* of an operation (e.g., `Conv2d`) from its
*implementation* (e.g., `torch.nn.Conv2d`).
* **👻 Ghost Mode**: Can analyze and transpile code for frameworks *not installed* on the local machine using cached API
snapshots.
* **🛡️ Safety Logic**: Automatically detects side-effects (IO, globals) that break functional compilation (JIT) via the
**Purity Scanner**.
* **🧬 Structural Rewriting**: Handles complex transformations for class hierarchies (e.g., `nn.Module` $\leftrightarrow$
`flax.nnx.Module`), random number threading, and state management.
---
## 🏗️ Architecture
Code is parsed into an Abstract Syntax Tree (AST), analyzed for safety, pivoted through the Abstract Standard, and
reconstructed for the target framework.
```mermaid
graph TD
%% Theme
classDef default font-family: 'Google Sans Normal', color: #20344b, stroke: #20344b, stroke-width: 1px;
classDef source fill: #ea4335, stroke: #20344b, stroke-width: 2px, color: #ffffff, font-family: 'Google Sans Medium', rx: 5px;
classDef engine fill: #4285f4, stroke: #20344b, stroke-width: 2px, color: #ffffff, font-family: 'Google Sans Medium', rx: 5px;
classDef hub fill: #f9ab00, stroke: #20344b, stroke-width: 2px, color: #20344b, font-family: 'Google Sans Medium', rx: 5px;
classDef spoke fill: #fff4c7, stroke: #f9ab00, stroke-width: 2px, stroke-dasharray: 5 5, color: #20344b, font-family: 'Google Sans Medium', rx: 5px;
classDef target fill: #34a853, stroke: #20344b, stroke-width: 2px, color: #ffffff, font-family: 'Google Sans Medium', rx: 5px;
classDef codeBlock fill: #ffffff, stroke: #20344b, stroke-width: 1px, font-family: 'Roboto Mono Normal', text-align: left, font-size: 12px;
SRC_HEADER("0. Source Code
(e.g., PyTorch)"):::source
PARSER("1. Analysis Phase
Parsing, Purity Check,
Lifecycle Scans"):::engine
SRC_HEADER --> PARSER
%% The Knowledge Base
subgraph KB [Distributed Knowledge Base]
direction TB
SPECS[("The Hub (Specs)
semantics/*.json
Abstract Operations")]:::hub
MAPS[("The Spokes (Overlays)
snapshots/*_mappings.json
Framework Variants")]:::spoke
MAPS -.->|" Hydrates "| SPECS
end
REWRITER("2. Pivot Rewriter
Semantic Translation"):::engine
KB -.->|" Lookup API "| REWRITER
PARSER --> REWRITER
PIVOT_LOGIC("1. Ingest: torch.abs(x)
2. Pivot: Abs(x) [Standard]
3. Project: jnp.abs(x)"):::codeBlock
REWRITER --- PIVOT_LOGIC
FIXER("3. Refinement
Import Injection & Pruning"):::engine
PIVOT_LOGIC --> FIXER
TGT_HEADER("4. Target Code
(e.g., JAX/Flax)"):::target
FIXER --> TGT_HEADER
```
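To make the pivot step concrete, here is a toy rewriter built on Python's standard `ast` module. It is a minimal sketch, not the real engine (which adds purity scanning, structural rewriting, and import refinement), and the one-entry `HUB` table is invented for the example:
```python
# Toy pivot rewriter using only the standard library (not the real engine).
import ast

# One Hub entry: abstract operation plus framework variants (illustrative).
HUB = {"torch.abs": ("Abs", {"jax": "jnp.abs"})}

class PivotRewriter(ast.NodeTransformer):
    def __init__(self, target: str):
        self.target = target

    def visit_Call(self, node: ast.Call) -> ast.Call:
        self.generic_visit(node)
        name = ast.unparse(node.func)            # e.g. "torch.abs"
        if name in HUB:
            _abstract_op, variants = HUB[name]   # Pivot through the Hub
            node.func = ast.parse(variants[self.target], mode="eval").body
        return node

tree = ast.parse("y = torch.abs(x)")
print(ast.unparse(PivotRewriter("jax").visit(tree)))  # -> y = jnp.abs(x)
```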
---
## 📦 Installation
```bash
# Install from source
pip install .
# Install with testing dependencies (for running the fuzzer/verification)
pip install ".[test]"
```
---
## 🛠️ CLI Usage
The `ml_switcheroo` tool provides a suite of commands for conversion, auditing, and knowledge base maintenance.
### 1. Transpilation (`convert`)
Convert a file or directory from one framework to another.
```bash
# Convert a PyTorch model to JAX (Flax NNX)
ml_switcheroo convert ./models/resnet.py \
--source torch \
--target jax \
--out ./resnet_jax.py
# Convert an entire directory, enabling strict mode
# Strict mode fails if an API mapping is missing, rather than passing it through.
ml_switcheroo convert ./src/ --out ./dst/ --strict
```
### 2. Codebase Audit (`audit`)
Analyze a codebase to check "Translation Readiness". This scans API calls and checks coverage against the Knowledge
Base.
```bash
ml_switcheroo audit ./my_project/ --roots torch
```
### 3. Verification (`ci`)
The CI command runs the built-in **Fuzzer**. It generates random inputs (Tensors, Scalars) based on the Type Hints in
the Spec, feeds them into both the Source and Target frameworks, and verifies that the outputs are numerically
equivalent.
```bash
# Run full verification suite on the Knowledge Base
ml_switcheroo ci
# Generate a lockfile of verified operations
ml_switcheroo ci --json-report verified_ops.json
```
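The check itself reduces to a tolerance comparison. A minimal sketch of the idea, assuming `torch`, `jax`, and `numpy` are installed (the real fuzzer derives shapes and dtypes from the Spec rather than hard-coding them):
```python
# Illustrative equivalence check for one op pair (torch.abs vs jnp.abs).
import jax.numpy as jnp
import numpy as np
import torch

def check_equivalence(trials: int = 100, atol: float = 1e-6) -> None:
    rng = np.random.default_rng(0)
    for _ in range(trials):
        x = rng.standard_normal((4, 4)).astype(np.float32)
        out_src = torch.abs(torch.from_numpy(x)).numpy()  # source framework
        out_tgt = np.asarray(jnp.abs(jnp.asarray(x)))     # target framework
        np.testing.assert_allclose(out_src, out_tgt, atol=atol)

check_equivalence()
```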
### 4. Knowledge Discovery (`scaffold` & `wizard`)
Populate the Knowledge Base automatically by scanning installed libraries.
```bash
# 1. Scaffold: Scan installed libs and generate JSON mappings via heuristics
ml_switcheroo scaffold --frameworks torch jax
# 2. Wizard: Interactive tool to manually categorize obscure APIs
ml_switcheroo wizard torch
```
### 5. Operation Definition (`define`)
Inject new operations into the Knowledge Base using declarative YAML files.
```bash
ml_switcheroo define my_ops.yaml
```
---
## ✅ API Support Matrix
Supported Frameworks via **Zero-Edit Adapters**:
| Framework | Status | Specialized Features Supported |
|:---------------|:----------:|:-----------------------------------------------------------------------|
| **PyTorch** | 🟢 Primary | Source/Target, `nn.Module`, `functional`, Optimizers, DataLoaders |
| **JAX / Flax** | 🟢 Primary | Source/Target (`flax.nnx`), `vmap`, `grad`, `jit`, Orbax Checkpointing |
| **TensorFlow** | 🔵 Beta | Keras Layer conversion, `tf.data`, IO operations |
| **NumPy** | 🟡 Stable | Array operations, fallback target for pure math |
| **Keras 3** | 🔵 Beta | Multi-backend layers, `keras.ops` math |
| **Apple MLX** | 🔵 Beta | `mlx.nn` layers, `mlx.core` array ops, Optimizers |
| **PaxML** | ⚪ Alpha | `praxis` layer structure translation |
To view the live compatibility table for your installed version:
```bash
ml_switcheroo matrix
```
---
## 🧠 Advanced Capabilities
### Functional Unwrapping
Frameworks like **JAX** require pure functions. ml-switcheroo automatically detects stateful imperative patterns (like
`drop_last=True` in loops or in-place lists) and warns via the **Purity Scanner**.
When converting **Flax NNX** (functional) to **Torch** (OO), it unwraps `layer.apply(params, x)` calls into standard
`layer(x)` calls using `Assign` restructuring.
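For illustration, the unwrap goes roughly like this; both functions below sketch the shape of the transformation, not a guaranteed byte-for-byte output:
```python
# Before: functional style, parameters threaded explicitly through the call.
def forward_functional(layer, params, x):
    return layer.apply(params, x)

# After: PyTorch object-oriented style, state lives on the module itself.
def forward_oo(layer, x):
    return layer(x)
```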
### State Injection (RNG Threading)
When converting **PyTorch** (global RNG state) to **JAX** (explicit RNG keys), the engine does the following (see the sketch after this list):
1. Detects stochastic operations (Dropout, Random init) via the **Analyzer**.
2. Injects an `rng` argument into function signatures.
3. Injects `rng, key = jax.random.split(rng)` preambles.
4. Threads the `key` argument into relevant function calls.
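A sketch of what this produces for a stochastic call such as `torch.randn` (illustrative naming; the engine controls the exact identifiers):
```python
import jax
import torch

# Before: the stochastic op reads PyTorch's hidden global RNG state.
def init_weights(shape):
    return torch.randn(shape)

# After: an `rng` argument is injected into the signature, split in a
# preamble, and the fresh `key` is threaded into the stochastic call.
def init_weights_jax(shape, rng):
    rng, key = jax.random.split(rng)
    return jax.random.normal(key, shape)
```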
### Intelligent Import Management
The **Import Fixer** does not just swap strings; it analyzes how each import is actually used (see the sketch after this list):
* Removes unused source imports (`import torch`).
* Injects required target imports (`import jax.numpy as jnp`) only if referenced.
* Handles alias conflicts (`import torch as t`).
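A before/after sketch of these behaviors combined (illustrative output only):
```python
# Before: source import under an alias.
import torch as t

def magnitude(x):
    return t.abs(x)

# After: `import torch as t` is pruned because nothing references it anymore,
# and `jnp` is injected only because the rewritten body now uses it.
import jax.numpy as jnp

def magnitude(x):
    return jnp.abs(x)
```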
---
## 🔌 Extensibility
ml-switcheroo is designed to be extended without modifying the core engine.
1. **Add Operations (ODL)**: Use the **Operation Definition Language (YAML)** to define math/neural ops. This is the
recommended way to add missing functionality.
```yaml
operation: "Erf"
std_args: [ "input" ]
variants:
torch: { api: "torch.erf" }
jax: { api: "jax.lax.erf" }
```
See [EXTENDING_WITH_DSL.md](EXTENDING_WITH_DSL.md) for the full guide. As an alternative to the YAML DSL, you can manually update:
- `src/ml_switcheroo/semantics/standards_internal.py` and
- `src/ml_switcheroo/frameworks/*.py` (for torch, mlx, tensorflow, jax, etc.)
2. **Add a Framework**: Create a class inheriting `FrameworkAdapter` in `src/ml_switcheroo/frameworks/`.
3. **Add Logic**: Write a localized hook in `src/ml_switcheroo/plugins/` (e.g., for custom layer rewrites like
`MultiHeadAttention` packing).
See [EXTENDING.md](EXTENDING.md) for architectural details on Adapters and Plugins.
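For orientation, a hypothetical adapter skeleton is shown below. The attribute names are illustrative guesses, not the real interface; `FrameworkAdapter` and its location are taken from this README, and the actual hooks are documented in [EXTENDING.md](EXTENDING.md):
```python
# Hypothetical skeleton; the real hook names are documented in EXTENDING.md.
from ml_switcheroo.frameworks import FrameworkAdapter  # assumed import path

class MyFrameworkAdapter(FrameworkAdapter):
    """Registers a new Spoke: how 'myfw' APIs map onto the Abstract Standard."""
    name = "myfw"                   # illustrative attribute
    import_alias = ("myfw", "mf")   # illustrative attribute
```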
---
## License
[Apache-2.0](LICENSE)