ml_switcheroo.plugins.schedulers¶
Plugin for Learning Rate Scheduler Rewiring.
Addresses the architectural difference between:

1. PyTorch: scheduler = StepLR(optimizer, step_size=30) (a stateful object wrapping the optimizer).
2. JAX/Optax: schedule_fn = optax.piecewise_constant_schedule(…) (a pure function passed to the optimizer).
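To make the contrast concrete, here is a minimal sketch using the public PyTorch and Optax APIs (the model, base learning rate, and decay values are arbitrary placeholders):

```python
import torch
import optax

# PyTorch: the scheduler is a stateful object that wraps the optimizer and
# mutates its learning rate in place each time scheduler.step() is called.
model = torch.nn.Linear(4, 2)
opt = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler = torch.optim.lr_scheduler.StepLR(opt, step_size=30, gamma=0.1)

# Optax: the schedule is a pure function (step count -> learning rate) that is
# handed to the optimizer factory; there is no separate object to step manually.
schedule_fn = optax.exponential_decay(
    init_value=0.1, transition_steps=30, decay_rate=0.1, staircase=True
)
tx = optax.sgd(learning_rate=schedule_fn)
```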
Transformation:

1. Instantiation (see the before/after sketch after this list):
   - Detects specific scheduler constructors.
   - Removes the optimizer argument (Arg 0).
   - Maps hyperparameters to their Optax equivalents.
   - Changes the API call to an Optax schedule factory.
   - Injects init_value=1.0 (yielding a partial, multiplier-style schedule) or otherwise tries to preserve the original semantics.
2. Stepping:
   - Detects scheduler.step().
   - Since JAX/Optax schedules are integrated into the gradient transformation chain and advanced automatically via the optimizer state, manual stepping is redundant.
   - Replaces scheduler.step() with a no-op placeholder (None).
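A before/after sketch of both rules together. The keyword mapping shown here (step_size to transition_steps, gamma to decay_rate) and the surrounding names (optimizer, loader, train_step) are illustrative assumptions about the rewrite, not the plugin's verbatim output:

```python
# Before (PyTorch fragment):
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)
for batch in loader:
    loss = train_step(batch)
    optimizer.step()
    scheduler.step()

# After (rewritten fragment): the optimizer argument is gone, the call targets
# an Optax schedule factory, init_value=1.0 yields a partial schedule
# (presumably scaled by the base learning rate elsewhere), and the manual
# step call is replaced by a no-op placeholder.
scheduler = optax.exponential_decay(
    init_value=1.0, transition_steps=30, decay_rate=0.1, staircase=True
)
for batch in loader:
    loss = train_step(batch)
    optimizer.step()  # the optimizer rewrite itself belongs to other plugins
    None
```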
Supported Mappings:
- StepLR -> optax.exponential_decay(staircase=True)
- CosineAnnealingLR -> optax.cosine_decay_schedule
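As a sketch of what these mappings mean semantically, both Optax factories return pure functions of the step count. Treating the schedule as a multiplier via init_value=1.0 and mapping eta_min onto alpha are assumptions about how the semantics would be preserved, not guarantees about the plugin's output:

```python
import optax

# StepLR(step_size=30, gamma=0.1): the value is multiplied by 0.1 every 30 steps.
step_schedule = optax.exponential_decay(
    init_value=1.0, transition_steps=30, decay_rate=0.1, staircase=True
)

# CosineAnnealingLR(T_max=1000, eta_min=0.0): cosine decay over 1000 steps.
# alpha is the final value as a fraction of init_value, roughly eta_min / base_lr.
cosine_schedule = optax.cosine_decay_schedule(init_value=1.0, decay_steps=1000, alpha=0.0)

# Both are plain functions of the step count, not stateful objects.
print(step_schedule(0), step_schedule(30))        # 1.0, then 0.1
print(cosine_schedule(0), cosine_schedule(1000))  # 1.0, then ~0.0
```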
Functions¶
- transform_scheduler_init: Hook that transforms Scheduler instantiation.
- transform_scheduler_step: Hook that replaces scheduler.step() with a no-op value (None).
Module Contents¶
- ml_switcheroo.plugins.schedulers.transform_scheduler_init(node: libcst.Call, ctx: ml_switcheroo.core.hooks.HookContext) → libcst.CSTNode¶
Hook: Transforms Scheduler instantiation.
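  A minimal sketch of what a hook with this signature might do for the StepLR path, using only the public libcst API. The keyword remapping table is an assumption, the ctx parameter is ignored, and the real hook presumably dispatches on which scheduler constructor was matched:

```python
import libcst as cst

def rewrite_steplr_call(node: cst.Call) -> cst.CSTNode:
    # Drop Arg 0 (the wrapped optimizer) and remap the remaining keywords.
    remap = {"step_size": "transition_steps", "gamma": "decay_rate"}
    new_args = []
    for arg in node.args[1:]:
        if arg.keyword is not None and arg.keyword.value in remap:
            arg = arg.with_changes(keyword=cst.Name(remap[arg.keyword.value]))
        new_args.append(arg)
    # Inject init_value=1.0 (partial schedule) and force staircase decay to
    # mirror StepLR's discrete jumps.
    new_args.insert(0, cst.Arg(keyword=cst.Name("init_value"), value=cst.Float("1.0")))
    new_args.append(cst.Arg(keyword=cst.Name("staircase"), value=cst.Name("True")))
    # Retarget the call at the Optax schedule factory.
    new_func = cst.Attribute(value=cst.Name("optax"), attr=cst.Name("exponential_decay"))
    return node.with_changes(func=new_func, args=new_args)
```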
- ml_switcheroo.plugins.schedulers.transform_scheduler_step(node: libcst.Call, ctx: ml_switcheroo.core.hooks.HookContext) → libcst.CSTNode¶
Hook: Replaces scheduler.step() with a no-op value (None).
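  The step hook's counterpart is essentially a one-liner: the matched call expression is replaced by the name None. Again a sketch of the idea, not the plugin's actual code:

```python
import libcst as cst

def rewrite_scheduler_step(node: cst.Call) -> cst.CSTNode:
    # The whole "scheduler.step()" expression becomes the bare name None, so
    # the surrounding expression statement reads simply "None".
    return cst.Name("None")
```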