ml_switcheroo.plugins.optimizer_step¶
Plugin for Optimizer Step Translation.
Handles the conversion of imperative optimization steps (PyTorch) to functional state updates (JAX/Optax).
Transformations (see the end-to-end sketch below):

1. Instantiation: Strips model.parameters() from the constructor, as Optax initializes state separately.
   * Input: opt = torch.optim.Adam(model.parameters(), lr=0.01)
   * Output: opt = optax.adam(learning_rate=0.01)
2. Step Execution: Rewrites step() to the Optax update/apply sequence.
   * Input: optimizer.step()
   * Output:
     updates, opt_state = optimizer.update(grads, opt_state, params)
     params = optax.apply_updates(params, updates)
3. Zero Grad: Strips zero_grad(), as JAX handles gradients explicitly via grad or value_and_grad.
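Taken together, the effect on a typical training step is roughly the following. This is an illustrative sketch, not verbatim plugin output; `params`, `loss_fn`, and the toy values are hypothetical stand-ins for the user's model and loss.

```python
import jax
import jax.numpy as jnp
import optax

# Hypothetical stand-ins for the user's parameters and loss.
params = {"w": jnp.ones((3,)), "b": jnp.zeros(())}

def loss_fn(p):
    return jnp.sum(p["w"] ** 2) + p["b"] ** 2

# 1. Instantiation: no model.parameters() argument; state is created separately.
optimizer = optax.adam(learning_rate=0.01)
opt_state = optimizer.init(params)

# 3. Zero Grad: nothing to clear -- gradients are computed explicitly below.
loss, grads = jax.value_and_grad(loss_fn)(params)

# 2. Step Execution: the Optax update/apply sequence.
updates, opt_state = optimizer.update(grads, opt_state, params)
params = optax.apply_updates(params, updates)
```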
Functions¶
- transform_optimizer_init: Hook to rewrite Optimizer instantiation.
- transform_optimizer_step: Hook to rewrite optimizer.step().
- strip_zero_grad: Hook for optimizer.zero_grad().
Module Contents¶
- ml_switcheroo.plugins.optimizer_step.transform_optimizer_init(node: libcst.Call, ctx: ml_switcheroo.core.hooks.HookContext) → libcst.Call¶
Hook to rewrite Optimizer instantiation. Removes the first argument (the model.parameters() call) when targeting JAX, since Optax is functional and initializes optimizer state separately.
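A minimal sketch of the argument-stripping part, assuming the hook receives the constructor call as a `libcst.Call`; the `_strip_params_arg` helper is hypothetical, and the rename of `torch.optim.Adam`/`lr` to `optax.adam`/`learning_rate` is not shown here.

```python
import libcst as cst

def _strip_params_arg(call: cst.Call) -> cst.Call:
    # Drop the first positional argument (model.parameters());
    # Optax creates optimizer state later via optimizer.init(params).
    if call.args:
        return call.with_changes(args=call.args[1:])
    return call

call = cst.ensure_type(
    cst.parse_expression("torch.optim.Adam(model.parameters(), lr=0.01)"), cst.Call
)
print(cst.Module(body=[]).code_for_node(_strip_params_arg(call)))
# prints: torch.optim.Adam(lr=0.01)
```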
- ml_switcheroo.plugins.optimizer_step.transform_optimizer_step(node: libcst.Call, ctx: ml_switcheroo.core.hooks.HookContext) → libcst.Call | libcst.FlattenSentinel¶
Hook to rewrite optimizer.step().
JAX Target: Emits the functional update pattern. Assumes the variables grads, opt_state, and params exist in the local scope, and uses EscapeHatch to warn the user when these variables cannot be inferred.
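A sketch of how the two-statement replacement can be built with libcst; `_optax_step_statements` is a hypothetical helper, and the real hook additionally has to resolve the grads/opt_state/params names via `ctx`.

```python
import libcst as cst

def _optax_step_statements(optimizer_name: str) -> cst.FlattenSentinel:
    # Replace a single `optimizer.step()` statement with the Optax
    # update/apply sequence. Assumes grads, opt_state and params are in scope.
    update_stmt = cst.parse_statement(
        f"updates, opt_state = {optimizer_name}.update(grads, opt_state, params)"
    )
    apply_stmt = cst.parse_statement("params = optax.apply_updates(params, updates)")
    return cst.FlattenSentinel([update_stmt, apply_stmt])
```

Returning a `FlattenSentinel` from a statement-level `leave_` method lets libcst splice both statements in place of the original `optimizer.step()` line.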
- ml_switcheroo.plugins.optimizer_step.strip_zero_grad(node: libcst.Call, ctx: ml_switcheroo.core.hooks.HookContext) → libcst.CSTNode¶
Hook for optimizer.zero_grad().
JAX Target: Removes the call (a no-op), as JAX gradients are not accumulated by default.
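For context on why the removal is safe: each call to `jax.grad` or `jax.value_and_grad` recomputes gradients from scratch, so there is no accumulated gradient state to clear. A small illustration with hypothetical `params`/`loss_fn`:

```python
import jax
import jax.numpy as jnp

params = {"w": jnp.arange(3.0)}

def loss_fn(p):
    return jnp.sum(p["w"] ** 2)

g1 = jax.grad(loss_fn)(params)
g2 = jax.grad(loss_fn)(params)
# Gradients are recomputed each call rather than accumulated on the
# parameters, so there is no counterpart to optimizer.zero_grad().
assert jnp.allclose(g1["w"], g2["w"])
```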