MakeAttentionMask

Creates a mask for attention weights from query and key inputs (Flax utility).

Abstract Signature:

MakeAttentionMask(query_input: Tensor, key_input: Tensor)

PyTorch

API: none (no direct equivalent)
Strategy: Plugin (attention_mask_generator)
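
Since PyTorch has no direct counterpart, the plugin has to build the mask itself. The following is only a minimal sketch of what such a plugin might do, not the actual plugin implementation: the function name is borrowed from the strategy entry above, while its parameters, the default dtype, and the multiplicative pairing of the two inputs are assumptions modelled on the Flax utility's default behaviour.

```python
import torch


def attention_mask_generator(query_input, key_input, dtype=torch.float32):
    """Hypothetical plugin body: pairwise mask of shape [batch..., 1, q_len, kv_len]."""
    # Cast first so that boolean or integer validity flags multiply cleanly.
    q = query_input.to(dtype).unsqueeze(-1)  # [batch..., q_len, 1]
    k = key_input.to(dtype).unsqueeze(-2)    # [batch..., 1, kv_len]
    # Broadcasting the product yields 1 only where both positions are valid.
    pairwise = q * k                         # [batch..., q_len, kv_len]
    # Insert a singleton axis that broadcasts over attention heads.
    return pairwise.unsqueeze(-3)


# Usage: 1 marks a real token, 0 marks padding.
query_valid = torch.tensor([[1, 1, 0]])
key_valid = torch.tensor([[1, 1, 1, 0]])
mask = attention_mask_generator(query_valid, key_valid)
```

Here the result has one row per query position and one column per key position; any entry that involves a padded position on either side is zero.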

Flax NNX

API: flax.nnx.make_attention_mask
Strategy: Direct Mapping