UnitNormalization
Normalizes a batch of inputs so that each input in the batch has an L2 norm equal to 1 along the normalized axis (`axis`, the last axis by default).
Abstract Signature:
UnitNormalization(axis: int = -1)
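As a framework-neutral illustration of the abstract signature, the NumPy sketch below divides each slice along `axis` by its L2 norm. The function name `unit_normalize` and the `eps` guard against all-zero slices are assumptions made for this example (the guard follows the `x / sqrt(max(sum(x**2), eps))` form used by `tf.math.l2_normalize`); this is not the plugin's actual implementation.

```python
import numpy as np


def unit_normalize(x, axis=-1, eps=1e-12):
    """Hypothetical reference: scale x so every slice along `axis` has L2 norm 1.

    `eps` is an assumed safeguard so an all-zero slice stays zero instead of NaN.
    """
    x = np.asarray(x, dtype=np.float64)
    # Sum of squares along `axis`, kept as a size-1 dim so it broadcasts back over x.
    sq_sum = np.sum(np.square(x), axis=axis, keepdims=True)
    # Clamp the squared norm from below before taking the square root.
    return x / np.sqrt(np.maximum(sq_sum, eps))


batch = np.array([[3.0, 4.0], [1.0, 0.0], [0.0, 0.0]])
out = unit_normalize(batch)
# rows become [0.6, 0.8], [1.0, 0.0], [0.0, 0.0]
print(out)
```

With the default `axis=-1`, each row of a 2-D batch is normalized independently. Passing `axis=0` would instead normalize each column.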
PyTorch
API:
Strategy: Plugin (unit_norm_layer)
Flax NNX
API:
Strategy: Plugin (unit_norm_layer)