L2Normalize
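Normalizes x along the dimension given by axis using the L2 norm. The squared norm is clamped below by epsilon so that all-zero slices do not cause a division by zero.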
Abstract Signature:
L2Normalize(x: Tensor, axis: int, epsilon: float = 1e-12)
JAX (Core)
API:
Strategy: Macro '{x} / jnp.sqrt(jnp.maximum(jnp.sum({x}**2, axis={axis}, keepdims=True), {epsilon}))'
Apple MLX
API:
Strategy: Macro '{x} / mx.sqrt(mx.maximum(mx.sum({x}**2, axis={axis}, keepdims=True), {epsilon}))'
Flax NNX
API:
Strategy: Macro '{x} / jnp.sqrt(jnp.maximum(jnp.sum({x}**2, axis={axis}, keepdims=True), {epsilon}))'
PaxML / Praxis
API:
Strategy: Macro '{x} / jnp.sqrt(jnp.maximum(jnp.sum({x}**2, axis={axis}, keepdims=True), {epsilon}))'
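As a rough illustration, the jnp macro shared by the JAX (Core), Flax NNX, and PaxML / Praxis entries can be expanded into a plain function. This is a minimal sketch, not the tool's generated output: the function name l2_normalize and the sample input are assumptions made here, while the expression itself is the macro with {x}, {axis}, and {epsilon} substituted.

```python
import jax.numpy as jnp


def l2_normalize(x, axis, epsilon=1e-12):
    # Hypothetical wrapper name; the body is the macro expression expanded.
    # Sum of squares along `axis`, keeping the reduced dimension so the
    # result broadcasts against `x`.
    squared_norm = jnp.sum(x**2, axis=axis, keepdims=True)
    # Clamp the squared norm from below by `epsilon` before the square root.
    return x / jnp.sqrt(jnp.maximum(squared_norm, epsilon))


# Sample input (an assumption for illustration): one ordinary row, one all-zero row.
x = jnp.array([[3.0, 4.0], [0.0, 0.0]])
y = l2_normalize(x, axis=1)
```

With this input the first row has L2 norm 5 and is scaled to unit length, while the all-zero row stays zero rather than becoming NaN, because the clamp keeps the denominator positive. Note that epsilon bounds the squared norm, not the norm itself, so the smallest possible denominator is sqrt(epsilon).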