ml_switcheroo.plugins.attention_packing¶
Plugin for MultiHead Attention Strategy Selection.
This module provides modular hooks for transforming MultiHeadAttention API calls between frameworks. Logic is split into discrete argument-mapping strategies (repack_attn_keras and repack_attn_flax).
- Decoupling Logic:
No Hardcoded Frameworks: The plugin does not contain strings like flax.nnx or keras.layers.
Strict Lookup: Target class names are resolved via ctx.lookup_api.
Safety: If the Knowledge Base is missing a mapping for “MultiheadAttention”, constructor transformations are aborted to prevent hallucination (sketched below).
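A minimal sketch of this guard, assuming ctx.lookup_api returns None when the Knowledge Base has no entry (the hook signature mirrors the functions below; the actual rewrite step is elided):

    import libcst as cst

    from ml_switcheroo.core.hooks import HookContext

    def repack_attn_example(node: cst.Call, ctx: HookContext) -> cst.Call:
        # Resolve the target class name strictly from the Knowledge Base.
        target = ctx.lookup_api("MultiheadAttention")
        if target is None:
            # Mapping is absent: return the node untouched instead of
            # guessing a framework-specific class name.
            return node
        # ... rewrite the constructor against `target` here ...
        return node

repack_attn_example is hypothetical; the shipped strategies are repack_attn_keras and repack_attn_flax, documented below.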
Functions¶
repack_attn_keras(node, ctx) | Strategy: Keras Attention Packing.
repack_attn_flax(node, ctx) | Strategy: Flax/JAX Attention Packing.
Module Contents¶
- ml_switcheroo.plugins.attention_packing.repack_attn_keras(node: libcst.Call, ctx: ml_switcheroo.core.hooks.HookContext) → libcst.Call[source]¶
Strategy: Keras Attention Packing.
Constructor:
- Requires “MultiheadAttention” mapping in Semantics.
- Renames embed_dim -> key_dim recursively.

Call (Inference):
- Remaps the typical Torch signature (q, k, v, mask) to Keras (q, v, key=k, attention_mask=mask); see the sketch after this entry.
- Parameters:
node – Original Call node.
ctx – HookContext for API lookup.
- Returns:
Transformed Call node, or original if dependencies missing.
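The embed_dim -> key_dim rename can be pictured with a small, self-contained libcst sketch (the plugin's recursive rename and the Knowledge-Base-driven class swap are omitted; RenameEmbedDim is illustrative, not part of the plugin):

    import libcst as cst

    class RenameEmbedDim(cst.CSTTransformer):
        """Rename the embed_dim keyword argument to key_dim."""

        def leave_Arg(self, original_node: cst.Arg, updated_node: cst.Arg) -> cst.Arg:
            kw = updated_node.keyword
            if kw is not None and kw.value == "embed_dim":
                return updated_node.with_changes(keyword=cst.Name("key_dim"))
            return updated_node

    src = "attn = nn.MultiheadAttention(embed_dim=512, num_heads=8)"
    print(cst.parse_module(src).visit(RenameEmbedDim()).code)
    # attn = nn.MultiheadAttention(key_dim=512, num_heads=8)

At call sites the strategy additionally reorders attn(q, k, v, mask) into attn(q, v, key=k, attention_mask=mask), matching Keras's query/value/key calling convention.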
- ml_switcheroo.plugins.attention_packing.repack_attn_flax(node: libcst.Call, ctx: ml_switcheroo.core.hooks.HookContext) → libcst.Call[source]¶
Strategy: Flax/JAX Attention Packing.
Constructor:
- Requires “MultiheadAttention” mapping in Semantics.

Call (Inference):
- Maps key_padding_mask -> mask; see the sketch after this entry.
- Parameters:
node – Original Call node.
ctx – HookContext for API lookup.
- Returns:
Transformed Call node, or original if dependencies missing.
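The inference-time mapping amounts to a single keyword rename. A sketch using libcst matchers (rename_keyword is a hypothetical helper, not part of the plugin; the variable names are illustrative):

    import libcst as cst
    import libcst.matchers as m

    def rename_keyword(call_src: str, old: str, new: str) -> str:
        """Rename one keyword argument in the given source snippet."""
        updated = m.replace(
            cst.parse_module(call_src),
            m.Arg(keyword=m.Name(old)),
            lambda node, _: node.with_changes(keyword=cst.Name(new)),
        )
        return updated.code

    print(rename_keyword("attn(q, k, v, key_padding_mask=pad_mask)",
                         "key_padding_mask", "mask"))
    # attn(q, k, v, mask=pad_mask)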