KLDivergence
Computes the Kullback-Leibler divergence loss between y_true and y_pred.
Abstract Signature:
KLDivergence(y_true: Tensor, y_pred: Tensor)
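The per-example loss is the sum over the last axis of y_true * log(y_true / y_pred), with y_true and y_pred interpreted as probability distributions. A minimal NumPy sketch of this definition (the helper name kl_divergence_ref and the eps clipping are illustrative, not part of any framework API):

    import numpy as np

    def kl_divergence_ref(y_true, y_pred, eps=1e-12):
        # Per-example KL divergence: sum over the last axis of
        # y_true * log(y_true / y_pred); eps guards against log(0).
        y_true = np.clip(y_true, eps, 1.0)
        y_pred = np.clip(y_pred, eps, 1.0)
        return np.sum(y_true * np.log(y_true / y_pred), axis=-1)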
PyTorch
API:
torch.nn.functional.kl_div
Strategy:
Macro 'torch.nn.functional.kl_div(({y_pred}).log(), {y_true}, reduction='none').sum(dim=-1)'
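A minimal sketch of the PyTorch macro applied to batched probability distributions (the tensor values and shapes are illustrative):

    import torch
    import torch.nn.functional as F

    # Two examples, each a 3-class probability distribution.
    y_true = torch.tensor([[0.7, 0.2, 0.1],
                           [0.1, 0.8, 0.1]])
    y_pred = torch.tensor([[0.6, 0.3, 0.1],
                           [0.2, 0.6, 0.2]])

    # kl_div expects log-probabilities as input and probabilities as
    # target; reduction='none' keeps the pointwise terms, which are then
    # summed over the class axis to give one KL value per example.
    loss = F.kl_div(y_pred.log(), y_true, reduction='none').sum(dim=-1)
    print(loss)  # shape: (2,)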
JAX (Core)
API:
optax.kl_divergence
Strategy:
Macro 'optax.kl_divergence(jnp.log({y_pred}), {y_true})'
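Note that optax.kl_divergence takes log-predictions as its first argument; taking jnp.log of y_pred directly (rather than a log-softmax, which would assume logits) matches the PyTorch macro above. A minimal sketch, assuming y_pred already holds probabilities:

    import jax.numpy as jnp
    import optax

    y_true = jnp.array([[0.7, 0.2, 0.1],
                        [0.1, 0.8, 0.1]])
    y_pred = jnp.array([[0.6, 0.3, 0.1],
                        [0.2, 0.6, 0.2]])

    # optax.kl_divergence(log_predictions, targets) returns the
    # per-example KL divergence summed over the last axis.
    loss = optax.kl_divergence(jnp.log(y_pred), y_true)
    print(loss)  # shape: (2,)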