MultiheadAttention

Auto-generated from flax_nnx_code_defs

PyTorch

API: torch.nn.MultiheadAttention
Strategy: Direct Mapping

Keras

API: keras.layers.MultiHeadAttention
Strategy: Plugin (repack_attn_keras)

TensorFlow

API: keras.layers.MultiHeadAttention
Strategy: Direct Mapping

Apple MLX

API: mlx.nn.layers.transformer.MultiHeadAttention
Strategy: Direct Mapping

Flax NNX

API: flax.nnx.MultiHeadAttention
Strategy: Plugin (repack_attn_flax)
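
Example (illustrative sketch)

This page does not describe what the repack_attn_keras and repack_attn_flax plugins do. One plausible reason a plugin is needed instead of a direct mapping is weight layout: torch.nn.MultiheadAttention stores the query, key, and value projections in a single packed in_proj_weight of shape (3 * embed_dim, embed_dim) when the key and value dimensions equal embed_dim, whereas keras.layers.MultiHeadAttention and flax.nnx.MultiHeadAttention keep separate per-head kernels of shape (embed_dim, num_heads, head_dim). The NumPy sketch below shows that kind of repacking under this assumption. The helper name split_packed_qkv is hypothetical and is not part of any of the libraries listed above, and biases and the output projection are left out.

```python
import numpy as np


def split_packed_qkv(in_proj_weight, num_heads):
    """Split a packed (3*E, E) projection into per-head Q, K, V kernels.

    Hypothetical helper for illustration only. Assumes the PyTorch-style
    packed layout in which rows [0:E] hold Q, [E:2E] hold K, [2E:3E] hold V.
    Returns three arrays of shape (E, num_heads, head_dim).
    """
    rows, embed_dim = in_proj_weight.shape
    if rows != 3 * embed_dim:
        raise ValueError("expected a packed weight of shape (3*E, E)")
    if embed_dim % num_heads != 0:
        raise ValueError("embed_dim must be divisible by num_heads")
    head_dim = embed_dim // num_heads

    kernels = []
    for w in np.split(in_proj_weight, 3, axis=0):
        # The packed layout is applied as x @ w.T, so transpose to (in, out)
        # first, then split the output axis into (num_heads, head_dim).
        kernels.append(w.T.reshape(embed_dim, num_heads, head_dim))
    return tuple(kernels)


# Usage: a tiny packed weight with embed_dim=4 and num_heads=2.
packed = np.arange(3 * 4 * 4, dtype=np.float32).reshape(12, 4)
q_kernel, k_kernel, v_kernel = split_packed_qkv(packed, num_heads=2)
print(q_kernel.shape)
```

Contracting an input of shape (batch, embed_dim) with one of these kernels, for example np.einsum("be,ehd->bhd", x, q_kernel), gives the same values as x @ W_q.T in the packed layout, reshaped per head. Whether the actual plugins perform exactly this transformation is not confirmed by this page.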