ml_switcheroo.plugins.einsum

Plugin for normalizing Einsum calls.

Standardizes einsum arguments so that the equation string is always the first argument. JAX expects the string equation first, as in einsum(equation, *operands), whereas other frameworks (such as older PyTorch versions or specific utility wrappers) may allow flexible ordering such as einsum(operand, operand, equation).
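To make the convention concrete, here is a minimal runtime sketch of the same normalization. It is not how the plugin works (the plugin rewrites source code, not calls at runtime). The names split_equation, einsum_equation_first and echo are hypothetical; echo stands in for an equation-first backend such as jax.numpy.einsum, and the sketch assumes operands are never plain strings.

```python
from typing import Any, Callable, Tuple


def split_equation(args: Tuple[Any, ...]) -> Tuple[str, Tuple[Any, ...]]:
    """Return (equation, operands), wherever the equation string appears."""
    for i, arg in enumerate(args):
        if isinstance(arg, str):
            # Operands keep their original relative order.
            return arg, args[:i] + args[i + 1:]
    raise ValueError("no equation string among the arguments")


def einsum_equation_first(backend: Callable[..., Any], *args: Any) -> Any:
    """Forward to backend(equation, *operands), the equation-first form."""
    equation, operands = split_equation(args)
    return backend(equation, *operands)


def echo(*args: Any) -> Tuple[Any, ...]:
    """Hypothetical stand-in backend that returns exactly what it was called with."""
    return args


x, y = [[1, 2]], [[3], [4]]
print(einsum_equation_first(echo, x, y, "ij,jk->ik"))
# ('ij,jk->ik', [[1, 2]], [[3], [4]])
```

Both einsum(x, y, equation) and einsum(equation, x, y) reach the backend in the same equation-first shape.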

This plugin handles:
1. Equation Identification: Scans the arguments to find the string literal (the equation).
2. Reordering: Moves the equation to the 0th position if it is not already there.
3. API Renaming: Updates the function call to the target framework's API (e.g., jax.numpy.einsum).
4. Syntax Cleaning: Repairs the commas between arguments after their order changes, because each argument node in the CST carries its own comma.
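Steps 1 to 3 can be sketched with the standard-library ast module. This is a swapped-in technique, not the plugin's libcst implementation: ast discards formatting and comments and has no comma tokens at all, which is why a libcst-based hook needs step 4 while this sketch does not. EinsumNormalizer and normalize_einsum_source are hypothetical names; matching any *.einsum attribute call is an assumption standing in for the plugin's mapping-based trigger, and the target API is passed as a string rather than looked up through HookContext. Requires Python 3.9+ for ast.unparse.

```python
import ast


class EinsumNormalizer(ast.NodeTransformer):
    """Hypothetical transformer: equation-first reordering plus API renaming."""

    def __init__(self, target_api: str) -> None:
        super().__init__()
        self.target_api = target_api

    def visit_Call(self, node: ast.Call) -> ast.AST:
        self.generic_visit(node)  # rewrite nested calls first
        func = node.func
        # Assumption: match by attribute name; the real plugin is selected via its mapping.
        if not (isinstance(func, ast.Attribute) and func.attr == "einsum"):
            return node
        # 1. Equation identification: the first positional string literal.
        idx = next(
            (i for i, arg in enumerate(node.args)
             if isinstance(arg, ast.Constant) and isinstance(arg.value, str)),
            None,
        )
        if idx is None:
            return node
        # 2. Reordering: move the equation to position 0; operands keep their order.
        node.args.insert(0, node.args.pop(idx))
        # 3. API renaming: parse the dotted target name into an expression node.
        node.func = ast.parse(self.target_api, mode="eval").body
        return node


def normalize_einsum_source(source: str, target_api: str = "jax.numpy.einsum") -> str:
    """Rewrite einsum calls in `source` and return the regenerated code."""
    tree = EinsumNormalizer(target_api).visit(ast.parse(source))
    return ast.unparse(ast.fix_missing_locations(tree))


print(normalize_einsum_source('out = torch.einsum(x, y, "ij,jk->ik")'))
# out = jax.numpy.einsum('ij,jk->ik', x, y)
```

Note that ast.unparse re-prints string literals with single quotes; a CST library such as libcst preserves the original quoting and layout, which is one reason to prefer it for source-to-source translation.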

Functions

normalize_einsum(node, ctx) → libcst.Call

Plugin Hook: Rotates arguments to place the equation string first and renames the function.

Module Contents

ml_switcheroo.plugins.einsum.normalize_einsum(node: libcst.Call, ctx: ml_switcheroo.core.hooks.HookContext) → libcst.Call

Plugin Hook: Rotates arguments to place the equation string first and renames the function.

Triggers:

Operations mapping to Einsum with requires_plugin: "einsum_normalizer".

Transformation:

Input: torch.einsum(x, y, "ij,jk->ik")
Output: jax.numpy.einsum("ij,jk->ik", x, y)
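The syntax-cleaning step matters for exactly this transformation. In libcst each libcst.Arg carries its own comma field, and when that field is left at libcst.MaybeSentinel.DEFAULT the code generator inserts ", " only between arguments. Moving the last argument to the front therefore leaves y holding the explicit comma it was parsed with. The sketch below models that with a hypothetical, simplified ArgText dataclass instead of real libcst nodes so it runs on the standard library alone; render, move_to_front and reset_commas are likewise illustrative names, and resetting every comma is one possible fix, not necessarily what ml_switcheroo does.

```python
from dataclasses import dataclass, replace
from typing import List, Optional


@dataclass(frozen=True)
class ArgText:
    """Hypothetical, simplified stand-in for a CST argument node."""
    value: str                   # source text of the argument
    comma: Optional[str] = None  # explicit comma text; None = "let the printer decide"


def render(func: str, args: List[ArgText]) -> str:
    """Print a call; arguments without an explicit comma get ', ' unless they are last."""
    parts = []
    for i, arg in enumerate(args):
        if arg.comma is not None:
            sep = arg.comma
        else:
            sep = ", " if i < len(args) - 1 else ""
        parts.append(arg.value + sep)
    return f"{func}(" + "".join(parts) + ")"


def move_to_front(args: List[ArgText], idx: int) -> List[ArgText]:
    """Naive reordering: move args[idx] first and leave every comma untouched."""
    return [args[idx]] + args[:idx] + args[idx + 1:]


def reset_commas(args: List[ArgText]) -> List[ArgText]:
    """Syntax cleaning: drop explicit commas so the printer re-derives them."""
    return [replace(a, comma=None) for a in args]


# torch.einsum(x, y, "ij,jk->ik"): x and y own explicit commas, the last arg has none.
parsed = [ArgText("x", ", "), ArgText("y", ", "), ArgText('"ij,jk->ik"')]
naive = move_to_front(parsed, 2)
print(render("jax.numpy.einsum", naive))
# jax.numpy.einsum("ij,jk->ik", x, y, )
print(render("jax.numpy.einsum", reset_commas(naive)))
# jax.numpy.einsum("ij,jk->ik", x, y)
```

With real libcst nodes, the analogous cleanup would be calling with_changes(comma=libcst.MaybeSentinel.DEFAULT) on each argument, at the cost of discarding any custom spacing around the original commas.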

Parameters:
  • node – The original CST Call node.

  • ctx – HookContext for looking up the target API.

Returns:

The transformed CST Call node.