ml_switcheroo.frameworks.torch

PyTorch Framework Adapter.

This module implements the FrameworkAdapter protocol for PyTorch. It provides:

  1. Import Abstraction: Self-declared namespace mappings (e.g., torch.nn is NEURAL).

  2. Semantic Definitions: Mappings loaded from frameworks/definitions/torch.json via a helper function.

  3. Discovery: Heuristics and logic for scanning the installed torch library.

  4. IO & Device Support: Wires up serialization and device allocation.

  5. Weight Migration: Generates scripts that convert .pth checkpoints to and from NumPy arrays for interoperability.

Attributes

torch

Classes

TorchAdapter

Adapter for PyTorch (Meta).

Module Contents

ml_switcheroo.frameworks.torch.torch = None
class ml_switcheroo.frameworks.torch.TorchAdapter

Adapter for PyTorch (Meta).

Handles source and target translation rules for PyTorch, including torch.nn, torch.optim, and torch.func (vmap/grad).

display_name: str = 'PyTorch'
inherits_from: str | None = None
ui_priority: int = 0
property import_alias: Tuple[str, str]

Returns the primary root import alias (‘torch’, ‘torch’).

Returns:

The module name and default alias.

property import_namespaces: Dict[str, ml_switcheroo.frameworks.base.ImportConfig]

Defines the semantic roles of PyTorch namespaces.

Returns:

Mapping of dot-path strings to configuration objects.

property search_modules: List[str]

Modules to scan during scaffold or sync operations.

Returns:

List of module names.

property unsafe_submodules: Set[str]

Submodules that cause recursion depth errors or C-extension crashes during scanning.

Returns:

Set of module names to exclude from recursive scanning.

property discovery_heuristics: Dict[str, List[str]]

Regex patterns to categorize discovered APIs.

Returns:

Dictionary mapping category names to lists of regex patterns.
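To illustrate how regex heuristics like these can sort discovered APIs, the sketch below applies each category's patterns to a dotted API path and returns the first category that matches. The category names and patterns in `HEURISTICS` are hypothetical; the adapter's actual `discovery_heuristics` values are not listed in this reference.

```python
import re
from typing import Dict, List, Optional

# Hypothetical patterns for illustration only; the adapter's real
# discovery_heuristics values are not shown in this reference.
HEURISTICS: Dict[str, List[str]] = {
    "loss": [r"Loss$"],
    "optimizer": [r"^torch\.optim\.[A-Z]\w*$"],
    "layer": [r"^torch\.nn\.[A-Z]\w*$"],
}


def categorize(api_path: str, heuristics: Dict[str, List[str]]) -> Optional[str]:
    """Return the first category whose patterns match api_path, else None."""
    for category, patterns in heuristics.items():
        # re.search matches anywhere in the string unless the pattern is anchored.
        if any(re.search(p, api_path) for p in patterns):
            return category
    return None
```

Because the first matching category wins, dictionary order matters: `torch.nn.CrossEntropyLoss` also matches the `layer` pattern and lands in `loss` only because `loss` is checked first.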

property supported_tiers: List[ml_switcheroo.enums.SemanticTier]

Returns the semantic tiers fully supported by this adapter.

Returns:

List of supported tiers.

property test_config: Dict[str, str]

Templates used by gen-tests to create physical verification files.

Returns:

Dictionary of code templates.

property harness_imports: List[str]

Imports required for harness initialization.

Returns:

List of import statements.

get_harness_init_code() → str

Returns helper code for initializing the harness.

Returns:

Python source code string.

get_to_numpy_code() → str

Returns code to convert Torch tensors to NumPy (detach/cpu check).

Returns:

Python statement string.

property structural_traits: ml_switcheroo.frameworks.base.StructuralTraits

Defines how classes and functions are rewritten when targeting PyTorch.

Returns:

Configuration object for structural rewriting.

property plugin_traits: ml_switcheroo.frameworks.base.PluginTraits

Capability flags. PyTorch uses imperative state and eager execution.

Returns:

Configuration object for plugin logic.

property rng_seed_methods: List[str]

Global seed-setting methods that are flagged as impure side effects.

Returns:

List of method names.
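A minimal sketch of how seed-setting calls might be flagged in user code, assuming detection works by matching the final attribute of each call against the method names. It uses Python's `ast` module; `RNG_SEED_METHODS` and `find_seed_calls` are hypothetical names, although `torch.manual_seed`, `torch.seed`, and `torch.cuda.manual_seed_all` are real PyTorch functions.

```python
import ast
from typing import List

# Hypothetical values: the adapter's real rng_seed_methods list is not shown
# here, though these are names of real PyTorch seeding functions.
RNG_SEED_METHODS: List[str] = ["manual_seed", "seed", "manual_seed_all"]


def find_seed_calls(source: str, methods: List[str]) -> List[str]:
    """Return dotted call targets whose final attribute is a seed method."""
    wanted = set(methods)
    hits: List[str] = []
    for node in ast.walk(ast.parse(source)):
        # Only attribute calls such as torch.manual_seed(...) are checked;
        # a bare manual_seed(...) after "from torch import ..." is missed.
        if isinstance(node, ast.Call) and isinstance(node.func, ast.Attribute):
            if node.func.attr in wanted:
                hits.append(ast.unparse(node.func))
    return hits
```

Matching on the final attribute alone would also flag `random.seed(...)`; a stricter version could compare the full dotted path instead.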

property declared_magic_args: List[str]

Returns the list of framework-specific magic arguments. PyTorch declares no magic arguments; all state is implicit.

Returns:

Empty list.

property definitions: Dict[str, ml_switcheroo.frameworks.base.StandardMap]

The definitive mapping of abstract operations to PyTorch APIs, loaded dynamically from frameworks/definitions/torch.json.

Returns:

Dictionary mapping abstract operation IDs to implementation details.

get_device_syntax(device_type: str, device_index: str | None = None) → str

Generates code for device creation.

Parameters:
  • device_type – The device type string (e.g. ‘cuda’, ‘cpu’).

  • device_index – The optional device index.

Returns:

Code string for device creation.

get_device_check_syntax() → str

Returns PyTorch syntax for checking CUDA availability.

Returns:

Python expression string.
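The two methods above return source text rather than performing device operations. The sketch below shows plausible outputs, assuming `device_type` is a bare device name and `device_index` is already a code string (so it is interpolated without quotes); the exact strings the adapter emits may differ. `torch.device` and `torch.cuda.is_available` are real PyTorch APIs.

```python
from typing import Optional


def get_device_syntax(device_type: str, device_index: Optional[str] = None) -> str:
    """Build a torch.device(...) expression as source text (assumed format)."""
    if device_index is not None:
        # device_index is treated as a code string, so it is inserted verbatim.
        return f"torch.device('{device_type}', {device_index})"
    return f"torch.device('{device_type}')"


def get_device_check_syntax() -> str:
    """Expression that evaluates to True when a CUDA device is usable."""
    return "torch.cuda.is_available()"
```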

get_rng_split_syntax(rng_var: str, key_var: str) → str

Returns syntax for splitting RNG state. PyTorch uses global state-based randomness, so explicit splitting is a no-op.

Parameters:
  • rng_var – Variable name holding random state.

  • key_var – Variable name for the new key.

Returns:

The string ‘pass’ (a no-op).

get_serialization_imports() → List[str]

Returns imports required for IO operations.

Returns:

List of import statements.

get_serialization_syntax(op: str, file_arg: str, object_arg: str | None = None) → str

Generates save/load syntax.

Parameters:
  • op – Operation name (‘save’ or ‘load’).

  • file_arg – Code string representing the file path.

  • object_arg – Code string representing the object to save (optional).

Returns:

Python code string for the operation.
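A sketch of the kind of code this method could emit, assuming `file_arg` and `object_arg` are code strings interpolated verbatim. Note that `torch.save` takes the object first and the destination second. Raising `ValueError` for a missing object or an unknown operation is an assumption of this sketch, not documented behavior.

```python
from typing import Optional


def get_serialization_syntax(op: str, file_arg: str, object_arg: Optional[str] = None) -> str:
    """Emit torch.save / torch.load source text; arguments are code strings."""
    if op == "save":
        if object_arg is None:
            raise ValueError("'save' requires object_arg")
        # torch.save(obj, f): the object comes first, then the destination.
        return f"torch.save({object_arg}, {file_arg})"
    if op == "load":
        return f"torch.load({file_arg})"
    raise ValueError(f"unsupported op: {op!r}")
```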

get_weight_conversion_imports() → List[str]

Returns the imports required by the generated weight-migration script.

Returns:

List of import strings.

get_weight_load_code(path_var: str) → str

Returns Python code to load a .pth file into a raw state dictionary. Handles both bare state dicts and Lightning-style checkpoints.

Parameters:

path_var – Variable name containing the file path string.

Returns:

Block of Python code that sets ‘raw_state’.
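The generated block can be sketched as follows, assuming Lightning-style checkpoints are recognized by a top-level `"state_dict"` key (the layout PyTorch Lightning uses when saving). `map_location="cpu"` lets GPU-saved checkpoints load on machines without CUDA; the intermediate variable name `checkpoint` is hypothetical.

```python
import textwrap


def get_weight_load_code(path_var: str) -> str:
    """Return source that loads a .pth file into `raw_state` (sketch)."""
    return textwrap.dedent(
        f"""\
        checkpoint = torch.load({path_var}, map_location="cpu")
        # Lightning-style checkpoints nest the weights under "state_dict".
        if isinstance(checkpoint, dict) and "state_dict" in checkpoint:
            raw_state = checkpoint["state_dict"]
        else:
            # Otherwise assume the file already holds a bare state dict.
            raw_state = checkpoint
        """
    )
```

On recent PyTorch releases `torch.load` defaults to `weights_only=True`, which may reject checkpoints that pickle arbitrary Python objects, so a real migration script may need to set that flag deliberately.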

get_tensor_to_numpy_expr(tensor_var: str) → str

Returns expression to convert a Torch tensor variable to a NumPy array. Includes detach and cpu calls for safety.

Parameters:

tensor_var – Name of variable holding the torch tensor.

Returns:

Expression string.
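A minimal sketch of the expression this method likely builds, based on the description above. Calling `.numpy()` directly fails on tensors that require gradients or live on a GPU, which is why `detach()` and `cpu()` come first; the exact call order in the adapter is an assumption.

```python
def get_tensor_to_numpy_expr(tensor_var: str) -> str:
    """Build an expression that turns a torch tensor into a NumPy array."""
    # detach() drops autograd history, cpu() moves device tensors to host
    # memory, and numpy() then exposes the CPU buffer as an ndarray.
    return f"{tensor_var}.detach().cpu().numpy()"
```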

get_weight_save_code(state_var: str, path_var: str) → str

Returns Python code to save the converted state dictionary back to .pth format. Converts NumPy arrays back to Torch tensors before saving.

Parameters:
  • state_var – Variable name of the flat dictionary {key: numpy_array}.

  • path_var – Variable name of the output path.

Returns:

Block of Python code.
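The inverse of the load step can be sketched like this, assuming `state_var` names a flat `{key: numpy_array}` dict as documented. `torch.from_numpy` and `torch.save` are real PyTorch functions; the intermediate name `torch_state` is hypothetical.

```python
import textwrap


def get_weight_save_code(state_var: str, path_var: str) -> str:
    """Return source that converts {key: ndarray} to tensors and saves them."""
    return textwrap.dedent(
        f"""\
        torch_state = {{
            key: torch.from_numpy(value) for key, value in {state_var}.items()
        }}
        torch.save(torch_state, {path_var})
        """
    )
```

`torch.from_numpy` shares memory with the source array and accepts only `ndarray` inputs; `torch.as_tensor` is the more permissive alternative if the state may contain Python scalars.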

get_doc_url(api_name: str) → str | None

Returns the official PyTorch documentation URL for the given API.

Parameters:

api_name – The fully qualified API name.

Returns:

URL string or None.
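One plausible implementation, assuming the adapter links to PyTorch's stable documentation, where most public APIs have a page at `https://pytorch.org/docs/stable/generated/<qualified name>.html`. Some APIs are documented elsewhere, so the URL scheme here is an assumption rather than a guarantee.

```python
from typing import Optional

# Assumed documentation root; most generated API pages live under this path.
DOCS_ROOT = "https://pytorch.org/docs/stable/generated"


def get_doc_url(api_name: str) -> Optional[str]:
    """Map a fully qualified torch API name to its generated docs page."""
    if not api_name.startswith("torch."):
        return None  # not a PyTorch API, so there is no page to link
    return f"{DOCS_ROOT}/{api_name}.html"
```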

get_tiered_examples() → Dict[str, str]

Provides code snippets for “Wizard” or “Demo” usage.

Returns:

Dictionary mapping tier IDs to code snippets.

convert(data: Any) → Any

Converts input data (NumPy arrays, lists) into PyTorch tensors for verification runners.

Parameters:

data – Input data structure.

Returns:

Converted PyTorch Tensor or original data if conversion fails.
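A sketch of the conversion with the fallback described above. The module-level `torch = None` attribute suggests the adapter guards its PyTorch import; this sketch assumes that pattern, and the exact exceptions caught are assumptions as well.

```python
from typing import Any

try:
    import torch
except ImportError:  # PyTorch not installed: conversion becomes a no-op
    torch = None


def convert(data: Any) -> Any:
    """Convert NumPy arrays or lists to a tensor, else return data unchanged."""
    if torch is None:
        return data
    try:
        # as_tensor can reuse a compatible NumPy buffer instead of copying it.
        return torch.as_tensor(data)
    except (TypeError, ValueError, RuntimeError):
        return data
```

`torch.as_tensor` is used rather than `torch.tensor` because it can share memory with a compatible NumPy array instead of always copying.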

collect_api(category: ml_switcheroo.frameworks.base.StandardCategory) → List[ml_switcheroo.frameworks.base.GhostRef]

Implementation of the Ghost Protocol. Scans the locally installed PyTorch library for API definitions corresponding to the requested category (Loss, Layer, etc.).

Parameters:

category – The standard category to search for.

Returns:

List of discovered API references.
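A simplified sketch of scanning a module for classes in one category. `GhostRef` here is a hypothetical stand-in dataclass (the fields of the real `ml_switcheroo.frameworks.base.GhostRef` are not documented in this section), and filtering by a single name regex, such as `Loss$` for losses, is an assumption about how categories map to names.

```python
import inspect
import re
import types
from dataclasses import dataclass, field
from typing import List


@dataclass
class GhostRef:
    """Hypothetical stand-in for ml_switcheroo.frameworks.base.GhostRef."""

    name: str
    api_path: str
    params: List[str] = field(default_factory=list)


def collect_classes(module: types.ModuleType, name_pattern: str) -> List[GhostRef]:
    """Return public classes in `module` whose names match `name_pattern`."""
    refs: List[GhostRef] = []
    # getmembers returns (name, object) pairs sorted by name.
    for name, obj in inspect.getmembers(module, inspect.isclass):
        if name.startswith("_") or not re.search(name_pattern, name):
            continue
        try:
            # For a class, signature() reports __init__ parameters without self.
            params = list(inspect.signature(obj).parameters)
        except (TypeError, ValueError):
            params = []  # some C-implemented classes expose no signature
        refs.append(GhostRef(name, f"{module.__name__}.{name}", params))
    return refs
```

With PyTorch installed, passing `torch.nn` and `r"Loss$"` would list classes such as `MSELoss`.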

apply_wiring(snapshot: Dict[str, Any]) → None

Applies manual patches to the standard mappings where necessary. Used to inject complex behaviors that simple API scanning cannot capture.

Parameters:

snapshot – The snapshot dictionary to update in-place.