ml_switcheroo.plugins.mlx_optimizers
Plugin for MLX optimizer translation.
Handles the impedance mismatches introduced by functional optimizers.
Functions
- transform_mlx_optimizer_init(node, ctx): Transforms the optimizer constructor.
- transform_mlx_optimizer_step(node, ctx): Transforms optimizer.step() into an EscapeHatch pattern.
- transform_mlx_zero_grad(node, ctx): Transforms optimizer.zero_grad() into None (no-op).
Module Contents
- ml_switcheroo.plugins.mlx_optimizers.transform_mlx_optimizer_init(node: libcst.Call, ctx: ml_switcheroo.core.hooks.HookContext) → libcst.Call
Hook: Transforms the optimizer constructor.
Renames the API based on a context lookup or dynamic class construction.
Strips the parameters argument (Arg 0).
Renames lr -> learning_rate.
- Parameters:
node (cst.Call) – Original CST call.
ctx (HookContext) – Hook execution context.
- Returns:
Transformed optimizer initialization.
- Return type:
cst.Call
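As an illustration, here is a minimal standalone sketch of the argument rewrite this hook performs, assuming only libcst; the context-driven API renaming is omitted and the helper name is hypothetical:

    import libcst as cst

    def _strip_params_and_rename_lr(node: cst.Call) -> cst.Call:
        # Drop Arg 0 (the parameters argument) and rename lr -> learning_rate.
        new_args = []
        for i, arg in enumerate(node.args):
            if i == 0:
                continue  # MLX optimizers take no parameter list at construction
            if arg.keyword is not None and arg.keyword.value == "lr":
                arg = arg.with_changes(keyword=cst.Name("learning_rate"))
            new_args.append(arg)
        return node.with_changes(args=new_args)

    call = cst.parse_expression("torch.optim.Adam(model.parameters(), lr=1e-3)")
    assert isinstance(call, cst.Call)
    print(cst.Module(body=[]).code_for_node(_strip_params_and_rename_lr(call)))
    # torch.optim.Adam(learning_rate=1e-3)  (the API rename itself comes from ctx)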
- ml_switcheroo.plugins.mlx_optimizers.transform_mlx_optimizer_step(node: libcst.Call, ctx: ml_switcheroo.core.hooks.HookContext) → libcst.Call | libcst.FlattenSentinel
Hook: Transforms optimizer.step() into an EscapeHatch pattern.
Functional optimizers (such as MLX or Optax) require an explicit update call, e.g. opt.update(model, state).
- Parameters:
node (cst.Call) – Original CST call.
ctx (HookContext) – Hook execution context.
- Returns:
The node wrapped in an EscapeHatch.
- Return type:
Union[cst.Call, cst.FlattenSentinel]
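For context, a minimal sketch of the functional training step that the EscapeHatch points toward, using MLX's public API; the model, data, and loss here are invented for the example, and in MLX the explicit update call takes the computed gradients:

    import mlx.core as mx
    import mlx.nn as nn
    import mlx.optimizers as optim

    model = nn.Linear(4, 1)
    optimizer = optim.Adam(learning_rate=1e-3)

    def loss_fn(model, x, y):
        return nn.losses.mse_loss(model(x), y)

    x = mx.random.normal((8, 4))
    y = mx.random.normal((8, 1))

    # Functional replacement for `loss.backward(); optimizer.step()`:
    loss, grads = nn.value_and_grad(model, loss_fn)(model, x, y)
    optimizer.update(model, grads)  # explicit update instead of .step()
    mx.eval(model.parameters(), optimizer.state)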
- ml_switcheroo.plugins.mlx_optimizers.transform_mlx_zero_grad(node: libcst.Call, ctx: ml_switcheroo.core.hooks.HookContext) → libcst.CSTNode
Hook: Transforms optimizer.zero_grad() into None (no-op).
- Parameters:
node (cst.Call) – Original CST call.
ctx (HookContext) – Unused hook context.
- Returns:
A None expression node.
- Return type:
cst.CSTNode
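A minimal sketch of this rewrite as a standalone libcst transformer (hypothetical; the actual hook is invoked through the plugin's hook machinery rather than a transformer class):

    import libcst as cst

    class _ZeroGradToNone(cst.CSTTransformer):
        # Replace any `<obj>.zero_grad()` call with a bare None expression:
        # functional optimizers never accumulate gradients, so there is
        # nothing to reset.
        def leave_Call(self, original_node: cst.Call, updated_node: cst.Call):
            func = updated_node.func
            if isinstance(func, cst.Attribute) and func.attr.value == "zero_grad":
                return cst.Name("None")
            return updated_node

    src = "optimizer.zero_grad()\nloss.backward()\n"
    print(cst.parse_module(src).visit(_ZeroGradToNone()).code)
    # Output:
    # None
    # loss.backward()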