PyTorch
API: torch.nn.functional.normalize
Strategy: Direct Mapping
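Any of these mappings also has to account for default axes. `torch.nn.functional.normalize` defaults to `p=2.0, dim=1, eps=1e-12`. By contrast, `tf.math.l2_normalize` defaults to `axis=None` (the whole tensor) and `keras.utils.normalize` defaults to `axis=-1`. For inputs with more than two dimensions, a PyTorch call that relies on the default `dim` therefore needs an explicit `axis=1` on the target side. The sketch below uses NumPy as a stand-in for torch so it runs without either library. The helper `l2_normalize` is hypothetical.

```python
import numpy as np


def l2_normalize(x: np.ndarray, axis: int, eps: float = 1e-12) -> np.ndarray:
    # Hypothetical helper: divide by the L2 norm along `axis`, floored at eps
    # (the clamp placement used by torch.nn.functional.normalize).
    norm = np.sqrt(np.sum(x**2, axis=axis, keepdims=True))
    return x / np.maximum(norm, eps)


# A (batch, channels, length) array, a layout where dim=1 is the channel axis.
x = np.arange(1.0, 25.0).reshape(2, 3, 4)

torch_default = l2_normalize(x, axis=1)  # mirrors torch.nn.functional.normalize(x) with its default dim
last_axis = l2_normalize(x, axis=-1)     # mirrors a target API whose default axis is -1

# Unit norm holds along whichever axis was normalized, and the two results differ.
print(np.allclose(np.linalg.norm(torch_default, axis=1), 1.0))
print(np.allclose(torch_default, last_axis))
```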
JAX (Core)
API: —
Strategy: Macro '{x} / jnp.sqrt(jnp.maximum(jnp.sum({x}**2, axis={axis}, keepdims=True), {epsilon}))'
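The following is a minimal sketch of how the placeholders work. It assumes that `{x}`, `{axis}` and `{epsilon}` are filled by plain string substitution; the actual expansion mechanism is not described here. The argument name `inputs` and the values `axis=-1` and `epsilon=1e-12` are illustrative only. The sketch evaluates the expanded string with NumPy bound to the name `jnp`. This works only because `jnp.sqrt`, `jnp.maximum` and `jnp.sum(..., axis=..., keepdims=...)` follow the NumPy signatures, and it means JAX is not needed to run the example.

```python
import numpy as np

# Template copied from the JAX (Core) row above.
TEMPLATE = (
    "{x} / jnp.sqrt(jnp.maximum(jnp.sum({x}**2, axis={axis}, keepdims=True), {epsilon}))"
)

# Illustrative substitution; "inputs", -1 and 1e-12 are assumptions, not values from the mapping.
expanded = TEMPLATE.format(x="inputs", axis=-1, epsilon=1e-12)
print(expanded)

# Evaluate the generated expression with NumPy standing in for jax.numpy.
inputs = np.array([[3.0, 4.0], [1.0, 0.0]])
result = eval(expanded, {"jnp": np, "inputs": inputs})

# Each row of the result should have unit L2 norm.
print(np.linalg.norm(result, axis=-1))
```

The Apple MLX, Flax NNX and PaxML / Praxis templates differ only in the `mx` or `jnp` prefix, so the same check carries over. Under the plain-substitution assumption, a compound argument such as `a + b` would need parentheses before substitution. Otherwise `{x}**2` would square only `b`.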
Keras
API: keras.utils.normalize
Strategy: Direct Mapping
TensorFlow
API: tf.math.l2_normalize
Strategy: Direct Mapping
Apple MLX
API: —
Strategy: Macro '{x} / mx.sqrt(mx.maximum(mx.sum({x}**2, axis={axis}, keepdims=True), {epsilon}))'
Flax NNX
API: —
Strategy: Macro '{x} / jnp.sqrt(jnp.maximum(jnp.sum({x}**2, axis={axis}, keepdims=True), {epsilon}))'
PaxML / Praxis
API: —
Strategy: Macro '{x} / jnp.sqrt(jnp.maximum(jnp.sum({x}**2, axis={axis}, keepdims=True), {epsilon}))'
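All four macro rows clamp the sum of squares at `{epsilon}`. This matches the formula given in the `tf.math.l2_normalize` documentation, `x / sqrt(max(sum(x**2), epsilon))`. By contrast, `torch.nn.functional.normalize` clamps the norm itself at `eps`. The minimal NumPy sketch below assumes the same value (here PyTorch's default `1e-12`) is forwarded to `{epsilon}`. It shows that the two forms agree on ordinary inputs and diverge only when the squared norm falls below epsilon. Both function names are hypothetical.

```python
import numpy as np

EPS = 1e-12  # PyTorch's default eps, assumed to be forwarded unchanged to {epsilon}


def macro_normalize(x: np.ndarray, axis: int = -1, epsilon: float = EPS) -> np.ndarray:
    # Mirrors the jnp/mx macros: epsilon floors the *squared* norm.
    return x / np.sqrt(np.maximum(np.sum(x**2, axis=axis, keepdims=True), epsilon))


def torch_normalize(x: np.ndarray, axis: int = -1, eps: float = EPS) -> np.ndarray:
    # Mirrors torch.nn.functional.normalize with p=2: eps floors the norm itself.
    norm = np.sqrt(np.sum(x**2, axis=axis, keepdims=True))
    return x / np.maximum(norm, eps)


ordinary = np.array([[3.0, 4.0]])
tiny = np.array([[3e-9, 4e-9]])  # norm 5e-9, squared norm 2.5e-17, below EPS

print(np.allclose(macro_normalize(ordinary), torch_normalize(ordinary)))
print(np.linalg.norm(macro_normalize(tiny)), np.linalg.norm(torch_normalize(tiny)))
```

With `epsilon=1e-12`, the macro's effective floor on the norm is `sqrt(1e-12) = 1e-6`. A vector whose norm is below that floor is therefore scaled down rather than rescaled to unit length. This matters only when outputs for near-zero inputs must match PyTorch exactly.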