PyTorch
API: —
Strategy: Macro 'torch.mean(torch.clamp(1 - {y_true} * {y_pred}, min=0) ** 2, dim=-1)'
JAX (Core)
API: —
Strategy: Macro 'jax.numpy.mean(jax.numpy.square(jax.numpy.maximum(1 - {y_true} * {y_pred}, 0)), axis=-1)'
Keras
API: keras.losses.squared_hinge
Strategy: Direct Mapping
TensorFlow
API: tf.keras.losses.squared_hinge
Strategy: Direct Mapping
Flax NNX
API: —
Strategy: Macro 'jax.numpy.mean(jax.numpy.square(jax.numpy.maximum(1 - {y_true} * {y_pred}, 0)), axis=-1)'
PaxML / Praxis
API: —
Strategy: Macro 'jax.numpy.mean(jax.numpy.square(jax.numpy.maximum(1 - {y_true} * {y_pred}, 0)), axis=-1)'
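All of the macro strategies above expand to the same expression: the mean over the last axis of max(1 - y_true * y_pred, 0) squared. The following is a minimal, dependency-free sketch of that formula for checking values by hand. It is not the implementation used by any of the listed frameworks. The function names `squared_hinge` and `squared_hinge_batch` are hypothetical, and the sketch assumes labels are already encoded as -1 or +1.

```python
def squared_hinge(y_true, y_pred):
    """Squared hinge loss for one sample (two equal-length 1-D sequences).

    Mirrors mean(max(1 - y_true * y_pred, 0) ** 2) over the last axis.
    Assumes y_true holds -1/+1 labels (an assumption of this sketch).
    """
    if len(y_true) != len(y_pred):
        raise ValueError("y_true and y_pred must have the same length")
    # Per-element hinge margin, clamped at zero, then squared.
    terms = [max(1.0 - t * p, 0.0) ** 2 for t, p in zip(y_true, y_pred)]
    return sum(terms) / len(terms)


def squared_hinge_batch(y_true, y_pred):
    """Apply squared_hinge row by row, like reducing over axis=-1 of a 2-D input."""
    return [squared_hinge(t, p) for t, p in zip(y_true, y_pred)]


# Worked example: margins are 0.5, 1.5, -1, -1, so the squared terms
# are 0.25, 2.25, 0, 0 and their mean is 2.5 / 4.
print(squared_hinge([1, -1, 1, -1], [0.5, 0.5, 2.0, -2.0]))  # 0.625
```

Predictions that clear the margin (y_true * y_pred >= 1) contribute zero, so only the first two elements in the worked example add to the loss.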