ml_switcheroo

ml-switcheroo Package.

A deterministic AST transpiler for converting deep learning models between frameworks (e.g., PyTorch -> JAX/Flax).

This package exposes the core compilation engine and configuration utilities for programmatic usage.

Usage

### Simple String Conversion

>>> import ml_switcheroo as mls
>>> code = "y = torch.abs(x)"
>>> result = mls.convert(code, source="torch", target="jax")
>>> print(result)
y = jax.numpy.abs(x)

### Advanced Usage (AST Engine)

>>> from ml_switcheroo import ASTEngine, RuntimeConfig
>>>
>>> config = RuntimeConfig(source_framework="torch", target_framework="jax", strict_mode=True)
>>> engine = ASTEngine(config=config)
>>> res = engine.run("y = torch.abs(x)")
>>>
>>> if res.success:
...     print(res.code)
... else:
...     print(f"Errors: {res.errors}")


Attributes

__version__

Classes

RuntimeConfig

Global configuration container for the translation engine.

ASTEngine

The main compilation unit.

ConversionResult

Structured result of a single file conversion.

SemanticsManager

Central database for semantic mappings and configuration.

Functions

convert(→ str)

Transpiles a string of Python code from one framework to another.

Package Contents

class ml_switcheroo.RuntimeConfig(/, **data: Any)

Bases: pydantic.BaseModel

Global configuration container for the translation engine.

source_framework: str = None
target_framework: str = None
source_flavour: str | None = None
target_flavour: str | None = None
strict_mode: bool = None
plugin_settings: Dict[str, Any] = None
plugin_paths: List[pathlib.Path] = None
validation_report: pathlib.Path | None = None
classmethod validate_framework(v: str) → str

Ensures the framework is registered in the system.

Parameters:

v (str) – The framework key to validate.

Returns:

The normalized (lowercase) framework key.

Return type:

str

Raises:

ValueError – If the framework is not found in the registry.
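The validator's documented contract (normalize the key to lowercase, reject unregistered keys with ValueError) can be sketched as below. `KNOWN_FRAMEWORKS` is a hypothetical stand-in for ml_switcheroo's framework registry, whose real contents and lookup mechanism are not shown in this reference.

```python
# Minimal sketch of the documented validate_framework contract.
# KNOWN_FRAMEWORKS is a hypothetical stand-in for the real framework registry.
KNOWN_FRAMEWORKS = {"torch", "jax", "flax_nnx", "tensorflow"}


def validate_framework(v: str) -> str:
    """Normalize a framework key to lowercase and ensure it is registered."""
    key = v.lower()
    if key not in KNOWN_FRAMEWORKS:
        raise ValueError(
            f"Unknown framework '{v}'. Registered: {sorted(KNOWN_FRAMEWORKS)}"
        )
    return key


print(validate_framework("Torch"))  # torch
```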

property effective_source: str

Resolves the specific framework key to use for source logic.

If a flavour (e.g. ‘flax_nnx’) is provided, it overrides the general framework (‘jax’).

Returns:

The active source framework key.

Return type:

str

property effective_target: str

Resolves the specific framework key to use for target logic.

Returns:

The active target framework key.

Return type:

str
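The override rule described for `effective_source` amounts to "use the flavour if one is set, otherwise the framework". The sketch below assumes `effective_target` applies the same rule to `target_flavour`; that symmetry is an assumption, since only the source-side behaviour is spelled out above. `ConfigSketch` is a hypothetical dataclass, not the real Pydantic model.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class ConfigSketch:
    """Hypothetical stand-in for RuntimeConfig's framework/flavour fields."""

    source_framework: str
    target_framework: str
    source_flavour: Optional[str] = None
    target_flavour: Optional[str] = None

    @property
    def effective_source(self) -> str:
        # A flavour (e.g. 'flax_nnx') takes precedence over the general framework.
        return self.source_flavour or self.source_framework

    @property
    def effective_target(self) -> str:
        # Assumed to mirror effective_source for the target side.
        return self.target_flavour or self.target_framework


cfg = ConfigSketch("torch", "jax", target_flavour="flax_nnx")
print(cfg.effective_source, cfg.effective_target)  # torch flax_nnx
```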

parse_plugin_settings(schema: Type[T]) → T

Validates the raw plugin settings dictionary against a specific Pydantic model.

Parameters:

schema (Type[T]) – The Pydantic model class defining expected settings.

Returns:

An instance of the schema model populated with runtime values.

Return type:

T
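`parse_plugin_settings` lets a plugin declare a typed schema for its slice of the loosely typed `plugin_settings` dictionary. The sketch below reproduces the idea with a standard-library dataclass instead of a Pydantic model so that it runs without dependencies. `RngPluginSettings` and its `rng_arg_name` field (borrowed from the `convert` parameter example further down) are hypothetical, and unlike Pydantic this version performs no type validation: it only drops undeclared keys and applies defaults.

```python
from dataclasses import dataclass, fields
from typing import Any, Dict, Type, TypeVar

T = TypeVar("T")


@dataclass
class RngPluginSettings:
    """Hypothetical settings schema for an RNG-related plugin."""

    rng_arg_name: str = "rng"


def parse_plugin_settings(raw: Dict[str, Any], schema: Type[T]) -> T:
    """Build a schema instance from the subset of raw keys it declares."""
    known = {f.name for f in fields(schema)}
    return schema(**{k: v for k, v in raw.items() if k in known})


raw_settings = {"rng_arg_name": "seed", "unrelated_flag": True}
settings = parse_plugin_settings(raw_settings, RngPluginSettings)
print(settings.rng_arg_name)  # seed
```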

classmethod load(source: str | None = None, target: str | None = None, source_flavour: str | None = None, target_flavour: str | None = None, strict_mode: bool | None = None, plugin_settings: Dict[str, Any] | None = None, validation_report: pathlib.Path | None = None, search_path: pathlib.Path | None = None) → RuntimeConfig

Loads configuration from pyproject.toml, then overrides it with any arguments passed explicitly (e.g., from the CLI).

Parameters:
  • source (Optional[str]) – Override for source framework.

  • target (Optional[str]) – Override for target framework.

  • source_flavour (Optional[str]) – Override for source flavour.

  • target_flavour (Optional[str]) – Override for target flavour.

  • strict_mode (Optional[bool]) – Override for strict mode setting.

  • plugin_settings (Optional[Dict]) – Additional CLI plugin settings.

  • validation_report (Optional[Path]) – Override for validation report path.

  • search_path (Optional[Path]) – Directory to start searching for TOML config.

Returns:

The fully resolved configuration object.

Return type:

RuntimeConfig
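The precedence rule in `load` (file values first, explicitly passed arguments on top) can be sketched as a dictionary merge in which `None` means "not provided". The sketch skips TOML parsing and takes the file values as an already-loaded dict; the helper `resolve_settings` and the file contents are hypothetical.

```python
from typing import Any, Dict, Optional


def resolve_settings(
    file_values: Dict[str, Any], **overrides: Optional[Any]
) -> Dict[str, Any]:
    """Merge config-file values with explicit overrides.

    An override of None means "not provided", so the file value (if any) wins.
    """
    merged = dict(file_values)
    for key, value in overrides.items():
        if value is not None:
            merged[key] = value
    return merged


# Values as if already read from pyproject.toml (hypothetical contents).
from_file = {"source": "torch", "target": "jax", "strict_mode": False}
print(resolve_settings(from_file, target="tensorflow", strict_mode=None))
# {'source': 'torch', 'target': 'tensorflow', 'strict_mode': False}
```

Note that `False` still counts as an explicit override; only `None` is skipped, which matches the `Optional[...]` parameter types above.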

class ml_switcheroo.ASTEngine(semantics: ml_switcheroo.semantics.manager.SemanticsManager = None, config: ml_switcheroo.config.RuntimeConfig | None = None, source: str = 'torch', target: str = 'jax', strict_mode: bool = False, plugin_config: Dict[str, Any] | None = None)

The main compilation unit.

This class encapsulates the state and logic required to transpile a single unit of code. It manages the lifecycle of the LibCST tree and the invocation of visitor passes.

semantics
source
target
strict_mode
parse(code: str) → libcst.Module

Parses a source string into a LibCST Module.

Parameters:

code (str) – Python source code.

Returns:

The parsed Abstract Syntax Tree.

Return type:

cst.Module

Raises:

libcst.ParserSyntaxError – If the input code is invalid Python.

to_source(tree: libcst.Module) → str

Converts a CST back into a source string.

Parameters:

tree (cst.Module) – The modified syntax tree.

Returns:

Generated Python code.

Return type:

str
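`parse` and `to_source` bracket every rewrite: text goes in, a tree is transformed, text comes out. The sketch below shows that round trip for the `torch.abs` example from the Usage section, but it swaps LibCST for the standard-library `ast` module so it runs without extra dependencies. Unlike LibCST, `ast.unparse` discards comments and original formatting, which is the usual reason source-to-source tools prefer a concrete syntax tree. The one-entry `RENAMES` table and the helper names are hypothetical.

```python
import ast
from typing import Optional

# Hypothetical one-entry mapping; the real engine consults SemanticsManager.
RENAMES = {"torch.abs": "jax.numpy.abs"}


def dotted_name(node: ast.AST) -> Optional[str]:
    """Return 'a.b.c' for a Name/Attribute chain, else None."""
    if isinstance(node, ast.Name):
        return node.id
    if isinstance(node, ast.Attribute):
        base = dotted_name(node.value)
        return f"{base}.{node.attr}" if base else None
    return None


def build_name(path: str) -> ast.expr:
    """Build a Name/Attribute chain (Load context) from 'a.b.c'."""
    head, *rest = path.split(".")
    node: ast.expr = ast.Name(id=head, ctx=ast.Load())
    for attr in rest:
        node = ast.Attribute(value=node, attr=attr, ctx=ast.Load())
    return node


class RenameCalls(ast.NodeTransformer):
    """Rewrite the callee of any call whose dotted name is in RENAMES."""

    def visit_Call(self, node: ast.Call) -> ast.Call:
        self.generic_visit(node)  # rewrite nested calls first
        target = RENAMES.get(dotted_name(node.func) or "")
        if target is not None:
            node.func = build_name(target)
        return node


tree = ast.parse("y = torch.abs(x)")  # parse step
tree = ast.fix_missing_locations(RenameCalls().visit(tree))
print(ast.unparse(tree))  # to_source step: y = jax.numpy.abs(x)
```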

run(code: str) → ConversionResult

Executes the full transpilation pipeline.

Passes performed:
  1. Parse.

  2. Purity Scan (if targeting JAX-like frameworks).

  3. Lifecycle Analysis (Init/Forward mismatch).

  4. Dependency Scan (checking third-party libraries).

  5. Pivot Rewrite (the main transformation).

  6. Import Fixer (injecting new imports, pruning old ones).

  7. Structural Linting (verifying output cleanliness).

Parameters:

code (str) – The input source string.

Returns:

Object containing transformed code and error logs.

Return type:

ConversionResult
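The pass list above suggests a simple shape: passes run in a fixed order over shared state, and findings accumulate in one error list that ends up on the result. The sketch below illustrates only that shape. It operates on strings rather than a LibCST tree, and the pass functions, their signature, and `ResultSketch` are hypothetical toys, not the engine's real passes.

```python
from dataclasses import dataclass, field
from typing import Callable, List

# Toy pass signature: (code, shared error list) -> new code.
# Real passes operate on a LibCST tree; strings keep the sketch short.
Pass = Callable[[str, List[str]], str]


@dataclass
class ResultSketch:
    """Hypothetical stand-in for ConversionResult (code, errors, has_errors)."""

    code: str
    errors: List[str] = field(default_factory=list)

    @property
    def has_errors(self) -> bool:
        return bool(self.errors)


def purity_scan(code: str, errors: List[str]) -> str:
    # Flag in-place tensor methods, which have no direct equivalent on
    # immutable JAX arrays. Reports only; does not modify the code.
    if ".add_(" in code:
        errors.append("purity: in-place call 'add_' found")
    return code


def pivot_rewrite(code: str, errors: List[str]) -> str:
    # Toy textual rewrite standing in for the real tree transformation.
    return code.replace("torch.abs(", "jax.numpy.abs(")


def run_pipeline(code: str, passes: List[Pass]) -> ResultSketch:
    errors: List[str] = []
    for run_pass in passes:  # fixed order, shared error list
        code = run_pass(code, errors)
    return ResultSketch(code=code, errors=errors)


res = run_pipeline("y = torch.abs(x)\nx.add_(1)", [purity_scan, pivot_rewrite])
print(res.code)
print(res.errors)  # ["purity: in-place call 'add_' found"]
```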

class ml_switcheroo.ConversionResult(/, **data: Any)

Bases: pydantic.BaseModel

Structured result of a single file conversion.

code: str = None
errors: List[str] = None
success: bool = None
trace_events: List[Dict[str, Any]] = None
property has_errors: bool

Returns True if any errors or warnings were recorded during conversion.

Returns:

True if errors list is non-empty.

Return type:

bool

class ml_switcheroo.SemanticsManager

Central database for semantic mappings and configuration.

data: Dict[str, Dict]
import_data: Dict[str, Dict]
framework_configs: Dict[str, Dict]
test_templates: Dict[str, Dict]
get_all_rng_methods() → Set[str]
resolve_variant(abstract_id: str, target_fw: str) → Dict[str, Any] | None
load_validation_report(report_path: pathlib.Path) → None
is_verified(abstract_id: str) → bool
get_definition_by_id(abstract_id: str) → Dict[str, Any] | None
get_definition(api_name: str) → Tuple[str, Dict] | None
get_known_apis() → Dict[str, Dict]
get_import_map(target_fw: str) → Dict[str, Tuple[str, str | None, str | None]]
get_framework_config(framework: str) → Dict[str, Any]
get_test_template(framework: str) → Dict[str, str] | None
get_framework_aliases() → Dict[str, Tuple[str, str]]
update_definition(abstract_id: str, new_data: Dict[str, Any]) → None
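The method names suggest a two-step lookup: a concrete API name (such as `torch.abs`) resolves to an abstract operation ID via `get_definition`, and `resolve_variant` then selects that operation's entry for the target framework. The sketch below models this reading with plain dictionaries. The data layout (`"variants"`, `"api"` keys), the abstract ID `"Abs"`, and the class `MiniSemantics` are all hypothetical; the real knowledge-base schema is not documented in this reference.

```python
from typing import Any, Dict, Optional, Tuple

# Hypothetical knowledge-base contents: abstract op ID -> per-framework variants.
DATA: Dict[str, Dict[str, Any]] = {
    "Abs": {
        "variants": {
            "torch": {"api": "torch.abs"},
            "jax": {"api": "jax.numpy.abs"},
        }
    }
}


class MiniSemantics:
    """Toy stand-in for SemanticsManager's lookup methods."""

    def __init__(self, data: Dict[str, Dict[str, Any]]) -> None:
        self.data = data
        # Reverse index: concrete API name -> abstract ID.
        self._by_api = {
            variant["api"]: abstract_id
            for abstract_id, definition in data.items()
            for variant in definition["variants"].values()
        }

    def get_definition(self, api_name: str) -> Optional[Tuple[str, Dict[str, Any]]]:
        abstract_id = self._by_api.get(api_name)
        if abstract_id is None:
            return None
        return abstract_id, self.data[abstract_id]

    def resolve_variant(self, abstract_id: str, target_fw: str) -> Optional[Dict[str, Any]]:
        return self.data.get(abstract_id, {}).get("variants", {}).get(target_fw)


sm = MiniSemantics(DATA)
found = sm.get_definition("torch.abs")
if found is not None:
    abstract_id, _ = found
    print(abstract_id, sm.resolve_variant(abstract_id, "jax"))
# Abs {'api': 'jax.numpy.abs'}
```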
ml_switcheroo.__version__ = '0.0.1'
ml_switcheroo.convert(code: str, source: str = 'torch', target: str = 'jax', strict: bool = False, plugin_settings: Dict[str, Any] | None = None, semantics: semantics.manager.SemanticsManager | None = None) → str

Transpiles a string of Python code from one framework to another.

This is a high-level convenience wrapper around ASTEngine. For file-based conversions or batch processing, consider using the ml_switcheroo.cli interface or ASTEngine directly.

Parameters:
  • code (str) – The source code to convert.

  • source (str) – The source framework key (e.g., “torch”, “jax”).

  • target (str) – The target framework key (e.g., “jax”, “tensorflow”).

  • strict (bool) – If True, conversion fails (raising ValueError) when an API cannot be mapped. If False (default), the unmapped code is preserved unchanged, wrapped in escape-hatch comments.

  • plugin_settings (dict, optional) – Specific configuration flags passed to plugin hooks (e.g., {“rng_arg_name”: “seed”}).

  • semantics (SemanticsManager, optional) – An existing Knowledge Base instance. If None, a new one is initialized from disk.

Returns:

The transpiled source code.

Return type:

str

Raises:

ValueError – If the conversion fails (e.g. syntax errors or strict mode violations).
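The `strict` flag decides what happens to an API with no mapping. The sketch below contrasts the two documented behaviours: strict mode fails with ValueError, lenient mode keeps the original line and fences it with marker comments. The `MAPPINGS` table, the line-by-line approach, the regex, and the marker text are hypothetical; the actual escape-hatch comment format is not shown in this reference.

```python
import re
from typing import Dict, List

# Hypothetical mapping table and escape-hatch markers.
MAPPINGS: Dict[str, str] = {"torch.abs": "jax.numpy.abs"}
HATCH_START = "# <escape-hatch: unmapped API>"
HATCH_END = "# </escape-hatch>"
# Matches dotted torch call targets, e.g. 'torch.fft.rfft' in 'torch.fft.rfft('.
TORCH_CALL = re.compile(r"\btorch(?:\.\w+)+(?=\()")


def convert_sketch(code: str, strict: bool = False) -> str:
    out: List[str] = []
    for line in code.splitlines():
        unmapped = [name for name in TORCH_CALL.findall(line) if name not in MAPPINGS]
        if unmapped and strict:
            # Strict: refuse to emit partially converted code.
            raise ValueError(f"strict mode: no mapping for {unmapped[0]}")
        if unmapped:
            # Lenient: keep the original line verbatim between marker comments.
            out.extend([HATCH_START, line, HATCH_END])
            continue
        for src_api, dst_api in MAPPINGS.items():
            line = line.replace(src_api + "(", dst_api + "(")
        out.append(line)
    return "\n".join(out)


src = "y = torch.abs(x)\nz = torch.fft.rfft(x)"
print(convert_sketch(src))
```

With the default `strict=False`, the mapped line is rewritten and the `torch.fft.rfft` line is kept between the two markers; calling `convert_sketch(src, strict=True)` raises ValueError instead.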