ml_switcheroo.plugins.optimizer_step¶
Plugin for Optimizer Step Translation.
Handles the conversion of imperative optimization steps (e.g., PyTorch) to functional state updates (e.g., JAX/Optax).
This logic is Wired-Only: it executes without further checks whenever the semantic map requests it.
Transformations:
Instantiation (`optimizer_constructor`):
Strips the first argument (commonly model.parameters() in Torch) because functional optimizers (Optax) are initialized stateless/factory-style.
Input: `opt = torch.optim.Adam(model.parameters(), lr=0.01)`
Output: `opt = optax.adam(lr=0.01)`
Step Execution (`optimizer_step`):
Flags step() calls as requiring manual intervention or a functional rewrite.
Output: An EscapeHatch warning block suggesting the functional update pattern (sketched after this list).
Zero Grad (`optimizer_zero_grad`):
Strips the call completely (No-Op), as gradients do not accumulate state in functional frameworks.
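Taken together, the three transformations target the standard Optax pattern: a factory-built optimizer, explicitly threaded state, and a pure update. A minimal runnable sketch of that target pattern; the names `params`, `grads`, `opt_state`, and `tx` are hypothetical and not produced by the plugin itself:

```python
import jax.numpy as jnp
import optax

# Hypothetical parameters and gradients; in a real program these come from
# the model definition and jax.grad, respectively.
params = {"w": jnp.ones(3)}
grads = {"w": jnp.full(3, 0.1)}

# Instantiation: a stateless factory -- no parameters are passed in,
# and the optimizer state is created by a separate init() call.
tx = optax.adam(learning_rate=0.01)
opt_state = tx.init(params)

# Step: an explicit functional update instead of an in-place opt.step().
updates, opt_state = tx.update(grads, opt_state, params)
params = optax.apply_updates(params, updates)

# Zero grad: nothing to do -- gradients are plain values, not accumulated state.
```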
Functions¶
| `transform_optimizer_init` | Hook to rewrite Optimizer instantiation. |
| `transform_optimizer_step` | Hook to rewrite `optimizer.step()`. |
| `strip_zero_grad` | Hook for `optimizer.zero_grad()`. |
Module Contents¶
- ml_switcheroo.plugins.optimizer_step.transform_optimizer_init(node: libcst.Call, ctx: ml_switcheroo.core.hooks.HookContext) → libcst.Call[source]¶
Hook to rewrite Optimizer instantiation.
Removes the first argument (parameters) to support factory-pattern initialization.
- Parameters:
node – Original CST call.
ctx – Hook context.
- Returns:
Transformed CST call.
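As a rough, stand-alone illustration of this behaviour (not the plugin's actual implementation), dropping the first positional argument of a libcst.Call only needs with_changes. The input call below is a hypothetical example, and the rename of the callable to optax.adam is assumed to be handled separately by the semantic map:

```python
import libcst as cst

call = cst.ensure_type(
    cst.parse_expression("torch.optim.Adam(model.parameters(), lr=0.01)"),
    cst.Call,
)

# Drop the first argument (the parameters) and keep everything else.
rewritten = call.with_changes(args=call.args[1:])

print(cst.Module(body=[]).code_for_node(rewritten))
# torch.optim.Adam(lr=0.01) -- the callable rename happens elsewhere.
```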
- ml_switcheroo.plugins.optimizer_step.transform_optimizer_step(node: libcst.Call, ctx: ml_switcheroo.core.hooks.HookContext) → libcst.Call | libcst.FlattenSentinel[source]¶
Hook to rewrite optimizer.step(). Since step() has side effects on the optimizer state and the parameters that do not translate 1:1 to functional updates without knowing the relevant variable names (params, grads, opt_state), this hook emits a specialized Escape Hatch.
- Parameters:
node – Original CST call.
ctx – Hook context.
- Returns:
CST node wrapped with escape hatch comments.
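The plugin's own EscapeHatch mechanism is not reproduced here. As a loose illustration of the idea in plain libcst, using a hypothetical statement-level transformer, comment text, and example input, one might flag step() call statements like this:

```python
import libcst as cst
import libcst.matchers as m


class FlagOptimizerStep(cst.CSTTransformer):
    """Attach a warning comment above `*.step()` call statements (illustrative only)."""

    def leave_SimpleStatementLine(self, original_node, updated_node):
        is_step_call = m.matches(
            original_node,
            m.SimpleStatementLine(
                body=[m.Expr(m.Call(func=m.Attribute(attr=m.Name("step"))))]
            ),
        )
        if is_step_call:
            warning = cst.EmptyLine(
                comment=cst.Comment(
                    "# TODO: replace with tx.update(...) + optax.apply_updates(...)"
                )
            )
            return updated_node.with_changes(
                leading_lines=[*updated_node.leading_lines, warning]
            )
        return updated_node


module = cst.parse_module("loss.backward()\nopt.step()\n")
print(module.visit(FlagOptimizerStep()).code)
```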
- ml_switcheroo.plugins.optimizer_step.strip_zero_grad(node: libcst.Call, ctx: ml_switcheroo.core.hooks.HookContext) → libcst.CSTNode[source]¶
Hook for optimizer.zero_grad(). Removes the call (No-op), as gradient accumulation is generally explicit in functional frameworks.
- Parameters:
node – Original CST call.
ctx – Hook context.
- Returns:
A CST Name(‘None’) representing a no-op expression.
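A minimal stand-alone sketch of the same idea in plain libcst (the hook context and wiring are omitted; the transformer name and example input are made up): replace the matched call with a bare None expression.

```python
import libcst as cst
import libcst.matchers as m


class StripZeroGrad(cst.CSTTransformer):
    """Replace `*.zero_grad()` calls with a no-op `None` expression."""

    def leave_Call(self, original_node, updated_node):
        if m.matches(updated_node, m.Call(func=m.Attribute(attr=m.Name("zero_grad")))):
            return cst.Name("None")
        return updated_node


module = cst.parse_module("optimizer.zero_grad()\nloss.backward()\n")
print(module.visit(StripZeroGrad()).code)
# Prints:
# None
# loss.backward()
```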