ml_switcheroo.plugins.mlx_extras
================================

.. py:module:: ml_switcheroo.plugins.mlx_extras

.. autoapi-nested-parse::

   Plugin for MLX ecosystem mapping.

   Handles:

   1. Compilation: ``@torch.compile`` -> ``@mx.compile``.
   2. Eager evaluation: ``torch.cuda.synchronize()`` -> ``mx.eval(state)``.
   3. Streams: ``torch.cuda.stream`` -> ``mx.stream(mx.gpu)``.


Functions
---------

.. autoapisummary::

   ml_switcheroo.plugins.mlx_extras.transform_compiler
   ml_switcheroo.plugins.mlx_extras.transform_synchronize


Module Contents
---------------

.. py:function:: transform_compiler(node: Union[libcst.Decorator, libcst.Call], ctx: ml_switcheroo.core.hooks.HookContext) -> libcst.CSTNode

   Hook: Maps JIT compilation decorators.

   Triggers: ``torch.compile`` (via ``requires_plugin: "mlx_compiler"``).

   Transformation:

   - Input: ``@torch.compile(fullgraph=True, dynamic=True)``
   - Output: ``@mx.compile`` (incompatible keyword arguments are stripped).

   Note: MLX's compiler (``mx.compile``) is largely a drop-in replacement, but it does not support PyTorch-specific flags such as ``fullgraph``, ``dynamic``, or ``backend``.


.. py:function:: transform_synchronize(node: libcst.Call, ctx: ml_switcheroo.core.hooks.HookContext) -> libcst.CSTNode

   Hook: Maps barrier synchronization.

   - Input: ``torch.cuda.synchronize()``
   - Output: ``mx.eval(state_vars)`` (possibly ``mx.async_eval()``), or a comment when no arrays to evaluate can be identified.

   MLX is lazy: computation runs only when results are needed, so the correct synchronization point is ``mx.eval(tensors)`` on specific arrays. Because ``torch.cuda.synchronize()`` waits for all pending work on the device globally, it has no direct MLX equivalent; the hook maps it to a comment, or to a specific stream sync where possible.

   Strict mapping: ``mx.eval()`` needs arguments that tell it what to evaluate. If no arguments can be found, the hook emits a comment instead.
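
The decorator rewrite performed by ``transform_compiler`` can be approximated with Python's standard-library ``ast`` module instead of ``libcst``. This is a hypothetical sketch, not the plugin's implementation: ``CompileDecoratorRewriter`` and ``transpile_compile_decorators`` are illustrative names, the sketch assumes Python 3.9+ for ``ast.unparse``, it handles only decorator usage (not ``torch.compile(model)`` calls), and it drops every keyword argument to produce the bare ``@mx.compile`` output described above.

.. code-block:: python

   import ast


   class CompileDecoratorRewriter(ast.NodeTransformer):
       """Hypothetical stdlib stand-in for the libcst-based hook.

       Rewrites ``@torch.compile`` and ``@torch.compile(...)`` to a bare
       ``@mx.compile``, discarding all PyTorch keyword arguments.
       """

       def _rewrite(self, dec: ast.expr) -> ast.expr:
           # The decorator is either the bare attribute or a call with kwargs.
           target = dec.func if isinstance(dec, ast.Call) else dec
           if ast.unparse(target) != "torch.compile":
               return dec  # unrelated decorator: leave it untouched
           # Build ``mx.compile`` with no arguments (fullgraph, dynamic, ... dropped).
           return ast.Attribute(
               value=ast.Name(id="mx", ctx=ast.Load()), attr="compile", ctx=ast.Load()
           )

       def visit_FunctionDef(self, node: ast.FunctionDef) -> ast.FunctionDef:
           self.generic_visit(node)  # rewrite nested functions first
           node.decorator_list = [self._rewrite(d) for d in node.decorator_list]
           return node


   def transpile_compile_decorators(src: str) -> str:
       """Return ``src`` with every torch.compile decorator mapped to mx.compile."""
       tree = CompileDecoratorRewriter().visit(ast.parse(src))
       return ast.unparse(ast.fix_missing_locations(tree))

For example, ``transpile_compile_decorators("@torch.compile(fullgraph=True)\ndef step(x):\n    return x * 2\n")`` yields source whose first line is ``@mx.compile``. Going through ``ast`` discards comments and original formatting, which is one reason a concrete-syntax-tree library such as ``libcst`` suits a source-to-source transpiler better.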
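
The fallback rule of ``transform_synchronize`` (evaluate known arrays, otherwise emit a comment) can be sketched with Python's standard-library ``ast`` module. This is a hypothetical stand-in, not the plugin's implementation: the real hook receives a ``libcst.Call`` and a ``HookContext``, whereas the helper ``rewrite_synchronize`` and its ``state_vars`` argument are assumptions representing whatever arrays the hook is able to discover.

.. code-block:: python

   import ast
   from typing import Sequence


   def _dotted_name(node: ast.expr) -> str:
       """Return 'a.b.c' for a Name/Attribute chain, or '' for anything else."""
       parts = []
       while isinstance(node, ast.Attribute):
           parts.append(node.attr)
           node = node.value
       if isinstance(node, ast.Name):
           parts.append(node.id)
           return ".".join(reversed(parts))
       return ""


   def rewrite_synchronize(stmt_src: str, state_vars: Sequence[str]) -> str:
       """Hypothetical helper: map a lone torch.cuda.synchronize() statement to MLX.

       ``state_vars`` stands in for the arrays the real hook would find.
       """
       tree = ast.parse(stmt_src.strip())
       if len(tree.body) != 1 or not isinstance(tree.body[0], ast.Expr):
           return stmt_src  # not a single expression statement: leave untouched
       call = tree.body[0].value
       if not (isinstance(call, ast.Call)
               and _dotted_name(call.func) == "torch.cuda.synchronize"):
           return stmt_src  # some other expression: leave untouched
       if state_vars:
           # MLX is lazy: force evaluation of the arrays we know about.
           return f"mx.eval({', '.join(state_vars)})"
       # Nothing to evaluate, so an mx.eval() call would be meaningless: emit a comment.
       return ("# TODO(mlx): torch.cuda.synchronize() has no global MLX equivalent; "
               "call mx.eval(...) on the arrays you need")

With ``state_vars=["loss", "params"]`` the statement becomes ``mx.eval(loss, params)``; with an empty list it becomes a comment. Emitting a comment rather than a bare ``mx.eval()`` avoids code that looks like a barrier but evaluates nothing.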