GroupQueryAttention

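Grouped-query attention (GQA) layer. The query heads are split into num_key_value_heads groups, and every query head in a group shares the same key head and value head, so num_query_heads must be a multiple of num_key_value_heads.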

Abstract Signature:

GroupQueryAttention(head_dim: int, num_query_heads: int, num_key_value_heads: int, dropout: float = 0.0)
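The signature fixes only the head counts and the per-head width. To illustrate how they interact, here is a minimal NumPy sketch of a grouped-query self-attention forward pass. It assumes no masking, no biases and dropout = 0.0. The function name grouped_query_attention, the einsum weight layouts and the random weights are illustrative choices, not the Keras or Flax internals.

```python
import numpy as np

def softmax(x, axis=-1):
    # Numerically stable softmax along `axis`.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def grouped_query_attention(x, w_q, w_k, w_v, w_o, num_query_heads, num_key_value_heads):
    """Hypothetical GQA forward pass (self-attention, no mask, no dropout, no biases).

    x:        (batch, seq, d_model)
    w_q:      (d_model, num_query_heads, head_dim)
    w_k, w_v: (d_model, num_key_value_heads, head_dim)
    w_o:      (num_query_heads, head_dim, d_model)
    """
    if num_query_heads % num_key_value_heads != 0:
        raise ValueError("num_query_heads must be a multiple of num_key_value_heads")
    group_size = num_query_heads // num_key_value_heads
    head_dim = w_q.shape[-1]

    # Project inputs: queries get num_query_heads heads, keys/values get fewer.
    q = np.einsum("bsd,dnh->bsnh", x, w_q)  # (b, s, Nq, h)
    k = np.einsum("bsd,dnh->bsnh", x, w_k)  # (b, s, Nkv, h)
    v = np.einsum("bsd,dnh->bsnh", x, w_v)  # (b, s, Nkv, h)

    # Share each key/value head across `group_size` consecutive query heads.
    k = np.repeat(k, group_size, axis=2)  # (b, s, Nq, h)
    v = np.repeat(v, group_size, axis=2)  # (b, s, Nq, h)

    # Scaled dot-product attention per query head.
    scores = np.einsum("bqnh,bknh->bnqk", q, k) / np.sqrt(head_dim)
    weights = softmax(scores, axis=-1)
    context = np.einsum("bnqk,bknh->bqnh", weights, v)

    # Merge heads and project back to the model width.
    return np.einsum("bqnh,nhd->bqd", context, w_o)

rng = np.random.default_rng(0)
batch, seq, d_model = 2, 5, 32
head_dim, num_query_heads, num_key_value_heads = 8, 4, 2

x = rng.standard_normal((batch, seq, d_model))
w_q = rng.standard_normal((d_model, num_query_heads, head_dim))
w_k = rng.standard_normal((d_model, num_key_value_heads, head_dim))
w_v = rng.standard_normal((d_model, num_key_value_heads, head_dim))
w_o = rng.standard_normal((num_query_heads, head_dim, d_model))

out = grouped_query_attention(x, w_q, w_k, w_v, w_o, num_query_heads, num_key_value_heads)
print(out.shape)  # (2, 5, 32)
```

Repeating each key/value head group_size times is the simplest way to express the sharing. Setting num_key_value_heads equal to num_query_heads recovers ordinary multi-head attention, and setting it to 1 gives multi-query attention.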

Keras

API: keras.layers.GroupQueryAttention
Strategy: Direct Mapping

TensorFlow

API: tf.keras.layers.GroupQueryAttention
Strategy: Direct Mapping

Flax NNX

API: flax.nnx.MultiHeadAttention
Strategy: Direct Mapping
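Assuming flax.nnx.MultiHeadAttention takes a single num_heads count and a total qkv_features width, a one-to-one translation of the abstract signature is only exact when num_key_value_heads equals num_query_heads. The following sketch shows that argument translation. The helper gqa_to_nnx_mha_kwargs is hypothetical and not part of any library.

```python
def gqa_to_nnx_mha_kwargs(head_dim, num_query_heads, num_key_value_heads, dropout=0.0):
    """Hypothetical translation of GroupQueryAttention arguments into
    flax.nnx.MultiHeadAttention keyword arguments.

    Assumes the target layer has one head count for queries, keys and values,
    so grouped (shared) key/value heads cannot be expressed directly.
    """
    if num_query_heads % num_key_value_heads != 0:
        raise ValueError("num_query_heads must be a multiple of num_key_value_heads")
    if num_key_value_heads != num_query_heads:
        raise ValueError(
            "grouped key/value heads have no one-to-one MultiHeadAttention equivalent; "
            "repeat the key/value projections to num_query_heads first"
        )
    return {
        "num_heads": num_query_heads,
        "qkv_features": num_query_heads * head_dim,  # total width across all heads
        "dropout_rate": dropout,
    }

print(gqa_to_nnx_mha_kwargs(head_dim=64, num_query_heads=8, num_key_value_heads=8))
# {'num_heads': 8, 'qkv_features': 512, 'dropout_rate': 0.0}
```

When constructing the nnx layer, the input width (in_features) and an rngs argument would still need to be supplied; they have no counterpart in the abstract signature.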