ml_switcheroo.plugins.loss_wrapper¶
Plugin for Loss Reduction Semantics.
Addresses the mismatch between:

1. PyTorch: loss = F.cross_entropy(x, y, reduction='mean') (scalar output by default).
2. JAX/Optax: loss = optax.softmax_cross_entropy(x, y) (vector output, one loss per sample in the batch).
JAX libraries typically return the loss per sample to support vmap/pmap flexibility. PyTorch defaults to averaging (mean) immediately.
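For illustration, a minimal runnable sketch of the two conventions; the batch size of 4 and 10 classes are assumed here, not taken from the plugin:

```python
import torch
import torch.nn.functional as F
import jax.numpy as jnp
import optax

logits = torch.randn(4, 10)             # batch of 4 samples, 10 classes
targets = torch.randint(0, 10, (4,))

# PyTorch: reduction='mean' is the default, so the result is a 0-dim scalar.
loss_pt = F.cross_entropy(logits, targets)                  # shape: ()

# Optax: no reduction kwarg; a per-sample loss vector is returned.
jax_logits = jnp.asarray(logits.numpy())
jax_labels = jnp.eye(10)[jnp.asarray(targets.numpy())]      # one-hot, as softmax_cross_entropy expects
loss_jax = optax.softmax_cross_entropy(jax_logits, jax_labels)  # shape: (4,)

# Matching PyTorch's default requires an explicit reduction.
loss_jax_mean = jnp.mean(loss_jax)                          # shape: ()
```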
Transformation:

1. Detects the reduction keyword argument.
2. Strips the argument (Optax/JAX loss functions don't usually accept it).
3. Wraps the function call in jnp.mean() (the default, 'mean') or jnp.sum().
4. If reduction='none', leaves the per-sample vector output alone.
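A before/after sketch of what these steps produce, assuming the operation mapping points F.cross_entropy at optax.softmax_cross_entropy_with_integer_labels (the actual target function comes from the mapping, not from this plugin):

```python
# Before (PyTorch source):
loss = F.cross_entropy(logits, targets, reduction='sum')

# After: the reduction kwarg is stripped and the call is wrapped
# in the matching jnp reduction.
loss = jnp.sum(optax.softmax_cross_entropy_with_integer_labels(logits, targets))

# reduction='none' keeps the per-sample vector unwrapped:
loss = optax.softmax_cross_entropy_with_integer_labels(logits, targets)
```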
Functions¶
| Function | Description |
| --- | --- |
| transform_loss_reduction(node, ctx) | Hook: Wraps loss functions to apply reduction. |
Module Contents¶
- ml_switcheroo.plugins.loss_wrapper.transform_loss_reduction(node: libcst.Call, ctx: ml_switcheroo.core.hooks.HookContext) → libcst.CSTNode¶
Hook: Wraps loss functions to apply reduction.
Trigger: Operations mapped with requires_plugin: "loss_reduction". Target: JAX, Flax.
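For orientation, a minimal sketch of the rewrite logic such a hook could implement with libcst. The helper below is hypothetical; only its node-in/node-out shape follows the documented signature, and the real plugin presumably also consults ctx for the mapped Optax target, which is not shown here:

```python
import libcst as cst

def _apply_loss_reduction(node: cst.Call) -> cst.CSTNode:
    """Hypothetical core of the hook: strip `reduction`, wrap in jnp.mean/jnp.sum."""
    reduction = "mean"  # PyTorch's default when the kwarg is absent
    kept_args = []
    for arg in node.args:
        if arg.keyword is not None and arg.keyword.value == "reduction":
            if isinstance(arg.value, cst.SimpleString):
                reduction = arg.value.evaluated_value  # 'mean' / 'sum' / 'none'
            continue  # strip the kwarg in every case
        kept_args.append(arg)
    stripped = node.with_changes(args=kept_args)
    if reduction == "none":
        return stripped  # leave the per-sample vector alone
    wrapper = "jnp.sum" if reduction == "sum" else "jnp.mean"
    return cst.Call(func=cst.parse_expression(wrapper), args=[cst.Arg(value=stripped)])
```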