PyTorch
API: torch.nn.functional.kl_div
Strategy: Macro 'torch.nn.functional.kl_div(({y_pred}).log(), {y_true}, reduction='none').sum(dim=-1)'
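You can sanity-check this macro without installing PyTorch by using a pure-Python model of the two conventions. The sketch assumes the documented semantics. `keras.losses.kl_divergence` clips both inputs to [epsilon, 1] and computes sum(y_true * log(y_true / y_pred)) over the last axis. `torch.nn.functional.kl_div` expects its first argument already in log-space and, with `reduction='none'`, returns target * (log(target) - input) elementwise. The helper names `keras_style_kl`, `torch_style_kl_div_none` and `torch_macro` are hypothetical stand-ins, not library calls.

```python
import math
from typing import List, Sequence

# Hypothetical pure-Python stand-ins for each framework's documented
# behaviour; none of these functions is part of Keras or PyTorch.


def keras_style_kl(y_true: Sequence[float], y_pred: Sequence[float],
                   eps: float = 1e-7) -> float:
    """KL(y_true || y_pred) for one row, clipping both inputs to [eps, 1]."""
    total = 0.0
    for t, p in zip(y_true, y_pred):
        t = min(max(t, eps), 1.0)
        p = min(max(p, eps), 1.0)
        total += t * math.log(t / p)
    return total


def torch_style_kl_div_none(log_input: Sequence[float],
                            target: Sequence[float]) -> List[float]:
    """Elementwise target * (log(target) - input); input is already log-probs.
    A zero target contributes 0 (the x*log(x) -> 0 convention)."""
    return [0.0 if t == 0 else t * (math.log(t) - li)
            for li, t in zip(log_input, target)]


def torch_macro(y_true: Sequence[float], y_pred: Sequence[float]) -> float:
    # Mirrors: kl_div(y_pred.log(), y_true, reduction='none').sum(dim=-1)
    log_pred = [math.log(p) for p in y_pred]
    return sum(torch_style_kl_div_none(log_pred, y_true))


y_true = [0.1, 0.6, 0.3]
y_pred = [0.2, 0.5, 0.3]
print(keras_style_kl(y_true, y_pred), torch_macro(y_true, y_pred))
```

The two values agree here because no entry falls below epsilon. The macro has no equivalent of the Keras clipping, so in PyTorch a zero in `y_pred` paired with a positive `y_true` makes the result infinite.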
JAX (Core)
API: optax.kl_divergence
Strategy: Macro 'optax.kl_divergence(jax.numpy.log({y_pred}), {y_true})'
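`optax.kl_divergence` takes its arguments in the opposite order from Keras, `(log_predictions, targets)`. It expects the predictions already in log-space and sums over the last axis. The pure-Python model below uses the hypothetical names `optax_style_kl` and `keras_order_kl`, which are not Optax functions. It shows that logging the predicted probabilities and putting them first reproduces the Keras value. Passing the Keras arguments straight through does not.

```python
import math
from typing import Sequence


# Hypothetical model of optax.kl_divergence's documented semantics:
# sum over the last axis of targets * (log(targets) - log_predictions).
def optax_style_kl(log_predictions: Sequence[float],
                   targets: Sequence[float]) -> float:
    return sum(0.0 if t == 0 else t * (math.log(t) - lp)
               for lp, t in zip(log_predictions, targets))


# Reference KL(y_true || y_pred) in Keras argument order (no clipping here).
def keras_order_kl(y_true: Sequence[float], y_pred: Sequence[float]) -> float:
    return sum(t * math.log(t / p) for t, p in zip(y_true, y_pred) if t > 0)


y_true = [0.1, 0.6, 0.3]
y_pred = [0.2, 0.5, 0.3]

# Mapped call: predicted probabilities go first, converted to log-space.
mapped = optax_style_kl([math.log(p) for p in y_pred], y_true)
# Naive "direct mapping": Keras argument order, probability-space inputs.
naive = optax_style_kl(y_true, y_pred)

print(mapped, keras_order_kl(y_true, y_pred), naive)
```

The same convention applies to the Flax NNX and PaxML / Praxis rows, which target the same Optax function.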
Keras
API: keras.losses.kl_divergence
Strategy: Direct Mapping
TensorFlow
API: tf.keras.losses.kl_divergence
Strategy: Direct Mapping
Apple MLX
API: mlx.nn.losses.kl_div_loss
Strategy: Macro 'mlx.nn.losses.kl_div_loss(mlx.core.log({y_pred}), mlx.core.log({y_true}), axis=-1)'
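`mlx.nn.losses.kl_div_loss` is documented as taking log-probabilities for both arguments, predicted distribution first. It computes `exp(targets) * (targets - inputs)` summed over `axis` (default -1), with `reduction='none'` by default. The pure-Python model below uses the hypothetical name `mlx_style_kl_div_loss`, which is not an MLX function. It checks that moving both probability vectors to log-space reproduces the Keras-order KL value.

```python
import math
from typing import Sequence


# Hypothetical model of the documented mlx.nn.losses.kl_div_loss formula
# for a single row: sum(exp(targets) * (targets - inputs)), where both
# arguments are log-probabilities.
def mlx_style_kl_div_loss(inputs: Sequence[float],
                          targets: Sequence[float]) -> float:
    return sum(math.exp(t) * (t - i) for i, t in zip(inputs, targets))


# Reference KL(y_true || y_pred) in Keras argument order (no clipping here).
def keras_order_kl(y_true: Sequence[float], y_pred: Sequence[float]) -> float:
    return sum(t * math.log(t / p) for t, p in zip(y_true, y_pred) if t > 0)


y_true = [0.1, 0.6, 0.3]
y_pred = [0.2, 0.5, 0.3]

# Both probability vectors are converted to log-space before the call.
mapped = mlx_style_kl_div_loss([math.log(p) for p in y_pred],
                               [math.log(t) for t in y_true])
print(mapped, keras_order_kl(y_true, y_pred))
```

Under IEEE floating point, `exp(log(0)) * (log(0) - x)` evaluates to NaN. A `y_true` containing exact zeros would therefore need clipping first; Keras clips both inputs to [epsilon, 1] internally.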
Flax NNX
API: optax.kl_divergence
Strategy: Macro 'optax.kl_divergence(jax.numpy.log({y_pred}), {y_true})'
PaxML / Praxis
API: optax.kl_divergence
Strategy: Macro 'optax.kl_divergence(jax.numpy.log({y_pred}), {y_true})'