ml_switcheroo
ml-switcheroo Package.
A deterministic AST transpiler for converting Deep Learning models between frameworks (e.g., PyTorch -> JAX/Flax).
This package exposes the core compilation engine and configuration utilities for programmatic usage.
Usage
### Simple String Conversion
>>> import ml_switcheroo as mls
>>> code = "y = torch.abs(x)"
>>> result = mls.convert(code, source="torch", target="jax")
>>> print(result)
y = jax.numpy.abs(x)
### Advanced Usage (AST Engine)
>>> from ml_switcheroo import ASTEngine, RuntimeConfig
>>>
>>> config = RuntimeConfig(source_framework="torch", target_framework="jax", strict_mode=True)
>>> engine = ASTEngine(config=config)
>>> res = engine.run("y = torch.abs(x)")
>>>
>>> if res.success:
...     print(res.code)
... else:
...     print(f"Errors: {res.errors}")
Submodules
- ml_switcheroo.__main__
- ml_switcheroo.analysis
- ml_switcheroo.cli
- ml_switcheroo.config
- ml_switcheroo.core
- ml_switcheroo.discovery
- ml_switcheroo.enums
- ml_switcheroo.frameworks
- ml_switcheroo.generated_tests
- ml_switcheroo.importers
- ml_switcheroo.plugins
- ml_switcheroo.semantics
- ml_switcheroo.testing
- ml_switcheroo.utils
Attributes
- __version__
Classes
- RuntimeConfig: Global configuration container for the translation engine.
- ASTEngine: The main compilation unit.
- ConversionResult: Structured result of a single file conversion.
- SemanticsManager: Central database for semantic mappings and configuration.
Functions
- convert: Transpiles a string of Python code from one framework to another.
Package Contents
- class ml_switcheroo.RuntimeConfig(/, **data: Any)
Bases: pydantic.BaseModel
Global configuration container for the translation engine.
- source_framework: str = None
- target_framework: str = None
- source_flavour: str | None = None
- target_flavour: str | None = None
- strict_mode: bool = None
- plugin_settings: Dict[str, Any] = None
- plugin_paths: List[pathlib.Path] = None
- validation_report: pathlib.Path | None = None
- classmethod validate_framework(v: str) → str
Ensures the framework is registered in the system.
- Parameters:
v (str) – The framework key to validate.
- Returns:
The normalized (lowercase) framework key.
- Return type:
str
- Raises:
ValueError – If the framework is not found in the registry.
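The documented contract of validate_framework (normalize the key to lowercase, reject anything not in the registry) can be sketched in plain Python. This is a minimal sketch, not the library's validator: `KNOWN_FRAMEWORKS` is a hypothetical stand-in for the real framework registry, and in the library the check runs as a Pydantic class-method validator rather than a free function.

```python
# Hypothetical stand-in for the framework registry; the real set of keys
# comes from whatever frameworks are registered in ml_switcheroo.
KNOWN_FRAMEWORKS = {"torch", "jax", "flax_nnx", "tensorflow"}


def validate_framework(v: str) -> str:
    """Lowercase a framework key and reject keys missing from the registry."""
    key = v.lower()
    if key not in KNOWN_FRAMEWORKS:
        raise ValueError(f"Unknown framework '{v}'. Known: {sorted(KNOWN_FRAMEWORKS)}")
    return key


print(validate_framework("Torch"))  # torch
```

Normalizing before the membership test means user input such as "JAX" and "jax" resolve to the same registry entry.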
- property effective_source: str
Resolves the specific framework key to use for source logic.
If a flavour (e.g. ‘flax_nnx’) is provided, it overrides the general framework (‘jax’).
- Returns:
The active source framework key.
- Return type:
str
- property effective_target: str
Resolves the specific framework key to use for target logic.
- Returns:
The active target framework key.
- Return type:
str
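effective_source is documented to let a flavour (e.g. 'flax_nnx') take precedence over the general framework key; effective_target presumably applies the same rule to target_flavour, though that is an assumption. The rule itself fits in one line; `resolve_effective` below is a hypothetical helper, not part of the package.

```python
from typing import Optional


def resolve_effective(framework: str, flavour: Optional[str] = None) -> str:
    """Return the flavour when one is set, otherwise the general framework key.

    Hypothetical helper mirroring the documented behaviour of
    RuntimeConfig.effective_source (and, by assumption, effective_target).
    """
    return flavour if flavour else framework


print(resolve_effective("jax", "flax_nnx"))  # flax_nnx
print(resolve_effective("torch"))            # torch
```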
- parse_plugin_settings(schema: Type[T]) → T
Validates the raw plugin settings dictionary against a specific Pydantic model.
- Parameters:
schema (Type[T]) – The Pydantic model class defining expected settings.
- Returns:
An instance of the schema model populated with runtime values.
- Return type:
T
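parse_plugin_settings turns the untyped plugin_settings dictionary into a typed settings object. The sketch below illustrates the idea with a standard-library dataclass standing in for the Pydantic model so that it runs without Pydantic installed. `RngPluginSettings` and its default are hypothetical; the `rng_arg_name` key mirrors the `{"rng_arg_name": "seed"}` example given for convert. Unlike Pydantic, this sketch does not coerce or type-check values.

```python
from dataclasses import dataclass, fields
from typing import Any, Dict, Type, TypeVar

T = TypeVar("T")


@dataclass
class RngPluginSettings:
    # Hypothetical plugin schema; the field name and default are illustrative.
    rng_arg_name: str = "rng"


def parse_plugin_settings(raw: Dict[str, Any], schema: Type[T]) -> T:
    """Build a typed settings object, ignoring keys the schema does not declare."""
    known = {f.name for f in fields(schema)}
    return schema(**{k: v for k, v in raw.items() if k in known})


settings = parse_plugin_settings({"rng_arg_name": "seed", "unrelated": 1}, RngPluginSettings)
print(settings.rng_arg_name)  # seed
```

Dropping undeclared keys lets several plugins share one flat settings dictionary while each sees only the fields it declares.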
- classmethod load(source: str | None = None, target: str | None = None, source_flavour: str | None = None, target_flavour: str | None = None, strict_mode: bool | None = None, plugin_settings: Dict[str, Any] | None = None, validation_report: pathlib.Path | None = None, search_path: pathlib.Path | None = None) → RuntimeConfig
Loads configuration from pyproject.toml, then applies any explicitly passed arguments (e.g. from the CLI) as overrides.
- Parameters:
source (Optional[str]) – Override for source framework.
target (Optional[str]) – Override for target framework.
source_flavour (Optional[str]) – Override for source flavour.
target_flavour (Optional[str]) – Override for target flavour.
strict_mode (Optional[bool]) – Override for strict mode setting.
plugin_settings (Optional[Dict]) – Additional CLI plugin settings.
validation_report (Optional[Path]) – Override for validation report path.
search_path (Optional[Path]) – Directory to start searching for TOML config.
- Returns:
The fully resolved configuration object.
- Return type:
RuntimeConfig
- class ml_switcheroo.ASTEngine(semantics: ml_switcheroo.semantics.manager.SemanticsManager = None, config: ml_switcheroo.config.RuntimeConfig | None = None, source: str = 'torch', target: str = 'jax', strict_mode: bool = False, plugin_config: Dict[str, Any] | None = None)
The main compilation unit.
This class encapsulates the state and logic required to transpile a single unit of code. It manages the lifecycle of the LibCST tree and the invocation of visitor passes.
- semantics
- source
- target
- strict_mode
- parse(code: str) → libcst.Module
Parses source string into a LibCST Module.
- Parameters:
code (str) – Python source code.
- Returns:
The parsed Concrete Syntax Tree (a LibCST Module).
- Return type:
cst.Module
- Raises:
libcst.ParserSyntaxError – If the input code is invalid Python.
- to_source(tree: libcst.Module) → str
Converts CST back to source string.
- Parameters:
tree (cst.Module) – The modified syntax tree.
- Returns:
Generated Python code.
- Return type:
str
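parse and to_source bracket every rewrite: source text goes in, a tree is transformed, and text comes back out. The sketch below shows that round trip with the standard-library ast module instead of LibCST, together with a hypothetical one-entry `API_MAP`. It is only a conceptual stand-in: unlike LibCST, `ast.unparse` discards comments and original formatting, and the engine's real rewrite is driven by the SemanticsManager rather than a hard-coded table.

```python
import ast
from typing import Optional

# Hypothetical mapping from a source API to its target equivalent.
API_MAP = {"torch.abs": "jax.numpy.abs"}


def dotted_name(node: ast.AST) -> Optional[str]:
    """Return 'a.b.c' for a Name/Attribute chain, or None for anything else."""
    if isinstance(node, ast.Name):
        return node.id
    if isinstance(node, ast.Attribute):
        base = dotted_name(node.value)
        return f"{base}.{node.attr}" if base else None
    return None


class ApiRewriter(ast.NodeTransformer):
    """Replace the callee of any call whose dotted name appears in API_MAP."""

    def visit_Call(self, node: ast.Call) -> ast.AST:
        self.generic_visit(node)  # rewrite nested calls first
        target = API_MAP.get(dotted_name(node.func) or "")
        if target:
            # Parse the target name as an expression, e.g. jax.numpy.abs.
            node.func = ast.parse(target, mode="eval").body
        return node


tree = ast.parse("y = torch.abs(x)")      # 1. parse
tree = ApiRewriter().visit(tree)          # 2. transform
tree = ast.fix_missing_locations(tree)
print(ast.unparse(tree))                  # 3. back to source: y = jax.numpy.abs(x)
```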
- run(code: str) → ConversionResult
Executes the full transpilation pipeline.
Passes performed:
1. Parse.
2. Purity Scan (if targeting JAX-like frameworks).
3. Lifecycle Analysis (Init/Forward mismatch).
4. Dependency Scan (checking third-party libraries).
5. Pivot Rewrite (the main transformation).
6. Import Fixer (injecting new imports, pruning old ones).
7. Structural Linting (verifying output cleanliness).
- Parameters:
code (str) – The input source string.
- Returns:
Object containing transformed code and error logs.
- Return type:
ConversionResult
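run chains the passes listed above and gathers their diagnostics into one result. A minimal sketch of that control flow follows, assuming a hypothetical pass interface in which each pass takes code and returns (new_code, messages). The passes work on plain strings only to keep the example short; the real engine operates on a LibCST tree, and its pass interfaces are not documented on this page.

```python
from dataclasses import dataclass, field
from typing import Callable, List, Tuple

# Hypothetical pass interface: code in, (new_code, messages) out.
Pass = Callable[[str], Tuple[str, List[str]]]


@dataclass
class PipelineResult:
    code: str
    errors: List[str] = field(default_factory=list)


def run_pipeline(code: str, passes: List[Pass]) -> PipelineResult:
    """Apply each pass in order, threading the code and collecting messages."""
    errors: List[str] = []
    for run_pass in passes:
        code, messages = run_pass(code)
        errors.extend(messages)
    return PipelineResult(code=code, errors=errors)


def rewrite(code: str) -> Tuple[str, List[str]]:
    # Stand-in for the Pivot Rewrite stage.
    return code.replace("torch.abs", "jax.numpy.abs"), []


def lint(code: str) -> Tuple[str, List[str]]:
    # Stand-in for Structural Linting: flag any torch reference left behind.
    return code, (["leftover torch reference"] if "torch." in code else [])


res = run_pipeline("y = torch.abs(x)", [rewrite, lint])
print(res.code)    # y = jax.numpy.abs(x)
print(res.errors)  # []
```

Ordering matters: linting runs after the rewrite so that it inspects the final output, which matches the position of Structural Linting at the end of the documented pass list.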
- class ml_switcheroo.ConversionResult(/, **data: Any)
Bases: pydantic.BaseModel
Structured result of a single file conversion.
- code: str = None
- errors: List[str] = None
- success: bool = None
- trace_events: List[Dict[str, Any]] = None
- property has_errors: bool
Returns True if any errors or warnings were recorded during conversion.
- Returns:
True if errors list is non-empty.
- Return type:
bool
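Because success and has_errors are separate, a result can presumably succeed while still carrying warnings (for example, escape-hatched APIs in non-strict mode); this reading is an assumption based on the has_errors description. The sketch below mirrors the documented fields with a standard-library dataclass in place of pydantic.BaseModel and shows one way a caller might branch on them. `ConversionResultSketch`, its defaults, and `report` are hypothetical.

```python
from dataclasses import dataclass, field
from typing import Any, Dict, List


@dataclass
class ConversionResultSketch:
    # Field names follow ConversionResult; the defaults are assumptions.
    code: str = ""
    errors: List[str] = field(default_factory=list)
    success: bool = True
    trace_events: List[Dict[str, Any]] = field(default_factory=list)

    @property
    def has_errors(self) -> bool:
        # Documented rule: True if the errors list is non-empty.
        return len(self.errors) > 0


def report(res: ConversionResultSketch) -> str:
    """Summarize a result, distinguishing hard failures from warnings."""
    if not res.success:
        return "failed: " + "; ".join(res.errors)
    if res.has_errors:
        return "converted with warnings: " + "; ".join(res.errors)
    return "clean"
```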
- class ml_switcheroo.SemanticsManager
Central database for semantic mappings and configuration.
- data: Dict[str, Dict]
- import_data: Dict[str, Dict]
- framework_configs: Dict[str, Dict]
- test_templates: Dict[str, Dict]
- get_all_rng_methods() → Set[str]
- resolve_variant(abstract_id: str, target_fw: str) → Dict[str, Any] | None
- load_validation_report(report_path: pathlib.Path) → None
- is_verified(abstract_id: str) → bool
- get_definition_by_id(abstract_id: str) → Dict[str, Any] | None
- get_definition(api_name: str) → Tuple[str, Dict] | None
- get_known_apis() → Dict[str, Dict]
- get_import_map(target_fw: str) → Dict[str, Tuple[str, str | None, str | None]]
- get_framework_config(framework: str) → Dict[str, Any]
- get_test_template(framework: str) → Dict[str, str] | None
- get_framework_aliases() → Dict[str, Tuple[str, str]]
- update_definition(abstract_id: str, new_data: Dict[str, Any]) → None
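The method signatures suggest a two-level lookup: get_definition maps a concrete API name (e.g. torch.abs) to an abstract operation ID, and resolve_variant picks that operation's entry for a target framework. The sketch below models that lookup over an in-memory dictionary. The data layout (a `variants` table keyed by framework, each with an `api` field) and the "Abs" ID are assumptions, not the manager's actual schema.

```python
from typing import Any, Dict, Optional, Tuple

# Hypothetical knowledge base: abstract op ID -> per-framework variants.
KNOWLEDGE_BASE: Dict[str, Dict[str, Any]] = {
    "Abs": {
        "variants": {
            "torch": {"api": "torch.abs"},
            "jax": {"api": "jax.numpy.abs"},
        }
    }
}


def get_definition(api_name: str) -> Optional[Tuple[str, Dict[str, Any]]]:
    """Find the abstract ID whose variants include this concrete API name."""
    for abstract_id, details in KNOWLEDGE_BASE.items():
        for variant in details["variants"].values():
            if variant.get("api") == api_name:
                return abstract_id, details
    return None


def resolve_variant(abstract_id: str, target_fw: str) -> Optional[Dict[str, Any]]:
    """Return the target framework's variant, or None if it is unmapped."""
    details = KNOWLEDGE_BASE.get(abstract_id)
    if details is None:
        return None
    return details["variants"].get(target_fw)


found = get_definition("torch.abs")
if found:
    abstract_id, _ = found
    print(resolve_variant(abstract_id, "jax"))  # {'api': 'jax.numpy.abs'}
```

A linear scan keeps the sketch short. A real knowledge base would more likely precompute a reverse index from API name to abstract ID.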
- ml_switcheroo.__version__ = '0.0.1'
- ml_switcheroo.convert(code: str, source: str = 'torch', target: str = 'jax', strict: bool = False, plugin_settings: Dict[str, Any] | None = None, semantics: semantics.manager.SemanticsManager | None = None) → str
Transpiles a string of Python code from one framework to another.
This is a high-level convenience wrapper around ASTEngine. For file-based conversions or batch processing, consider using ml_switcheroo.cli or ASTEngine directly.
- Parameters:
code (str) – The source code to convert.
source (str) – The source framework key (e.g., “torch”, “jax”).
target (str) – The target framework key (e.g., “jax”, “tensorflow”).
strict (bool) – If True, the engine reports an error when an API cannot be mapped. If False (default), the original code is preserved and wrapped in escape-hatch comments.
plugin_settings (dict, optional) – Specific configuration flags passed to plugin hooks (e.g., {“rng_arg_name”: “seed”}).
semantics (SemanticsManager, optional) – An existing Knowledge Base instance. If None, a new one is initialized from disk.
- Returns:
The transpiled source code.
- Return type:
str
- Raises:
ValueError – If the conversion fails (e.g. syntax errors or strict mode violations).
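As described above, convert wraps the engine and raises ValueError when conversion fails. The sketch below illustrates that wrapper contract with a hypothetical stub engine so it runs without ml_switcheroo installed. `StubEngine`, `StubResult`, and `convert_sketch` are illustrative names; the stub's mapping logic is invented, and in non-strict mode it simply leaves unmapped code unchanged instead of adding escape-hatch comments.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class StubResult:
    code: str
    success: bool
    errors: List[str] = field(default_factory=list)


class StubEngine:
    """Hypothetical engine that only knows how to map torch.abs."""

    def __init__(self, strict_mode: bool = False) -> None:
        self.strict_mode = strict_mode

    def run(self, code: str) -> StubResult:
        out = code.replace("torch.abs", "jax.numpy.abs")
        if self.strict_mode and "torch." in out:
            return StubResult(code=code, success=False, errors=["unmapped torch API"])
        return StubResult(code=out, success=True)


def convert_sketch(code: str, strict: bool = False) -> str:
    """Run the stub engine and surface failures as a single ValueError."""
    res = StubEngine(strict_mode=strict).run(code)
    if not res.success:
        raise ValueError("Conversion failed: " + "; ".join(res.errors))
    return res.code


print(convert_sketch("y = torch.abs(x)"))  # y = jax.numpy.abs(x)
```

Callers who prefer to inspect errors rather than catch exceptions can call ASTEngine.run directly, as in the Advanced Usage example.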