ml_switcheroo.plugins.attention_packing

Plugin for MultiHead Attention Argument Alignment.

Handles the divergence in call signatures for Attention layers:

- Reorders (Query, Key, Value) tuples.
- Maps key_padding_mask (Torch: True=Masked) to mask (Keras/Flax: True=Keep); see the sketch below.
- Handles ‘packed’ inputs vs separate arguments.
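The mask mapping amounts to a boolean inversion. A minimal sketch follows; the tensor values are hypothetical, and only the True=Masked vs. True=Keep semantics come from the list above:

```python
import torch

# PyTorch convention: key_padding_mask is True at positions to IGNORE (padding).
key_padding_mask = torch.tensor([[False, False, True]])  # last position is padding

# Keras/Flax convention: mask is True at positions to KEEP.
# Translating between the two therefore inverts the boolean mask.
mask = ~key_padding_mask  # tensor([[ True,  True, False]])
```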

Functions

repack_attention(node, ctx) → libcst.Call

Plugin Hook: Repacks arguments for MultiHeadAttention.

Module Contents

ml_switcheroo.plugins.attention_packing.repack_attention(node: libcst.Call, ctx: ml_switcheroo.core.hooks.HookContext) → libcst.Call

Plugin Hook: Repacks arguments for MultiHeadAttention. Handles both Constructor (Init) and Forward Call patterns.
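A minimal sketch of the Forward Call half of such a hook is shown below, assuming the plugin renames a key_padding_mask= keyword to mask= and inverts its value. The helper name and the exact rewrite strategy are illustrative; only the libcst.Call → libcst.Call contract comes from the signature above, and the HookContext-based dispatch between Constructor and Forward patterns is omitted.

```python
import libcst as cst


def repack_attention_sketch(node: cst.Call) -> cst.Call:
    # Hypothetical, simplified rewrite: map Torch's key_padding_mask (True=Masked)
    # onto a Keras/Flax-style mask (True=Keep) by renaming the keyword argument
    # and inverting the mask expression.
    new_args = []
    for arg in node.args:
        if arg.keyword is not None and arg.keyword.value == "key_padding_mask":
            # Parenthesize the original expression so `~` binds to all of it.
            value = arg.value.with_changes(lpar=[cst.LeftParen()], rpar=[cst.RightParen()])
            inverted = cst.UnaryOperation(operator=cst.BitInvert(), expression=value)
            new_args.append(arg.with_changes(keyword=cst.Name("mask"), value=inverted))
        else:
            new_args.append(arg)
    return node.with_changes(args=new_args)
```

The real hook also receives a HookContext, which presumably tells it whether the call it is visiting is a Constructor (Init) or a Forward Call so it can apply the appropriate repacking.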