TransformerFeedForward

Transformer feedforward layer with residual connection and dropout.

Abstract Signature:

TransformerFeedForward(input_dims: int, hidden_dims: int, dropout_prob: float = 0.0)

PyTorch

API: torch.nn.Sequential
Strategy: Plugin (transformer_ff_block)
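
The plugin body is not shown here, so the following is only a minimal sketch of what a transformer_ff_block helper might look like when it is built on torch.nn.Sequential. The _Residual wrapper, the pre-layer-norm placement, and the ReLU activation are assumptions made for illustration. They are not taken from the actual plugin or confirmed against the Praxis defaults.

```python
import torch
from torch import nn


class _Residual(nn.Module):
    """Hypothetical helper: adds the input back to the wrapped module's output."""

    def __init__(self, body: nn.Module) -> None:
        super().__init__()
        self.body = body

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.body(x)


def transformer_ff_block(
    input_dims: int, hidden_dims: int, dropout_prob: float = 0.0
) -> nn.Sequential:
    """Sketch of a feedforward block with a residual connection and dropout."""
    body = nn.Sequential(
        nn.LayerNorm(input_dims),            # assumed pre-norm placement
        nn.Linear(input_dims, hidden_dims),  # expand to the hidden width
        nn.ReLU(),                           # assumed activation
        nn.Dropout(dropout_prob),
        nn.Linear(hidden_dims, input_dims),  # project back to the input width
        nn.Dropout(dropout_prob),
    )
    # nn.Sequential has no built-in skip path, so the residual is a wrapper module.
    return nn.Sequential(_Residual(body))
```

Because the block projects back to input_dims, its output has the same shape as its input, e.g. a (batch, seq, input_dims) tensor passes through transformer_ff_block(input_dims, hidden_dims) unchanged in shape.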

PaxML / Praxis

API: praxis.layers.TransformerFeedForward
Strategy: Direct Mapping