GShardSharedEmbeddingSoftmax
Softmax layer with a shared embedding lookup and Gaussian weight initialization, as used in GShard.
Abstract Signature:
GShardSharedEmbeddingSoftmax(in_features: int, num_classes: int)
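To make the signature concrete, here is a minimal pure-Python sketch of the general idea: a single weight table serves both the embedding lookup and the softmax projection. It is not the PaxML / Praxis implementation. The class name `SharedEmbeddingSoftmaxSketch`, the method names, the `seed` argument, and the choice of standard deviation `1 / sqrt(in_features)` are all assumptions made for illustration.

```python
import math
import random


class SharedEmbeddingSoftmaxSketch:
    """Illustrative only: one weight table is used for lookup and for logits."""

    def __init__(self, in_features, num_classes, seed=0):
        rng = random.Random(seed)
        # Assumption: Gaussian init with std = 1 / sqrt(in_features).
        std = 1.0 / math.sqrt(in_features)
        # Shared weight table of shape [num_classes, in_features].
        self.weight = [
            [rng.gauss(0.0, std) for _ in range(in_features)]
            for _ in range(num_classes)
        ]

    def emb_lookup(self, ids):
        # Embedding lookup: return the table row for each class id.
        return [self.weight[i] for i in ids]

    def logits(self, x):
        # Dot the feature vector with every row of the same table.
        return [sum(w * v for w, v in zip(row, x)) for row in self.weight]

    def softmax(self, x):
        z = self.logits(x)
        m = max(z)  # subtract the max for numerical stability
        exps = [math.exp(v - m) for v in z]
        total = sum(exps)
        return [e / total for e in exps]


layer = SharedEmbeddingSoftmaxSketch(in_features=4, num_classes=6)
vec = layer.emb_lookup([2])[0]   # one embedding vector of length in_features
probs = layer.softmax(vec)       # one probability per class
```

Because `emb_lookup` and `logits` read the same `weight` table, a backend marked Custom / Partial would need to reproduce this weight sharing by hand rather than rely on two independent layers.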
PyTorch
API: (none)
Strategy: Custom / Partial
Keras
API: (none)
Strategy: Custom / Partial
Flax NNX
API: (none)
Strategy: Custom / Partial
PaxML / Praxis
API: paxml.layers.GShardSharedEmbeddingSoftmax
Strategy: Direct Mapping