ml_switcheroo.plugins.schedulers¶
Plugin for Learning Rate Scheduler Rewiring.
Addresses the architectural difference between:
PyTorch: scheduler = StepLR(optimizer, step_size=30) (a stateful object wrapping the optimizer).
JAX/Optax: schedule_fn = optax.piecewise_constant_schedule(…) (a pure function passed to the optimizer).
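For concreteness, a minimal sketch of the two paradigms using the public PyTorch and Optax APIs (a staircase exponential decay is the usual Optax analogue of StepLR; the actual training step is elided):

```python
import torch
import optax

# PyTorch: the scheduler is a stateful object that wraps the optimizer
# and mutates its learning rate in place on every .step() call.
params = [torch.nn.Parameter(torch.zeros(3))]
optimizer = torch.optim.SGD(params, lr=0.1)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)
for epoch in range(90):
    # ... forward/backward pass elided ...
    optimizer.step()
    scheduler.step()  # mutates optimizer.param_groups[0]["lr"]

# JAX/Optax: the schedule is a pure function step -> value, handed to the
# optimizer factory once; there is no scheduler object to step.
schedule_fn = optax.exponential_decay(
    init_value=0.1, transition_steps=30, decay_rate=0.1, staircase=True
)
optimizer = optax.sgd(learning_rate=schedule_fn)
```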
Transformation
Instantiation:
- Detects specific Scheduler constructors via scheduler_rewire.
- Removes the optimizer argument (Arg 0).
- Maps hyperparameters to target equivalents using keys defined in the semantic map (e.g. step_size -> transition_steps for Optax or decay_steps for Keras).
- Rewrites the call to the target factory declared in the Knowledge Base.
- Injects init_value=1.0 (treating the result as a partial schedule) or otherwise tries to preserve semantics.
Stepping:
- Detects scheduler.step() via scheduler_step_noop.
- Replaces scheduler.step() with a no-op placeholder None (see the before/after sketch below).
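A plausible before/after of the full rewrite, assuming the Knowledge Base maps StepLR to optax.exponential_decay and gamma to decay_rate (the actual target factory and key names come from the semantic map; surrounding code is elided):

```python
# Before (PyTorch source):
scheduler = StepLR(optimizer, step_size=30, gamma=0.1)
scheduler.step()

# After (rewritten for Optax): the optimizer argument is dropped, step_size
# is mapped to transition_steps, init_value=1.0 is injected so the schedule
# acts as a factor on the base learning rate, and the step call collapses
# to the no-op placeholder None.
scheduler = optax.exponential_decay(
    init_value=1.0, transition_steps=30, decay_rate=0.1
)
None
```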
Functions¶
| transform_scheduler_init(node, ctx) | Hook: Transforms Scheduler instantiation. |
| transform_scheduler_step(node, ctx) | Hook: Replaces scheduler.step() with a no-op value (None). |
Module Contents¶
- ml_switcheroo.plugins.schedulers.transform_scheduler_init(node: libcst.Call, ctx: ml_switcheroo.core.hooks.HookContext) → libcst.CSTNode[source]¶
Hook: Transforms Scheduler instantiation.
Routing is based on the Operation ID detected in the context (StepLR vs. Cosine). The hook is fully decoupled: it reads the target API and argument names from ctx (a minimal sketch follows the parameter list).
- Parameters:
node – Original CST call.
ctx – Hook context containing operation ID.
- Returns:
Transformed CST call.
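As an illustration of the mechanics, a minimal sketch of such a hook written with libcst; the ctx fields param_map and target_api below are hypothetical stand-ins for whatever the real HookContext exposes:

```python
import libcst as cst

def transform_scheduler_init_sketch(node: cst.Call, ctx) -> cst.CSTNode:
    # Drop the optimizer argument (positional Arg 0).
    args = list(node.args)[1:]

    # Rename hyperparameter keywords per the semantic map,
    # e.g. {"step_size": "transition_steps"}.
    renames = ctx.param_map  # hypothetical attribute
    args = [
        a.with_changes(keyword=cst.Name(renames[a.keyword.value]))
        if a.keyword is not None and a.keyword.value in renames
        else a
        for a in args
    ]

    # Inject init_value=1.0 so the result acts as a partial schedule
    # scaling the base learning rate.
    args.append(cst.Arg(keyword=cst.Name("init_value"), value=cst.Float("1.0")))

    # Swap the callee for the target factory declared in the Knowledge
    # Base, e.g. "optax.exponential_decay".
    return node.with_changes(
        func=cst.parse_expression(ctx.target_api),  # hypothetical attribute
        args=args,
    )
```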
- ml_switcheroo.plugins.schedulers.transform_scheduler_step(node: libcst.Call, ctx: ml_switcheroo.core.hooks.HookContext) → libcst.CSTNode[source]¶
Hook: Replaces scheduler.step() with a no-op value (None). Triggered if the scheduler step operation is wired to scheduler_step_noop.
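A minimal sketch of this hook: the whole call expression is simply replaced by the None literal, leaving a harmless expression statement behind:

```python
import libcst as cst

def transform_scheduler_step_sketch(node: cst.Call, ctx) -> cst.CSTNode:
    # scheduler.step() has no Optax counterpart: the optimizer consults the
    # schedule function itself, so the call collapses to the literal None.
    return cst.Name("None")
```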