UnitNormalization

Normalize a batch of inputs so that each input in the batch has an L2 norm equal to 1. The norm is computed across the dimension given by axis, which is the last dimension by default.
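Written out, the operation below is applied to each slice of the input x along axis. The ε term is an assumed safeguard against division by zero and is not part of the abstract signature. Frameworks differ in its value and in whether it clamps the norm or the squared sum.

```latex
y = \frac{x}{\max\left(\lVert x \rVert_2,\ \epsilon\right)},
\qquad
\lVert x \rVert_2 = \sqrt{\sum_{j} x_j^{2}}
```

Here the sum runs over the elements of the slice along axis. For example, x = (3, 4) has ‖x‖₂ = 5, so y = (0.6, 0.8).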

Abstract Signature:

UnitNormalization(axis: int = -1)
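The following is a minimal NumPy reference sketch of the abstract signature. The function name unit_normalization and the eps parameter are hypothetical additions for illustration. The abstract signature only takes axis, and eps just keeps all-zero slices from producing NaN.

```python
import numpy as np


def unit_normalization(x, axis: int = -1, eps: float = 1e-12) -> np.ndarray:
    """Hypothetical reference for the abstract UnitNormalization op.

    Divides each slice of `x` along `axis` by its L2 norm. `eps` is an
    assumed safeguard, not part of the abstract signature.
    """
    x = np.asarray(x)
    # keepdims=True so the norm broadcasts back against x
    norm = np.sqrt(np.sum(np.square(x), axis=axis, keepdims=True))
    # Clamp the norm so an all-zero slice maps to zeros instead of NaN
    return x / np.maximum(norm, eps)


if __name__ == "__main__":
    batch = np.array([[3.0, 4.0], [0.0, 0.0]])
    # First row becomes [0.6, 0.8]; the all-zero row stays zero
    print(unit_normalization(batch))
```

Passing axis=0 instead normalizes each column rather than each row.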

PyTorch

API: none
Strategy: Plugin (unit_norm_layer)
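This page does not show what the unit_norm_layer plugin actually emits. With that caveat, one plausible PyTorch target is a small nn.Module that wraps torch.nn.functional.normalize. The class name and the eps default below are assumptions. F.normalize defaults to dim=1, so the sketch passes the axis explicitly to match the abstract default of -1. A real translation may also need to remap axis when converting between channels-last and channels-first layouts.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class UnitNormalization(nn.Module):
    """Hypothetical PyTorch stand-in; not the documented plugin output."""

    def __init__(self, axis: int = -1, eps: float = 1e-12) -> None:
        super().__init__()
        self.axis = axis
        self.eps = eps  # assumed safeguard, mirrors F.normalize's default

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # F.normalize computes x / max(||x||_p, eps) along `dim`;
        # dim is passed explicitly because its default is 1, not -1
        return F.normalize(x, p=2.0, dim=self.axis, eps=self.eps)


if __name__ == "__main__":
    layer = UnitNormalization(axis=-1)
    print(layer(torch.tensor([[3.0, 4.0], [0.0, 0.0]])))
```

The module has no parameters, so it can be dropped into an nn.Sequential wherever the source model used the Keras layer.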

Keras

API: keras.layers.UnitNormalization
Strategy: Direct Mapping

TensorFlow

API: tf.keras.layers.UnitNormalization
Strategy: Direct Mapping

Flax NNX

API: none
Strategy: Plugin (unit_norm_layer)