ml_switcheroo.plugins.attention_packing
=======================================

.. py:module:: ml_switcheroo.plugins.attention_packing

.. autoapi-nested-parse::

   Plugin for MultiHead Attention Argument Alignment.

   Handles the divergence in call signatures for Attention layers:

   - Reorders (Query, Key, Value) tuples.
   - Maps `key_padding_mask` (Torch: True=Masked) to `mask` (Keras/Flax: True=Keep).
   - Handles 'packed' inputs vs separate arguments.


Functions
---------

.. autoapisummary::

   ml_switcheroo.plugins.attention_packing.repack_attention


Module Contents
---------------

.. py:function:: repack_attention(node: libcst.Call, ctx: ml_switcheroo.core.hooks.HookContext) -> libcst.Call

   Plugin Hook: Repacks arguments for MultiHeadAttention.

   Handles both Constructor (Init) and Forward Call patterns.
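
The mask mapping described above (Torch ``key_padding_mask`` with True=Masked to a ``mask`` with True=Keep) can be sketched as follows. This is a minimal illustration, not the plugin's implementation: it uses the standard-library ``ast`` module as a stand-in for ``libcst``, and the names ``MaskKeywordRewriter`` and ``rewrite`` are hypothetical. It assumes the mask is a boolean array, so that ``~`` flips the polarity, and it ignores any mask shape differences between frameworks.

.. code-block:: python

   import ast


   class MaskKeywordRewriter(ast.NodeTransformer):
       """Hypothetical helper: rename ``key_padding_mask`` to ``mask`` and invert it."""

       def visit_Call(self, node: ast.Call) -> ast.Call:
           # Rewrite nested calls first, then this call's own keywords.
           self.generic_visit(node)
           for kw in node.keywords:
               if kw.arg == "key_padding_mask":
                   kw.arg = "mask"
                   # Source: True means "masked". Target: True means "keep".
                   # Assumes a boolean array, where ``~`` is element-wise NOT.
                   kw.value = ast.UnaryOp(op=ast.Invert(), operand=kw.value)
           return node


   def rewrite(source: str) -> str:
       """Parse ``source``, apply the mask rewrite, and return the new code."""
       tree = MaskKeywordRewriter().visit(ast.parse(source))
       ast.fix_missing_locations(tree)
       return ast.unparse(tree)  # requires Python 3.9+


   result = rewrite("out = attn(q, k, v, key_padding_mask=pad)")
   print(result)

Calls that do not pass ``key_padding_mask`` are returned unchanged, so the transformer can be applied to a whole module without first filtering for attention calls.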