L2Normalize

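Normalizes x along dimension axis by dividing it by its L2 norm, computed as sqrt(max(sum(x**2, axis), epsilon)). The epsilon lower bound on the squared sum prevents division by zero when a slice is all zeros.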

Abstract Signature:

L2Normalize(x: Tensor, axis: int, epsilon: float = 1e-12)
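The abstract semantics can be checked against a dependency-free reference. The sketch below is a minimal pure-Python version, assuming the formula used by the macro backends on this page: the squared sum is clamped by epsilon before the square root is taken. It only handles 1-D lists and 2-D lists of rows, and the name l2_normalize_ref is hypothetical, not part of any library.

```python
import math
from typing import List, Sequence


def _normalize_vector(v: Sequence[float], epsilon: float) -> List[float]:
    # Clamp the squared sum (not the norm) from below, as the macros do.
    sq_sum = sum(e * e for e in v)
    denom = math.sqrt(max(sq_sum, epsilon))
    return [e / denom for e in v]


def l2_normalize_ref(x, axis: int = -1, epsilon: float = 1e-12):
    """Hypothetical reference for L2Normalize on 1-D or 2-D nested lists."""
    if not x or not isinstance(x[0], (list, tuple)):
        # 1-D input: the only valid axes are 0 and -1.
        if axis not in (0, -1):
            raise ValueError("axis out of range for 1-D input")
        return _normalize_vector(x, epsilon)
    if axis in (1, -1):
        # Normalize each row independently.
        return [_normalize_vector(row, epsilon) for row in x]
    if axis in (0, -2):
        # Normalize each column: transpose, normalize rows, transpose back.
        cols = [list(c) for c in zip(*x)]
        normed = [_normalize_vector(c, epsilon) for c in cols]
        return [list(r) for r in zip(*normed)]
    raise ValueError("axis out of range for 2-D input")


print(l2_normalize_ref([3.0, 4.0]))  # [0.6, 0.8]
```

An all-zero slice comes back as zeros rather than NaN, because the denominator bottoms out at sqrt(epsilon).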

PyTorch

API: torch.nn.functional.normalize
Strategy: Direct Mapping
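One caveat on this direct mapping: PyTorch names the parameters dim and eps (with p=2.0 as the default norm order), and its documentation defines the result as v / max(‖v‖_p, eps), clamping the norm itself. The macro definitions on this page instead clamp the squared sum and then take the square root. The two agree except for very small norms. The pure-Python sketch below compares both formulas on a 1-D vector; clamp_squared_sum and clamp_norm are hypothetical helper names, not library APIs.

```python
import math
from typing import List, Sequence


def clamp_squared_sum(v: Sequence[float], epsilon: float = 1e-12) -> List[float]:
    # Macro semantics: x / sqrt(max(sum(x**2), epsilon)).
    denom = math.sqrt(max(sum(e * e for e in v), epsilon))
    return [e / denom for e in v]


def clamp_norm(v: Sequence[float], eps: float = 1e-12) -> List[float]:
    # PyTorch-documented semantics for p=2: x / max(||x||_2, eps).
    denom = max(math.sqrt(sum(e * e for e in v)), eps)
    return [e / denom for e in v]


tiny = [1e-9, 0.0]
print(clamp_squared_sum(tiny)[0])  # about 1e-3: denominator clamped to sqrt(1e-12)
print(clamp_norm(tiny)[0])         # about 1.0: norm 1e-9 is above eps
```

With epsilon = 1e-12, the macro form treats any slice whose norm is below 1e-6 as having norm 1e-6, while the PyTorch form only clamps below 1e-12.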

JAX (Core)

API: none
Strategy: Macro '{x} / jnp.sqrt(jnp.maximum(jnp.sum({x}**2, axis={axis}, keepdims=True), {epsilon}))'
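The Macro strategy reads as a string template whose {x}, {axis} and {epsilon} fields are filled with the caller's source expressions. Assuming the placeholders follow Python str.format syntax (the page does not specify the expansion mechanism), expansion can be sketched as below. JAX_L2_NORMALIZE, expand_macro and _atom are hypothetical names. The sketch parenthesizes compound expressions, because substituting a raw "a + b" into "{x} / ..." and "{x}**2" would silently change operator precedence.

```python
# Hypothetical copy of the JAX macro template from this page.
JAX_L2_NORMALIZE = (
    "{x} / jnp.sqrt(jnp.maximum(jnp.sum({x}**2, axis={axis}, keepdims=True), {epsilon}))"
)


def _atom(value: object) -> str:
    # Identifiers and plain numbers are inserted as-is; any other expression
    # is parenthesized so that e.g. "a + b" keeps its meaning inside the template.
    if isinstance(value, (int, float)) or (isinstance(value, str) and value.isidentifier()):
        return str(value)
    return f"({value})"


def expand_macro(template: str, **args: object) -> str:
    # Substitute each placeholder with the (possibly parenthesized) argument text.
    return template.format(**{name: _atom(v) for name, v in args.items()})


print(expand_macro(JAX_L2_NORMALIZE, x="h", axis=-1, epsilon=1e-12))
print(expand_macro(JAX_L2_NORMALIZE, x="a + b", axis=0, epsilon=1e-12))
```

The MLX, Flax NNX and PaxML macros use the same three placeholders, so the same expansion applies to them with only the array-module prefix changing.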

Keras

API: keras.utils.normalize
Strategy: Direct Mapping

TensorFlow

API: tf.math.l2_normalize
Strategy: Direct Mapping
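For the Direct Mapping strategies, only the callee and the keyword names change. Going by the public signatures, tf.math.l2_normalize takes axis and epsilon, torch.nn.functional.normalize takes dim and eps, and keras.utils.normalize takes axis and order but no epsilon. The sketch below renders such calls from a rename table. DIRECT_MAPPINGS and render_direct_call are hypothetical names, and dropping epsilon for Keras is an assumption about how the mapping should behave, not documented behavior of this page.

```python
from typing import Dict, Optional, Tuple

# Hypothetical rename table: abstract argument name -> framework keyword.
# None means the framework has no equivalent parameter, so it is dropped.
DIRECT_MAPPINGS: Dict[str, Tuple[str, Dict[str, Optional[str]]]] = {
    "torch": ("torch.nn.functional.normalize", {"axis": "dim", "epsilon": "eps"}),
    "tensorflow": ("tf.math.l2_normalize", {"axis": "axis", "epsilon": "epsilon"}),
    # keras.utils.normalize(x, axis=-1, order=2) has no epsilon parameter.
    "keras": ("keras.utils.normalize", {"axis": "axis", "epsilon": None}),
}


def render_direct_call(framework: str, x: str, axis: int, epsilon: float = 1e-12) -> str:
    api, renames = DIRECT_MAPPINGS[framework]
    parts = [x]
    for name, value in (("axis", axis), ("epsilon", epsilon)):
        target = renames[name]
        if target is not None:
            # Always pass axis explicitly; the frameworks' defaults differ.
            parts.append(f"{target}={value!r}")
    return f"{api}({', '.join(parts)})"


for fw in DIRECT_MAPPINGS:
    print(render_direct_call(fw, "h", axis=-1))
```

Passing axis explicitly matters because the defaults disagree: dim=1 for PyTorch, axis=-1 for Keras, and axis=None for TensorFlow.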

Apple MLX

API: none
Strategy: Macro '{x} / mx.sqrt(mx.maximum(mx.sum({x}**2, axis={axis}, keepdims=True), {epsilon}))'

Flax NNX

API: none
Strategy: Macro '{x} / jnp.sqrt(jnp.maximum(jnp.sum({x}**2, axis={axis}, keepdims=True), {epsilon}))'

PaxML / Praxis

API: none
Strategy: Macro '{x} / jnp.sqrt(jnp.maximum(jnp.sum({x}**2, axis={axis}, keepdims=True), {epsilon}))'