ml_switcheroo.plugins.loss_wrapper
Plugin for Loss Reduction Semantics.
Addresses the mismatch between:
- PyTorch: loss = F.cross_entropy(..., reduction='mean') (scalar output by default).
- Functional frameworks (JAX/Optax): loss = optax.softmax_cross_entropy(x, y) (vector output, one loss per batch element).
JAX libraries typically return the loss per sample to preserve vmap/pmap flexibility; PyTorch, by contrast, averages (reduction='mean') immediately by default.
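The mismatch is easy to see side by side. A minimal sketch, assuming a 4-example batch with 10 classes (shapes chosen only for illustration):

```python
import torch
import torch.nn.functional as F
import jax
import jax.numpy as jnp
import optax

logits_t = torch.randn(4, 10)                 # batch of 4, 10 classes
targets_t = torch.randint(0, 10, (4,))

# PyTorch: reduction='mean' is the default, so the result is a 0-dim scalar.
F.cross_entropy(logits_t, targets_t).shape    # torch.Size([])

logits_j = jnp.asarray(logits_t.numpy())
labels_j = jax.nn.one_hot(jnp.asarray(targets_t.numpy()), 10)

# Optax: one loss value per batch element; reduction is left to the caller.
optax.softmax_cross_entropy(logits_j, labels_j).shape  # (4,)
```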
Transformation (sketched in code below):
- Detects the reduction keyword argument.
- Strips the argument (Optax/JAX loss functions do not usually accept it).
- Wraps the function call in Mean(x) or Sum(x). This step dynamically looks up the "Mean" or "Sum" API from the Semantic Knowledge Base, supporting any target framework definition (e.g. tf.reduce_mean, jnp.mean, mx.mean).
- If reduction='none', leaves the vector output alone.
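A minimal libcst sketch of these steps, assuming a hard-coded jnp target in place of the plugin's Knowledge Base lookup (the function name wrap_loss_reduction is hypothetical):

```python
import libcst as cst

def wrap_loss_reduction(node: cst.Call) -> cst.CSTNode:
    """Hypothetical stand-in for the plugin's hook logic."""
    reduction = "mean"  # PyTorch's default when the kwarg is absent
    kept = []
    for arg in node.args:
        if arg.keyword is not None and arg.keyword.value == "reduction":
            reduction = cst.ensure_type(arg.value, cst.SimpleString).raw_value
        else:
            kept.append(arg)
    stripped = node.with_changes(args=kept)  # kwarg removed in every case
    if reduction == "none":
        return stripped  # keep the per-sample vector untouched
    # Real plugin: resolve "Mean"/"Sum" via the Semantic Knowledge Base.
    # Here we assume jnp.mean / jnp.sum purely for illustration.
    reducer = cst.Attribute(value=cst.Name("jnp"), attr=cst.Name(reduction))
    return cst.Call(func=reducer, args=[cst.Arg(value=stripped)])
```

Note the sketch handles only the reduction semantics; renaming the loss function itself (F.cross_entropy to its Optax equivalent) is presumably done by the surrounding mapping machinery rather than this hook.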
Functions
transform_loss_reduction(node, ctx) | Hook: Wraps loss functions to apply reduction.
Module Contents
- ml_switcheroo.plugins.loss_wrapper.transform_loss_reduction(node: libcst.Call, ctx: ml_switcheroo.core.hooks.HookContext) → libcst.CSTNode
Hook: Wraps loss functions to apply reduction.
Trigger: Operations mapped with requires_plugin: "loss_reduction". Target: Frameworks requiring explicit reduction (JAX, Flax).
- Parameters:
node – The original CST Call node.
ctx – HookContext for API lookup.
- Returns:
Transformed Call node (wrapped or unwrapped).
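For illustration, assuming the Knowledge Base resolves "Sum" to jnp.sum and the loss op itself to optax.softmax_cross_entropy_with_integer_labels (both mappings are assumptions, not confirmed defaults), the hook's effect on source code would be roughly:

```python
# Before (PyTorch):
loss = F.cross_entropy(logits, targets, reduction="sum")

# After (reduction kwarg stripped, call wrapped in the resolved Sum API):
loss = jnp.sum(optax.softmax_cross_entropy_with_integer_labels(logits, targets))

# With reduction="none" the kwarg is stripped but no wrapper is added:
loss = optax.softmax_cross_entropy_with_integer_labels(logits, targets)
```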