ml_switcheroo.plugins.mlx_extras

Plugin for MLX Ecosystem Mapping.

Handles:

1. Compilation: @torch.compile -> @mx.compile.
2. Eager evaluation: torch.cuda.synchronize() -> mx.eval(state).
3. Streams: torch.cuda.stream -> mx.stream(mx.gpu).
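A source-level sketch of what converted code looks like for these three mappings. Here `state` is only a stand-in for whatever values the plugin decides need evaluation; it is not emitted verbatim.

    import mlx.core as mx

    # 1. @torch.compile(fullgraph=True)  ->  @mx.compile (incompatible kwargs stripped)
    @mx.compile
    def step(x):
        return x * 2

    x = mx.ones(4)
    state = step(x)

    # 2. torch.cuda.synchronize()  ->  mx.eval(state): force the lazy graph to run
    mx.eval(state)

    # 3. with torch.cuda.stream(s):  ->  with mx.stream(mx.gpu):
    with mx.stream(mx.gpu):
        y = step(x)
    mx.eval(y)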

Functions

transform_compiler(node, ctx) → libcst.CSTNode

Hook: Maps JIT compilation decorators.

transform_synchronize(node, ctx) → libcst.CSTNode

Hook: Maps barrier synchronization.

Module Contents

ml_switcheroo.plugins.mlx_extras.transform_compiler(node: libcst.Decorator | libcst.Call, ctx: ml_switcheroo.core.hooks.HookContext) → libcst.CSTNode

Hook: Maps JIT compilation decorators.

Triggers: torch.compile (via requires_plugin: “mlx_compiler”).

Transformation:

Input: @torch.compile(fullgraph=True, dynamic=True)
Output: @mx.compile (stripping incompatible kwargs).

Note: MLX’s compiler (mx.compile) is largely drop-in but does not support PyTorch-specific flags such as fullgraph or backend.
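A minimal sketch of such a hook, assuming only the node is needed (ctx is accepted but unused); it is illustrative, not the plugin's actual body.

    import libcst as cst

    def transform_compiler(node, ctx):
        """Rewrite @torch.compile(...) to a bare @mx.compile."""
        mx_compile = cst.Attribute(value=cst.Name("mx"), attr=cst.Name("compile"))

        if isinstance(node, cst.Decorator):
            # Covers both @torch.compile and @torch.compile(fullgraph=True, ...):
            # the call (and its keyword arguments) is dropped entirely.
            return node.with_changes(decorator=mx_compile)

        if isinstance(node, cst.Call):
            # Call form, e.g. compiled = torch.compile(fn, backend="inductor"):
            # keep positional arguments, discard PyTorch-specific keywords.
            positional = [
                arg.with_changes(comma=cst.MaybeSentinel.DEFAULT)
                for arg in node.args
                if arg.keyword is None
            ]
            return node.with_changes(func=mx_compile, args=positional)

        return node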

ml_switcheroo.plugins.mlx_extras.transform_synchronize(node: libcst.Call, ctx: ml_switcheroo.core.hooks.HookContext) → libcst.CSTNode

Hook: Maps barrier synchronization.

Input: torch.cuda.synchronize()
Output: mx.eval(state_vars), rather than mx.async_eval(), since only the blocking form acts as a barrier.

MLX is lazy, so the correct synchronization primitive is mx.eval(tensors). Because torch.cuda.synchronize() catches up everything globally, we map it to a specific stream sync where possible, and otherwise fall back to a comment.

Strict mapping: mx.eval() requires arguments (the values to evaluate). If no arguments can be found, we emit a comment instead.
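A minimal sketch under the assumption that the context exposes tracked state names through a hypothetical state_vars attribute; how the real hook attaches the fallback comment depends on the host transformer.

    import libcst as cst

    def transform_synchronize(node, ctx):
        """Rewrite torch.cuda.synchronize() to mx.eval(<state>)."""
        # Hypothetical: names of live tensors the plugin is tracking.
        state_names = getattr(ctx, "state_vars", [])

        if state_names:
            return cst.Call(
                func=cst.Attribute(value=cst.Name("mx"), attr=cst.Name("eval")),
                args=[cst.Arg(value=cst.Name(name)) for name in state_names],
            )

        # Nothing to evaluate: leave the call in place in this sketch; the real
        # hook emits an explanatory comment instead of a bare mx.eval().
        return node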