ml_switcheroo.plugins.clipping¶
Plugin for Gradient Clipping.
Addresses the mismatch between:
1. PyTorch: torch.nn.utils.clip_grad_norm_(parameters, max_norm) (in-place, returns the total norm).
2. JAX/Optax: optax.clip_by_global_norm(max_norm).update(grads, state) (functional, returns new updates).
Decoupling update: the hook checks the traits.requires_functional_state trait before rewriting.
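For orientation, here is a small runnable comparison of the two APIs. It is a sketch with toy values; the names p, grads, and clipper are illustrative and not part of this plugin.

```python
import torch

# PyTorch: imperative. clip_grad_norm_ mutates p.grad in place and
# returns the total norm of the gradients before clipping.
p = torch.nn.Parameter(torch.ones(3))
p.grad = torch.full((3,), 10.0)
total_norm = torch.nn.utils.clip_grad_norm_([p], max_norm=1.0)

import jax.numpy as jnp
import optax

# JAX/Optax: functional. Gradients go in, clipped updates come out;
# nothing is mutated. clip_by_global_norm carries no meaningful state,
# which is why the transformed code shown below can pass None for it.
grads = {"w": jnp.full((3,), 10.0)}
clipper = optax.clip_by_global_norm(1.0)
state = clipper.init(grads)
clipped_grads, state = clipper.update(grads, state)
```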
Functions¶
transform_grad_clipping(node, ctx) | Hook: Transforms imperative clipping to Optax functional clipping.
Module Contents¶
- ml_switcheroo.plugins.clipping.transform_grad_clipping(node: libcst.Call, ctx: ml_switcheroo.core.hooks.HookContext) → libcst.CSTNode[source]¶
Hook: Transforms imperative clipping to Optax functional clipping.
Trigger: clip_grad_norm_ operation. Target: Frameworks with requires_functional_state=True (e.g. JAX, Flax).
- Transformation:
Input: clip_grad_norm_(grads, 1.0)
Output: optax.clip_by_global_norm(1.0).update(grads, None)[0]
(a standalone sketch of this rewrite follows this entry)
- Parameters:
node (cst.Call) – The source call.
ctx (HookContext) – Execution context with traits.
- Returns:
The transformed node, or the original node if the trait is not met.
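To make the rewrite concrete, here is a minimal, hypothetical re-implementation using libcst. It is a sketch under assumptions, not the plugin's actual code: the real hook reads the trait from its HookContext, whose exact interface is not documented here, so the traits parameter and the getattr probe below are stand-ins.

```python
import libcst as cst

def transform_grad_clipping_sketch(node: cst.Call, traits: object) -> cst.BaseExpression:
    # Guard: only rewrite for functional-state frameworks (JAX, Flax).
    # `traits` stands in for ctx.traits; the attribute name comes from the docs.
    if not getattr(traits, "requires_functional_state", False):
        return node  # trait not met: return the original node unchanged

    # clip_grad_norm_(grads, max_norm): read both positional arguments.
    render = cst.Module([]).code_for_node
    grads_src = render(node.args[0].value)
    max_norm_src = render(node.args[1].value)

    # Build: optax.clip_by_global_norm(max_norm).update(grads, None)[0]
    return cst.parse_expression(
        f"optax.clip_by_global_norm({max_norm_src}).update({grads_src}, None)[0]"
    )

# Usage:
call = cst.parse_expression("clip_grad_norm_(grads, 1.0)")
assert isinstance(call, cst.Call)

class Traits:  # toy trait object for the demo
    requires_functional_state = True

rewritten = transform_grad_clipping_sketch(call, Traits())
print(cst.Module([]).code_for_node(rewritten))
# optax.clip_by_global_norm(1.0).update(grads, None)[0]
```

Rendering the argument expressions back to source and re-parsing a template keeps the sketch short; the real plugin may construct CST nodes directly, but the output shape matches the documented transformation.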