DotProductAttention

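Computes scaled dot-product attention over query, key, and value arrays, with an optional additive bias, an optional mask, attention dropout, and optional causal masking.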
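For reference, the computation is usually written as below. This is a sketch of the standard formulation, not a statement of any one backend's exact behavior. It assumes the $1/\sqrt{d_k}$ default scale used by the PyTorch and JAX functions, treats mask as a boolean "may attend" mask folded into $M$, and has is_causal additionally exclude every key position $j > i$. When dropout_rate > 0, dropout is applied to the softmax weights.

$$
\operatorname{Attention}(Q, K, V) = \operatorname{softmax}\!\left(\frac{Q K^{\top}}{\sqrt{d_k}} + B + M\right) V,
\qquad
M_{ij} =
\begin{cases}
0 & \text{if query } i \text{ may attend to key } j,\\
-\infty & \text{otherwise.}
\end{cases}
$$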

Abstract Signature:

DotProductAttention(query: Array, key: Array, value: Array, bias, mask, dropout_rate: float = 0.0, is_causal: bool = False)
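To pin down what the abstract parameters mean independently of any backend, here is a minimal stdlib-only sketch. The function name `reference_attention` is hypothetical. It assumes a single head with no batch dimensions, a $1/\sqrt{d}$ scale, an additive `bias`, a boolean `mask` where True means "may attend", and no dropout (the `dropout_rate = 0.0` case).

```python
import math
from typing import List, Optional

Matrix = List[List[float]]


def reference_attention(
    query: Matrix,                              # T x d
    key: Matrix,                                # S x d
    value: Matrix,                              # S x d_v
    bias: Optional[Matrix] = None,              # T x S, added to the logits
    mask: Optional[List[List[bool]]] = None,    # T x S, True = may attend
    is_causal: bool = False,
) -> Matrix:
    """Hypothetical single-head reference for the abstract signature (no dropout)."""
    d = len(query[0])
    scale = 1.0 / math.sqrt(d)  # assumed default scale
    out: Matrix = []
    for i, q in enumerate(query):
        logits = []
        for j, k in enumerate(key):
            s = scale * sum(qa * ka for qa, ka in zip(q, k))
            if bias is not None:
                s += bias[i][j]
            allowed = (mask is None or mask[i][j]) and (not is_causal or j <= i)
            logits.append(s if allowed else -math.inf)
        # Numerically stable softmax; assumes each query row keeps at least one key.
        m = max(logits)
        weights = [math.exp(s - m) for s in logits]
        total = sum(weights)
        out.append([
            sum(w * v[c] for w, v in zip(weights, value)) / total
            for c in range(len(value[0]))
        ])
    return out


q = [[1.0, 0.0], [0.0, 1.0]]
k = [[1.0, 0.0], [0.0, 1.0]]
v = [[10.0], [20.0]]
causal_out = reference_attention(q, k, v, is_causal=True)
uniform_out = reference_attention([[0.0, 0.0]], k, v)
masked_out = reference_attention([[0.0, 0.0]], k, v, mask=[[False, True]])
print(causal_out, uniform_out, masked_out)
```

With is_causal=True the first query can only see the first key, so its output is exactly that key's value; a zero query weights all visible keys equally.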

PyTorch

API: torch.nn.functional.scaled_dot_product_attention
Strategy: Direct Mapping
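The PyTorch function has no separate bias argument, and its inputs are laid out as [batch, num_heads, seq_len, head_dim]. The sketch below shows one assumed way to express bias and mask together: fold them into a single float attn_mask, which PyTorch adds to the scaled logits. It also checks that a boolean lower-triangular attn_mask matches is_causal=True. Shapes and random inputs are illustrative only.

```python
import math

import torch
import torch.nn.functional as F

# PyTorch layout: [batch, num_heads, seq_len, head_dim].
B, N, T, E = 2, 3, 4, 8
torch.manual_seed(0)
query = torch.randn(B, N, T, E)
key = torch.randn(B, N, T, E)
value = torch.randn(B, N, T, E)

# Boolean attn_mask: True marks positions that take part in attention.
keep = torch.tril(torch.ones(T, T)).bool()
out_masked = F.scaled_dot_product_attention(query, key, value, attn_mask=keep, dropout_p=0.0)
out_causal = F.scaled_dot_product_attention(query, key, value, is_causal=True)

# Assumed bias handling: fold bias and the boolean mask into one additive float mask.
bias = torch.randn(T, T)
additive = bias.masked_fill(~keep, float("-inf"))
out_biased = F.scaled_dot_product_attention(query, key, value, attn_mask=additive)

# Manual computation of softmax(QK^T / sqrt(E) + additive) V for comparison.
manual = torch.softmax(query @ key.transpose(-2, -1) / math.sqrt(E) + additive, dim=-1) @ value

print(torch.allclose(out_masked, out_causal, atol=1e-5))
print(torch.allclose(out_biased, manual, atol=1e-5))
```

Note that the abstract signature's dropout_rate corresponds to PyTorch's dropout_p, and mask to attn_mask.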

JAX (Core)

API: jax.nn.dot_product_attention
Strategy: Direct Mapping

Keras

API: keras.layers.Attention
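Strategy: Macro 'keras.layers.Attention(dropout={dropout_rate})([{query}, {value}, {key}], use_causal_mask={is_causal})'
Note: the layer takes its inputs as a single list [query, value, key] and does not accept key= or attention_mask= keywords. Its mask argument expects per-position padding masks [query_mask, value_mask] rather than a full attention mask, and it applies no 1/sqrt(d) scaling unless constructed with use_scale=True (which learns a scalar instead).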

Apple MLX

API: mlx.core.fast.scaled_dot_product_attention
Strategy: Direct Mapping

Flax NNX

API: flax.nnx.dot_product_attention
Strategy: Direct Mapping