ml_switcheroo.plugins.optimizer_step
====================================

.. py:module:: ml_switcheroo.plugins.optimizer_step

.. autoapi-nested-parse::

   Plugin for Optimizer Step Translation.

   Handles the conversion of imperative optimization steps (PyTorch) to functional
   state updates (JAX/Optax).

   Transformations:

   1. **Instantiation**: Strips `model.parameters()` from the constructor, as Optax
      initializes state separately.

      * Input: `opt = torch.optim.Adam(model.parameters(), lr=0.01)`
      * Output: `opt = optax.adam(learning_rate=0.01)`

   2. **Step Execution**: Rewrites `step()` to the Optax update/apply sequence.

      * Input: `optimizer.step()`
      * Output: `updates, opt_state = optimizer.update(grads, opt_state, params)`
      * `params = optax.apply_updates(params, updates)`

   3. **Zero Grad**: Strips `zero_grad()` as JAX handles gradients explicitly via
      `grad` or `value_and_grad`.


Functions
---------

.. autoapisummary::

   ml_switcheroo.plugins.optimizer_step.transform_optimizer_init
   ml_switcheroo.plugins.optimizer_step.transform_optimizer_step
   ml_switcheroo.plugins.optimizer_step.strip_zero_grad


Module Contents
---------------

.. py:function:: transform_optimizer_init(node: libcst.Call, ctx: ml_switcheroo.core.hooks.HookContext) -> libcst.Call

   Hook to rewrite Optimizer instantiation.

   Removes the first argument (parameters) if targeting JAX, as Optax is functional.


.. py:function:: transform_optimizer_step(node: libcst.Call, ctx: ml_switcheroo.core.hooks.HookContext) -> Union[libcst.Call, libcst.FlattenSentinel]

   Hook to rewrite `optimizer.step()`.

   JAX Target:
       Emits the functional update pattern.
       Assumes existence of `grads`, `opt_state`, `params` variables in the local scope.
       Use EscapeHatch to warn user if variables can't be inferred.


.. py:function:: strip_zero_grad(node: libcst.Call, ctx: ml_switcheroo.core.hooks.HookContext) -> libcst.CSTNode

   Hook for `optimizer.zero_grad()`.

   JAX Target:
       Removes the call (No-op), as JAX gradients are not accumulated by default.
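
The three rewrites described for this module can be illustrated with a small, self-contained sketch. It is not the plugin's implementation: the hooks documented here operate on ``libcst`` nodes and receive a ``HookContext``, whereas this sketch uses only the standard-library ``ast`` module (Python 3.9+ for ``ast.unparse``). The names ``OPTIMIZER_MAP``, ``KWARG_MAP``, ``OptimizerRewriter`` and ``rewrite`` are hypothetical and exist only for illustration. Matching on the bare method names ``step`` and ``zero_grad`` is a simplifying assumption; a real tool would need to confirm that the receiver is an optimizer.

```python
import ast

# Hypothetical lookup tables for this sketch only; they are not taken from
# ml_switcheroo and cover just the example given in the module docstring.
OPTIMIZER_MAP = {"torch.optim.Adam": "optax.adam"}
KWARG_MAP = {"lr": "learning_rate"}

# Functional update pattern quoted in the module docstring. Like the plugin,
# it assumes `grads`, `opt_state` and `params` exist in the local scope.
STEP_TEMPLATE = (
    "updates, opt_state = {opt}.update(grads, opt_state, params)\n"
    "params = optax.apply_updates(params, updates)\n"
)


class OptimizerRewriter(ast.NodeTransformer):
    """Toy stand-in for the three hooks, built on `ast` instead of libcst."""

    def visit_Call(self, node: ast.Call) -> ast.Call:
        self.generic_visit(node)
        target = OPTIMIZER_MAP.get(ast.unparse(node.func))
        if target is not None:
            # 1. Instantiation: swap the constructor, drop the first
            #    positional argument (the parameters), rename keywords.
            node.func = ast.parse(target, mode="eval").body
            node.args = node.args[1:]
            for kw in node.keywords:
                kw.arg = KWARG_MAP.get(kw.arg, kw.arg)
        return node

    def visit_Expr(self, node: ast.Expr):
        self.generic_visit(node)
        call = node.value
        if isinstance(call, ast.Call) and isinstance(call.func, ast.Attribute):
            if call.func.attr == "zero_grad":
                # 3. Zero Grad: remove the whole statement.
                return None
            if call.func.attr == "step":
                # 2. Step Execution: one statement becomes two.
                receiver = ast.unparse(call.func.value)
                return ast.parse(STEP_TEMPLATE.format(opt=receiver)).body
        return node


def rewrite(source: str) -> str:
    """Apply the three rewrites to `source` and return the new source text."""
    tree = OptimizerRewriter().visit(ast.parse(source))
    ast.fix_missing_locations(tree)
    return ast.unparse(tree)


SAMPLE = (
    "opt = torch.optim.Adam(model.parameters(), lr=0.01)\n"
    "optimizer.zero_grad()\n"
    "optimizer.step()\n"
)

if __name__ == "__main__":
    print(rewrite(SAMPLE))
```

Returning a list from ``visit_Expr`` plays roughly the role that ``libcst.FlattenSentinel`` plays in the return type of ``transform_optimizer_step``: a single statement expands into several. Returning ``None`` removes the statement, which is how the sketch mimics ``strip_zero_grad``. The sketch does not guard against leaving a block empty after a removal, and unlike ``libcst`` the ``ast`` round trip discards comments and original formatting.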