ml_switcheroo.frameworks.mlx

Apple MLX Framework Adapter.

This module provides the adapter for Apple’s MLX array framework. It supports:

1. Unified Memory math: Mapping mlx.core operations.
2. Neural Networks: Mapping mlx.nn layers and containers.
3. Discovery: Runtime introspection of the MLX API surface.
4. Types: Mapping Abstract Types to mlx.core dtypes (e.g. mx.float32).
5. Casting: Generic casting plugin integration via .astype().
6. Weight Migration: Loading/saving .npz or .safetensors files (via stubs/core).

Definitions are loaded from frameworks/definitions/mlx.json.
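The type mapping and casting support listed above can be illustrated with a minimal sketch. The abstract type names, the table, and the cast_expr helper are hypothetical and chosen for illustration; only the mx.* dtype names and the .astype() method come from MLX. The sketch builds source strings, so it runs without MLX installed.

```python
from typing import Dict

# Hypothetical abstract-type table; the real mappings are loaded from mlx.json.
ABSTRACT_TO_MLX_DTYPE: Dict[str, str] = {
    "Float32": "mx.float32",
    "Float16": "mx.float16",
    "Int32": "mx.int32",
    "Bool": "mx.bool_",
}


def cast_expr(var_name: str, abstract_type: str) -> str:
    """Build an MLX cast expression as source text, e.g. for variable ``x``."""
    dtype = ABSTRACT_TO_MLX_DTYPE[abstract_type]
    return f"{var_name}.astype({dtype})"


print(cast_expr("x", "Float32"))
```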

Classes

MLXAdapter

Adapter for Apple MLX (Silicon-optimized tensor framework).

Module Contents

class ml_switcheroo.frameworks.mlx.MLXAdapter

Adapter for Apple MLX (Silicon-optimized tensor framework).

display_name: str = 'Apple MLX'
inherits_from: str | None = None
ui_priority: int = 50
property search_modules: List[str]

Returns list of MLX submodules to scan during discovery.

Returns:

Module list.

Return type:

List[str]

property unsafe_submodules: Set[str]

Submodules to avoid during recursion.

Returns:

Default empty set.

Return type:

Set[str]

property import_alias: Tuple[str, str]

Default alias for core array operations: import mlx.core as mx.

Returns:

(“mlx.core”, “mx”).

Return type:

Tuple[str, str]
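As an illustration of how a (module, alias) tuple of this shape could be turned into an import statement, here is a minimal sketch. The render_import helper is hypothetical and not part of the adapter.

```python
from typing import Tuple


def render_import(alias: Tuple[str, str]) -> str:
    """Render a (module, alias) pair as an ``import ... as ...`` line."""
    module, short_name = alias
    return f"import {module} as {short_name}"


print(render_import(("mlx.core", "mx")))
```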

property import_namespaces: Dict[str, ml_switcheroo.frameworks.base.ImportConfig]

Self-declaration of namespaces.

Returns:

Namespace map.

Return type:

Dict[str, ImportConfig]

property discovery_heuristics: Dict[str, List[str]]

Regex patterns for categorizing discovered APIs into Tiers.

Returns:

Heuristics map.

Return type:

Dict[str, List[str]]
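A minimal sketch of how a tier-to-patterns map of this shape might be applied to a discovered API name. The tier keys, the patterns, and the categorize helper are assumptions for illustration; the adapter's actual heuristics may differ.

```python
import re
from typing import Dict, List, Optional

# Hypothetical patterns; not taken from the adapter.
HEURISTICS: Dict[str, List[str]] = {
    "neural": [r"\.nn\.", r"\.layers\."],
    "extras": [r"\.random\.", r"\.optimizers\."],
}


def categorize(api_name: str, heuristics: Dict[str, List[str]]) -> Optional[str]:
    """Return the first tier with a matching pattern, or None if none match."""
    for tier, patterns in heuristics.items():
        if any(re.search(pattern, api_name) for pattern in patterns):
            return tier
    return None


print(categorize("mlx.nn.Linear", HEURISTICS))
```

Names that match no pattern return None here, which a caller could treat as the default array tier.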

property test_config: Dict[str, str]

Templates for generating physical verification tests.

Returns:

Templates.

Return type:

Dict[str, str]

property harness_imports: List[str]

Imports for harness.

Returns:

Empty list.

Return type:

List[str]

get_harness_init_code() → str

Initialization code.

Returns:

Empty string.

Return type:

str

get_to_numpy_code() → str

Returns code to convert MLX arrays (which have .tolist()) to NumPy.

Returns:

Python logic for conversion.

Return type:

str

property supported_tiers: List[ml_switcheroo.enums.SemanticTier]

Returns supported semantic tiers (Array, Neural, Extras).

Returns:

Supported Tiers.

Return type:

List[SemanticTier]

property declared_magic_args: List[str]

Implicit RNG arguments.

Returns:

Empty.

Return type:

List[str]

property structural_traits: ml_switcheroo.frameworks.base.StructuralTraits

Defines structural rewriting rules (Classes, Methods, Init).

Strips the ‘rngs’ argument coming from Flax NNX, since MLX handles initialization statefully and eagerly.

Returns:

Config object.

Return type:

StructuralTraits
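The effect of stripping the ‘rngs’ argument can be sketched with the standard library's ast module. This is a stand-in for illustration only: the StripRngs class and strip_rngs helper are hypothetical, and the project's own rewriter may work differently.

```python
import ast


class StripRngs(ast.NodeTransformer):
    """Remove the ``rngs=...`` keyword from every call expression."""

    def visit_Call(self, node: ast.Call) -> ast.Call:
        self.generic_visit(node)  # process nested calls first
        node.keywords = [kw for kw in node.keywords if kw.arg != "rngs"]
        return node


def strip_rngs(source: str) -> str:
    """Parse source, drop ``rngs`` keywords, and return the rewritten code."""
    tree = StripRngs().visit(ast.parse(source))
    return ast.unparse(tree)  # ast.unparse requires Python 3.9+


print(strip_rngs("self.fc = nnx.Linear(3, 4, rngs=rngs)"))
```

Renaming the call target (for example to an mlx.nn layer) is a separate step that this sketch does not attempt.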

property plugin_traits: ml_switcheroo.frameworks.base.PluginTraits

Plugin behavior configuration.

Returns:

Config object.

Return type:

PluginTraits

property definitions: Dict[str, ml_switcheroo.frameworks.base.StandardMap]

Static definitions for MLX mappings. Loaded dynamically from frameworks/definitions/mlx.json.

Returns:

Definitions map.

Return type:

Dict[str, StandardMap]

property rng_seed_methods: List[str]

Returns list of global seed setters.

Returns:

Method names.

Return type:

List[str]

collect_api(category: ml_switcheroo.frameworks.base.StandardCategory) → List[ml_switcheroo.core.ghost.GhostRef]

Performs runtime introspection to discover available MLX APIs.

Parameters:

category (StandardCategory) – Category to scan.

Returns:

Found items.

Return type:

List[GhostRef]
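Runtime introspection of this kind can be sketched with importlib and inspect. The collect_public_callables helper is hypothetical and returns plain strings rather than GhostRef objects; the standard math module stands in for mlx.core so the sketch runs on any machine.

```python
import importlib
import inspect
from typing import List


def collect_public_callables(module_name: str) -> List[str]:
    """List fully qualified names of public callables in a module."""
    try:
        module = importlib.import_module(module_name)
    except ImportError:
        return []  # framework not installed: nothing to discover
    return sorted(
        f"{module_name}.{name}"
        for name, _obj in inspect.getmembers(module, callable)
        if not name.startswith("_")
    )


# "math" stands in for "mlx.core" here.
print(collect_public_callables("math")[:3])
```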

convert(data: Any) → Any

Converts input data (NumPy array or list) to an MLX array for verification.

Parameters:

data (Any) – Input.

Returns:

MLX Array or original.

Return type:

Any
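A minimal sketch of the "MLX array or original" behaviour, assuming the conversion goes through mx.array and falls back to the input when MLX is not importable. The function body is an illustration, not the adapter's code.

```python
from typing import Any


def convert(data: Any) -> Any:
    """Return an MLX array for list-like or NumPy-like input, else the input."""
    try:
        import mlx.core as mx  # only available where MLX is installed
    except ImportError:
        return data
    # Lists, tuples, and objects exposing __array__ (e.g. NumPy arrays).
    if isinstance(data, (list, tuple)) or hasattr(data, "__array__"):
        return mx.array(data)
    return data


print(convert("not an array"))
```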

get_tiered_examples() → Dict[str, str]

Returns MLX idiomatic examples used for validity testing.

Returns:

Example maps.

Return type:

Dict[str, str]

get_device_syntax(device_type: str, device_index: str | None = None) → str

Returns device constructor syntax.

Parameters:
  • device_type – Device description.

  • device_index – Device index.

Returns:

Generated code.

Return type:

str
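One plausible shape for the generated device expression, sketched under assumptions: the accelerator names mapped to the MLX GPU and the exact output format are guesses, and device_syntax is a hypothetical stand-in for the method. mx.Device, mx.gpu, and mx.cpu are MLX names.

```python
from typing import Optional


def device_syntax(device_type: str, device_index: Optional[str] = None) -> str:
    """Build an MLX device constructor expression as source text."""
    # Assumption: accelerator names from other frameworks map to the MLX GPU.
    name = device_type.strip("'\"").lower()
    backend = "gpu" if name in ("cuda", "gpu", "mps") else "cpu"
    if device_index is None:
        return f"mx.Device(mx.{backend})"
    return f"mx.Device(mx.{backend}, {device_index})"


print(device_syntax("cuda", "0"))
```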

get_device_check_syntax() → str

Checks whether the default device is the GPU. Note: MLX’s unified memory model has no strict ‘is_available’ query, so the default backend is checked instead.

Returns:

Code string.

Return type:

str

get_rng_split_syntax(rng_var: str, key_var: str) → str

MLX usually uses implicit RNG state. If explicit mode is requested, this returns ‘pass’ because the split logic differs significantly.

Returns:

“pass”.

Return type:

str

get_serialization_imports() → List[str]

Returns imports for serialization.

Returns:

Imports.

Return type:

List[str]

get_serialization_syntax(op: str, file_arg: str, object_arg: str | None = None) → str

Returns save/load syntax.

Parameters:
  • op – ‘save’ or ‘load’.

  • file_arg – Target file path.

  • object_arg – Object name.

Returns:

Code string.

Return type:

str
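A minimal sketch of how save/load syntax could be generated for the two documented operations. The choice of mx.save_safetensors for ‘save’ is an assumption (it presumes the object is a dict of arrays), and serialization_syntax is a hypothetical stand-in for the method.

```python
from typing import Optional


def serialization_syntax(op: str, file_arg: str, object_arg: Optional[str] = None) -> str:
    """Return an MLX save or load call as source text."""
    if op == "save":
        if object_arg is None:
            raise ValueError("'save' needs an object to write")
        # Assumption: the object is a dict of arrays written as .safetensors.
        return f"mx.save_safetensors({file_arg}, {object_arg})"
    if op == "load":
        return f"mx.load({file_arg})"
    raise ValueError(f"unsupported op: {op!r}")


print(serialization_syntax("load", "'weights.safetensors'"))
```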

get_weight_conversion_imports() → List[str]

Returns imports needed for weight scripts.

get_weight_load_code(path_var: str) → str

Loads weights using mx.load (npz/safetensors) into a raw dictionary.

get_tensor_to_numpy_expr(tensor_var: str) → str

Converts MLX array to numpy.

get_weight_save_code(state_var: str, path_var: str) → str

Saves dictionary of arrays to .npz or .safetensors.
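The two target formats suggest a dispatch on the file suffix. The sketch below emits such code as a string; the suffix check and the weight_save_code helper are assumptions about the generated script, while mx.savez and mx.save_safetensors are MLX functions.

```python
def weight_save_code(state_var: str, path_var: str) -> str:
    """Emit code that saves a dict of MLX arrays, choosing the format by suffix."""
    return (
        f"if str({path_var}).endswith('.safetensors'):\n"
        f"    mx.save_safetensors(str({path_var}), {state_var})\n"
        f"else:\n"
        f"    mx.savez(str({path_var}), **{state_var})\n"
    )


code = weight_save_code("state", "out_path")
compile(code, "<generated>", "exec")  # syntax check only; running it needs MLX
print(code)
```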

apply_wiring(snapshot: Dict[str, Any]) → None

Applies manual wiring for MLX. Overrides or patches snapshot items that cannot be statically defined.

Parameters:

snapshot – Snapshot dict.

get_doc_url(api_name: str) → str | None

Generates documentation URL for MLX APIs using autosummary pattern.

Parameters:

api_name – Fully qualified API string.

Returns:

URL.

Return type:

Optional[str]
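A sketch of URL generation following an autosummary layout. The base URL and the extra nn/ directory for mlx.nn pages are assumptions about the hosted MLX documentation, and doc_url is a hypothetical stand-in for the method.

```python
from typing import Optional

# Assumed base URL of the hosted MLX Python documentation.
BASE = "https://ml-explore.github.io/mlx/build/html/python"


def doc_url(api_name: str) -> Optional[str]:
    """Build an autosummary-style URL, or return None for non-MLX names."""
    if not api_name.startswith("mlx."):
        return None
    # Assumption: mlx.nn pages sit under an extra "nn/" directory.
    prefix = "nn/" if api_name.startswith("mlx.nn.") else ""
    return f"{BASE}/{prefix}_autosummary/{api_name}.html"


print(doc_url("mlx.core.abs"))
```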