StackedTransformer

A stack of num_layers Transformer layers applied in sequence.

Abstract Signature:

StackedTransformer(num_layers: int, model_dims: int, num_heads: int)

PyTorch

API: torch.nn.TransformerEncoder
Strategy: Direct Mapping
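Here is a minimal sketch of the PyTorch side, assuming batch-first inputs of shape [batch, seq_len, model_dims]. The helper make_stacked_transformer is hypothetical and is not part of either library. torch.nn.TransformerEncoder takes a prebuilt layer instead of model_dims and num_heads. Those two arguments therefore go to torch.nn.TransformerEncoderLayer as d_model and nhead, and num_layers is passed to the encoder itself. The abstract signature has no feed-forward width, so the sketch keeps PyTorch's default dim_feedforward.

```python
import torch
from torch import nn


def make_stacked_transformer(num_layers: int, model_dims: int, num_heads: int) -> nn.TransformerEncoder:
    """Hypothetical helper: a PyTorch equivalent of
    StackedTransformer(num_layers, model_dims, num_heads)."""
    # model_dims -> d_model and num_heads -> nhead on the per-layer module.
    # model_dims must be divisible by num_heads (checked by MultiheadAttention).
    layer = nn.TransformerEncoderLayer(
        d_model=model_dims,
        nhead=num_heads,
        batch_first=True,  # assumption: inputs are [batch, seq_len, model_dims]
    )
    # num_layers -> num_layers on the stack. The template layer is deep-copied
    # num_layers times, so the stacked layers do not share weights.
    return nn.TransformerEncoder(layer, num_layers=num_layers)


stack = make_stacked_transformer(num_layers=2, model_dims=16, num_heads=4)
x = torch.randn(3, 5, 16)  # [batch, seq_len, model_dims]
y = stack(x)
print(y.shape)  # torch.Size([3, 5, 16])
```

The stack preserves the input shape, so its output can be fed straight into another layer that expects model_dims features.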

PaxML / Praxis

API: praxis.layers.StackedTransformer
Strategy: Direct Mapping