ml_switcheroo.plugins.attention_packing¶

Plugin for MultiHead Attention Strategy Selection.

This module provides modular hooks for transforming MultiHeadAttention API calls between frameworks. Logic is split into discrete argument-mapping strategies (repack_attn_keras and repack_attn_flax).

Decoupling Logic:
  • No Hardcoded Frameworks: The plugin does not contain strings like flax.nnx or keras.layers.

  • Strict Lookup: Target class names are resolved via ctx.lookup_api.

  • Safety: If the Knowledge Base is missing a mapping for "MultiheadAttention", constructor transformations are aborted to prevent hallucination (see the sketch below).
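
A minimal sketch of this pattern, assuming a hook with the documented (node, ctx) signature; the exact call shape of ctx.lookup_api and its return value are assumptions made here for illustration:

    import libcst as cst


    def repack_example(node: cst.Call, ctx) -> cst.Call:
        # Strict lookup: resolve the target class through the Knowledge Base
        # rather than hardcoding a framework path such as "keras.layers.…".
        target = ctx.lookup_api("MultiheadAttention")  # assumed call shape
        if target is None:
            # Safety: no mapping found, so abort and return the node
            # unchanged instead of inventing a target API.
            return node
        # ... build and return the transformed Call here ...
        return node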

Functions¶

repack_attn_keras(node, ctx) → libcst.Call

Strategy: Keras Attention Packing.

repack_attn_flax(node, ctx) → libcst.Call

Strategy: Flax/JAX Attention Packing.

Module Contents¶

ml_switcheroo.plugins.attention_packing.repack_attn_keras(node: libcst.Call, ctx: ml_switcheroo.core.hooks.HookContext) → libcst.Call[source]¶

Strategy: Keras Attention Packing.

Constructor:
  • Requires 'MultiheadAttention' mapping in Semantics.
  • Renames embed_dim -> key_dim recursively.

Call (Inference):
  • Remaps the typical Torch signature (q, k, v, mask) to the Keras form (q, v, key=k, attention_mask=mask); see the sketch below.

Parameters:
  • node – Original Call node.

  • ctx – HookContext for API lookup.

Returns:

The transformed Call node, or the original node if required mappings are missing.
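
For illustration, a self-contained sketch of the call-site remap described above, built directly with libcst rather than through the plugin (the real hook also resolves the target class via ctx.lookup_api, which is omitted here):

    import libcst as cst

    # Typical Torch-style invocation: attn(q, k, v, mask)
    call = cst.parse_expression("attn(q, k, v, mask)")
    assert isinstance(call, cst.Call)
    q, k, v, mask = (arg.value for arg in call.args)

    # Repack into the Keras form: attn(q, v, key=k, attention_mask=mask)
    keras_call = call.with_changes(
        args=[
            cst.Arg(value=q),
            cst.Arg(value=v),
            cst.Arg(value=k, keyword=cst.Name("key")),
            cst.Arg(value=mask, keyword=cst.Name("attention_mask")),
        ]
    )

    print(cst.Module(body=[]).code_for_node(keras_call))
    # prints (up to whitespace): attn(q, v, key=k, attention_mask=mask)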

ml_switcheroo.plugins.attention_packing.repack_attn_flax(node: libcst.Call, ctx: ml_switcheroo.core.hooks.HookContext) → libcst.Call[source]¶

Strategy: Flax/JAX Attention Packing.

Constructor:
  • Requires 'MultiheadAttention' mapping in Semantics.

Call (Inference):
  • Maps key_padding_mask -> mask; see the sketch below.

Parameters:
  • node – Original Call node.

  • ctx – HookContext for API lookup.

Returns:

The transformed Call node, or the original node if required mappings are missing.
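
A minimal illustration of the key_padding_mask -> mask rename, again using libcst directly rather than the plugin (the target-class lookup through ctx.lookup_api is omitted):

    import libcst as cst

    call = cst.parse_expression("attn(query, key_padding_mask=pad_mask)")
    assert isinstance(call, cst.Call)

    # Rename the keyword while keeping the argument value untouched.
    new_args = [
        arg.with_changes(keyword=cst.Name("mask"))
        if arg.keyword is not None and arg.keyword.value == "key_padding_mask"
        else arg
        for arg in call.args
    ]
    flax_call = call.with_changes(args=new_args)

    print(cst.Module(body=[]).code_for_node(flax_call))
    # prints: attn(query, mask=pad_mask)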