ml_switcheroo.frameworks.mlx
Apple MLX Framework Adapter.
This module provides the adapter for Apple’s MLX array framework.
It supports:
1. Unified Memory math: Mapping mlx.core operations.
2. Neural Networks: Mapping mlx.nn layers and containers.
3. Discovery: Runtime introspection of the MLX API surface.
4. Types: Mapping Abstract Types to mlx.core dtypes (e.g. mx.float32).
5. Casting: Generic casting plugin integration via .astype().
6. Weight Migration: Loading/saving .npz or .safetensors files (via stubs/core).
Definitions are loaded from frameworks/definitions/mlx.json.
Classes
MLXAdapter: Adapter for Apple MLX (Silicon-optimized tensor framework).
Module Contents
- class ml_switcheroo.frameworks.mlx.MLXAdapter
Adapter for Apple MLX (Silicon-optimized tensor framework).
- display_name: str = 'Apple MLX'
- inherits_from: str | None = None
- ui_priority: int = 50
- property search_modules: List[str]
Returns list of MLX submodules to scan during discovery.
- Returns:
Module list.
- Return type:
List[str]
- property unsafe_submodules: Set[str]
Submodules to skip during recursive discovery.
- Returns:
Default empty set.
- Return type:
Set[str]
- property import_alias: Tuple[str, str]
Default alias for core array operations: import mlx.core as mx.
- Returns:
("mlx.core", "mx").
- Return type:
Tuple[str, str]
- property import_namespaces: Dict[str, ml_switcheroo.frameworks.base.ImportConfig]
Import namespaces declared by this adapter.
- Returns:
Namespace map.
- Return type:
Dict[str, ImportConfig]
- property discovery_heuristics: Dict[str, List[str]]
Regex patterns for categorizing discovered APIs into Tiers.
- Returns:
Heuristics map.
- Return type:
Dict[str, List[str]]
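A minimal sketch of how regex heuristics could sort discovered API paths into tiers. The tier names and patterns in HEURISTICS are hypothetical placeholders, not the adapter's actual discovery_heuristics values, and the first-match rule is an assumption.

```python
import re
from typing import Dict, List

# Hypothetical heuristics map: tier name -> regex patterns over fully qualified API paths.
HEURISTICS: Dict[str, List[str]] = {
    "neural": [r"^mlx\.nn\.", r"\.layers\."],
    "extras": [r"\.random\.", r"\.io\."],
}


def categorize(api_path: str, heuristics: Dict[str, List[str]], default: str = "array") -> str:
    """Return the first tier whose patterns match api_path, else the default tier."""
    for tier, patterns in heuristics.items():
        if any(re.search(pattern, api_path) for pattern in patterns):
            return tier
    return default
```

Because Python dicts preserve insertion order, earlier tiers take precedence when several patterns match the same path.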
- property test_config: Dict[str, str]
Templates for generating physical verification tests.
- Returns:
Templates.
- Return type:
Dict[str, str]
- property harness_imports: List[str]
Additional imports required by the verification harness.
- Returns:
Empty list.
- Return type:
List[str]
- get_to_numpy_code() → str
Returns code that converts MLX arrays (which expose .tolist()) to NumPy arrays.
- Returns:
Python logic for conversion.
- Return type:
str
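A sketch of the kind of code string get_to_numpy_code() could return, assuming the conversion goes through .tolist() as described above. The to_numpy function name and the pass-through for non-array inputs are hypothetical. NumPy is imported lazily inside the array branch, so the generated snippet can be executed and checked with plain Python values.

```python
# Hypothetical snippet of the kind get_to_numpy_code() could return.
TO_NUMPY_CODE = '''
def to_numpy(obj):
    # MLX arrays expose .tolist(); round-trip through a nested Python list.
    if hasattr(obj, "tolist"):
        import numpy as np
        return np.array(obj.tolist())
    # Anything else (scalars, None, ...) passes through unchanged.
    return obj
'''

# Execute the generated code in an isolated namespace to obtain the helper.
namespace: dict = {}
exec(TO_NUMPY_CODE, namespace)
to_numpy = namespace["to_numpy"]
```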
- property supported_tiers: List[ml_switcheroo.enums.SemanticTier]
Returns supported semantic tiers (Array, Neural, Extras).
- Returns:
Supported Tiers.
- Return type:
List[SemanticTier]
- property declared_magic_args: List[str]
Implicit ("magic") RNG arguments declared by the framework.
- Returns:
An empty list; MLX declares none.
- Return type:
List[str]
- property structural_traits: ml_switcheroo.frameworks.base.StructuralTraits
Defines structural rewriting rules (classes, methods, __init__).
Strips the rngs argument that Flax NNX passes to module constructors, because MLX initializes parameters eagerly and keeps RNG state implicit.
- Returns:
Config object.
- Return type:
StructuralTraits
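To make the rngs stripping concrete, here is a minimal sketch using the standard-library ast module. It is not ml_switcheroo's own rewriter (which may use a different parsing library); it assumes the Flax NNX convention of declaring rngs as a keyword-only constructor parameter and passing it as an rngs= keyword at call sites. ast.unparse requires Python 3.9 or later.

```python
import ast


class StripRngs(ast.NodeTransformer):
    """Drop keyword-only `rngs` parameters and `rngs=` call arguments."""

    def visit_FunctionDef(self, node: ast.FunctionDef) -> ast.FunctionDef:
        self.generic_visit(node)  # rewrite nested calls first
        args = node.args
        # kwonlyargs and kw_defaults are parallel lists, so filter them together.
        kept = [(a, d) for a, d in zip(args.kwonlyargs, args.kw_defaults) if a.arg != "rngs"]
        args.kwonlyargs = [a for a, _ in kept]
        args.kw_defaults = [d for _, d in kept]
        return node

    def visit_Call(self, node: ast.Call) -> ast.Call:
        self.generic_visit(node)
        node.keywords = [kw for kw in node.keywords if kw.arg != "rngs"]
        return node


def strip_rngs(source: str) -> str:
    """Parse source, remove rngs plumbing, and render it back to code."""
    tree = StripRngs().visit(ast.parse(source))
    return ast.unparse(ast.fix_missing_locations(tree))
```

The sketch does not touch other uses of rngs (for example self.rngs attributes in the body); a real rewriter would need to handle those as well.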
- property plugin_traits: ml_switcheroo.frameworks.base.PluginTraits
Plugin behavior configuration.
- Returns:
Config object.
- Return type:
PluginTraits
- property definitions: Dict[str, ml_switcheroo.frameworks.base.StandardMap]
Static mapping definitions for MLX, loaded at runtime from frameworks/definitions/mlx.json.
- Returns:
Definitions map.
- Return type:
Dict[str, StandardMap]
- property rng_seed_methods: List[str]
Returns the names of global RNG seed-setting methods.
- Returns:
Method names.
- Return type:
List[str]
- collect_api(category: ml_switcheroo.frameworks.base.StandardCategory) → List[ml_switcheroo.core.ghost.GhostRef]
Performs runtime introspection to discover available MLX APIs.
- Parameters:
category (StandardCategory) – Category to scan.
- Returns:
Found items.
- Return type:
List[GhostRef]
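Runtime introspection of this kind can be sketched with the standard inspect module. The helper collect_members is hypothetical: it returns bare names rather than GhostRef objects, and splitting a category into "classes" versus "routines" is an assumption about how a StandardCategory might be interpreted.

```python
import inspect
from types import ModuleType
from typing import List


def collect_members(module: ModuleType, want_classes: bool) -> List[str]:
    """List public classes (want_classes=True) or public routines of a module."""
    names: List[str] = []
    for name, obj in inspect.getmembers(module):
        if name.startswith("_"):
            continue  # skip private and dunder names
        if want_classes and inspect.isclass(obj):
            names.append(name)
        elif not want_classes and inspect.isroutine(obj):
            names.append(name)
    return names
```

With MLX installed, a call such as collect_members(mlx.nn, want_classes=True) would list layer classes; the helper works on any module, including the standard library.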
- convert(data: Any) → Any
Converts input data (a NumPy array or list) to an MLX array for verification.
- Parameters:
data (Any) – Input.
- Returns:
An MLX array, or the original input unchanged.
- Return type:
Any
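A hedged sketch of a convert() with a fallback path. Which inputs count as convertible (lists, tuples, objects exposing __array__) and returning the input unchanged when MLX is not importable are assumptions, not the adapter's documented behavior; mx.array itself is the real MLX constructor.

```python
from typing import Any


def convert(data: Any) -> Any:
    """Convert lists, tuples, or array-likes to an mlx.core array; pass anything else through."""
    try:
        import mlx.core as mx
    except ImportError:
        return data  # MLX unavailable: leave the input untouched
    if isinstance(data, (list, tuple)) or hasattr(data, "__array__"):
        return mx.array(data)
    return data
```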
- get_tiered_examples() → Dict[str, str]
Returns MLX idiomatic examples used for validity testing.
- Returns:
Example maps.
- Return type:
Dict[str, str]
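An illustrative shape for the returned mapping: each value is an idiomatic MLX source string. The tier keys "tier1_math" and "tier2_neural" are hypothetical and may differ from the real method's keys; the snippets use real mlx.core and mlx.nn APIs (mx.abs, mx.exp, nn.Module, nn.Linear, nn.relu). Compiling each entry checks syntax without needing MLX installed.

```python
from typing import Dict

# Hypothetical tier keys; the real get_tiered_examples() keys may differ.
TIERED_EXAMPLES: Dict[str, str] = {
    "tier1_math": """\
import mlx.core as mx

def compute(x, y):
    return mx.abs(x - y) + mx.exp(x)
""",
    "tier2_neural": """\
import mlx.core as mx
import mlx.nn as nn

class MLP(nn.Module):
    def __init__(self, din: int, dout: int):
        super().__init__()
        self.linear = nn.Linear(din, dout)

    def __call__(self, x: mx.array) -> mx.array:
        return nn.relu(self.linear(x))
""",
}

# Validity check without MLX installed: every example must at least parse.
for name, src in TIERED_EXAMPLES.items():
    compile(src, f"<{name}>", "exec")
```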
- get_device_syntax(device_type: str, device_index: str | None = None) → str
Returns source code that constructs an MLX device.
- Parameters:
device_type – Device type.
device_index – Device index (optional).
- Returns:
Generated code.
- Return type:
str
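A sketch of device-constructor generation, assuming the output targets MLX's mx.Device(type, index) with the mx.cpu and mx.gpu device types. The mapping of foreign device names ("cuda", "mps") to the MLX GPU is a hypothetical choice, and device_index is interpolated verbatim because it may be a variable name rather than a literal.

```python
from typing import Optional


def device_syntax(device_type: str, device_index: Optional[str] = None) -> str:
    """Render an mx.Device(...) constructor for a framework-neutral device string."""
    # Assumed mapping: accelerator names from other frameworks -> MLX GPU.
    kind = "mx.gpu" if device_type.lower() in {"cuda", "gpu", "mps"} else "mx.cpu"
    if device_index is None:
        return f"mx.Device({kind})"
    return f"mx.Device({kind}, {device_index})"
```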
- get_device_check_syntax() → str
Returns code that checks whether the default device is the GPU. MLX's unified memory model has no strict is_available() check, so the generated code inspects the default backend instead.
- Returns:
Code string.
- Return type:
str
- get_rng_split_syntax(rng_var: str, key_var: str) → str
MLX normally uses implicit global RNG state. If explicit key splitting is requested, this returns "pass", because split logic differs significantly between frameworks.
- Returns:
“pass”.
- Return type:
str
- get_serialization_imports() → List[str]
Returns imports for serialization.
- Returns:
Imports.
- Return type:
List[str]
- get_serialization_syntax(op: str, file_arg: str, object_arg: str | None = None) → str
Returns save/load syntax.
- Parameters:
op – 'save' or 'load'.
file_arg – Target file path.
object_arg – Name of the object to save (optional; None by default).
- Returns:
Code string.
- Return type:
str
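A minimal sketch of save/load code generation, assuming single-array serialization with mx.save (which writes one array in .npy format) and mx.load. The helper names and the error handling for a missing object_arg are hypothetical; both file_arg and object_arg are treated as source expressions and interpolated verbatim.

```python
from typing import List, Optional


def serialization_imports() -> List[str]:
    """Imports the generated save/load code relies on."""
    return ["import mlx.core as mx"]


def serialization_syntax(op: str, file_arg: str, object_arg: Optional[str] = None) -> str:
    """Render an mx.save / mx.load call for the requested operation."""
    if op == "save":
        if object_arg is None:
            raise ValueError("'save' requires object_arg")
        return f"mx.save({file_arg}, {object_arg})"
    if op == "load":
        return f"mx.load({file_arg})"
    raise ValueError(f"unsupported op: {op!r}")
```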
- get_weight_load_code(path_var: str) → str
Returns code that loads weights with mx.load (.npz or .safetensors) into a raw dictionary.
- get_weight_save_code(state_var: str, path_var: str) → str
Returns code that saves a dictionary of arrays to .npz or .safetensors.
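A sketch of weight load/save code generation under stated assumptions: mx.load returns a dict of arrays for .npz and .safetensors files, mx.savez accepts named arrays as keyword arguments, and mx.save_safetensors takes a dict. Dispatching on the file extension at run time is a hypothetical design choice, made because path_var names a variable whose value is not known at generation time.

```python
def weight_load_code(path_var: str) -> str:
    # mx.load returns a dict of arrays for .npz and .safetensors files.
    return f"mx.load({path_var})"


def weight_save_code(state_var: str, path_var: str) -> str:
    # Choose the writer at run time from the file extension.
    return (
        f"if str({path_var}).endswith('.safetensors'):\n"
        f"    mx.save_safetensors({path_var}, {state_var})\n"
        f"else:\n"
        f"    mx.savez({path_var}, **{state_var})\n"
    )
```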