ml_switcheroo.plugins.clipping¶
Plugin for Gradient Clipping.
Addresses the mismatch between:

1. PyTorch: torch.nn.utils.clip_grad_norm_(parameters, max_norm) (in-place, returns the total norm).
2. JAX/Optax: optax.clip_by_global_norm(max_norm).update(grads, state) (functional, returns the updates).
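The following side-by-side snippet illustrates the mismatch. It is a minimal sketch, not part of the plugin, and assumes both torch and optax are installed; the parameter and gradient values are made up:

```python
import torch

# PyTorch: mutates the gradients in-place; the return value is the total norm.
p = torch.nn.Parameter(torch.ones(3))
p.grad = torch.full((3,), 10.0)
total_norm = torch.nn.utils.clip_grad_norm_([p], max_norm=1.0)
# p.grad now has global norm 1.0; total_norm holds the pre-clip norm.

import jax.numpy as jnp
import optax

# JAX/Optax: a pure function of the gradient PyTree; returns (updates, state).
grads = {"w": jnp.full((3,), 10.0)}
clipped, _ = optax.clip_by_global_norm(1.0).update(grads, None)
# `grads` is untouched; `clipped` is a new PyTree with global norm 1.0.
```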
- Transformation (a runnable demonstration follows this list):

  Input: clip_grad_norm_(grads, 1.0)
  Output: optax.clip_by_global_norm(1.0).update(grads, None)[0]
- Limitations (workarounds are sketched after this list):
  - In-place mutation: PyTorch modifies gradients in-place, whereas JAX requires reassignment (grads = …). This plugin generates only the expression for the clipped gradients; it relies on the user or the surrounding rewriting logic to assign the result back to grads.
  - Return value: PyTorch returns the total norm; Optax returns the clipped gradients. If the original code uses the return value (e.g. for logging total_norm), this translation changes semantics.
  - Parameters: Assumes the first argument passed corresponds to the gradient PyTree in JAX.
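To see the emitted expression in action, here is a small runnable check (illustrative only; the PyTree and values are invented for the demonstration):

```python
import jax.numpy as jnp
import optax

grads = {"w": jnp.array([3.0, 4.0])}  # global norm = 5.0

# Exactly the expression the plugin emits: indexing [0] selects the updates
# from the (updates, state) pair returned by update().
clipped = optax.clip_by_global_norm(1.0).update(grads, None)[0]
# clipped["w"] == [0.6, 0.8], i.e. the gradients rescaled to global norm 1.0.
```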
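And a hedged sketch of how calling code can work around both limitations, assuming Optax's global_norm helper is acceptable for recovering PyTorch's return value:

```python
import jax.numpy as jnp
import optax

grads = {"w": jnp.array([3.0, 4.0])}

# Return value: recover PyTorch's total norm separately, *before* clipping.
total_norm = optax.global_norm(grads)  # 5.0, matching clip_grad_norm_'s return

# In-place mutation: JAX arrays are immutable, so the emitted expression must
# be assigned back; the plugin does not generate this assignment itself.
grads = optax.clip_by_global_norm(1.0).update(grads, None)[0]
```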
Functions¶
| transform_grad_clipping(node, ctx) | Hook: Transforms imperative clipping to Optax functional clipping. |
Module Contents¶
- ml_switcheroo.plugins.clipping.transform_grad_clipping(node: libcst.Call, ctx: ml_switcheroo.core.hooks.HookContext) → libcst.CSTNode¶
Hook: Transforms imperative clipping to Optax functional clipping.
Trigger: clip_grad_norm_ operation. Targets: JAX, Flax.
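For orientation, here is a minimal, hypothetical sketch of what such a hook could look like, using only public libcst calls. The real hook's internals and its use of HookContext are not shown in this documentation, so everything below beyond the documented rewrite rule is an assumption:

```python
import libcst as cst

def transform_grad_clipping_sketch(node: cst.Call) -> cst.BaseExpression:
    """Hypothetical re-implementation of the documented rewrite. The real
    hook also receives a HookContext, which this sketch ignores."""
    renderer = cst.Module([])  # used only to render argument nodes to source
    grads_src = renderer.code_for_node(node.args[0].value)     # gradient PyTree
    max_norm_src = renderer.code_for_node(node.args[1].value)  # max_norm value
    return cst.parse_expression(
        f"optax.clip_by_global_norm({max_norm_src})"
        f".update({grads_src}, None)[0]"
    )

# Feeding the docstring's example through the sketch:
call = cst.parse_expression("clip_grad_norm_(grads, 1.0)")
assert isinstance(call, cst.Call)
print(cst.Module([]).code_for_node(transform_grad_clipping_sketch(call)))
# -> optax.clip_by_global_norm(1.0).update(grads, None)[0]
```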