ml_switcheroo.frameworks.torch
PyTorch Framework Adapter.
This module implements the FrameworkAdapter protocol for PyTorch. It provides:
- Import Abstraction: Self-declared namespace mappings (e.g., torch.nn is NEURAL).
- Semantic Definitions: Mappings loaded from definitions/torch.json via a helper.
- Discovery: Heuristics and logic for scanning the installed torch library.
- IO & Device Support: Wires up serialization and device allocation.
- Weight Migration: Generates scripts that convert .pth checkpoints to and from NumPy format for interoperability.
Classes
- TorchAdapter: Adapter for PyTorch (Meta).
Module Contents
- class ml_switcheroo.frameworks.torch.TorchAdapter
Adapter for PyTorch (Meta).
Handles Source and Target translation rules for PyTorch, including torch.nn, torch.optim, and torch.func (vmap/grad).
- display_name: str = 'PyTorch'
- inherits_from: str | None = None
- ui_priority: int = 0
- property import_alias: Tuple[str, str]
Returns the primary root import alias ('torch', 'torch').
- Returns:
The module name and default alias.
- property import_namespaces: Dict[str, ml_switcheroo.frameworks.base.ImportConfig]
Defines the semantic roles of PyTorch namespaces.
- Returns:
Mapping of dot-path strings to configuration objects.
- property search_modules: List[str]
Modules to scan during scaffold or sync operations.
- Returns:
List of module names.
- property unsafe_submodules: Set[str]
Submodules that cause recursion-depth errors or C-extension crashes when scanned.
- Returns:
Set of module names to exclude from recursive scanning.
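One way an exclusion set like unsafe_submodules could be applied during a recursive scan is a dotted-prefix check, sketched below. The helper names should_scan and filter_modules and the pkg.* module names are hypothetical; the adapter's actual exclusion list and scanning code are not shown in this reference.

```python
from typing import Iterable, List, Set


def should_scan(module_name: str, unsafe: Set[str]) -> bool:
    """Return False if module_name is an unsafe submodule or nested inside one.

    Hypothetical helper: 'a.b' excludes 'a.b' and 'a.b.c' but not 'a.bc'.
    """
    for bad in unsafe:
        if module_name == bad or module_name.startswith(bad + "."):
            return False
    return True


def filter_modules(names: Iterable[str], unsafe: Set[str]) -> List[str]:
    """Keep only the module names that are safe to import and walk."""
    return [name for name in names if should_scan(name, unsafe)]
```

Matching on the full dotted segment (bad + ".") avoids accidentally excluding sibling modules that merely share a name prefix.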
- property discovery_heuristics: Dict[str, List[str]]
Regex patterns to categorize discovered APIs.
- Returns:
Dictionary mapping category names to lists of regex patterns.
- property supported_tiers: List[ml_switcheroo.enums.SemanticTier]
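A minimal sketch of how a category-to-regex mapping of this shape might be applied to discovered API names. The HEURISTICS dictionary and the categorize helper are hypothetical illustrations; the patterns the adapter actually ships are not reproduced here.

```python
import re
from typing import Dict, List, Optional

# Hypothetical patterns in the documented Dict[str, List[str]] shape;
# the adapter's real categories and regexes may differ.
HEURISTICS: Dict[str, List[str]] = {
    "loss": [r"Loss$"],
    "optimizer": [r"^torch\.optim\.[A-Z]"],
    "activation": [r"^torch\.nn\.(ReLU|GELU|Sigmoid|Tanh)$"],
}


def categorize(api_name: str, heuristics: Dict[str, List[str]]) -> Optional[str]:
    """Return the first category (in dict order) whose patterns match, else None."""
    for category, patterns in heuristics.items():
        if any(re.search(pattern, api_name) for pattern in patterns):
            return category
    return None
```

Because dictionaries preserve insertion order, earlier categories win when several patterns match the same name.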
Returns the semantic tiers fully supported by this adapter.
- Returns:
List of supported tiers.
- property test_config: Dict[str, str]
Templates used by gen-tests to create physical verification files.
- Returns:
Dictionary of code templates.
- property harness_imports: List[str]
Imports required for harness initialization.
- Returns:
List of import statements.
- get_harness_init_code() → str
Returns helper code for initializing the harness.
- Returns:
Python source code string.
- get_to_numpy_code() → str
Returns code to convert Torch tensors to NumPy (detach/cpu check).
- Returns:
Python statement string.
- property structural_traits: ml_switcheroo.frameworks.base.StructuralTraits
Defines how classes and functions are rewritten when targeting PyTorch.
- Returns:
Configuration object for structural rewriting.
- property plugin_traits: ml_switcheroo.frameworks.base.PluginTraits
Capability flags. PyTorch uses imperative state and eager execution.
- Returns:
Configuration object for plugin logic.
- property rng_seed_methods: List[str]
Global seed setting methods detected as impure side-effects.
- Returns:
List of method names.
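Detecting seed calls as impure side-effects could be done with a syntax-tree walk, as in this sketch. torch.manual_seed and torch.cuda.manual_seed_all are real PyTorch functions, but the SEED_METHODS list and the find_seed_calls helper are assumptions, not the adapter's actual list or detection logic.

```python
import ast
from typing import List

# Hypothetical seed-method list; real PyTorch exposes torch.manual_seed
# and torch.cuda.manual_seed_all, among others.
SEED_METHODS: List[str] = ["manual_seed", "manual_seed_all"]


def find_seed_calls(source: str, seed_methods: List[str]) -> List[str]:
    """Return the source text of every call whose final name is a seed method."""
    hits: List[str] = []
    for node in ast.walk(ast.parse(source)):
        if isinstance(node, ast.Call):
            func = node.func
            # torch.manual_seed(...) is an Attribute; manual_seed(...) after
            # a from-import is a Name. Other callee shapes yield None.
            if isinstance(func, ast.Attribute):
                name = func.attr
            else:
                name = getattr(func, "id", None)
            if name in seed_methods:
                hits.append(ast.unparse(func))  # requires Python 3.9+
    return hits
```

Matching on the final attribute name catches both torch.manual_seed and torch.cuda.manual_seed_all without enumerating every namespace.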
- property declared_magic_args: List[str]
Returns the list of framework-specific magic arguments. Torch emits no magic args; all state is implicit.
- Returns:
Empty list.
- property definitions: Dict[str, ml_switcheroo.frameworks.base.StandardMap]
The definitive mapping of Abstract Operations to PyTorch APIs. Loaded dynamically from frameworks/definitions/torch.json.
- Returns:
Dictionary mapping operation abstract IDs to implementation details.
- get_device_syntax(device_type: str, device_index: str | None = None) → str
Generates code for device creation.
- Parameters:
device_type – The device type string (e.g. 'cuda', 'cpu').
device_index – The optional device index.
- Returns:
Code string for device creation.
- get_device_check_syntax() → str
Returns PyTorch syntax for checking CUDA availability.
- Returns:
Python expression string.
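The device helpers emit source text rather than executing anything. The sketch below shows plausible output using the real torch.device and torch.cuda.is_available APIs; the helper names are hypothetical, and treating device_index as a code string inserted verbatim (for example "0" or a variable name) is an assumption based on its str type.

```python
from typing import Optional


def device_syntax(device_type: str, device_index: Optional[str] = None) -> str:
    """Render a torch.device(...) expression as source text (hypothetical helper)."""
    if device_index is None:
        return f"torch.device('{device_type}')"
    # device_index is treated as code, so it is inserted without quotes.
    return f"torch.device('{device_type}', {device_index})"


def device_check_syntax() -> str:
    """Render a CUDA availability check as source text."""
    return "torch.cuda.is_available()"
```

For example, device_syntax("cuda", "rank") yields torch.device('cuda', rank), which lets generated code pick the index at runtime.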
- get_rng_split_syntax(rng_var: str, key_var: str) → str
Returns syntax for splitting RNG state. PyTorch uses global state-based randomness, so explicit splitting is a no-op.
- Parameters:
rng_var – Variable name holding random state.
key_var – Variable name for the new key.
- Returns:
The string 'pass' (a no-op).
- get_serialization_imports() → List[str]
Returns imports required for IO operations.
- Returns:
List of import statements.
- get_serialization_syntax(op: str, file_arg: str, object_arg: str | None = None) → str
Generates save/load syntax.
- Parameters:
op – Operation name ('save' or 'load').
file_arg – Code string representing the file path.
object_arg – Code string representing the object to save (optional).
- Returns:
Python code string for the operation.
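A minimal sketch of the save/load dispatch, emitting calls to the real torch.save(obj, f) and torch.load(f) functions. The serialization_syntax name and its error handling are hypothetical, and the sketch omits optional arguments such as map_location.

```python
from typing import Optional


def serialization_syntax(op: str, file_arg: str, object_arg: Optional[str] = None) -> str:
    """Render torch.save / torch.load calls as source text (hypothetical helper).

    file_arg and object_arg are code strings and are inserted verbatim.
    """
    if op == "save":
        if object_arg is None:
            raise ValueError("'save' requires object_arg")
        # torch.save takes the object first, then the destination.
        return f"torch.save({object_arg}, {file_arg})"
    if op == "load":
        return f"torch.load({file_arg})"
    raise ValueError(f"unsupported op: {op!r}")
```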
- get_weight_conversion_imports() → List[str]
Returns imports required for the generated weight migration script logic.
- Returns:
List of import strings.
- get_weight_load_code(path_var: str) → str
Returns Python code to load a .pth file into a raw state dictionary. Handles both bare state dicts and Lightning-style checkpoints.
- Parameters:
path_var – Variable name containing the file path string.
- Returns:
Block of Python code that sets raw_state.
- get_tensor_to_numpy_expr(tensor_var: str) → str
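The generated load block might look like the sketch below, which unwraps Lightning-style checkpoints by checking for a top-level "state_dict" key (the key PyTorch Lightning uses for model weights). The weight_load_code helper, the intermediate checkpoint variable, and map_location="cpu" are assumptions; only the raw_state output name comes from the description above.

```python
import textwrap


def weight_load_code(path_var: str) -> str:
    """Emit source that loads a .pth file into `raw_state` (hypothetical helper)."""
    return textwrap.dedent(
        f"""\
        checkpoint = torch.load({path_var}, map_location="cpu")
        if isinstance(checkpoint, dict) and "state_dict" in checkpoint:
            # Lightning-style checkpoint: weights are nested under "state_dict".
            raw_state = checkpoint["state_dict"]
        else:
            # Bare state dict saved with torch.save(model.state_dict(), ...).
            raw_state = checkpoint
        """
    )
```

Loading onto the CPU keeps the migration script usable on machines without a GPU.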
Returns expression to convert a Torch tensor variable to a NumPy array. Includes detach and cpu calls for safety.
- Parameters:
tensor_var – Name of variable holding the torch tensor.
- Returns:
Expression string.
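A plausible form of the emitted expression is sketched below. Tensor.numpy() raises for tensors that require grad or live on a GPU, which is why detach() and cpu() come first. The helper name is hypothetical.

```python
def tensor_to_numpy_expr(tensor_var: str) -> str:
    """Render a tensor-to-NumPy conversion as source text (hypothetical helper).

    detach() drops autograd tracking; cpu() copies GPU tensors to host memory
    and is a no-op for tensors already on the CPU.
    """
    return f"{tensor_var}.detach().cpu().numpy()"
```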
- get_weight_save_code(state_var: str, path_var: str) → str
Returns Python code to save the converted state dictionary back to .pth format. Converts NumPy arrays back to Torch tensors before saving.
- Parameters:
state_var – Variable name of the flat dictionary {key: numpy_array}.
path_var – Variable name of the output path.
- Returns:
Block of Python code.
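A sketch of the inverse step: the emitted block rebuilds tensors with the real torch.from_numpy and writes them with torch.save. The weight_save_code helper and the torch_state variable are hypothetical.

```python
import textwrap


def weight_save_code(state_var: str, path_var: str) -> str:
    """Emit source that converts {key: ndarray} to tensors and saves them (hypothetical)."""
    return textwrap.dedent(
        f"""\
        torch_state = {{k: torch.from_numpy(v) for k, v in {state_var}.items()}}
        torch.save(torch_state, {path_var})
        """
    )
```

torch.from_numpy shares memory with the source array and rejects arrays with negative strides; torch.tensor(v) would copy instead, at the cost of extra memory.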
- get_doc_url(api_name: str) → str | None
Returns the official PyTorch documentation URL for the given API, or None if no URL can be determined.
- Parameters:
api_name – The fully qualified API name.
- Returns:
URL string or None.
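A hedged sketch of URL construction, assuming the stable-docs layout https://pytorch.org/docs/stable/generated/<qualified name>.html used for many torch APIs (for example torch.nn.Linear). Not every API has a generated page, so a real implementation would need a fallback; returning None for non-torch names is likewise an assumption.

```python
from typing import Optional

# Assumed URL layout for the PyTorch stable docs.
DOCS_BASE = "https://pytorch.org/docs/stable/generated"


def doc_url(api_name: str) -> Optional[str]:
    """Guess a PyTorch docs URL for a fully qualified API name (hypothetical helper)."""
    # Require the "torch." prefix so names like "torchvision.x" are not matched.
    if not api_name.startswith("torch."):
        return None
    return f"{DOCS_BASE}/{api_name}.html"
```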
- get_tiered_examples() → Dict[str, str]
Provides code snippets for "Wizard" or "Demo" usage.
- Returns:
Dictionary mapping tier IDs to code snippets.
- convert(data: Any) → Any
Converts input data (numpy, lists) into PyTorch Tensors for verification runners.
- Parameters:
data – Input data structure.
- Returns:
Converted PyTorch Tensor or original data if conversion fails.
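A best-effort conversion consistent with the description above could wrap torch.as_tensor and fall back to the original object. This is a sketch, not the adapter's code: the lazy import and the broad exception handler are assumptions chosen so the helper degrades gracefully when PyTorch is absent or the input cannot be converted.

```python
from typing import Any


def convert(data: Any) -> Any:
    """Best-effort conversion to a torch.Tensor; returns data unchanged on failure."""
    try:
        import torch  # imported lazily so the sketch runs without PyTorch
    except ImportError:
        return data
    try:
        # as_tensor accepts NumPy arrays, scalars, and (nested) lists.
        return torch.as_tensor(data)
    except Exception:
        # Unsupported inputs (strings, ragged lists, None) raise several
        # exception types, so all of them fall back to the original object.
        return data
```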
- collect_api(category: ml_switcheroo.frameworks.base.StandardCategory) → List[ml_switcheroo.frameworks.base.GhostRef]
Implementation of the Ghost Protocol. Scans the locally installed PyTorch library for API definitions corresponding to the requested category (Loss, Layer, etc.).
- Parameters:
category – The standard category to search for.
- Returns:
List of discovered API references.
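The scanning half of collect_api can be approximated with inspect.getmembers, as below. collect_names is a hypothetical stand-in that returns dotted names instead of GhostRef objects (whose fields are not documented here), and the demo module replaces the installed torch package so the sketch runs without PyTorch.

```python
import inspect
import re
import types
from typing import List


def collect_names(module: types.ModuleType, patterns: List[str]) -> List[str]:
    """Return sorted dotted names of public classes in module matching any pattern."""
    found = []
    for name, _obj in inspect.getmembers(module, inspect.isclass):
        if name.startswith("_"):
            continue  # skip private helpers
        if any(re.search(pattern, name) for pattern in patterns):
            found.append(f"{module.__name__}.{name}")
    return sorted(found)


# Usage with a stand-in module (hypothetical names, no torch required).
demo = types.ModuleType("demo_nn")
demo.MSELoss = type("MSELoss", (), {})
demo.Linear = type("Linear", (), {})
demo._PrivateLoss = type("_PrivateLoss", (), {})
```

With the demo module, collect_names(demo, [r"Loss$"]) returns only the public loss class; the private _PrivateLoss is filtered out.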