ml_switcheroo.plugins.optimizer_step

Plugin for Optimizer Step Translation.

Handles the conversion of imperative optimization steps (PyTorch) to functional state updates (JAX/Optax).

Transformations:

  1. Instantiation: Strips model.parameters() from the constructor, as Optax initializes optimizer state separately.
     * Input: opt = torch.optim.Adam(model.parameters(), lr=0.01)
     * Output: opt = optax.adam(learning_rate=0.01)

  2. Step Execution: Rewrites step() to the Optax update/apply sequence.
     * Input: optimizer.step()
     * Output:
       updates, opt_state = optimizer.update(grads, opt_state, params)
       params = optax.apply_updates(params, updates)

  3. Zero Grad: Strips zero_grad(), as JAX computes gradients explicitly via grad or value_and_grad.

Functions

transform_optimizer_init(→ libcst.Call)

Hook to rewrite Optimizer instantiation.

transform_optimizer_step(→ Union[libcst.Call, ...)

Hook to rewrite optimizer.step().

strip_zero_grad(→ libcst.CSTNode)

Hook for optimizer.zero_grad().

Module Contents

ml_switcheroo.plugins.optimizer_step.transform_optimizer_init(node: libcst.Call, ctx: ml_switcheroo.core.hooks.HookContext) → libcst.Call

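Hook to rewrite Optimizer instantiation. When targeting JAX, removes the first argument (the model parameters), because Optax optimizers are functional: they are constructed without parameters, and their state is initialized separately.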

ml_switcheroo.plugins.optimizer_step.transform_optimizer_step(node: libcst.Call, ctx: ml_switcheroo.core.hooks.HookContext) → libcst.Call | libcst.FlattenSentinel

Hook to rewrite optimizer.step().

JAX Target: Emits the functional update pattern. Assumes that variables named grads, opt_state, and params exist in the local scope; EscapeHatch should be used to warn the user when these variables cannot be inferred.

ml_switcheroo.plugins.optimizer_step.strip_zero_grad(node: libcst.Call, ctx: ml_switcheroo.core.hooks.HookContext) → libcst.CSTNode

Hook for optimizer.zero_grad().

JAX Target: Removes the call (a no-op in JAX), since JAX does not accumulate gradients across calls: each call to grad or value_and_grad returns fresh values.