ml_switcheroo
=============

.. py:module:: ml_switcheroo

.. autoapi-nested-parse::

   ml-switcheroo Package.

   A deterministic AST transpiler for converting Deep Learning models between
   frameworks (e.g., PyTorch -> JAX/Flax).

   This package exposes the core compilation engine and configuration utilities
   for programmatic usage.

   Usage
   -----

   Simple String Conversion
   ~~~~~~~~~~~~~~~~~~~~~~~~

   >>> import ml_switcheroo as mls
   >>> code = "y = torch.abs(x)"
   >>> result = mls.convert(code, source="torch", target="jax")
   >>> print(result)
   y = jax.numpy.abs(x)

   Advanced Usage (AST Engine)
   ~~~~~~~~~~~~~~~~~~~~~~~~~~~

   >>> from ml_switcheroo import ASTEngine, RuntimeConfig
   >>>
   >>> config = RuntimeConfig(source_framework="torch", target_framework="jax", strict_mode=True)
   >>> engine = ASTEngine(config=config)
   >>> res = engine.run("y = torch.abs(x)")
   >>>
   >>> if res.success:
   ...     print(res.code)
   ... else:
   ...     print(f"Errors: {res.errors}")


Submodules
----------

.. toctree::
   :maxdepth: 1

   /api/ml_switcheroo/__main__/index
   /api/ml_switcheroo/analysis/index
   /api/ml_switcheroo/cli/index
   /api/ml_switcheroo/config/index
   /api/ml_switcheroo/core/index
   /api/ml_switcheroo/discovery/index
   /api/ml_switcheroo/enums/index
   /api/ml_switcheroo/frameworks/index
   /api/ml_switcheroo/generated_tests/index
   /api/ml_switcheroo/importers/index
   /api/ml_switcheroo/plugins/index
   /api/ml_switcheroo/semantics/index
   /api/ml_switcheroo/testing/index
   /api/ml_switcheroo/utils/index


Attributes
----------

.. autoapisummary::

   ml_switcheroo.__version__


Classes
-------

.. autoapisummary::

   ml_switcheroo.RuntimeConfig
   ml_switcheroo.ASTEngine
   ml_switcheroo.ConversionResult
   ml_switcheroo.SemanticsManager


Functions
---------

.. autoapisummary::

   ml_switcheroo.convert


Package Contents
----------------

.. py:class:: RuntimeConfig(/, **data: Any)

   Bases: :py:obj:`pydantic.BaseModel`

   Global configuration container for the translation engine.


   .. py:attribute:: source_framework
      :type: str
      :value: None


   .. py:attribute:: target_framework
      :type: str
      :value: None


   .. py:attribute:: source_flavour
      :type: Optional[str]
      :value: None


   .. py:attribute:: target_flavour
      :type: Optional[str]
      :value: None


   .. py:attribute:: strict_mode
      :type: bool
      :value: None


   .. py:attribute:: plugin_settings
      :type: Dict[str, Any]
      :value: None


   .. py:attribute:: plugin_paths
      :type: List[pathlib.Path]
      :value: None


   .. py:attribute:: validation_report
      :type: Optional[pathlib.Path]
      :value: None


   .. py:method:: validate_framework(v: str) -> str
      :classmethod:

      Ensures the framework is registered in the system.

      :param v: The framework key to validate.
      :type v: str
      :returns: The normalized (lowercase) framework key.
      :rtype: str
      :raises ValueError: If the framework is not found in the registry.


   .. py:property:: effective_source
      :type: str

      Resolves the specific framework key to use for source logic.

      If a ``source_flavour`` (e.g. ``'flax_nnx'``) is set, it takes precedence over
      the general ``source_framework`` (e.g. ``'jax'``).

      :returns: The active source framework key.
      :rtype: str


   .. py:property:: effective_target
      :type: str

      Resolves the specific framework key to use for target logic.

      :returns: The active target framework key.
      :rtype: str


   .. py:method:: parse_plugin_settings(schema: Type[T]) -> T

      Validates the raw ``plugin_settings`` dictionary against a specific Pydantic model.

      :param schema: The Pydantic model class defining the expected settings.
      :type schema: Type[T]
      :returns: An instance of the schema model populated with runtime values.
      :rtype: T


   .. py:method:: load(source: Optional[str] = None, target: Optional[str] = None, source_flavour: Optional[str] = None, target_flavour: Optional[str] = None, strict_mode: Optional[bool] = None, plugin_settings: Optional[Dict[str, Any]] = None, validation_report: Optional[pathlib.Path] = None, search_path: Optional[pathlib.Path] = None) -> RuntimeConfig
      :classmethod:

      Loads configuration from ``pyproject.toml`` and overrides it with CLI arguments.

      :param source: Override for the source framework.
      :type source: Optional[str]
      :param target: Override for the target framework.
      :type target: Optional[str]
      :param source_flavour: Override for the source flavour.
      :type source_flavour: Optional[str]
      :param target_flavour: Override for the target flavour.
      :type target_flavour: Optional[str]
      :param strict_mode: Override for the strict mode setting.
      :type strict_mode: Optional[bool]
      :param plugin_settings: Additional plugin settings passed from the CLI.
      :type plugin_settings: Optional[Dict]
      :param validation_report: Override for the validation report path.
      :type validation_report: Optional[Path]
      :param search_path: Directory from which to start searching for the TOML config.
      :type search_path: Optional[Path]
      :returns: The fully resolved configuration object.
      :rtype: RuntimeConfig


.. py:class:: ASTEngine(semantics: ml_switcheroo.semantics.manager.SemanticsManager = None, config: Optional[ml_switcheroo.config.RuntimeConfig] = None, source: str = 'torch', target: str = 'jax', strict_mode: bool = False, plugin_config: Optional[Dict[str, Any]] = None)

   The main compilation unit.

   This class encapsulates the state and logic required to transpile a single unit of code.
   It manages the lifecycle of the LibCST tree and the invocation of visitor passes.


   .. py:attribute:: semantics


   .. py:attribute:: source


   .. py:attribute:: target


   .. py:attribute:: strict_mode


   .. py:method:: parse(code: str) -> libcst.Module

      Parses a source string into a LibCST Module.

      :param code: Python source code.
      :type code: str
      :returns: The parsed LibCST concrete syntax tree.
      :rtype: cst.Module
      :raises libcst.ParserSyntaxError: If the input code is invalid Python.


   .. py:method:: to_source(tree: libcst.Module) -> str

      Converts the CST back into a source string.

      :param tree: The modified syntax tree.
      :type tree: cst.Module
      :returns: Generated Python code.
      :rtype: str


   .. py:method:: run(code: str) -> ConversionResult

      Executes the full transpilation pipeline.

      Passes performed:

      1. Parse.
      2. Purity Scan (if targeting JAX-like frameworks).
      3. Lifecycle Analysis (Init/Forward mismatch).
      4. Dependency Scan (checking third-party libraries).
      5. Pivot Rewrite (the main transformation).
      6. Import Fixer (injecting new imports, pruning old ones).
      7. Structural Linting (verifying output cleanliness).

      :param code: The input source string.
      :type code: str
      :returns: Object containing the transformed code and error logs.
      :rtype: ConversionResult


.. py:class:: ConversionResult(/, **data: Any)

   Bases: :py:obj:`pydantic.BaseModel`

   Structured result of a single file conversion.


   .. py:attribute:: code
      :type: str
      :value: None


   .. py:attribute:: errors
      :type: List[str]
      :value: None


   .. py:attribute:: success
      :type: bool
      :value: None


   .. py:attribute:: trace_events
      :type: List[Dict[str, Any]]
      :value: None


   .. py:property:: has_errors
      :type: bool

      Returns True if any errors or warnings were recorded during conversion.

      :returns: True if the errors list is non-empty.
      :rtype: bool


.. py:class:: SemanticsManager

   Central database for semantic mappings and configuration.


   .. py:attribute:: data
      :type: Dict[str, Dict]


   .. py:attribute:: import_data
      :type: Dict[str, Dict]


   .. py:attribute:: framework_configs
      :type: Dict[str, Dict]


   .. py:attribute:: test_templates
      :type: Dict[str, Dict]


   .. py:method:: get_all_rng_methods() -> Set[str]


   .. py:method:: resolve_variant(abstract_id: str, target_fw: str) -> Optional[Dict[str, Any]]


   .. py:method:: load_validation_report(report_path: pathlib.Path) -> None


   .. py:method:: is_verified(abstract_id: str) -> bool


   .. py:method:: get_definition_by_id(abstract_id: str) -> Optional[Dict[str, Any]]


   .. py:method:: get_definition(api_name: str) -> Optional[Tuple[str, Dict]]


   .. py:method:: get_known_apis() -> Dict[str, Dict]


   .. py:method:: get_import_map(target_fw: str) -> Dict[str, Tuple[str, Optional[str], Optional[str]]]


   .. py:method:: get_framework_config(framework: str) -> Dict[str, Any]


   .. py:method:: get_test_template(framework: str) -> Optional[Dict[str, str]]


   .. py:method:: get_framework_aliases() -> Dict[str, Tuple[str, str]]


   .. py:method:: update_definition(abstract_id: str, new_data: Dict[str, Any]) -> None


.. py:data:: __version__
   :value: '0.0.1'


.. py:function:: convert(code: str, source: str = 'torch', target: str = 'jax', strict: bool = False, plugin_settings: Optional[Dict[str, Any]] = None, semantics: Optional[semantics.manager.SemanticsManager] = None) -> str

   Transpiles a string of Python code from one framework to another.

   This is a high-level convenience wrapper around :py:class:`ASTEngine`.
   For file-based conversions or batch processing, consider using
   ``ml_switcheroo.cli`` or using :py:class:`ASTEngine` directly.

   :param code: The source code to convert.
   :type code: str
   :param source: The source framework key (e.g., "torch", "jax").
   :type source: str
   :param target: The target framework key (e.g., "jax", "tensorflow").
   :type target: str
   :param strict: If True, an API that cannot be mapped makes the conversion fail, which ``convert`` reports by raising ``ValueError``. If False (default), the original (unmapped) code is preserved and wrapped in escape-hatch comments.
   :type strict: bool
   :param plugin_settings: Specific configuration flags passed to plugin hooks (e.g., ``{"rng_arg_name": "seed"}``).
   :type plugin_settings: dict, optional
   :param semantics: An existing Knowledge Base instance. If None, a new one is initialized from disk.
   :type semantics: SemanticsManager, optional
   :returns: The transpiled source code.
   :rtype: str
   :raises ValueError: If the conversion fails (e.g. syntax errors or strict mode violations).
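The usage example at the top of this page rewrites ``torch.abs(x)`` into ``jax.numpy.abs(x)``. The core idea behind such a "Pivot Rewrite" is to visit call nodes and swap the callee according to a mapping table, and it can be sketched with nothing but the standard-library :py:mod:`ast` module. This is an illustrative toy, not ml-switcheroo's implementation. The real engine works on LibCST trees and takes its mappings from :py:class:`SemanticsManager`. The names ``API_MAP``, ``PivotRewriter`` and ``toy_convert`` are hypothetical.

```python
import ast
from typing import Optional

# Hypothetical mapping table; ml-switcheroo itself takes mappings from SemanticsManager.
API_MAP = {"torch.abs": "jax.numpy.abs"}


def dotted_name(node: ast.AST) -> Optional[str]:
    """Return 'a.b.c' for a Name/Attribute chain, or None for any other expression."""
    parts = []
    while isinstance(node, ast.Attribute):
        parts.append(node.attr)
        node = node.value
    if isinstance(node, ast.Name):
        parts.append(node.id)
        return ".".join(reversed(parts))
    return None


def build_attr(path: str) -> ast.expr:
    """Build the expression node for a dotted path such as 'jax.numpy.abs'."""
    head, *rest = path.split(".")
    expr: ast.expr = ast.Name(id=head, ctx=ast.Load())
    for attr in rest:
        expr = ast.Attribute(value=expr, attr=attr, ctx=ast.Load())
    return expr


class PivotRewriter(ast.NodeTransformer):
    """Swap the callee of every call whose dotted name appears in API_MAP."""

    def visit_Call(self, node: ast.Call) -> ast.Call:
        self.generic_visit(node)  # rewrite nested calls (e.g. inside arguments) first
        name = dotted_name(node.func)
        if name in API_MAP:
            node.func = build_attr(API_MAP[name])
        return node


def toy_convert(code: str) -> str:
    """Parse, rewrite and regenerate source (ast.unparse needs Python 3.9+)."""
    tree = PivotRewriter().visit(ast.parse(code))
    ast.fix_missing_locations(tree)  # give new nodes line/column info in case the tree is compiled
    return ast.unparse(tree)


print(toy_convert("y = torch.abs(x)"))  # y = jax.numpy.abs(x)
```

Because :py:func:`ast.unparse` discards comments and the original formatting, a round trip through :py:mod:`ast` is lossy. Preserving both is one reason a concrete syntax tree library such as LibCST suits a source-to-source transpiler better.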
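The override rules documented for :py:meth:`RuntimeConfig.load` (values from ``pyproject.toml`` are overridden by CLI arguments) and :py:attr:`RuntimeConfig.effective_source` (a flavour overrides the general framework) can be sketched with plain dictionaries. This minimal illustration assumes an override counts as "given" whenever it is not ``None``, which matches the ``None`` defaults in the ``load`` signature. The helpers ``merge_overrides`` and ``resolve_framework`` and the dictionary keys are hypothetical and not part of the package.

```python
from typing import Any, Dict, Optional


def merge_overrides(toml_values: Dict[str, Any], **overrides: Any) -> Dict[str, Any]:
    """Start from file-based settings; every override that is not None wins (hypothetical helper)."""
    merged = dict(toml_values)  # copy so the caller's dict is left untouched
    for key, value in overrides.items():
        if value is not None:  # None means "not given", so the file value is kept
            merged[key] = value
    return merged


def resolve_framework(framework: str, flavour: Optional[str]) -> str:
    """A flavour such as 'flax_nnx', when set, takes precedence over the general framework."""
    return flavour or framework  # a missing or empty flavour falls back to the framework


file_cfg = {"source": "torch", "target": "jax", "strict_mode": True}
cfg = merge_overrides(file_cfg, target=None, strict_mode=False)
print(cfg)  # {'source': 'torch', 'target': 'jax', 'strict_mode': False}
print(resolve_framework("jax", "flax_nnx"))  # flax_nnx
print(resolve_framework("jax", None))  # jax
```

Testing ``is not None`` rather than truthiness matters for boolean flags: an explicit ``strict_mode=False`` must still override a ``True`` value read from the TOML file.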