KlDiv
The Kullback-Leibler (KL) divergence loss.
Abstract Signature:
KlDiv(input: Tensor, target: Tensor, reduction: str = 'mean')
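The signature does not say how `input` and `target` are encoded. Assuming KlDiv follows the PyTorch `torch.nn.functional.kl_div` convention (an assumption, not stated above), `input` holds log-probabilities $x$ and `target` holds probabilities $y$, both of shape $N \times C$. The elementwise loss and the three reductions would then be:

```latex
\ell_{n,c} = y_{n,c}\left(\log y_{n,c} - x_{n,c}\right), \qquad 0 \cdot \log 0 := 0
\\
L_{\text{none}} = \ell, \qquad
L_{\text{sum}} = \sum_{n=1}^{N}\sum_{c=1}^{C} \ell_{n,c}, \qquad
L_{\text{mean}} = \frac{1}{NC}\sum_{n=1}^{N}\sum_{c=1}^{C} \ell_{n,c}
```

Under this convention `mean` averages over every element, not over the batch, so the average per-sample KL divergence is $C \cdot L_{\text{mean}}$.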
JAX (Core)
API:
optax.kl_divergence
Strategy: Plugin (loss_reduction)
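`optax.kl_divergence(log_predictions, targets)` already sums over the last axis and returns one value per example, so a reduction step only has to combine those per-row values. The sketch below imitates that two-step split in plain Python (no JAX needed) to show what such a step must account for. `optax_style_rows` and `apply_loss_reduction` are hypothetical names, not part of optax or of this library, and the PyTorch-style meaning of `mean` is an assumption.

```python
import math


def optax_style_rows(log_predictions, targets):
    """Pure-Python stand-in for optax.kl_divergence: one summed KL value per row."""
    rows = []
    for log_q_row, p_row in zip(log_predictions, targets):
        total = 0.0
        for log_q, p in zip(log_q_row, p_row):
            # Zero targets contribute nothing (0 * log 0 is taken as 0).
            if p > 0:
                total += p * (math.log(p) - log_q)
        rows.append(total)
    return rows


def apply_loss_reduction(row_losses, num_classes, reduction="mean"):
    """Hypothetical reduction step mapping per-row sums onto KlDiv's modes."""
    if reduction == "sum":
        # Summing the row sums equals summing every element.
        return sum(row_losses)
    if reduction == "mean":
        # Element-wise mean (assumed): divide by N * C, not only by N rows.
        return sum(row_losses) / (len(row_losses) * num_classes)
    # 'none' needs the unsummed per-element terms, which row sums no longer carry.
    raise ValueError(f"unsupported reduction for row-summed losses: {reduction!r}")
```

The `none` case is rejected on purpose: once the last axis has been summed, the per-element terms cannot be recovered from the row values.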
NumPy
API:
None
Strategy: Custom / Partial
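NumPy ships no KL divergence loss, so this backend needs a hand-written implementation. Here is a minimal sketch, assuming the same PyTorch-style convention (`log_input` holds log-probabilities, `target` holds probabilities). The function name `kl_div_numpy` is hypothetical.

```python
import numpy as np


def kl_div_numpy(log_input, target, reduction="mean"):
    """Sketch of KlDiv in NumPy; log_input holds log-probabilities."""
    log_input = np.asarray(log_input, dtype=np.float64)
    target = np.asarray(target, dtype=np.float64)
    positive = target > 0
    # Substitute 1.0 where target == 0 so np.log never sees 0; those slots are zeroed below.
    safe_log_target = np.log(np.where(positive, target, 1.0))
    pointwise = np.where(positive, target * (safe_log_target - log_input), 0.0)
    if reduction == "none":
        return pointwise
    if reduction == "sum":
        return pointwise.sum()
    if reduction == "mean":
        # Mean over every element, matching the assumed PyTorch-style 'mean'.
        return pointwise.mean()
    raise ValueError(f"unknown reduction: {reduction!r}")
```

Masking with `np.where` instead of calling `scipy.special.xlogy` keeps the sketch dependent on NumPy alone.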
Apple MLX
API:
mx.nn.losses.kl_div
Strategy: Plugin (loss_reduction)
Flax NNX
API:
optax.kl_divergence
Strategy: Plugin (loss_reduction)
PaxML / Praxis
API:
optax.kl_divergence
Strategy: Plugin (loss_reduction)