ml_switcheroo.plugins.schedulers¶

Plugin for Learning Rate Scheduler Rewiring.

Addresses the architectural difference between:

  1. PyTorch: scheduler = StepLR(optimizer, step_size=30) (a stateful object that wraps the optimizer).

  2. JAX/Optax: schedule_fn = optax.piecewise_constant_schedule(…) (a pure function passed to the optimizer); see the side-by-side sketch below.
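
For concreteness, a minimal side-by-side sketch of the two styles. The hyperparameter values and the decay boundaries below are illustrative only, not taken from any real Knowledge Base entry:

    # PyTorch: the scheduler is a stateful object that mutates the optimizer in place.
    import torch

    model = torch.nn.Linear(4, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)

    for epoch in range(90):
        ...  # train one epoch
        scheduler.step()  # side effect: rewrites optimizer.param_groups[i]["lr"]

    # JAX/Optax: the schedule is a pure function step -> value, handed to the optimizer.
    import optax

    schedule_fn = optax.piecewise_constant_schedule(
        init_value=0.1,
        boundaries_and_scales={30: 0.1, 60: 0.1},  # multiply the lr by 0.1 at steps 30 and 60
    )
    tx = optax.sgd(learning_rate=schedule_fn)  # no .step() call anywhere in the loop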

Transformation

  1. Instantiation:

    • Detects specific Scheduler constructors via scheduler_rewire.

    • Removes the optimizer argument (Arg 0).

    • Maps hyperparameters to target equivalents using keys defined in the semantic map (e.g. step_size -> transition_steps for Optax or decay_steps for Keras).

    • Changes the API call to the target factory declared in the Knowledge Base.

    • Injects init_value=1.0 (so the result acts as a partial, multiplier-style schedule) or otherwise tries to preserve the original semantics.

  2. Stepping:

    • Detects scheduler.step() via scheduler_step_noop.

    • Replaces scheduler.step() with a no-op placeholder (None); a combined before/after sketch follows this list.
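
Put together, the rewrite behaves roughly like the following before/after pair. The optax.exponential_decay target, the gamma -> decay_rate mapping, and staircase=True are illustrative assumptions here; the actual factory and key names come from the Knowledge Base and the semantic map:

    # Before (PyTorch source):
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)
    ...
    scheduler.step()

    # After (rewired toward Optax; target factory assumed for illustration):
    scheduler = optax.exponential_decay(
        transition_steps=30,  # mapped from step_size via the semantic map
        decay_rate=0.1,       # assumed mapping from gamma
        staircase=True,       # assumed: StepLR decays in discrete jumps
        init_value=1.0,       # injected: the schedule acts as an lr multiplier
    )
    ...
    None  # scheduler.step() replaced with a no-op placeholder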

Functions¶

transform_scheduler_init(node, ctx) → libcst.CSTNode

Hook: Transforms Scheduler instantiation.

transform_scheduler_step(node, ctx) → libcst.CSTNode

Hook: Replaces scheduler.step() with a no-op value (None).

Module Contents¶

ml_switcheroo.plugins.schedulers.transform_scheduler_init(node: libcst.Call, ctx: ml_switcheroo.core.hooks.HookContext) → libcst.CSTNode[source]¶

Hook: Transforms Scheduler instantiation.

Logic routes on the operation ID detected in the context (StepLR vs. cosine variants). The hook is fully decoupled from any single backend: it reads the target API and argument names from ctx rather than hard-coding them.

Parameters:
  • node – Original CST call.

  • ctx – Hook context containing the operation ID.

Returns:

Transformed CST call.
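
As a rough illustration of the shape of such a hook, here is a minimal libcst sketch. The ctx fields used (ctx.target_api, ctx.arg_map) are hypothetical stand-ins for whatever HookContext actually exposes:

    import libcst as cst

    # Hypothetical sketch; ctx.target_api and ctx.arg_map are assumed names,
    # not the real HookContext attributes.
    def transform_scheduler_init_sketch(node: cst.Call, ctx) -> cst.CSTNode:
        # 1. Drop Arg 0 (the wrapped optimizer): pure schedules take no optimizer.
        args = list(node.args[1:])

        # 2. Rename hyperparameters via the semantic map,
        #    e.g. step_size -> transition_steps (Optax) or decay_steps (Keras).
        renamed = []
        for arg in args:
            if arg.keyword is not None and arg.keyword.value in ctx.arg_map:
                arg = arg.with_changes(keyword=cst.Name(ctx.arg_map[arg.keyword.value]))
            renamed.append(arg)

        # 3. Inject init_value=1.0 so the schedule acts as an lr multiplier.
        renamed.append(cst.Arg(value=cst.Float("1.0"), keyword=cst.Name("init_value")))

        # 4. Retarget the call at the factory declared in the Knowledge Base.
        return node.with_changes(func=cst.parse_expression(ctx.target_api), args=renamed)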

ml_switcheroo.plugins.schedulers.transform_scheduler_step(node: libcst.Call, ctx: ml_switcheroo.core.hooks.HookContext) → libcst.CSTNode[source]¶

Hook: Replaces scheduler.step() with a no-op value (None). Triggered when the scheduler step operation is wired to scheduler_step_noop.

Parameters:
  • node – Original CST call.

  • ctx – Hook context containing the operation ID.

Returns:

Transformed CST node (the expression None).
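
A minimal sketch of this hook, assuming the whole call expression is swapped for the literal None (which libcst models as a Name node):

    import libcst as cst

    # Minimal sketch: replace the entire scheduler.step() call expression
    # with the literal None; downstream passes can elide the dead statement.
    def transform_scheduler_step_sketch(node: cst.Call, ctx) -> cst.CSTNode:
        return cst.Name("None")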