CosineEmbeddingLoss
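Creates a criterion that measures the loss given input tensors x1, x2 and a Tensor label y with values 1 or -1. Use y = 1 to maximize the cosine similarity of the two inputs, and y = -1 to penalize any similarity above margin.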
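For reference, this is the per-pair loss as documented for PyTorch's torch.nn.CosineEmbeddingLoss, where cos is the cosine similarity of the two input vectors:

```latex
\ell(x_1, x_2, y) =
\begin{cases}
1 - \cos(x_1, x_2), & y = 1 \\
\max\bigl(0,\ \cos(x_1, x_2) - \text{margin}\bigr), & y = -1
\end{cases}
\qquad
\cos(x_1, x_2) = \frac{x_1 \cdot x_2}{\lVert x_1 \rVert \, \lVert x_2 \rVert}
```

With reduction = 'mean' the per-pair values are averaged over the batch, with 'sum' they are summed, and with 'none' they are returned unreduced.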
Abstract Signature:
CosineEmbeddingLoss(input1: Tensor, input2: Tensor, target: Tensor, margin: float = 0.0, reduction: str = 'mean')
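A minimal plain-Python sketch of the semantics behind this abstract signature follows. It assumes each input is a batch of N equal-length vectors given as nested lists (standing in for Tensor) and that target holds one value of 1 or -1 per pair. The helper names _cosine and _EPS are illustrative, and the small epsilon added to the squared norms is an assumed numerical stabilizer, not part of the signature.

```python
import math
from typing import List, Sequence, Union

# Assumption: nested lists stand in for Tensors.
# input1, input2: shape (N, D); target: shape (N,) with values 1 or -1.
Batch = Sequence[Sequence[float]]

_EPS = 1e-12  # assumed stabilizer added to squared norms to avoid division by zero


def _cosine(a: Sequence[float], b: Sequence[float]) -> float:
    """Cosine similarity of two equal-length vectors."""
    if len(a) != len(b):
        raise ValueError("vectors must have the same length")
    dot = sum(x * y for x, y in zip(a, b))
    sq_a = sum(x * x for x in a) + _EPS
    sq_b = sum(y * y for y in b) + _EPS
    return dot / math.sqrt(sq_a * sq_b)


def cosine_embedding_loss(
    input1: Batch,
    input2: Batch,
    target: Sequence[int],
    margin: float = 0.0,
    reduction: str = "mean",
) -> Union[float, List[float]]:
    """Reference semantics of CosineEmbeddingLoss over a batch of vector pairs."""
    if not (len(input1) == len(input2) == len(target)):
        raise ValueError("input1, input2 and target must have the same batch size")
    losses: List[float] = []
    for a, b, y in zip(input1, input2, target):
        cos = _cosine(a, b)
        if y == 1:
            losses.append(1.0 - cos)  # similar pair: penalize any dissimilarity
        elif y == -1:
            losses.append(max(0.0, cos - margin))  # dissimilar pair: hinge at margin
        else:
            raise ValueError("target values must be 1 or -1")
    if reduction == "none":
        return losses
    if reduction == "sum":
        return sum(losses)
    if reduction == "mean":
        return sum(losses) / len(losses)
    raise ValueError(f"unknown reduction: {reduction!r}")
```

For example, with input1 = [[1.0, 0.0], [1.0, 0.0]], input2 = [[0.0, 1.0], [1.0, 1.0]] and target = [1, -1], reduction='none' gives roughly [1.0, 0.7071]: the first pair is orthogonal but labeled similar, and the second pair has cosine 1/sqrt(2), which exceeds the default margin of 0.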
JAX (Core)
API: optax.cosine_similarity_loss
Strategy: Plugin (loss_wrapper)
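The Plugin (loss_wrapper) strategy suggests the JAX side is an adapter around a similarity primitive rather than a one-to-one call. One way such an adapter could be structured is sketched below in plain Python. The loss_wrapper interface shown, the rowwise_cosine stand-in, the REDUCTIONS table, and the assumption that the backend returns one similarity value per row are all hypothetical; the sketch does not call optax itself.

```python
from typing import Callable, Dict, List, Sequence, Union

Rows = Sequence[Sequence[float]]
# Hypothetical backend primitive: one similarity value per (row_a, row_b) pair.
SimilarityFn = Callable[[Rows, Rows], List[float]]

# Hypothetical reduction table the wrapper dispatches on.
REDUCTIONS: Dict[str, Callable[[List[float]], Union[float, List[float]]]] = {
    "none": lambda v: v,
    "sum": lambda v: sum(v),
    "mean": lambda v: sum(v) / len(v),
}


def loss_wrapper(similarity_fn: SimilarityFn):
    """Lift a row-wise similarity function into the abstract
    CosineEmbeddingLoss(input1, input2, target, margin, reduction) signature."""

    def cosine_embedding_loss(input1, input2, target, margin=0.0, reduction="mean"):
        if reduction not in REDUCTIONS:
            raise ValueError(f"unknown reduction: {reduction!r}")
        sims = similarity_fn(input1, input2)  # delegated to the backend
        if len(sims) != len(target):
            raise ValueError("one target value is required per input pair")
        losses: List[float] = []
        for s, y in zip(sims, target):
            if y == 1:
                losses.append(1.0 - s)
            elif y == -1:
                losses.append(max(0.0, s - margin))
            else:
                raise ValueError("target values must be 1 or -1")
        return REDUCTIONS[reduction](losses)

    return cosine_embedding_loss


def rowwise_cosine(a_rows: Rows, b_rows: Rows, eps: float = 1e-12) -> List[float]:
    """Plain-Python stand-in for a backend cosine-similarity primitive."""
    out: List[float] = []
    for a, b in zip(a_rows, b_rows):
        dot = sum(x * y for x, y in zip(a, b))
        norm = (sum(x * x for x in a) * sum(y * y for y in b)) ** 0.5
        out.append(dot / max(norm, eps))  # clamp the norm to avoid division by zero
    return out
```

Keeping the target, margin, and reduction handling in the wrapper means the backend only has to supply the similarity; in a JAX port, rowwise_cosine would be replaced by the mapped JAX function, whose exact signature should be checked against the optax documentation.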