PyTorch
API: torch.kl_div
Strategy: Direct Mapping
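For reference, every row below has to reproduce the quantity sketched here. It assumes the documented semantics of `torch.nn.functional.kl_div`, which is built on the same underlying operator: the input $x$ holds log-probabilities, the target $y$ holds probabilities (`log_target=False`), and the batch has $B$ rows of $C$ classes. Only `batchmean` matches the mathematical KL divergence averaged per example. The plain `mean` divides by every element.

```latex
\ell_{ij} = y_{ij}\,\bigl(\log y_{ij} - x_{ij}\bigr), \qquad 0 \cdot \log 0 := 0

\mathrm{sum} = \sum_{i=1}^{B}\sum_{j=1}^{C} \ell_{ij}, \qquad
\mathrm{mean} = \frac{1}{BC}\sum_{i=1}^{B}\sum_{j=1}^{C} \ell_{ij}, \qquad
\mathrm{batchmean} = \frac{1}{B}\sum_{i=1}^{B}\sum_{j=1}^{C} \ell_{ij}
```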
JAX (Core)
API: optax.kl_divergence
Strategy: Plugin (loss_reduction)
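The "Plugin (loss_reduction)" strategy is needed because `optax.kl_divergence(log_predictions, targets)` sums over the last axis and returns one loss per example, so the batch reduction is left to the caller. The same holds for the Flax NNX and PaxML / Praxis rows, which reuse the optax call. The sketch below is a pure-Python stand-in, not the plugin's actual code: `per_example_kl` only mimics optax's per-example output, and `apply_loss_reduction` is a hypothetical name for the reduction step.

```python
import math
from typing import Sequence


def per_example_kl(log_predictions: Sequence[Sequence[float]],
                   targets: Sequence[Sequence[float]]) -> list[float]:
    """Stand-in for an optax-style KL loss: one summed value per example."""
    out = []
    for log_p, t in zip(log_predictions, targets):
        # Each term is target * (log target - log_prediction);
        # terms with target == 0 contribute 0 (0 * log 0 taken as 0).
        out.append(sum(ti * (math.log(ti) - lpi) if ti > 0 else 0.0
                       for lpi, ti in zip(log_p, t)))
    return out


def apply_loss_reduction(per_example: Sequence[float], num_classes: int,
                         reduction: str = "mean") -> float:
    """Hypothetical plugin: rebuild PyTorch-style reductions from per-example sums.

    num_classes is the size of the axis that was summed per example.
    """
    total = sum(per_example)
    batch = len(per_example)
    if reduction == "sum":
        return total
    if reduction == "batchmean":
        # Sum divided by batch size: the per-example KL averaged over the batch.
        return total / batch
    if reduction == "mean":
        # PyTorch's "mean" divides by the number of elements, not examples.
        return total / (batch * num_classes)
    if reduction == "none":
        # Element-wise values were already summed away; they cannot be recovered.
        raise ValueError("'none' needs element-wise values, not per-example sums")
    raise ValueError(f"unknown reduction: {reduction!r}")


log_preds = [[math.log(0.5), math.log(0.5)], [math.log(0.5), math.log(0.5)]]
targets = [[1.0, 0.0], [0.5, 0.5]]
per_ex = per_example_kl(log_preds, targets)
print(per_ex, apply_loss_reduction(per_ex, num_classes=2, reduction="batchmean"))
```

One consequence of this design is that `reduction="none"` cannot be served from the per-example output. A mapping that needs element-wise values has to compute them directly.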
NumPy
API: —
Strategy: Custom / Partial
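NumPy has no built-in KL divergence, so the mapping has to be written by hand. The following is a minimal sketch of such a custom implementation. It assumes PyTorch-style arguments (`log_input` as log-probabilities, `target` as probabilities unless `log_target=True`) and valid probability targets; the function name and signature are illustrative, not part of NumPy.

```python
import numpy as np


def kl_div(log_input, target, reduction="mean", log_target=False):
    """NumPy sketch of a KL-divergence loss with PyTorch-style arguments."""
    log_input = np.asarray(log_input, dtype=np.float64)
    target = np.asarray(target, dtype=np.float64)
    if log_target:
        # Target already in log space: exp(t) * (t - x).
        pointwise = np.exp(target) * (target - log_input)
    else:
        # target * (log target - x), with 0 * log 0 treated as 0.
        # Replacing zeros by 1.0 before the log avoids log(0) warnings.
        safe_log = np.log(np.where(target > 0, target, 1.0))
        pointwise = np.where(target > 0, target * (safe_log - log_input), 0.0)
    if reduction == "none":
        return pointwise
    if reduction == "sum":
        return pointwise.sum()
    if reduction == "mean":
        # Average over all elements, as PyTorch's "mean" does.
        return pointwise.mean()
    if reduction == "batchmean":
        # Sum divided by the size of the leading (batch) axis.
        return pointwise.sum() / pointwise.shape[0]
    raise ValueError(f"unknown reduction: {reduction!r}")


x = np.log(np.array([[0.5, 0.5], [0.5, 0.5]]))
t = np.array([[1.0, 0.0], [0.5, 0.5]])
print(kl_div(x, t, reduction="batchmean"))
```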
Keras
API: keras.losses.kl_divergence
Strategy: Direct Mapping
TensorFlow
API: tf.keras.losses.kl_divergence
Strategy: Direct Mapping
Apple MLX
API: mlx.nn.losses.kl_div_loss
Strategy: Plugin (loss_reduction)
Flax NNX
API: optax.kl_divergence
Strategy: Plugin (loss_reduction)
PaxML / Praxis
API: optax.kl_divergence
Strategy: Plugin (loss_reduction)