MSELoss

Auto-generated from jax_code_defs

PyTorch
API: torch.nn.functional.mse_loss
Strategy: Direct Mapping
Official Docs

JAX (Core)
API: optax.l2_loss
Strategy: Plugin (loss_reduction)