ml_switcheroo.plugins.clipping¶

Plugin for Gradient Clipping.

Addresses the mismatch between:

1. PyTorch: torch.nn.utils.clip_grad_norm_(parameters, max_norm) (in-place, returns the total norm).
2. JAX/Optax: optax.clip_by_global_norm(max_norm).update(grads, state) (functional, returns an (updates, state) tuple).
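A minimal sketch of the two idioms (assuming torch, jax, and optax are installed; the values are arbitrary):

```python
import torch
import jax.numpy as jnp
import optax

# PyTorch: mutates each parameter's .grad in place and returns the total norm.
p = torch.nn.Parameter(torch.zeros(2))
p.grad = torch.tensor([3.0, 4.0])
total_norm = torch.nn.utils.clip_grad_norm_([p], max_norm=1.0)  # tensor(5.)

# Optax: a pure function of the gradients; returns (clipped updates, new state).
grads = {"w": jnp.array([3.0, 4.0])}
tx = optax.clip_by_global_norm(1.0)
clipped, _state = tx.update(grads, tx.init(grads))
```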

Decoupling update: the hook checks the target framework's traits.requires_functional_state flag and only rewrites the call when it is set.
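A hypothetical sketch of that gate, not the plugin's actual implementation; only the guard is described by these docs, and the _rewrite_clip helper is sketched under Transformation below:

```python
import libcst as cst

from ml_switcheroo.core.hooks import HookContext


def transform_grad_clipping(node: cst.Call, ctx: HookContext) -> cst.CSTNode:
    # Imperative targets (e.g. plain PyTorch) keep the in-place call.
    if not getattr(ctx.traits, "requires_functional_state", False):
        return node
    # Functional targets get the optax rewrite (see Transformation below).
    return _rewrite_clip(node)
```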

Functions¶

transform_grad_clipping(node, ctx) → libcst.CSTNode

Hook: Transforms imperative clipping to Optax functional clipping.

Module Contents¶

ml_switcheroo.plugins.clipping.transform_grad_clipping(node: libcst.Call, ctx: ml_switcheroo.core.hooks.HookContext) → libcst.CSTNode[source]¶

Hook: Transforms imperative clipping to Optax functional clipping.

Trigger: the clip_grad_norm_ operation.

Target: frameworks with requires_functional_state=True (e.g. JAX, Flax).

Transformation:

Input: clip_grad_norm_(grads, 1.0)

Output: optax.clip_by_global_norm(1.0).update(grads, None)[0]
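As a sketch, this rewrite can be expressed with libcst by re-rendering the original arguments and parsing the target expression; the _rewrite_clip helper name is hypothetical:

```python
import libcst as cst

_RENDER = cst.Module(body=[])  # used only to print nodes back to source


def _rewrite_clip(node: cst.Call) -> cst.BaseExpression:
    # clip_grad_norm_(grads, max_norm): gradients first, norm bound second.
    grads = _RENDER.code_for_node(node.args[0].value)
    max_norm = _RENDER.code_for_node(node.args[1].value)
    return cst.parse_expression(
        f"optax.clip_by_global_norm({max_norm}).update({grads}, None)[0]"
    )


call = cst.parse_expression("clip_grad_norm_(grads, 1.0)")
print(_RENDER.code_for_node(_rewrite_clip(call)))
# optax.clip_by_global_norm(1.0).update(grads, None)[0]
```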

Parameters:
  • node (cst.Call) – The source call.

  • ctx (HookContext) – Execution context with traits.

Returns:

The transformed node, or the original node unchanged if the requires_functional_state trait is not met.
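A quick runtime check of the generated expression (assuming jax and optax are installed): clip_by_global_norm's update passes the state through unchanged, so None is a safe placeholder, and index [0] selects the clipped updates from the (updates, state) pair.

```python
import jax.numpy as jnp
import optax

grads = {"w": jnp.array([3.0, 4.0])}  # global norm = 5.0
clipped = optax.clip_by_global_norm(1.0).update(grads, None)[0]
print(clipped["w"])  # [0.6 0.8]: scaled by max_norm / global_norm = 1/5
```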