ml_switcheroo.plugins.schedulers

Plugin for Learning Rate Scheduler Rewiring.

Addresses the architectural difference between:

  1. PyTorch: scheduler = StepLR(optimizer, step_size=30), a stateful object that wraps the optimizer.

  2. JAX/Optax: schedule_fn = optax.piecewise_constant_schedule(…), a pure function that is passed to the optimizer.
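To make the contrast concrete, here is a minimal plain-Python sketch with no PyTorch or Optax dependency. `ToyOptimizer`, `ToyStepLR`, and `step_decay` are hypothetical stand-ins written for illustration, not the real library classes. The stateful object mutates the optimizer's learning rate whenever `step()` is called, while the pure function maps a step count to a value and holds no state.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class ToyOptimizer:
    """Hypothetical stand-in for a PyTorch optimizer; it only stores a learning rate."""
    lr: float


class ToyStepLR:
    """Stateful, PyTorch-style scheduler: owns a counter and mutates the wrapped optimizer."""

    def __init__(self, optimizer: ToyOptimizer, step_size: int, gamma: float = 0.1) -> None:
        self.optimizer = optimizer
        self.base_lr = optimizer.lr
        self.step_size = step_size
        self.gamma = gamma
        self.last_epoch = 0  # hidden state that lives inside the scheduler object

    def step(self) -> None:
        # Advancing the schedule is a side effect on the optimizer.
        self.last_epoch += 1
        self.optimizer.lr = self.base_lr * self.gamma ** (self.last_epoch // self.step_size)


def step_decay(count: int, init_value: float, step_size: int, gamma: float = 0.1) -> float:
    """Pure, Optax-style schedule: the learning rate depends only on the step count."""
    return init_value * gamma ** (count // step_size)


opt = ToyOptimizer(lr=0.1)
sched = ToyStepLR(opt, step_size=2, gamma=0.5)
stateful: List[float] = []
for _ in range(5):
    stateful.append(opt.lr)  # read the current value, then advance the hidden counter
    sched.step()

pure = [step_decay(t, init_value=0.1, step_size=2, gamma=0.5) for t in range(5)]
```

Both sequences agree; the difference is only where the step counter lives, inside the scheduler object or in whatever calls the pure function.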

Transformation:

  1. Instantiation:

    • Detects specific Scheduler constructors.

    • Removes the optimizer argument (Arg 0).

    • Maps hyperparameters to their Optax equivalents.

    • Changes the API call to an Optax schedule factory.

    • Injects init_value=1.0, yielding a partial (unit-scaled) schedule, or otherwise tries to preserve the original semantics.

  2. Stepping:

    • Detects scheduler.step().

    • Since Optax schedules are integrated into the gradient transformation chain and advanced automatically through the optimizer state, manual stepping is redundant.

    • Replaces scheduler.step() with a no-op placeholder None.
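Why stepping becomes redundant can be illustrated with a toy gradient transformation, loosely modeled on Optax's scale_by_schedule, whose state carries a step count. `ScheduleState`, `init`, and `update` below are hypothetical plain-Python simplifications, not the Optax API: the counter advances inside every update call, so no code outside the optimizer has to call step().

```python
from typing import Callable, List, NamedTuple, Tuple

Schedule = Callable[[int], float]


class ScheduleState(NamedTuple):
    """Hypothetical optimizer state: just the number of updates applied so far."""
    count: int


def init() -> ScheduleState:
    return ScheduleState(count=0)


def update(
    grads: List[float], state: ScheduleState, schedule: Schedule
) -> Tuple[List[float], ScheduleState]:
    """Scale gradients by -schedule(count) and return a state with the counter advanced."""
    lr = schedule(state.count)
    updates = [-lr * g for g in grads]
    return updates, ScheduleState(count=state.count + 1)


def schedule(count: int) -> float:
    # Staircase decay: 1.0 for steps 0-1, 0.5 for steps 2-3, and so on.
    return 0.5 ** (count // 2)


params = [1.0]
state = init()
seen_lrs: List[float] = []
for _ in range(4):
    grads = [1.0]  # constant gradient, so each update's magnitude equals the learning rate
    updates, state = update(grads, state, schedule)
    seen_lrs.append(-updates[0])
    params = [p + u for p, u in zip(params, updates)]
    # There is no scheduler.step() anywhere in this loop.
```

The schedule is evaluated with whatever count the state holds, so deleting the user's step() call loses nothing.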

Supported Mappings:

  • StepLR -> optax.exponential_decay(staircase=True)

  • CosineAnnealingLR -> optax.cosine_decay_schedule
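One way to read init_value=1.0 is as a unit-scaled multiplier: an Optax schedule built that way reproduces the PyTorch learning rate once it is multiplied by the optimizer's base learning rate. The plain-Python sketch below checks this for the cosine mapping by transcribing the documented closed form of PyTorch's CosineAnnealingLR and the formula of optax.cosine_decay_schedule (default exponent=1.0) instead of calling either library. Setting alpha = eta_min / base_lr is an assumption for illustration, not necessarily what the plugin emits. For StepLR the analogous correspondence would presumably be step_size to transition_steps and gamma to decay_rate.

```python
import math


def torch_cosine_lr(t: int, base_lr: float, t_max: int, eta_min: float = 0.0) -> float:
    """Closed form documented for PyTorch's CosineAnnealingLR."""
    return eta_min + (base_lr - eta_min) * (1 + math.cos(math.pi * t / t_max)) / 2


def optax_cosine_multiplier(
    count: int, decay_steps: int, alpha: float = 0.0, init_value: float = 1.0
) -> float:
    """Plain-Python transcription of optax.cosine_decay_schedule with exponent=1.0."""
    count = min(count, decay_steps)  # the Optax schedule stays flat after decay_steps
    cosine = 0.5 * (1 + math.cos(math.pi * count / decay_steps))
    return init_value * ((1 - alpha) * cosine + alpha)


base_lr, t_max, eta_min = 0.1, 10, 0.001
alpha = eta_min / base_lr  # assumed mapping: absolute floor (PyTorch) -> relative floor (Optax)

for t in range(t_max + 1):
    expected = torch_cosine_lr(t, base_lr, t_max, eta_min)
    rebuilt = base_lr * optax_cosine_multiplier(t, decay_steps=t_max, alpha=alpha)
    assert math.isclose(expected, rebuilt, rel_tol=1e-9), (t, expected, rebuilt)
```

The agreement only holds within one annealing period: past t_max the Optax formula clamps the count, while the PyTorch closed form keeps following the cosine curve back up.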

Functions

transform_scheduler_init(node, ctx) → libcst.CSTNode

Hook: Transforms Scheduler instantiation.

transform_scheduler_step(node, ctx) → libcst.CSTNode

Hook: Replaces scheduler.step() with a no-op value (None).

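Module Contents

ml_switcheroo.plugins.schedulers.transform_scheduler_init(node: libcst.Call, ctx: ml_switcheroo.core.hooks.HookContext) → libcst.CSTNode

Hook: Transforms Scheduler instantiation. Drops the optimizer argument (Arg 0), maps the remaining hyperparameters to their Optax equivalents, and rewrites the call as the corresponding Optax schedule factory (for example, StepLR to optax.exponential_decay with staircase=True).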
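A minimal sketch of the instantiation rewrite, using Python's standard-library ast module as a stand-in for libcst so it runs without extra dependencies. It does not use HookContext, and the keyword mapping (step_size to transition_steps, gamma to decay_rate, with PyTorch's default gamma of 0.1 filled in) is an assumption for illustration, not necessarily what the plugin emits. `StepLRToOptax` and `rewrite_init` are hypothetical names.

```python
import ast


class StepLRToOptax(ast.NodeTransformer):
    """Rewrites StepLR(optimizer, step_size, gamma) into optax.exponential_decay(...)."""

    def visit_Call(self, node: ast.Call) -> ast.AST:
        self.generic_visit(node)
        # Only bare-name StepLR(...) calls are matched in this sketch.
        if not (isinstance(node.func, ast.Name) and node.func.id == "StepLR"):
            return node
        kwargs = {kw.arg: kw.value for kw in node.keywords}
        # Drop the optimizer (Arg 0); remaining positionals are step_size, gamma.
        for name, value in zip(("step_size", "gamma"), node.args[1:]):
            kwargs.setdefault(name, value)
        # Other StepLR keywords (e.g. last_epoch) are ignored here.
        new_keywords = [
            ast.keyword(arg="init_value", value=ast.Constant(1.0)),
            ast.keyword(arg="transition_steps", value=kwargs["step_size"]),
            ast.keyword(arg="decay_rate", value=kwargs.get("gamma", ast.Constant(0.1))),
            ast.keyword(arg="staircase", value=ast.Constant(True)),
        ]
        new_func = ast.Attribute(
            value=ast.Name(id="optax", ctx=ast.Load()), attr="exponential_decay", ctx=ast.Load()
        )
        return ast.copy_location(ast.Call(func=new_func, args=[], keywords=new_keywords), node)


def rewrite_init(source: str) -> str:
    tree = StepLRToOptax().visit(ast.parse(source))
    return ast.unparse(ast.fix_missing_locations(tree))
```

Unlike libcst, ast discards comments and original formatting, which is one reason a source-to-source tool would prefer a concrete syntax tree.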

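ml_switcheroo.plugins.schedulers.transform_scheduler_step(node: libcst.Call, ctx: ml_switcheroo.core.hooks.HookContext) → libcst.CSTNode

Hook: Replaces scheduler.step() with a no-op value (None). Optax advances the schedule through the optimizer state on each update, so the explicit step call has no counterpart.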
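A sketch of the stepping rule, again with the standard-library ast module standing in for libcst. The `scheduler_names` parameter is a hypothetical way to avoid also rewriting optimizer.step(), which has the same syntactic shape; how the real hook decides which calls to rewrite is not described in this section. `DropSchedulerStep` and `rewrite_step` are likewise illustrative names.

```python
import ast
from typing import Set


class DropSchedulerStep(ast.NodeTransformer):
    """Replaces `<name>.step()` with the constant None, only for known scheduler names."""

    def __init__(self, scheduler_names: Set[str]) -> None:
        self.scheduler_names = scheduler_names

    def visit_Call(self, node: ast.Call) -> ast.AST:
        self.generic_visit(node)
        func = node.func
        if (
            isinstance(func, ast.Attribute)
            and func.attr == "step"
            and isinstance(func.value, ast.Name)
            and func.value.id in self.scheduler_names
            and not node.args  # calls with arguments are left untouched
            and not node.keywords
        ):
            return ast.copy_location(ast.Constant(value=None), node)
        return node


def rewrite_step(source: str, scheduler_names: Set[str]) -> str:
    tree = DropSchedulerStep(scheduler_names).visit(ast.parse(source))
    return ast.unparse(ast.fix_missing_locations(tree))
```

Leaving a bare None expression, rather than deleting the statement, keeps the surrounding code syntactically valid, for example a loop body whose only statement was scheduler.step(); that is a likely reason for choosing a placeholder.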