ml_switcheroo.plugins.attention_packing¶
Plugin for MultiHead Attention Argument Alignment.
Handles the divergence in call signatures for Attention layers:
- Reorders (Query, Key, Value) tuples.
- Maps key_padding_mask (Torch: True=Masked) to mask (Keras/Flax: True=Keep); see the sketch below.
- Handles 'packed' inputs vs. separate arguments.
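The mask-semantics flip is the part most likely to silently invert behaviour. Below is a minimal tensor-level sketch of the conversion, separate from the AST rewrite itself; the tensor names and the broadcast shape at the end are illustrative assumptions, not part of the plugin's API.

```python
import torch

# Torch semantics: key_padding_mask marks padded positions with True (ignore).
key_padding_mask = torch.tensor([[False, False, True],
                                 [False, True,  True]])  # shape (batch, seq)

# Keras/Flax semantics: the mask marks positions to KEEP with True (attend).
# The conversion is a plain logical negation.
keep_mask = ~key_padding_mask

# Many target layers expect a mask that broadcasts against the attention
# scores; adding a query-length axis, e.g. (batch, 1, seq), is a common
# pattern (assumption: the downstream layer broadcasts over the query axis).
keep_mask = keep_mask[:, None, :]
print(keep_mask.shape)  # torch.Size([2, 1, 3])
```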
Functions¶
| repack_attention(node, ctx) | Plugin Hook: Repacks arguments for MultiHeadAttention. |
Module Contents¶
- ml_switcheroo.plugins.attention_packing.repack_attention(node: libcst.Call, ctx: ml_switcheroo.core.hooks.HookContext) → libcst.Call¶
Plugin Hook: Repacks arguments for MultiHeadAttention. Handles both Constructor (Init) and Forward Call patterns.
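  For orientation, the following is a simplified, illustrative libcst sketch of the kind of keyword repacking such a hook performs on a forward-call pattern. It is not the plugin's actual implementation: the keyword mapping shown is an assumption, and (Query, Key, Value) reordering, mask inversion, packed inputs, constructor calls, and HookContext usage are all omitted.

  ```python
  import libcst as cst

  # Illustrative keyword mapping (assumption, not the plugin's real table).
  TORCH_TO_KERAS_KW = {"key_padding_mask": "attention_mask"}

  def _repack_forward_call(node: cst.Call) -> cst.Call:
      """Rename Torch-style keyword arguments on an attention forward call.

      Simplified sketch: the real hook also reorders positional (query, key,
      value) arguments, flips mask semantics, and distinguishes constructor
      calls from forward calls.
      """
      new_args = []
      for arg in node.args:
          if arg.keyword is not None and arg.keyword.value in TORCH_TO_KERAS_KW:
              new_name = TORCH_TO_KERAS_KW[arg.keyword.value]
              arg = arg.with_changes(keyword=cst.Name(new_name))
          new_args.append(arg)
      return node.with_changes(args=new_args)

  # Example: rewrite a single call expression.
  call = cst.parse_expression("attn(q, k, v, key_padding_mask=pad_mask)")
  assert isinstance(call, cst.Call)
  print(cst.Module(body=[]).code_for_node(_repack_forward_call(call)))
  # -> attn(q, k, v, attention_mask=pad_mask)
  ```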