ml_switcheroo.plugins.schedulers
================================

.. py:module:: ml_switcheroo.plugins.schedulers

.. autoapi-nested-parse::

   Plugin for Learning Rate Scheduler Rewiring.

   Addresses the architectural difference between:

   1. PyTorch: ``scheduler = StepLR(optimizer, step_size=30)`` (a stateful object wrapping the optimizer).
   2. JAX/Optax: ``schedule_fn = optax.piecewise_constant_schedule(...)`` (a pure function passed to the optimizer).

   Transformation:

   1. **Instantiation**:

      - Detects the supported scheduler constructors (see Supported Mappings).
      - Removes the ``optimizer`` argument (positional argument 0).
      - Maps the hyperparameters to their Optax equivalents.
      - Changes the API call to an Optax schedule factory.
      - Injects ``init_value=1.0`` (yielding a partial schedule), or otherwise tries to preserve the original semantics.

   2. **Stepping**:

      - Detects ``scheduler.step()``.
      - Optax schedules are consumed inside the gradient transformation chain and evaluated from the step count kept in the optimizer state, so manual stepping is redundant.
      - Replaces ``scheduler.step()`` with the no-op placeholder ``None``.

   Supported Mappings:

   - ``StepLR`` -> ``optax.exponential_decay(staircase=True)``
   - ``CosineAnnealingLR`` -> ``optax.cosine_decay_schedule``


Functions
---------

.. autoapisummary::

   ml_switcheroo.plugins.schedulers.transform_scheduler_init
   ml_switcheroo.plugins.schedulers.transform_scheduler_step


Module Contents
---------------

.. py:function:: transform_scheduler_init(node: libcst.Call, ctx: ml_switcheroo.core.hooks.HookContext) -> libcst.CSTNode

   Hook: Transforms scheduler instantiation.


.. py:function:: transform_scheduler_step(node: libcst.Call, ctx: ml_switcheroo.core.hooks.HookContext) -> libcst.CSTNode

   Hook: Replaces ``scheduler.step()`` with a no-op value (``None``).
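The ``StepLR`` -> ``optax.exponential_decay(staircase=True)`` mapping listed in the module description can be checked numerically. The sketch below re-derives both schedules in plain Python so it runs without PyTorch or JAX installed; the function names are hypothetical and only mirror the documented formulas (``StepLR`` multiplies by ``gamma`` every ``step_size`` epochs; Optax's staircase decay floors the exponent ``count / transition_steps``). It assumes ``transition_begin=0`` and no ``end_value``. The ``relative`` variant illustrates one reading of the injected ``init_value=1.0``: the schedule then returns a scale factor on the base learning rate rather than an absolute learning rate.

```python
import math


def torch_step_lr(base_lr: float, step_size: int, gamma: float, epoch: int) -> float:
    """Closed form of PyTorch's StepLR: decay by gamma every step_size epochs."""
    return base_lr * gamma ** (epoch // step_size)


def optax_exponential_decay_staircase(
    init_value: float, transition_steps: int, decay_rate: float, count: int
) -> float:
    """Plain-Python re-derivation of optax.exponential_decay(staircase=True),
    assuming transition_begin=0 and no end_value clipping."""
    p = math.floor(count / transition_steps)  # staircase=True floors the exponent
    return init_value * decay_rate ** p


base_lr, step_size, gamma = 0.1, 30, 0.1

for epoch in range(100):
    expected = torch_step_lr(base_lr, step_size, gamma, epoch)
    # Absolute schedule: init_value carries the optimizer's base learning rate.
    absolute = optax_exponential_decay_staircase(base_lr, step_size, gamma, epoch)
    # Relative schedule: init_value=1.0 yields a multiplier on the base learning rate.
    relative = optax_exponential_decay_staircase(1.0, step_size, gamma, epoch)
    assert math.isclose(expected, absolute)
    assert math.isclose(expected, base_lr * relative)
```

The hyperparameter correspondence this relies on is ``step_size`` -> ``transition_steps`` and ``gamma`` -> ``decay_rate``.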
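For the ``CosineAnnealingLR`` -> ``optax.cosine_decay_schedule`` mapping, the hyperparameters do not line up one to one: PyTorch takes an absolute floor ``eta_min``, while Optax takes ``alpha``, a fraction of ``init_value``. The sketch below assumes ``T_max`` -> ``decay_steps`` and ``eta_min`` -> ``alpha = eta_min / base_lr``; that mapping is an assumption for illustration, not something the plugin documents. Both functions are hypothetical plain-Python re-derivations (the PyTorch closed form from its documentation, and Optax's formula with ``exponent=1.0``).

```python
import math


def torch_cosine_annealing(base_lr: float, t_max: int, eta_min: float, epoch: int) -> float:
    """Closed form given in the PyTorch CosineAnnealingLR documentation."""
    return eta_min + (base_lr - eta_min) * (1 + math.cos(math.pi * epoch / t_max)) / 2


def optax_cosine_decay(init_value: float, decay_steps: int, alpha: float, count: int) -> float:
    """Plain-Python re-derivation of optax.cosine_decay_schedule (exponent=1.0)."""
    count = min(count, decay_steps)  # Optax holds the final value after decay_steps
    cosine = 0.5 * (1 + math.cos(math.pi * count / decay_steps))
    return init_value * ((1 - alpha) * cosine + alpha)


base_lr, t_max, eta_min = 0.1, 50, 0.001
# Assumed hyperparameter mapping: T_max -> decay_steps, eta_min -> alpha = eta_min / base_lr.
alpha = eta_min / base_lr

# Within [0, T_max] the two schedules agree.
for epoch in range(t_max + 1):
    assert math.isclose(
        torch_cosine_annealing(base_lr, t_max, eta_min, epoch),
        optax_cosine_decay(base_lr, t_max, alpha, epoch),
        rel_tol=1e-9,
    )
```

Past ``T_max`` the two diverge: the Optax formula clamps the count at ``decay_steps`` and stays at ``init_value * alpha``, while the PyTorch closed-form expression keeps following the cosine back upward.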
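The two hooks, ``transform_scheduler_init`` and ``transform_scheduler_step``, operate on ``libcst`` nodes. As a rough, hypothetical illustration of the same rewrite logic, the sketch below swaps in the standard-library ``ast`` module and handles only ``StepLR``. ``SchedulerRewriter``, ``rewrite``, and the mapping tables are invented names and are not part of ``ml_switcheroo``. It drops the optimizer (argument 0), renames the hyperparameters, injects ``init_value=1.0``, and replaces ``.step()`` with ``None``. To avoid also deleting ``optimizer.step()``, it does this only on names previously bound to a ``StepLR(...)`` call.

```python
import ast

# Hypothetical tables; the plugin's real mapping data is not documented here.
STEP_LR_ARG_MAP = {"step_size": "transition_steps", "gamma": "decay_rate"}
STEP_LR_POSITIONAL = ("step_size", "gamma")  # StepLR positional order after the optimizer


class SchedulerRewriter(ast.NodeTransformer):
    """Stdlib-ast stand-in for the libcst hooks; handles StepLR only."""

    def __init__(self) -> None:
        self.scheduler_names: set[str] = set()

    def visit_Assign(self, node: ast.Assign) -> ast.AST:
        # Remember variables bound to StepLR(...) so only their .step() is removed.
        value = node.value
        if isinstance(value, ast.Call) and isinstance(value.func, ast.Name) and value.func.id == "StepLR":
            self.scheduler_names.update(t.id for t in node.targets if isinstance(t, ast.Name))
        self.generic_visit(node)
        return node

    def visit_Call(self, node: ast.Call) -> ast.AST:
        self.generic_visit(node)
        func = node.func
        if isinstance(func, ast.Name) and func.id == "StepLR":
            return self._rewrite_step_lr(node)
        if (
            isinstance(func, ast.Attribute)
            and func.attr == "step"
            and isinstance(func.value, ast.Name)
            and func.value.id in self.scheduler_names
        ):
            return ast.Constant(None)  # stepping is redundant: no-op placeholder
        return node

    def _rewrite_step_lr(self, node: ast.Call) -> ast.Call:
        mapped = {}
        # node.args[0] is the optimizer and is intentionally dropped.
        for name, value in zip(STEP_LR_POSITIONAL, node.args[1:]):
            mapped[STEP_LR_ARG_MAP[name]] = value
        for kw in node.keywords:
            if kw.arg in STEP_LR_ARG_MAP:  # other kwargs (e.g. last_epoch) are dropped
                mapped[STEP_LR_ARG_MAP[kw.arg]] = kw.value
        mapped.setdefault("decay_rate", ast.Constant(0.1))  # StepLR's default gamma
        keywords = [ast.keyword(arg="init_value", value=ast.Constant(1.0))]
        keywords += [ast.keyword(arg=k, value=v) for k, v in mapped.items()]
        keywords.append(ast.keyword(arg="staircase", value=ast.Constant(True)))
        target = ast.Attribute(value=ast.Name(id="optax", ctx=ast.Load()), attr="exponential_decay", ctx=ast.Load())
        return ast.Call(func=target, args=[], keywords=keywords)


def rewrite(source: str) -> str:
    tree = SchedulerRewriter().visit(ast.parse(source))
    return ast.unparse(ast.fix_missing_locations(tree))


torch_code = """
scheduler = StepLR(optimizer, step_size=30, gamma=0.1)
for epoch in range(3):
    optimizer.step()
    scheduler.step()
"""
print(rewrite(torch_code))
```

Running it prints the ``StepLR`` line rewritten to ``optax.exponential_decay(init_value=1.0, transition_steps=30, decay_rate=0.1, staircase=True)`` and the ``scheduler.step()`` statement reduced to a bare ``None``, while ``optimizer.step()`` is left alone. ``libcst``, which the hook signatures show the plugin uses, preserves comments and whitespace; ``ast`` does not, so this sketch is only meant to illustrate the rewrite logic.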
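The claim that manual stepping is redundant rests on Optax schedules being pure functions of a step count that the optimizer state carries and advances on every update. The sketch below mimics that pattern in plain Python. ``ScheduleState``, ``make_step_decay``, and ``sgd_update`` are hypothetical names, not Optax APIs, and the update rule is a bare SGD step on a scalar gradient.

```python
from typing import Callable, NamedTuple, Tuple

Schedule = Callable[[int], float]


class ScheduleState(NamedTuple):
    """Hypothetical optimizer state: only the step counter the schedule reads."""

    count: int


def make_step_decay(init_value: float, transition_steps: int, decay_rate: float) -> Schedule:
    # A pure function of the step count, in the style of an Optax schedule.
    def schedule(count: int) -> float:
        return init_value * decay_rate ** (count // transition_steps)

    return schedule


def sgd_update(grad: float, state: ScheduleState, schedule: Schedule) -> Tuple[float, ScheduleState]:
    # The learning rate is looked up from the count carried in the state,
    # and the count advances as part of the same update: no separate step() call.
    lr = schedule(state.count)
    return -lr * grad, ScheduleState(count=state.count + 1)


schedule = make_step_decay(0.1, 2, 0.5)
state = ScheduleState(count=0)
updates = []
for _ in range(5):
    update, state = sgd_update(1.0, state, schedule)
    updates.append(update)
```

Because the schedule only reads ``state.count``, nothing outside the update needs to advance it, which is why ``scheduler.step()`` can be reduced to ``None``.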