ml_switcheroo.plugins.clipping¶

Plugin for Gradient Clipping.

Addresses the mismatch between:

1. PyTorch: torch.nn.utils.clip_grad_norm_(parameters, max_norm) (in-place; mutates the gradients and returns the total norm).

2. JAX/Optax: optax.clip_by_global_norm(max_norm).update(grads, state) (functional; returns new updates).
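
To make the mismatch concrete, the sketch below contrasts the two APIs side by side. It is illustrative only; the parameter and gradient values are placeholders:

    import torch
    import jax.numpy as jnp
    import optax

    # PyTorch: clips in place via each parameter's .grad and returns the
    # total norm, computed before clipping.
    params = [torch.nn.Parameter(torch.randn(3))]
    params[0].grad = torch.randn(3) * 100.0
    total_norm = torch.nn.utils.clip_grad_norm_(params, max_norm=1.0)

    # Optax: a pure transformation over a gradient PyTree; nothing is
    # mutated, and the clipped updates come back as a new PyTree.
    grads = {"w": jnp.full((3,), 100.0)}
    clipper = optax.clip_by_global_norm(1.0)
    clipped_grads, _ = clipper.update(grads, clipper.init(grads))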

Transformation:

Input: clip_grad_norm_(grads, 1.0)

Output: optax.clip_by_global_norm(1.0).update(grads, None)[0]
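
The generated expression works because optax.clip_by_global_norm's update function returns an (updates, state) pair and passes the state through without reading it, so None is an acceptable state and [0] selects the clipped PyTree. A minimal check of that behavior:

    import jax.numpy as jnp
    import optax

    grads = {"w": jnp.full((2,), 10.0)}

    # update() returns (clipped_updates, new_state); [0] extracts the
    # clipped gradients. None is safe as the state because
    # clip_by_global_norm never inspects it.
    clipped = optax.clip_by_global_norm(1.0).update(grads, None)[0]
    print(optax.global_norm(clipped))  # ~1.0 after clipping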

Limitations:
  • In-place mutation: PyTorch modifies gradients in place; JAX requires reassignment (grads = …). This plugin generates the expression for the clipped gradients and relies on the user or surrounding rewriting logic to assign the result back to grads (see the sketch after this list).

  • Return Value: PyTorch returns the total norm, while the translated Optax expression yields the clipped gradients. If the original code uses the return value (e.g. logging total_norm), this translation changes semantics; the sketch after this list shows one way to recover the pre-clip norm.

  • Parameters: Assumes the first argument passed corresponds to the gradient PyTree in JAX.
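
The following sketch (illustrative values only) shows the reassignment the first limitation requires, and one way to keep a pre-clip norm for logging that the second limitation would otherwise silently drop:

    import jax.numpy as jnp
    import optax

    grads = {"w": jnp.full((4,), 5.0)}

    # clip_grad_norm_ reported the pre-clip total norm; to preserve that
    # value for logging, compute it before clipping.
    total_norm = optax.global_norm(grads)  # 10.0 here

    # The translated expression is pure, so its result must be assigned
    # back to grads; nothing is mutated in place.
    grads = optax.clip_by_global_norm(1.0).update(grads, None)[0]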

Functions¶

transform_grad_clipping(node, ctx) → libcst.CSTNode

Hook: Transforms imperative clipping to Optax functional clipping.

Module Contents¶

ml_switcheroo.plugins.clipping.transform_grad_clipping(node: libcst.Call, ctx: ml_switcheroo.core.hooks.HookContext) → libcst.CSTNode¶

Hook: Transforms imperative clipping to Optax functional clipping.

Trigger: clip_grad_norm_ operation. Target: JAX, Flax.
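
Purely as an illustration of the shape such a libcst hook can take, here is a minimal sketch that rebuilds the target expression textually. The helper name, the positional-argument handling, and the omission of the HookContext are assumptions for brevity, not the plugin's actual logic:

    import libcst as cst

    def _clipping_sketch(node: cst.Call) -> cst.BaseExpression:
        # Hypothetical re-implementation for illustration; the real hook
        # also receives a HookContext and may handle keyword arguments.
        module = cst.Module(body=[])  # used only to render nodes to source
        grads_src = module.code_for_node(node.args[0].value)     # assumed gradient PyTree
        max_norm_src = module.code_for_node(node.args[1].value)  # assumed max_norm
        return cst.parse_expression(
            f"optax.clip_by_global_norm({max_norm_src})"
            f".update({grads_src}, None)[0]"
        )

    call = cst.parse_expression("clip_grad_norm_(grads, 1.0)")
    print(cst.Module(body=[]).code_for_node(_clipping_sketch(call)))
    # -> optax.clip_by_global_norm(1.0).update(grads, None)[0]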