AdaptedTransformerFeedForward

A wrapper for MultitaskResidualAdapter inserted before residual connections.

Abstract Signature:

AdaptedTransformerFeedForward(input_dims: int = 0, hidden_dims: int = 0, dropout_prob: float = 0.0)

PyTorch

API: N/A
Strategy: Custom / Partial
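
Because PyTorch has no built-in equivalent, a custom module is needed. The following is a minimal sketch of what one might look like, not the Praxis implementation. It keeps the three parameters from the abstract signature. Everything else is an assumption made for illustration: the `ResidualAdapter` class (a single-task stand-in for MultitaskResidualAdapter), the `bottleneck_dims` parameter, the ReLU activations, and the placement of layer normalization.

```python
import torch
from torch import nn


class ResidualAdapter(nn.Module):
    """Hypothetical single-task bottleneck adapter.

    Stand-in for MultitaskResidualAdapter; the multitask routing is omitted.
    """

    def __init__(self, input_dims: int, bottleneck_dims: int):
        super().__init__()
        self.norm = nn.LayerNorm(input_dims)
        self.down = nn.Linear(input_dims, bottleneck_dims)
        self.up = nn.Linear(bottleneck_dims, input_dims)
        # Zero-initialize the up projection so the adapter starts as identity.
        nn.init.zeros_(self.up.weight)
        nn.init.zeros_(self.up.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.up(torch.relu(self.down(self.norm(x))))


class AdaptedTransformerFeedForward(nn.Module):
    """Feed-forward block whose output passes through an adapter
    before the outer residual connection is applied."""

    def __init__(
        self,
        input_dims: int = 0,
        hidden_dims: int = 0,
        dropout_prob: float = 0.0,
        bottleneck_dims: int = 16,  # assumption: not in the abstract signature
    ):
        super().__init__()
        if input_dims <= 0 or hidden_dims <= 0:
            raise ValueError("input_dims and hidden_dims must be positive")
        self.norm = nn.LayerNorm(input_dims)
        self.ffn1 = nn.Linear(input_dims, hidden_dims)
        self.ffn2 = nn.Linear(hidden_dims, input_dims)
        self.dropout = nn.Dropout(dropout_prob)
        self.adapter = ResidualAdapter(input_dims, bottleneck_dims)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h = self.ffn2(self.dropout(torch.relu(self.ffn1(self.norm(x)))))
        h = self.adapter(self.dropout(h))  # adapter runs before the residual add
        return x + h


layer = AdaptedTransformerFeedForward(input_dims=8, hidden_dims=32, dropout_prob=0.1)
out = layer(torch.randn(2, 5, 8))  # output has the same shape as the input
```

Zero-initializing the adapter's up projection is a common choice for adapters, since the wrapped layer then behaves like the unadapted feed-forward block at the start of fine-tuning.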

Flax NNX

API: N/A
Strategy: Custom / Partial

PaxML / Praxis

API: praxis.layers.AdaptedTransformerFeedForward
Strategy: Direct Mapping