ml_switcheroo.plugins.loss_wrapper
==================================

.. py:module:: ml_switcheroo.plugins.loss_wrapper

.. autoapi-nested-parse::

   Plugin for Loss Reduction Semantics.

   Addresses the mismatch between:

   1. PyTorch: ``loss = F.cross_entropy(x, y, reduction='mean')`` (scalar output by default).
   2. JAX/Optax: ``loss = optax.softmax_cross_entropy(x, y)`` (vector output, one loss value per sample).

   JAX libraries typically return the loss *per sample* so that it composes cleanly with
   ``vmap``/``pmap``. PyTorch, by contrast, averages immediately (``reduction='mean'`` is its default).

   Transformation:

   1. Detects the ``reduction`` keyword argument.
   2. Strips the argument, since Optax/JAX loss functions usually do not accept it.
   3. Wraps the function call in ``jnp.mean()`` (for ``'mean'``, PyTorch's default) or ``jnp.sum()`` (for ``'sum'``).
   4. If ``reduction='none'``, leaves the per-sample vector output unchanged.

Functions
---------

.. autoapisummary::

   ml_switcheroo.plugins.loss_wrapper.transform_loss_reduction

Module Contents
---------------

.. py:function:: transform_loss_reduction(node: libcst.Call, ctx: ml_switcheroo.core.hooks.HookContext) -> libcst.CSTNode

   Hook: Wraps loss functions to apply reduction.

   Trigger: Operations mapped with ``requires_plugin: "loss_reduction"``.

   Target: JAX, Flax.
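
The four transformation steps in the module docstring can be sketched with the standard-library ``ast`` module. This is a minimal, hypothetical illustration, not the plugin's implementation: the real hook rewrites ``libcst.Call`` nodes and receives a ``HookContext``, whereas ``REDUCERS``, ``rewrite_loss_call`` and ``transform_source`` below are illustrative names. It assumes that ``reduction`` is given as a string literal and that an omitted keyword means ``'mean'`` (PyTorch's default).

.. code-block:: python

   import ast
   from typing import Optional

   # Hypothetical table mapping PyTorch ``reduction`` strings to JAX reducers.
   # ``None`` means "keep the per-sample vector unchanged".
   REDUCERS = {"mean": "jnp.mean", "sum": "jnp.sum", "none": None}


   def _reduction_of(call: ast.Call) -> Optional[str]:
       """Return the literal ``reduction`` value, 'mean' if absent, None if dynamic."""
       for kw in call.keywords:
           if kw.arg == "reduction":
               value = kw.value
               if isinstance(value, ast.Constant) and isinstance(value.value, str):
                   return value.value
               return None  # non-literal value: cannot be resolved statically
       return "mean"  # assumption: omitted keyword follows PyTorch's default


   def rewrite_loss_call(call: ast.Call) -> ast.expr:
       """Strip ``reduction=...`` and wrap the call in ``jnp.mean``/``jnp.sum``."""
       reduction = _reduction_of(call)
       if reduction not in REDUCERS:
           return call  # unknown or dynamic reduction: leave the call untouched
       # Step 2: drop the keyword, since Optax loss functions do not accept it.
       call.keywords = [kw for kw in call.keywords if kw.arg != "reduction"]
       reducer = REDUCERS[reduction]
       if reducer is None:
           return call  # Step 4: reduction='none' keeps per-sample losses
       # Step 3: build ``jnp.mean(<call>)`` or ``jnp.sum(<call>)``.
       func = ast.parse(reducer, mode="eval").body
       return ast.Call(func=func, args=[call], keywords=[])


   def transform_source(src: str) -> str:
       """Parse a single call expression, rewrite it, and return new source text."""
       call = ast.parse(src, mode="eval").body
       if not isinstance(call, ast.Call):
           raise ValueError("expected a call expression")
       return ast.unparse(rewrite_loss_call(call))  # requires Python 3.9+

For example, ``transform_source("optax.softmax_cross_entropy(x, y, reduction='sum')")`` returns ``"jnp.sum(optax.softmax_cross_entropy(x, y))"``. Leaving a non-literal ``reduction`` untouched is a conservative design choice: its value is unknown at conversion time, so choosing a reducer could silently change the loss semantics.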