PyTorch
API: —
Strategy: Macro '1 - (2 * ({y_true} * {y_pred}).sum(dim={axis})) / ({y_true}.sum(dim={axis}) + {y_pred}.sum(dim={axis}))'
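The Macro strategy reads as a source-level template. Its placeholders {y_true}, {y_pred} and {axis} stand for the caller's argument expressions. Below is a minimal sketch of that expansion for the PyTorch row. It assumes plain Python str.format substitution, which the table does not itself specify. The helper name expand_macro is hypothetical and is not part of any listed library.

```python
# Hypothetical expansion helper: assumes macros are filled by plain str.format
# substitution of source-text arguments; the real tooling may work differently.
TORCH_DICE_MACRO = (
    "1 - (2 * ({y_true} * {y_pred}).sum(dim={axis})) / "
    "({y_true}.sum(dim={axis}) + {y_pred}.sum(dim={axis}))"
)


def expand_macro(template: str, **args: str) -> str:
    """Replace each {name} placeholder with the given source-text expression."""
    return template.format(**args)


# Expand the PyTorch macro for a per-sample reduction over the last dimension.
torch_src = expand_macro(TORCH_DICE_MACRO, y_true="target", y_pred="probs", axis="-1")
print(torch_src)
```

The substitution is purely textual, so operator precedence is not protected. An argument such as a + b would bind incorrectly inside {y_true} * {y_pred}. Wrapping non-trivial argument expressions in parentheses before expansion avoids this problem.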
JAX (Core)
API: —
Strategy: Macro '1 - (2 * ({y_true} * {y_pred}).sum(axis={axis})) / ({y_true}.sum(axis={axis}) + {y_pred}.sum(axis={axis}))'
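To check the JAX macro by hand, the same arithmetic can be reproduced in plain Python. This sketch assumes 2-D inputs (a batch of rows) reduced over the last axis, as with axis=-1. The function name dice_loss_rows and the sample values are illustrative only.

```python
def dice_loss_rows(y_true, y_pred):
    """Per-row Dice loss on nested lists, mirroring the macro with axis=-1."""
    losses = []
    for t_row, p_row in zip(y_true, y_pred):
        # Numerator term: sum of element-wise products along the row.
        intersection = sum(t * p for t, p in zip(t_row, p_row))
        # Denominator: row sum of y_true plus row sum of y_pred (no smoothing term).
        denominator = sum(t_row) + sum(p_row)
        losses.append(1 - (2 * intersection) / denominator)
    return losses


y_true = [[1, 0, 1, 1], [0, 1, 0, 0]]
y_pred = [[0.9, 0.1, 0.8, 0.4], [0.2, 0.6, 0.1, 0.1]]
losses = dice_loss_rows(y_true, y_pred)
print(losses)  # approximately [0.1923, 0.4]
```

For the first row the intersection is 0.9 + 0.8 + 0.4 = 2.1 and the denominator is 3 + 2.2 = 5.2, so the loss is 1 - 4.2/5.2 ≈ 0.192. For the second row the intersection is 0.6 and the denominator is 1 + 1.0 = 2.0, so the loss is 1 - 1.2/2.0 = 0.4.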
Keras
API: keras.losses.dice
Strategy: Direct Mapping
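With Direct Mapping, the call is forwarded to keras.losses.dice unchanged, so edge-case handling is whatever Keras implements. The macro rows, by contrast, contain no guard against a zero denominator. If y_true and y_pred are both all zeros along the reduced axis, the fraction becomes 0/0. The sketch below uses 1-D lists and adds a hypothetical eps parameter to the denominator to show the difference.

```python
def dice_loss(y_true, y_pred, eps=0.0):
    """1-D Dice loss; eps is a hypothetical smoothing term added to the denominator."""
    intersection = sum(t * p for t, p in zip(y_true, y_pred))
    denominator = sum(y_true) + sum(y_pred) + eps
    return 1 - (2 * intersection) / denominator


empty = [0.0, 0.0, 0.0]
guarded = dice_loss(empty, empty, eps=1e-7)  # 1 - 0.0 / 1e-7 == 1.0
try:
    dice_loss(empty, empty)  # eps=0.0 reproduces the bare macro formula
except ZeroDivisionError:
    print("bare formula: division by zero")
print(guarded)
```

Plain Python raises an exception for this case. Array libraries such as JAX and PyTorch instead return NaN for a floating-point 0/0, which then propagates silently through the loss.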
TensorFlow
API: tf.keras.losses.dice
Strategy: Direct Mapping
Flax NNX
API: —
Strategy: Macro '1 - (2 * ({y_true} * {y_pred}).sum(axis={axis})) / ({y_true}.sum(axis={axis}) + {y_pred}.sum(axis={axis}))'
PaxML / Praxis
API: —
Strategy: Macro '1 - (2 * ({y_true} * {y_pred}).sum(axis={axis})) / ({y_true}.sum(axis={axis}) + {y_pred}.sum(axis={axis}))'
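The table pairs each framework with one of two strategies. The sketch below shows one way such a table could drive code generation, under two assumptions. Direct Mapping entries are emitted as a call on the two positional arguments only; forwarding axis is deliberately left out because the table does not say how it maps. Macro entries are expanded with str.format. DICE_TABLE and render_dice are hypothetical names.

```python
# Hypothetical encoding of the mapping table: framework -> (strategy, api_or_template).
JAX_MACRO = (
    "1 - (2 * ({y_true} * {y_pred}).sum(axis={axis})) / "
    "({y_true}.sum(axis={axis}) + {y_pred}.sum(axis={axis}))"
)
# The PyTorch macro differs only in using the keyword dim= instead of axis=.
TORCH_MACRO = JAX_MACRO.replace("axis=", "dim=")

DICE_TABLE = {
    "PyTorch": ("macro", TORCH_MACRO),
    "JAX (Core)": ("macro", JAX_MACRO),
    "Keras": ("direct", "keras.losses.dice"),
    "TensorFlow": ("direct", "tf.keras.losses.dice"),
    "Flax NNX": ("macro", JAX_MACRO),
    "PaxML / Praxis": ("macro", JAX_MACRO),
}


def render_dice(framework: str, y_true: str, y_pred: str, axis: str) -> str:
    """Return target-framework source text for a Dice loss computation."""
    strategy, payload = DICE_TABLE[framework]
    if strategy == "direct":
        # Assumption: only the positional arguments are forwarded to the library call.
        return f"{payload}({y_true}, {y_pred})"
    return payload.format(y_true=y_true, y_pred=y_pred, axis=axis)


for name in DICE_TABLE:
    print(f"{name}: {render_dice(name, 'y', 'p', '-1')}")
```

Storing one shared JAX-style template for JAX, Flax NNX and PaxML / Praxis reflects that the three macro rows in the table are identical.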