ml_switcheroo.plugins.einsum
============================

.. py:module:: ml_switcheroo.plugins.einsum

.. autoapi-nested-parse::

   Plugin for normalizing Einsum calls.

   Standardizes ``einsum`` arguments so that the equation string is always the first argument.
   JAX strictly enforces ``einsum(equation, *operands)``, whereas other frameworks (such as older
   PyTorch versions or specific utility wrappers) might allow flexible ordering such as
   ``einsum(operand, operand, equation)``.

   This plugin handles:

   1. **Equation Identification**: Scans the arguments to find the string literal (the equation).
   2. **Reordering**: Moves the equation to the 0th position if it is not already there.
   3. **API Renaming**: Updates the function call to the target framework's API (e.g., ``jax.numpy.einsum``).
   4. **Syntax Cleaning**: Manages commas when shuffling argument order in the CST, so the regenerated call stays syntactically valid.



Functions
---------

.. autoapisummary::

   ml_switcheroo.plugins.einsum.normalize_einsum


Module Contents
---------------

.. py:function:: normalize_einsum(node: libcst.Call, ctx: ml_switcheroo.core.hooks.HookContext) -> libcst.Call

   Plugin Hook: Rotates the arguments to place the equation string first and renames the function.

   Triggers:
      Operations mapping to ``Einsum`` with ``requires_plugin: "einsum_normalizer"``.

   Transformation:
      * Input: ``torch.einsum(x, y, "ij,jk->ik")``
      * Output: ``jax.numpy.einsum("ij,jk->ik", x, y)``

   :param node: The original CST Call node.
   :param ctx: HookContext for looking up the target API.
   :returns: The transformed CST Call node.
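The four steps handled by this plugin can be sketched with libcst, the library whose ``Call`` node the hook receives. This is a minimal sketch under stated assumptions, not the plugin's actual code: ``normalize_einsum_sketch`` is a hypothetical name, the target API is passed in as a plain dotted string instead of being resolved through ``HookContext``, and only plain string literals (``libcst.SimpleString``) are recognized as the equation.

.. code-block:: python

   import libcst as cst


   def normalize_einsum_sketch(node: cst.Call, target_api: str) -> cst.Call:
       """Hypothetical stand-in for ``normalize_einsum``.

       Assumption: the target API is passed in as a dotted string instead of
       being resolved through ``HookContext``.
       """
       args = list(node.args)

       # 1. Equation identification: the first positional, non-starred argument
       #    that is a plain string literal (f-strings and implicitly concatenated
       #    strings are not recognized in this sketch).
       eq_index = next(
           (
               i
               for i, arg in enumerate(args)
               if arg.keyword is None
               and arg.star == ""
               and isinstance(arg.value, cst.SimpleString)
           ),
           None,
       )

       # 2. Reordering: move the equation to position 0 if it is not there yet.
       if eq_index is not None and eq_index != 0:
           args.insert(0, args.pop(eq_index))

       # 4. Syntax cleaning: reset every comma to MaybeSentinel.DEFAULT. libcst's
       #    code generator then emits ", " between arguments and no comma after
       #    the last one, so a moved argument never drags a stray comma along.
       args = [arg.with_changes(comma=cst.MaybeSentinel.DEFAULT) for arg in args]

       # 3. API renaming: parse the dotted target name into an expression node.
       return node.with_changes(func=cst.parse_expression(target_api), args=args)


   # Usage: parse a call, normalize it, and render it back to source code.
   call = cst.parse_expression('torch.einsum(x, y, "ij,jk->ik")')
   assert isinstance(call, cst.Call)
   new_call = normalize_einsum_sketch(call, "jax.numpy.einsum")
   print(cst.Module(body=[]).code_for_node(new_call))
   # jax.numpy.einsum("ij,jk->ik", x, y)

Resetting every comma at once, rather than patching only the moved argument, is a design choice of this sketch: it discards any custom spacing or trailing comma in the original call in exchange for always producing valid syntax.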