ml_switcheroo.plugins.optimizer_step

Plugin for Optimizer Step Translation.

Handles the conversion of imperative optimization steps (e.g., PyTorch) to functional state updates (e.g., JAX/Optax).

This logic is Wired-Only: it performs no detection of its own and runs whenever the semantic map requests it.
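Wired-only dispatch can be pictured as a lookup from a semantic map to a hook registry. The sketch below is purely hypothetical: the registry, the register decorator, the string-based hook signature, and the SEMANTIC_MAP key are illustrative assumptions. Only the operation name optimizer_zero_grad comes from this page. It shows that once the map routes an API to a hook, the hook runs without inspecting the call.

```python
from typing import Callable, Dict, Optional

# Hypothetical hook signature: takes call source text, returns replacement text.
Hook = Callable[[str], str]

# Registry of hooks keyed by abstract operation name.
HOOKS: Dict[str, Hook] = {}


def register(op_name: str) -> Callable[[Hook], Hook]:
    """Decorator that wires a hook to an abstract operation name."""
    def decorator(fn: Hook) -> Hook:
        HOOKS[op_name] = fn
        return fn
    return decorator


@register("optimizer_zero_grad")
def to_noop(call_src: str) -> str:
    # Ignores its input entirely: the mapping alone decided this hook should run.
    return "None"


# Hypothetical semantic map: concrete source API -> abstract operation.
SEMANTIC_MAP: Dict[str, str] = {
    "torch.optim.Optimizer.zero_grad": "optimizer_zero_grad",
}


def rewrite(api_name: str, call_src: str) -> str:
    """Run the wired hook for api_name, or return the call unchanged."""
    op_name: Optional[str] = SEMANTIC_MAP.get(api_name)
    hook = HOOKS.get(op_name) if op_name is not None else None
    # No inspection of call_src happens here: dispatch is purely by the map.
    return hook(call_src) if hook is not None else call_src


print(rewrite("torch.optim.Optimizer.zero_grad", "opt.zero_grad()"))  # prints: None
print(rewrite("torch.nn.Linear", "nn.Linear(2, 3)"))  # prints: nn.Linear(2, 3)
```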

Transformations:

  1. Instantiation (`optimizer_constructor`):

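    • Strips the first argument (commonly model.parameters() in PyTorch) because functional optimizers such as Optax are built factory-style from hyperparameters alone. The parameters are only supplied later, when the optimizer state is initialized.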

    • Input: opt = torch.optim.Adam(model.parameters(), lr=0.01)

    • Output: opt = optax.adam(lr=0.01)

  2. Step Execution (`optimizer_step`):

    • Flags step() calls as requiring manual intervention or a functional rewrite.

    • Output: An EscapeHatch warning block suggesting the update pattern.

  3. Zero Grad (`optimizer_zero_grad`):

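    • Replaces the call with a no-op (the expression None), because functional frameworks return gradients as fresh values on every step instead of accumulating them in stateful buffers that need clearing.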
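The functional pattern these transformations target can be sketched in plain Python. The code below is a hypothetical stand-in that is loosely modeled on Optax's init/update split rather than taken from Optax itself. The names sgd_momentum, FunctionalOptimizer, and apply_updates are illustrative, and the update function omits the optional params argument that Optax's update accepts. The sketch shows three things. First, the constructor needs no parameters. Second, step() becomes explicit state passing; in Optax this is updates, opt_state = opt.update(grads, opt_state, params) followed by params = optax.apply_updates(params, updates). Third, zero_grad() has nothing to clear.

One caveat about the example output above: Optax factories name the step size learning_rate (e.g. optax.adam(learning_rate=0.01)), so the lr keyword also has to be mapped for that output to run.

```python
from typing import Callable, Dict, NamedTuple, Tuple

Params = Dict[str, float]


class FunctionalOptimizer(NamedTuple):
    """Pair of pure functions, loosely mirroring Optax's init/update split."""
    init: Callable[[Params], Params]
    update: Callable[[Params, Params], Tuple[Params, Params]]


def sgd_momentum(learning_rate: float, momentum: float = 0.9) -> FunctionalOptimizer:
    """Factory: configured from hyperparameters only, no model parameters."""

    def init(params: Params) -> Params:
        # State (one momentum buffer per parameter) is created from the params later.
        return {name: 0.0 for name in params}

    def update(grads: Params, state: Params) -> Tuple[Params, Params]:
        # Pure function: returns new updates and new state, mutates nothing.
        new_state = {name: momentum * state[name] + g for name, g in grads.items()}
        updates = {name: -learning_rate * v for name, v in new_state.items()}
        return updates, new_state

    return FunctionalOptimizer(init, update)


def apply_updates(params: Params, updates: Params) -> Params:
    """Return new parameters; the old dict is left untouched."""
    return {name: params[name] + updates[name] for name in params}


# Usage: the constructor never sees the parameters (compare opt = optax.adam(...)).
opt = sgd_momentum(learning_rate=0.1)
params = {"w": 1.0}
opt_state = opt.init(params)
for _ in range(2):
    grads = {"w": 2.0}  # fresh each step, so there is nothing to zero
    updates, opt_state = opt.update(grads, opt_state)
    params = apply_updates(params, updates)
print(params)
```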

Functions

transform_optimizer_init(→ libcst.Call)

Hook to rewrite Optimizer instantiation.

transform_optimizer_step(→ Union[libcst.Call, ...)

Hook to rewrite optimizer.step().

strip_zero_grad(→ libcst.CSTNode)

Hook for optimizer.zero_grad().

Module Contents

ml_switcheroo.plugins.optimizer_step.transform_optimizer_init(node: libcst.Call, ctx: ml_switcheroo.core.hooks.HookContext) → libcst.Call

Hook to rewrite Optimizer instantiation.

Removes the first argument (parameters) to support factory-pattern initialization.

Parameters:
  • node – Original CST call.

  • ctx – Hook context.

Returns:

Transformed CST call.
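This is a minimal sketch of the same argument-stripping idea. It is written against Python's standard-library ast module instead of libcst, so unlike the real hook it does not preserve comments or formatting, and it requires Python 3.9+ for ast.unparse. StripFirstArg and strip_optimizer_params are hypothetical names. The sketch matches the target call by its dotted name only for illustration; in the plugin, the semantic map decides which calls reach the hook. Only positional arguments are considered.

```python
import ast


class StripFirstArg(ast.NodeTransformer):
    """Drop the first positional argument of calls to a given dotted name."""

    def __init__(self, target: str) -> None:
        self.target = target

    def visit_Call(self, node: ast.Call) -> ast.Call:
        self.generic_visit(node)  # rewrite nested calls first
        if ast.unparse(node.func) == self.target and node.args:
            node.args = node.args[1:]  # e.g. drop model.parameters()
        return node


def strip_optimizer_params(source: str, target: str = "torch.optim.Adam") -> str:
    tree = StripFirstArg(target).visit(ast.parse(source))
    return ast.unparse(tree)


print(strip_optimizer_params("opt = torch.optim.Adam(model.parameters(), lr=0.01)"))
```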

ml_switcheroo.plugins.optimizer_step.transform_optimizer_step(node: libcst.Call, ctx: ml_switcheroo.core.hooks.HookContext) → libcst.Call | libcst.FlattenSentinel

Hook to rewrite optimizer.step().

Because step() mutates the optimizer state and the parameters in place, it cannot be translated 1:1 into a functional update without knowing the names of the variables involved (params, grads, opt_state). This hook therefore emits a specialized escape hatch instead.

Parameters:
  • node – Original CST call.

  • ctx – Hook context.

Returns:

CST node wrapped with escape hatch comments.
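The shape of the emitted escape hatch can be pictured with a line-based rewrite. The original step() call is kept and preceded by comments that spell out the Optax update pattern. This is a hypothetical illustration: the comment wording, the regex matching, and wrap_step_calls are assumptions. The real hook works on CST nodes, and its return type admits a libcst.FlattenSentinel, which is libcst's way of replacing one node with several.

```python
import re

# Matches a bare "<name>.step()" statement, capturing its indentation.
STEP_CALL = re.compile(r"^(?P<indent>\s*)(?P<opt>\w+)\.step\(\)\s*$")


def wrap_step_calls(source: str) -> str:
    """Precede each bare optimizer step() call with escape-hatch comments."""
    out = []
    for line in source.splitlines():
        match = STEP_CALL.match(line)
        if match:
            indent, opt = match.group("indent"), match.group("opt")
            # Hypothetical warning text; variable names are guesses the user must fix.
            out.extend([
                f"{indent}# ESCAPE HATCH: {opt}.step() needs a manual functional rewrite, e.g.",
                f"{indent}#   updates, opt_state = {opt}.update(grads, opt_state, params)",
                f"{indent}#   params = optax.apply_updates(params, updates)",
            ])
        out.append(line)  # the original call is kept, not deleted
    return "\n".join(out)


loop = "for batch in data:\n    loss.backward()\n    opt.step()"
print(wrap_step_calls(loop))
```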

ml_switcheroo.plugins.optimizer_step.strip_zero_grad(node: libcst.Call, ctx: ml_switcheroo.core.hooks.HookContext) → libcst.CSTNode

Hook for optimizer.zero_grad().

Replaces the call with a no-op, as gradient accumulation is generally explicit in functional frameworks: gradients are returned as fresh values rather than summed into persistent buffers.

Parameters:
  • node – Original CST call.

  • ctx – Hook context.

Returns:

A CST Name("None") node representing a no-op expression.
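Here is a sketch of the same replacement using the standard-library ast module instead of libcst (Python 3.9+ for ast.unparse). NeutralizeZeroGrad and neutralize_zero_grad are hypothetical names. The sketch matches on the attribute name alone, which would also catch torch.nn.Module.zero_grad(). In the plugin, the semantic map rather than the name decides which calls are rewritten. The statement survives as a bare None expression, which is valid Python and has no effect.

```python
import ast


class NeutralizeZeroGrad(ast.NodeTransformer):
    """Replace any `<expr>.zero_grad(...)` call with the constant None."""

    def visit_Call(self, node: ast.Call) -> ast.AST:
        self.generic_visit(node)
        if isinstance(node.func, ast.Attribute) and node.func.attr == "zero_grad":
            # The stdlib ast spells None as Constant(None); libcst uses Name("None").
            return ast.copy_location(ast.Constant(value=None), node)
        return node


def neutralize_zero_grad(source: str) -> str:
    tree = NeutralizeZeroGrad().visit(ast.parse(source))
    return ast.unparse(tree)


print(neutralize_zero_grad("opt.zero_grad()\nloss.backward()"))
```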