ml_switcheroo.plugins.loss_wrapper

Plugin for Loss Reduction Semantics.

Addresses the mismatch between:

  1. PyTorch: loss = F.cross_entropy(…, reduction='mean') (scalar output by default).

  2. Functional frameworks (JAX/Optax): loss = optax.softmax_cross_entropy(x, y) (vector output, one loss per batch element).

JAX libraries typically return the loss per sample to support vmap/pmap flexibility. PyTorch defaults to averaging (mean) immediately.
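
A minimal sketch of the mismatch, using placeholder data (optax.softmax_cross_entropy expects label distributions, hence the one-hot encoding):

    import torch
    import torch.nn.functional as F
    import jax
    import jax.numpy as jnp
    import optax

    # PyTorch: reduction='mean' is the default, so the loss is a scalar.
    pt_loss = F.cross_entropy(torch.randn(4, 10), torch.randint(0, 10, (4,)))
    print(pt_loss.shape)  # torch.Size([])

    # Optax: one loss per example; the caller reduces explicitly.
    logits = jnp.zeros((4, 10))
    labels = jax.nn.one_hot(jnp.array([0, 1, 2, 3]), 10)
    per_sample = optax.softmax_cross_entropy(logits, labels)
    print(per_sample.shape)            # (4,)
    print(jnp.mean(per_sample).shape)  # ()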

Transformation (see the sketch after this list):

  1. Detects reduction keyword argument.

  2. Strips the argument (since Optax/JAX loss functions don't usually accept it).

  3. Wraps the function call in Mean(x) or Sum(x). This step dynamically looks up the "Mean" or "Sum" API from the Semantic Knowledge Base, supporting any target framework definition (e.g. tf.reduce_mean, jnp.mean, mx.mean).

  4. If reduction='none', leaves the vector output alone.
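
A hedged sketch of these steps with libcst. apply_reduction_semantics is a hypothetical helper, not the plugin's actual code: the real hook resolves "Mean"/"Sum" through the Semantic Knowledge Base, whereas jnp.mean/jnp.sum are hard-coded here for a JAX target:

    import libcst as cst

    # Hypothetical stand-ins for the knowledge-base lookup.
    _REDUCERS = {"mean": "jnp.mean", "sum": "jnp.sum"}

    def apply_reduction_semantics(node: cst.Call) -> cst.Call:
        reduction = "mean"  # PyTorch's default when the kwarg is absent
        kept = []
        for arg in node.args:
            # Step 1: detect the reduction keyword argument.
            if arg.keyword and arg.keyword.value == "reduction":
                if isinstance(arg.value, cst.SimpleString):
                    reduction = arg.value.value.strip("'\"")
            else:
                kept.append(arg)
        if kept:
            # Tidy the trailing comma left behind by the stripped kwarg.
            kept[-1] = kept[-1].with_changes(comma=cst.MaybeSentinel.DEFAULT)
        stripped = node.with_changes(args=kept)  # Step 2: strip the argument.
        if reduction == "none":
            return stripped                      # Step 4: leave the vector alone.
        # Step 3: wrap the call in the looked-up reduction API.
        return cst.Call(
            func=cst.parse_expression(_REDUCERS[reduction]),
            args=[cst.Arg(value=stripped)],
        )

    expr = cst.parse_expression(
        "optax.softmax_cross_entropy(logits, labels, reduction='sum')"
    )
    assert isinstance(expr, cst.Call)
    print(cst.Module(body=[]).code_for_node(apply_reduction_semantics(expr)))
    # jnp.sum(optax.softmax_cross_entropy(logits, labels))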

Functions

transform_loss_reduction(node, ctx) → libcst.CSTNode

Hook: Wraps loss functions to apply reduction.

Module Contents

ml_switcheroo.plugins.loss_wrapper.transform_loss_reduction(node: libcst.Call, ctx: ml_switcheroo.core.hooks.HookContext) → libcst.CSTNode

Hook: Wraps loss functions to apply reduction.

Trigger: Operations mapped with requires_plugin: "loss_reduction". Target: Frameworks requiring explicit reduction (JAX, Flax).

Parameters:
  • node – The original CST Call node.

  • ctx – HookContext for API lookup.

Returns:

The transformed Call node: wrapped in a reduction call, or left unwrapped when reduction='none'.
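
Continuing the hypothetical apply_reduction_semantics sketch above (imports and helper as defined there; the real hook additionally consults ctx for the target framework's API names), the return contract looks like:

    def _render(src: str) -> str:
        expr = cst.parse_expression(src)
        assert isinstance(expr, cst.Call)
        return cst.Module(body=[]).code_for_node(apply_reduction_semantics(expr))

    # Wrapped: the mean reduction becomes an explicit call.
    print(_render("optax.softmax_cross_entropy(x, y, reduction='mean')"))
    # jnp.mean(optax.softmax_cross_entropy(x, y))

    # Unwrapped: reduction='none' strips the kwarg and keeps the vector.
    print(_render("optax.softmax_cross_entropy(x, y, reduction='none')"))
    # optax.softmax_cross_entropy(x, y)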