PyTorch
API: —
Strategy: Macro 'torch.nn.functional.mse_loss(torch.log1p({y_true}), torch.log1p({y_pred}), reduction="none")'
JAX (Core)
API: —
Strategy: Macro 'optax.squared_error(jax.numpy.log1p({y_true}), jax.numpy.log1p({y_pred}))'
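The PyTorch macro (with `reduction='none'`) and the `optax.squared_error` macro both produce an unreduced, elementwise result with the inputs' broadcast shape. The JAX (Core), Flax NNX and PaxML / Praxis entries all use the `optax.squared_error` macro. The following minimal NumPy sketch shows that per-element value. It illustrates the formula and is not either framework's implementation. The helper name `elementwise_squared_log_error` is hypothetical, and inputs are assumed to be greater than -1 so that `log1p` is finite.

```python
import numpy as np


def elementwise_squared_log_error(y_true, y_pred):
    """Hypothetical helper: (log1p(y_true) - log1p(y_pred)) ** 2, no reduction.

    Mirrors the elementwise value the macros compute; it assumes all
    inputs are > -1 so that log1p is finite.
    """
    y_true = np.asarray(y_true, dtype=np.float64)
    y_pred = np.asarray(y_pred, dtype=np.float64)
    # log1p(x) = log(1 + x), computed accurately even for small x.
    return np.square(np.log1p(y_true) - np.log1p(y_pred))


# Usage: no reduction is applied, so the result keeps the (2, 2) input shape.
y_true = np.array([[0.0, 1.0], [3.0, 7.0]])
y_pred = np.array([[0.0, 0.0], [1.0, 7.0]])
per_element = elementwise_squared_log_error(y_true, y_pred)
print(per_element.shape)  # (2, 2)
```

Argument order does not matter here because the difference is squared. Passing `{y_true}` as the `predictions` argument of `optax.squared_error` is therefore harmless.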
Keras
API: keras.losses.mean_squared_logarithmic_error
Strategy: Direct Mapping
TensorFlow
API: tf.keras.losses.mean_squared_logarithmic_error
Strategy: Direct Mapping
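The Keras and TensorFlow entries map directly to the built-in functional loss. That loss differs from the elementwise macros in the PyTorch and JAX-based entries in two ways:

- It clamps both inputs at the backend epsilon (1e-7 by default) before computing `log(x + 1)`.
- It averages the squared difference over the last axis.

The NumPy sketch below reproduces that computation, assuming the default epsilon. The helper name `keras_style_msle` is hypothetical.

```python
import numpy as np

EPSILON = 1e-7  # assumed: Keras' default backend epsilon


def keras_style_msle(y_true, y_pred, epsilon=EPSILON):
    """Hypothetical NumPy sketch of Keras-style MSLE.

    Clamps both inputs at `epsilon`, applies log(x + 1), squares the
    difference, and averages over the last axis.
    """
    y_true = np.asarray(y_true, dtype=np.float64)
    y_pred = np.asarray(y_pred, dtype=np.float64)
    first_log = np.log(np.maximum(y_pred, epsilon) + 1.0)
    second_log = np.log(np.maximum(y_true, epsilon) + 1.0)
    # Reduce over the last axis: one loss value per sample.
    return np.mean(np.square(first_log - second_log), axis=-1)


# Usage: a (2, 2) batch reduces to one value per row.
y_true = np.array([[1.0, 3.0], [7.0, 7.0]])
y_pred = np.array([[0.0, 1.0], [7.0, 7.0]])
loss = keras_style_msle(y_true, y_pred)
print(loss.shape)  # (2,)
```

For non-negative inputs, the clamp changes the result by roughly epsilon. Taking `mean(..., axis=-1)` of the elementwise macro output is therefore a close match in that range. For inputs between -1 and 0, the clamp makes the two disagree.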
Flax NNX
API: —
Strategy: Macro 'optax.squared_error(jax.numpy.log1p({y_true}), jax.numpy.log1p({y_pred}))'
PaxML / Praxis
API: —
Strategy: Macro 'optax.squared_error(jax.numpy.log1p({y_true}), jax.numpy.log1p({y_pred}))'