GroupedQueryAttention

Dot-product attention in which groups of query heads share a single key/value head: there are num_heads query heads but only num_kv_heads key and value heads.

Abstract Signature:

GroupedQueryAttention(embed_dim: int, num_heads: int, num_kv_heads: int)
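To make the relationship between num_heads and num_kv_heads concrete, here is a minimal sketch of how query heads could be assigned to key/value heads. It assumes the common convention that num_heads is divisible by num_kv_heads and that consecutive query heads share one key/value head; the helper name kv_head_for_query_head is hypothetical and not part of any framework listed here.

```python
def kv_head_for_query_head(num_heads: int, num_kv_heads: int) -> list[int]:
    """Return, for each query head index, the index of the KV head it reads from."""
    # Assumption: heads are grouped evenly, so the counts must divide cleanly.
    if num_heads % num_kv_heads != 0:
        raise ValueError("num_heads must be divisible by num_kv_heads")
    group_size = num_heads // num_kv_heads
    # Consecutive query heads fall into the same group and share one KV head.
    return [q // group_size for q in range(num_heads)]


mapping = kv_head_for_query_head(num_heads=8, num_kv_heads=2)
print(mapping)  # [0, 0, 0, 0, 1, 1, 1, 1]
```

Under this convention, num_kv_heads equal to num_heads reduces to ordinary multi-head attention, and num_kv_heads equal to 1 reduces to multi-query attention.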

PyTorch

API: none
Strategy: Custom / Partial
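Since no direct PyTorch API is listed, a custom module is one way to fill the gap. The following is a rough sketch only, not an official PyTorch layer. It assumes self-attention on a (batch, seq_len, embed_dim) input, no masking or dropout, head_dim = embed_dim // num_heads, and PyTorch 2.0 or later for torch.nn.functional.scaled_dot_product_attention. The projection attribute names (q_proj, k_proj, v_proj, out_proj) are illustrative choices.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class GroupedQueryAttention(nn.Module):
    """Illustrative self-attention with shared key/value heads (not a PyTorch API)."""

    def __init__(self, embed_dim: int, num_heads: int, num_kv_heads: int):
        super().__init__()
        if embed_dim % num_heads != 0:
            raise ValueError("embed_dim must be divisible by num_heads")
        if num_heads % num_kv_heads != 0:
            raise ValueError("num_heads must be divisible by num_kv_heads")
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.head_dim = embed_dim // num_heads
        # Queries get num_heads heads; keys and values get only num_kv_heads heads.
        self.q_proj = nn.Linear(embed_dim, num_heads * self.head_dim)
        self.k_proj = nn.Linear(embed_dim, num_kv_heads * self.head_dim)
        self.v_proj = nn.Linear(embed_dim, num_kv_heads * self.head_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        batch, seq_len, _ = x.shape
        # Split projections into heads: (batch, heads, seq_len, head_dim).
        q = self.q_proj(x).view(batch, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(batch, seq_len, self.num_kv_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(batch, seq_len, self.num_kv_heads, self.head_dim).transpose(1, 2)
        # Repeat each KV head so that consecutive query heads share it.
        group_size = self.num_heads // self.num_kv_heads
        k = k.repeat_interleave(group_size, dim=1)
        v = v.repeat_interleave(group_size, dim=1)
        out = F.scaled_dot_product_attention(q, k, v)
        # Merge heads back into the embedding dimension.
        out = out.transpose(1, 2).reshape(batch, seq_len, self.num_heads * self.head_dim)
        return self.out_proj(out)


layer = GroupedQueryAttention(embed_dim=64, num_heads=8, num_kv_heads=2)
y = layer(torch.randn(2, 10, 64))
print(y.shape)  # torch.Size([2, 10, 64])
```

Repeating the key/value heads keeps the sketch short but materializes the expanded tensors; a production implementation would likely avoid that copy.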

Keras

API: keras.layers.MultiHeadAttention
Strategy: Direct Mapping

Flax NNX

API: flax.nnx.MultiHeadAttention
Strategy: Direct Mapping

PaxML / Praxis

API: paxml.layers.GroupedQueryAttention
Strategy: Direct Mapping