ml_switcheroo.plugins.clipping
==============================

.. py:module:: ml_switcheroo.plugins.clipping

.. autoapi-nested-parse::

   Plugin for Gradient Clipping.

   Addresses the mismatch between:

   1. PyTorch: `torch.nn.utils.clip_grad_norm_(parameters, max_norm)` (in-place, returns the total norm).
   2. JAX/Optax: `optax.clip_by_global_norm(max_norm).update(grads, state)` (functional, returns updates).

   Transformation:
      Input: `clip_grad_norm_(grads, 1.0)`

      Output: `optax.clip_by_global_norm(1.0).update(grads, None)[0]`

   Limitations:

   - **In-place mutation**: PyTorch modifies gradients in place; JAX requires
     reassignment (`grads = ...`). This plugin generates the expression for the
     clipped gradients and relies on the user or the surrounding rewriting logic
     to ensure that the result is assigned back to `grads`.
   - **Return value**: PyTorch returns the total norm; Optax returns the clipped
     gradients. If the original code uses the return value (e.g. for logging
     `total_norm`), this translation changes semantics.
   - **Parameters**: Assumes the first argument passed corresponds to the
     gradient PyTree in JAX.

Functions
---------

.. autoapisummary::

   ml_switcheroo.plugins.clipping.transform_grad_clipping

Module Contents
---------------

.. py:function:: transform_grad_clipping(node: libcst.Call, ctx: ml_switcheroo.core.hooks.HookContext) -> libcst.CSTNode

   Hook: Transforms imperative clipping to Optax functional clipping.

   Trigger: `clip_grad_norm_` operation.

   Target: JAX, Flax.
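
The runtime semantics of the emitted expression can be checked directly against Optax. The sketch below uses a hypothetical gradient PyTree (the names `grads`, `"w"`, and `"b"` are illustrative) and shows that `optax.clip_by_global_norm(max_norm).update(grads, None)[0]` yields the clipped gradients rather than the total norm that PyTorch's `clip_grad_norm_` would return:

.. code-block:: python

   import jax.numpy as jnp
   import optax

   # Hypothetical gradient PyTree standing in for `grads` in translated code.
   grads = {"w": jnp.array([3.0, 4.0]), "b": jnp.array([0.0])}

   # PyTorch original (in-place, returns the total norm):
   #   total_norm = torch.nn.utils.clip_grad_norm_(params, max_norm=1.0)
   # Optax form emitted by this plugin (functional; note the required reassignment):
   grads = optax.clip_by_global_norm(1.0).update(grads, None)[0]

   # The global norm of (3, 4, 0) is 5.0, so each leaf is scaled by 1.0 / 5.0.
   print(grads["w"])  # ~[0.6, 0.8]

If the original code consumed the return value, calling `optax.global_norm(grads)` before clipping recovers the total norm that PyTorch would have reported.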
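On the CST side, the core of the hook presumably swaps the matched call expression for the Optax form. A minimal, hypothetical re-implementation (`sketch_transform` and its argument handling are illustrative and not the plugin's actual code) could look like:

.. code-block:: python

   import libcst as cst

   def sketch_transform(node: cst.Call) -> cst.BaseExpression:
       # Assumes the documented call shape: clip_grad_norm_(grads, max_norm).
       # The first argument is treated as the gradient PyTree (see Limitations).
       grads = cst.Module([]).code_for_node(node.args[0].value)
       max_norm = cst.Module([]).code_for_node(node.args[1].value)
       return cst.parse_expression(
           f"optax.clip_by_global_norm({max_norm}).update({grads}, None)[0]"
       )

   call = cst.parse_expression("clip_grad_norm_(grads, 1.0)")
   assert isinstance(call, cst.Call)
   print(cst.Module([]).code_for_node(sketch_transform(call)))
   # optax.clip_by_global_norm(1.0).update(grads, None)[0]

The real hook additionally receives a `HookContext` and returns a `libcst.CSTNode`; the sketch keeps only the expression rewrite to show the shape of the output.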