AdaptedTransformerFeedForward
A wrapper for MultitaskResidualAdapter inserted before residual connections.
Abstract Signature:
AdaptedTransformerFeedForward(input_dims: int = 0, hidden_dims: int = 0, dropout_prob: float = 0.0)
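The description above can be read as "run the feed-forward block, pass its output through the adapter, then add the residual". The following dependency-free sketch illustrates that ordering only. The class names `ResidualAdapter` and `AdaptedFeedForwardSketch`, the `bottleneck_dims`, `num_tasks`, and `seed` parameters, and the zero-initialised up projection are all assumptions made for illustration and are not taken from any of the libraries listed here. Layer normalisation and dropout are omitted, which corresponds to `dropout_prob = 0.0`.

```python
import random


def matvec(w, x):
    """Multiply a (rows x cols) weight matrix by a vector of length cols."""
    return [sum(wi * xi for wi, xi in zip(row, x)) for row in w]


def relu(v):
    return [max(0.0, a) for a in v]


def init(rows, cols, rng):
    """Small random weight matrix stored as a list of rows."""
    return [[rng.uniform(-0.1, 0.1) for _ in range(cols)] for _ in range(rows)]


class ResidualAdapter:
    """Hypothetical per-task bottleneck adapter: x + up[task](relu(down[task](x)))."""

    def __init__(self, input_dims, bottleneck_dims, num_tasks, rng):
        self.down = [init(bottleneck_dims, input_dims, rng) for _ in range(num_tasks)]
        # Assumption of this sketch: a zero up projection makes the adapter
        # start out as the identity function.
        self.up = [
            [[0.0] * bottleneck_dims for _ in range(input_dims)]
            for _ in range(num_tasks)
        ]

    def __call__(self, x, task):
        hidden = relu(matvec(self.down[task], x))
        delta = matvec(self.up[task], hidden)
        return [a + b for a, b in zip(x, delta)]


class AdaptedFeedForwardSketch:
    """Feed-forward block whose output passes through an adapter before the residual add."""

    def __init__(self, input_dims, hidden_dims, bottleneck_dims=2, num_tasks=1, seed=0):
        rng = random.Random(seed)
        self.w_in = init(hidden_dims, input_dims, rng)
        self.w_out = init(input_dims, hidden_dims, rng)
        self.adapter = ResidualAdapter(input_dims, bottleneck_dims, num_tasks, rng)

    def feed_forward(self, x):
        return matvec(self.w_out, relu(matvec(self.w_in, x)))

    def __call__(self, x, task=0):
        h = self.feed_forward(x)
        h = self.adapter(h, task)  # adapter is applied before the residual connection
        return [a + b for a, b in zip(x, h)]  # residual connection


layer = AdaptedFeedForwardSketch(input_dims=4, hidden_dims=8)
x = [1.0, -2.0, 0.5, 3.0]
y = layer(x)
```

With the zero-initialised up projection, the adapter contributes nothing at first, so `y` equals `x` plus the plain feed-forward output. A real implementation would operate on batched tensors and take its adapter configuration from the framework in use.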
PyTorch
API:
None
Strategy: Custom / Partial
Flax NNX
API:
None
Strategy: Custom / Partial
PaxML / Praxis
API:
praxis.layers.AdaptedTransformerFeedForward
Strategy: Direct Mapping