MakeAttentionMask
Creates a mask for attention weights from query and key inputs (Flax utility).
Abstract Signature:
MakeAttentionMask(query_input: Tensor, key_input: Tensor)
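For reference, here is a minimal PyTorch sketch of the computation. It assumes the abstract operation follows the defaults of Flax's flax.linen.make_attention_mask: a pairwise elementwise product, no extra batch dimensions, and a float32 result. The inputs are per-position indicators of shape [batch..., len_q] and [batch..., len_kv], for example 1 for real tokens and 0 for padding. The output has shape [batch..., 1, len_q, len_kv], and the singleton axis broadcasts over attention heads. The function name make_attention_mask is illustrative only and is not a PyTorch API.

```python
import torch


def make_attention_mask(query_input: torch.Tensor, key_input: torch.Tensor) -> torch.Tensor:
    """Illustrative sketch of Flax-style make_attention_mask with default settings.

    query_input: [batch..., len_q], key_input: [batch..., len_kv].
    Returns a float32 mask of shape [batch..., 1, len_q, len_kv].
    """
    # Pairwise product: entry (i, j) is query_input[..., i] * key_input[..., j].
    mask = query_input.unsqueeze(-1) * key_input.unsqueeze(-2)
    # Insert a singleton head axis so the mask broadcasts over attention heads.
    mask = mask.unsqueeze(-3)
    return mask.to(torch.float32)


# Usage: one sequence of length 3 whose last token (id 0) is padding.
tokens = torch.tensor([[5, 7, 0]])
valid = (tokens != 0).to(torch.float32)  # 1.0 for real tokens, 0.0 for padding
mask = make_attention_mask(valid, valid)
print(mask.shape)  # torch.Size([1, 1, 3, 3])
```

Each entry is 1.0 only when both the query position and the key position are real tokens, so the padded row and column are zero.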
PyTorch
API: none
Strategy: Plugin (attention_mask_generator)
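Since PyTorch has no directly mapped API here, one plausible consumer of a generated mask is torch.nn.functional.scaled_dot_product_attention. Its boolean attn_mask uses True to mark positions that may be attended to, and it must be broadcastable to [batch, heads, len_q, len_kv]. The sketch below is an assumption about how such a mask could be used and does not describe the attention_mask_generator plugin's actual output. The helper flax_style_mask is hypothetical.

```python
import torch
import torch.nn.functional as F


def flax_style_mask(query_valid: torch.Tensor, key_valid: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper: boolean [batch, 1, len_q, len_kv] mask that is True
    # where both the query and the key position are valid (Flax default semantics).
    return (query_valid.unsqueeze(-1) & key_valid.unsqueeze(-2)).unsqueeze(-3)


torch.manual_seed(0)
batch, heads, seq_len, head_dim = 2, 4, 5, 8
q = torch.randn(batch, heads, seq_len, head_dim)
k = torch.randn(batch, heads, seq_len, head_dim)
v = torch.randn(batch, heads, seq_len, head_dim)

# The second sequence has two padding positions at the end.
lengths = torch.tensor([5, 3])
valid = torch.arange(seq_len).unsqueeze(0) < lengths.unsqueeze(1)  # [batch, seq_len] bool

mask = flax_style_mask(valid, valid)  # [batch, 1, seq_len, seq_len], broadcast over heads
out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
# Rows for padded query positions are fully masked; depending on the PyTorch
# version they may contain NaN, so downstream code should ignore them.
```

Flax's float 0/1 mask converts to this boolean convention by comparing against zero, for example mask != 0. Fully masked query rows are a known difference worth checking when porting, because Flax's default attention does not produce NaN for them.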