ml_switcheroo.plugins.einsum
Plugin for normalizing Einsum calls.
Standardizes einsum arguments so the equation string is always the first argument. JAX strictly enforces einsum(equation, *operands), whereas other frameworks (like older PyTorch versions or specific utility wrappers) might allow flexible ordering like einsum(operand, operand, equation).
This plugin handles:
1. Equation Identification: Scans the arguments to find the string literal (the equation).
2. Reordering: Moves the equation to position 0 if it is not already there.
3. API Renaming: Updates the function call to the target framework's API (e.g., jax.numpy.einsum).
4. Syntax Cleaning: Manages commas, which is essential when argument order is shuffled in the CST.
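Step 4 matters because a concrete syntax tree keeps punctuation: in libcst, each argument node carries its own `comma` field, so reordering arguments also moves their commas. The toy model below is a hypothetical stand-in, not libcst. It represents each argument as a `(source_text, has_trailing_comma)` pair and stores every comma explicitly. In this model, naively moving the trailing equation to the front drops the separator between the equation and `x` and leaves a stray trailing comma after `y`. Reassigning commas so that every argument except the last has one repairs the call. The names `render`, `move_equation_first`, and `fix_commas` are illustrative only.

```python
from typing import List, Tuple

# Toy stand-in for a CST argument list: (source_text, has_trailing_comma).
# It only mimics the idea that each libcst Arg stores its own comma;
# it is NOT the libcst API.
Arg = Tuple[str, bool]


def render(args: List[Arg]) -> str:
    """Join arguments exactly as stored, including each one's own comma."""
    return " ".join(text + ("," if comma else "") for text, comma in args)


def move_equation_first(args: List[Arg]) -> List[Arg]:
    """Move the first string-literal argument to position 0, commas untouched."""
    for i, (text, _) in enumerate(args):
        if text[:1] in ("'", '"'):  # crude string-literal check for the toy model
            return [args[i]] + args[:i] + args[i + 1:]
    return list(args)


def fix_commas(args: List[Arg]) -> List[Arg]:
    """Give every argument a trailing comma except the last one."""
    last = len(args) - 1
    return [(text, i != last) for i, (text, _) in enumerate(args)]


if __name__ == "__main__":
    call_args = [("x", True), ("y", True), ("'ij,jk->ik'", False)]
    naive = move_equation_first(call_args)
    print(render(naive))              # 'ij,jk->ik' x, y,
    print(render(fix_commas(naive)))  # 'ij,jk->ik', x, y
```

In real libcst code, the same repair would typically rewrite each `Arg` node's comma (for example via the node's `with_changes` method) rather than operate on text pairs.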
Functions
| normalize_einsum(node, ctx) | Plugin Hook: Rotates arguments to place the equation string first and renames the function. |
Module Contents
- ml_switcheroo.plugins.einsum.normalize_einsum(node: libcst.Call, ctx: ml_switcheroo.core.hooks.HookContext) -> libcst.Call
Plugin Hook: Rotates arguments to place the equation string first and renames the function.
- Triggers:
Operations mapping to Einsum with requires_plugin: "einsum_normalizer".
- Transformation:
Input: torch.einsum(x, y, 'ij,jk->ik')
Output: jax.numpy.einsum('ij,jk->ik', x, y)
- Parameters:
node: The original CST Call node.
ctx: HookContext for looking up the target API.
- Returns:
The transformed CST Call node.
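The Input/Output pair under Transformation can be reproduced with a rough sketch of steps 1 to 3. The actual hook operates on libcst nodes and gets the target name from `ctx`. This sketch instead uses the standard-library `ast` module (Python 3.9+ for `ast.unparse`) and takes the target name as a plain string parameter, `target_api`. `EinsumNormalizer`, `normalize_source`, and `target_api` are hypothetical names. Because `ast` discards commas and formatting and `ast.unparse` regenerates the source, step 4 has no counterpart here. The sketch also leaves a call unchanged when no string literal is found; that conservative choice is an assumption, not documented plugin behavior.

```python
import ast


class EinsumNormalizer(ast.NodeTransformer):
    """Hypothetical stdlib-ast analogue of the einsum_normalizer hook.

    For calls whose callee name ends in `einsum`, moves the first
    string-literal positional argument to position 0 and rewrites the
    callee to `target_api`.
    """

    def __init__(self, target_api: str = "jax.numpy.einsum") -> None:
        self.target_api = target_api

    def visit_Call(self, node: ast.Call) -> ast.Call:
        self.generic_visit(node)  # normalize any nested calls first
        func = node.func
        # Match both `torch.einsum(...)` (Attribute) and `einsum(...)` (Name).
        name = func.attr if isinstance(func, ast.Attribute) else getattr(func, "id", None)
        if name != "einsum":
            return node
        # Step 1: find the equation, i.e. the first str constant among positional args.
        eq_index = next(
            (i for i, arg in enumerate(node.args)
             if isinstance(arg, ast.Constant) and isinstance(arg.value, str)),
            None,
        )
        if eq_index is None:
            return node  # equation is not a literal: leave the call as-is (assumption)
        # Step 2: rotate the equation to the front.
        node.args.insert(0, node.args.pop(eq_index))
        # Step 3: rename the callee by parsing the dotted target name.
        node.func = ast.parse(self.target_api, mode="eval").body
        return node


def normalize_source(src: str, target_api: str = "jax.numpy.einsum") -> str:
    """Parse `src`, normalize einsum calls, and regenerate source text."""
    tree = EinsumNormalizer(target_api).visit(ast.parse(src))
    ast.fix_missing_locations(tree)
    return ast.unparse(tree)


if __name__ == "__main__":
    print(normalize_source("torch.einsum(x, y, 'ij,jk->ik')"))
    # jax.numpy.einsum('ij,jk->ik', x, y)
```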