ml_switcheroo.plugins.mlx_optimizers

Plugin for MLX Optimizer translation.

Handles impedance mismatches that arise when translating stateful optimizer calls into functional optimizer APIs.
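
In practical terms, this module bridges a stateful optimizer API (construct with parameters, then step()/zero_grad()) and a functional one (construct with hyperparameters only, then update(model, grads)). A rough before/after sketch; all names are illustrative and the output shape is approximate, not the plugin's literal emission:

    # Stateful (PyTorch-style) source consumed by the plugin:
    #     opt = torch.optim.Adam(model.parameters(), lr=1e-3)
    #     opt.zero_grad(); loss.backward(); opt.step()
    #
    # Functional (MLX-style) shape the translation targets:
    #     opt = mlx.optimizers.Adam(learning_rate=1e-3)
    #     loss, grads = mlx.nn.value_and_grad(model, loss_fn)(model, x, y)
    #     opt.update(model, grads)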

Functions

transform_mlx_optimizer_init(→ libcst.Call)

Hook: Transforms Optimizer Constructor.

transform_mlx_optimizer_step(→ Union[libcst.Call, libcst.FlattenSentinel])

Hook: Transforms optimizer.step() into an EscapeHatch pattern.

transform_mlx_zero_grad(→ libcst.CSTNode)

Hook: Transforms optimizer.zero_grad() into None (No-Op).

Module Contents

ml_switcheroo.plugins.mlx_optimizers.transform_mlx_optimizer_init(node: libcst.Call, ctx: ml_switcheroo.core.hooks.HookContext) → libcst.Call

Hook: Transforms Optimizer Constructor.

  1. Renames API based on context lookup or dynamic class construction.

  2. Strips the parameters argument (arg 0).

  3. Renames lr -> learning_rate.

Parameters:
  • node (cst.Call) – Original CST call.

  • ctx (HookContext) – Hook execution context.

Returns:

Transformed optimizer initialization.

Return type:

cst.Call
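
A minimal standalone sketch of the argument surgery in steps 2 and 3, written directly against libcst; the rename in step 1 depends on the context lookup and is omitted. The function rewrite_optimizer_init is hypothetical, not the plugin's implementation:

    import libcst as cst

    def rewrite_optimizer_init(call: cst.Call) -> cst.Call:
        # Hypothetical approximation of steps 2-3: drop the parameters
        # argument (arg 0) and rename the lr keyword to learning_rate.
        new_args = []
        for i, arg in enumerate(call.args):
            if i == 0 and arg.keyword is None:
                continue  # strip model.parameters()
            if arg.keyword is not None and arg.keyword.value == "lr":
                arg = arg.with_changes(keyword=cst.Name("learning_rate"))
            new_args.append(arg)
        return call.with_changes(args=new_args)

    expr = cst.parse_expression("Adam(model.parameters(), lr=1e-3)")
    assert isinstance(expr, cst.Call)
    print(cst.Module([]).code_for_node(rewrite_optimizer_init(expr)))
    # prints: Adam(learning_rate=1e-3)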

ml_switcheroo.plugins.mlx_optimizers.transform_mlx_optimizer_step(node: libcst.Call, ctx: ml_switcheroo.core.hooks.HookContext) → libcst.Call | libcst.FlattenSentinel

Hook: Transforms optimizer.step() into an EscapeHatch pattern. Functional optimizers (such as those in MLX and Optax) require an explicit update call, e.g. opt.update(model, state), so a bare step() has no direct one-to-one translation.

Parameters:
  • node (cst.Call) – Original CST call.

  • ctx (HookContext) – Hook execution context.

Returns:

The node wrapped in an EscapeHatch.

Return type:

Union[cst.Call, cst.FlattenSentinel]
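
For context on why step() needs an escape hatch rather than a rename, this is the functional training-step shape MLX expects (standard MLX usage, not the code this hook emits; model and loss_fn are placeholders):

    import mlx.core as mx
    import mlx.nn as nn
    import mlx.optimizers as optim

    model = nn.Linear(4, 1)
    optimizer = optim.Adam(learning_rate=1e-3)

    def loss_fn(model, x, y):
        return nn.losses.mse_loss(model(x), y)

    x = mx.random.normal((8, 4))
    y = mx.random.normal((8, 1))

    # Gradients are computed explicitly and handed to the optimizer;
    # this update call is the functional counterpart of optimizer.step().
    loss, grads = nn.value_and_grad(model, loss_fn)(model, x, y)
    optimizer.update(model, grads)
    mx.eval(model.parameters(), optimizer.state)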

ml_switcheroo.plugins.mlx_optimizers.transform_mlx_zero_grad(node: libcst.Call, ctx: ml_switcheroo.core.hooks.HookContext) → libcst.CSTNode

Hook: Transforms optimizer.zero_grad() into None (No-Op).

Parameters:
  • node (cst.Call) – Original CST call.

  • ctx (HookContext) – Unused hook context.

Returns:

A CST node representing None (the call becomes a no-op).

Return type:

cst.CSTNode
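
A self-contained approximation of this rewrite as a plain libcst transformer; the real hook runs through the plugin's hook machinery rather than a visitor, so the DropZeroGrad class is illustrative:

    import libcst as cst

    class DropZeroGrad(cst.CSTTransformer):
        # Replace any <expr>.zero_grad() call with None: functional
        # optimizers keep no gradient buffers to clear.
        def leave_Call(
            self, original_node: cst.Call, updated_node: cst.Call
        ) -> cst.BaseExpression:
            func = updated_node.func
            if isinstance(func, cst.Attribute) and func.attr.value == "zero_grad":
                return cst.Name("None")
            return updated_node

    module = cst.parse_module("optimizer.zero_grad()\n")
    print(module.visit(DropZeroGrad()).code)  # prints: None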