ml_switcheroo.plugins.mlx_optimizers

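Plugin for MLX optimizer translation.

Handles the API mismatches between PyTorch and Apple MLX optimizers. PyTorch optimizers receive the model parameters at construction, are applied with step(), and rely on zero_grad() to clear accumulated gradients. MLX optimizers are constructed from hyperparameters alone, are applied with update(model, grads), and have no gradient buffer to clear.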

Functions

transform_mlx_optimizer_init(node, ctx) → libcst.Call

transform_mlx_optimizer_step(node, ctx) → libcst.Call | libcst.FlattenSentinel

transform_mlx_zero_grad(node, ctx) → libcst.CSTNode

Module Contents

ml_switcheroo.plugins.mlx_optimizers.transform_mlx_optimizer_init(node: libcst.Call, ctx: ml_switcheroo.core.hooks.HookContext) → libcst.Call
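In PyTorch the parameters to optimize are passed to the optimizer constructor, as in torch.optim.Adam(model.parameters(), lr=1e-3). An MLX optimizer is built from hyperparameters alone, as in mlx.optimizers.Adam(learning_rate=1e-3), and only sees the model when update is called. A minimal sketch of that constructor rewrite follows. It assumes the source spells the class as torch.optim.<Name> and that lr is the only keyword needing a rename. It uses the standard-library ast module rather than libcst so that it runs without extra dependencies. OptimizerInitRewriter and rewrite_optimizer_init are hypothetical names, not part of ml_switcheroo.

```python
import ast


class OptimizerInitRewriter(ast.NodeTransformer):
    """Hypothetical sketch: rewrite torch.optim.<Name>(params, lr=...) calls."""

    def visit_Call(self, node: ast.Call) -> ast.AST:
        self.generic_visit(node)
        func = node.func
        # Match only the fully qualified attribute chain torch.optim.<Name>.
        if (
            isinstance(func, ast.Attribute)
            and isinstance(func.value, ast.Attribute)
            and func.value.attr == "optim"
            and isinstance(func.value.value, ast.Name)
            and func.value.value.id == "torch"
        ):
            # MLX optimizers are built without parameters, so drop the first
            # positional argument (typically model.parameters()).
            new_args = node.args[1:]
            # Assumption: lr is the only keyword that needs renaming.
            new_keywords = [
                ast.keyword(
                    arg="learning_rate" if kw.arg == "lr" else kw.arg,
                    value=kw.value,
                )
                for kw in node.keywords
            ]
            # Build mlx.optimizers.<Name> with the same class name.
            new_func = ast.Attribute(
                value=ast.Attribute(
                    value=ast.Name(id="mlx", ctx=ast.Load()),
                    attr="optimizers",
                    ctx=ast.Load(),
                ),
                attr=func.attr,
                ctx=ast.Load(),
            )
            return ast.Call(func=new_func, args=new_args, keywords=new_keywords)
        return node


def rewrite_optimizer_init(source: str) -> str:
    """Parse, rewrite and unparse a snippet (requires Python 3.9+)."""
    tree = OptimizerInitRewriter().visit(ast.parse(source))
    return ast.unparse(ast.fix_missing_locations(tree))
```

For example, rewrite_optimizer_init("opt = torch.optim.Adam(model.parameters(), lr=1e-3)") returns "opt = mlx.optimizers.Adam(learning_rate=0.001)". Keywords other than lr pass through unchanged, so any whose name or meaning differs in MLX would need their own mapping. Unlike libcst, ast also discards comments and formatting, which is one reason a source-to-source translator would prefer libcst.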
ml_switcheroo.plugins.mlx_optimizers.transform_mlx_optimizer_step(node: libcst.Call, ctx: ml_switcheroo.core.hooks.HookContext) → libcst.Call | libcst.FlattenSentinel
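A PyTorch training step calls optimizer.step(), which reads the gradients stored on each parameter. MLX keeps no such per-parameter gradient state. Gradients come back from a function built with nn.value_and_grad and are applied with optimizer.update(model, grads). This is usually followed by mx.eval(model.parameters(), optimizer.state) to force MLX's lazy computation. Turning one call into two statements is the kind of one-to-many rewrite that libcst.FlattenSentinel allows, which may be why this hook can return one.

The sketch below illustrates that expansion under stated assumptions, again with the standard-library ast module standing in for libcst. It assumes the names model and grads already exist in scope and that mx refers to mlx.core. OptimizerStepRewriter and rewrite_optimizer_step are hypothetical names.

```python
import ast


class OptimizerStepRewriter(ast.NodeTransformer):
    """Hypothetical sketch: expand `<opt>.step()` into MLX-style statements.

    Assumes `model`, `grads` and `mx` (mlx.core) are available in scope.
    """

    def visit_Expr(self, node: ast.Expr):
        call = node.value
        # Only bare, argument-free .step() calls used as statements.
        if (
            isinstance(call, ast.Call)
            and isinstance(call.func, ast.Attribute)
            and call.func.attr == "step"
            and not call.args
            and not call.keywords
        ):
            opt = ast.unparse(call.func.value)
            # Returning a list splices several statements in place of one,
            # the ast counterpart of returning a libcst.FlattenSentinel.
            return ast.parse(
                f"{opt}.update(model, grads)\n"
                f"mx.eval(model.parameters(), {opt}.state)"
            ).body
        return node


def rewrite_optimizer_step(source: str) -> str:
    """Parse, rewrite and unparse a snippet (requires Python 3.9+)."""
    tree = OptimizerStepRewriter().visit(ast.parse(source))
    return ast.unparse(tree)
```

The match is deliberately naive. Any bare .step() call is rewritten, so scheduler.step() would be caught too, while a call with arguments such as optimizer.step(closure) is left alone. A real translator needs type or name information to tell optimizers apart.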
ml_switcheroo.plugins.mlx_optimizers.transform_mlx_zero_grad(node: libcst.Call, ctx: ml_switcheroo.core.hooks.HookContext) → libcst.CSTNode
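PyTorch accumulates gradients into each parameter's .grad attribute, so training loops call optimizer.zero_grad() before each backward pass. MLX computes gradients functionally and returns fresh ones on every call, so there is nothing to reset.

One plausible strategy is to replace the statement with pass so the surrounding block stays syntactically valid. The sketch below shows this using the standard-library ast module and the hypothetical names ZeroGradRewriter and rewrite_zero_grad. The actual hook returns a libcst.CSTNode and may neutralize the call differently.

```python
import ast


class ZeroGradRewriter(ast.NodeTransformer):
    """Hypothetical sketch: neutralize `<obj>.zero_grad()` statements."""

    def visit_Expr(self, node: ast.Expr) -> ast.stmt:
        call = node.value
        if (
            isinstance(call, ast.Call)
            and isinstance(call.func, ast.Attribute)
            and call.func.attr == "zero_grad"
        ):
            # MLX returns fresh gradients from functions built with
            # nn.value_and_grad, so there is nothing to clear. A `pass`
            # keeps blocks such as loop bodies from becoming empty.
            return ast.copy_location(ast.Pass(), node)
        return node


def rewrite_zero_grad(source: str) -> str:
    """Parse, rewrite and unparse a snippet (requires Python 3.9+)."""
    tree = ZeroGradRewriter().visit(ast.parse(source))
    return ast.unparse(tree)
```

Replacing rather than deleting the statement trades a leftover pass line for never producing an empty block. A more polished rewrite could drop the statement and add pass only where the enclosing body would otherwise be empty.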