ml_switcheroo.plugins.decompositions
====================================

.. py:module:: ml_switcheroo.plugins.decompositions

.. autoapi-nested-parse::

   Plugin module defining AST decomposition and recomposition rules.

   This module provides hooks that transform complex function calls into simpler
   primitives (Decomposition) or reconstruct complex calls from primitives
   (Recomposition/Composition), supporting bidirectional transpilation.


Functions
---------

.. autoapisummary::

   ml_switcheroo.plugins.decompositions.transform_alpha_add
   ml_switcheroo.plugins.decompositions.transform_alpha_add_reverse


Module Contents
---------------

.. py:function:: transform_alpha_add(node: libcst.Call, ctx: ml_switcheroo.core.hooks.HookContext) -> libcst.Call

   Transforms an ``add`` call with an ``alpha`` keyword into an ``add`` call whose
   second operand is explicitly multiplied by ``alpha``.

   For example, ``torch.add(x, y, alpha=a)``, which computes ``x + a * y``, becomes
   ``jax.numpy.add(x, y * a)``, since ``jax.numpy.add`` has no ``alpha`` parameter.

   Transformation::

      Input:  add(x, y, alpha=a)
      Output: target_api(x, y * a)

   :param node: The CST Call node to transform.
   :param ctx: The plugin execution context.
   :returns: The transformed CST Call node.


.. py:function:: transform_alpha_add_reverse(node: libcst.Call, ctx: ml_switcheroo.core.hooks.HookContext) -> libcst.Call

   Transforms an ``add`` call whose second argument is a multiplication back into
   an ``add`` call with an ``alpha`` keyword, inverting ``transform_alpha_add``.

   Transformation::

      Input:  target_api(x, y * a)
      Output: torch.add(x, y, alpha=a)   (API name resolved via mapping)

   :param node: The CST Call node to transform.
   :param ctx: The plugin execution context.
   :returns: The transformed CST Call node.
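
Both hooks rest on the identity ``x + alpha * y == x + (y * alpha)``. The following is a minimal sketch of the round trip, not the plugin's implementation, under several assumptions: it uses Python's standard-library ``ast`` module (Python 3.9+ for ``ast.unparse``) instead of libcst so that it runs without third-party packages; it hardcodes ``jax.numpy.add`` and ``torch.add`` where the real hooks presumably obtain API names through ``ctx`` and the mapping; and the helpers ``alpha_add_to_mul`` and ``mul_to_alpha_add`` are hypothetical names that do not exist in this module.

.. code-block:: python

   import ast

   # Hypothetical stand-ins for API names that the real hooks resolve via the
   # mapping / HookContext; hardcoded here for illustration only.
   TARGET_ADD = "jax.numpy.add"
   SOURCE_ADD = "torch.add"


   def alpha_add_to_mul(call: ast.Call, target_api: str = TARGET_ADD) -> ast.Call:
       """Rewrite ``add(x, y, alpha=a)`` as ``target_api(x, y * a)``."""
       alpha = next((kw for kw in call.keywords if kw.arg == "alpha"), None)
       if alpha is None or len(call.args) != 2:
           # No alpha keyword (or unexpected arity): leave the call untouched.
           return call
       x, y = call.args
       # Fold alpha into the second operand: x + alpha * y == x + (y * alpha).
       scaled = ast.BinOp(left=y, op=ast.Mult(), right=alpha.value)
       # Keep any other keywords unchanged.
       others = [kw for kw in call.keywords if kw.arg != "alpha"]
       func = ast.parse(target_api, mode="eval").body
       return ast.Call(func=func, args=[x, scaled], keywords=others)


   def mul_to_alpha_add(call: ast.Call, source_api: str = SOURCE_ADD) -> ast.Call:
       """Rewrite ``target_api(x, y * a)`` as ``source_api(x, y, alpha=a)``."""
       if call.keywords or len(call.args) != 2:
           return call
       x, second = call.args
       if not (isinstance(second, ast.BinOp) and isinstance(second.op, ast.Mult)):
           # Second operand is not a product: nothing to recompose.
           return call
       func = ast.parse(source_api, mode="eval").body
       # Assumption: the right-hand factor of the product is treated as alpha.
       alpha = ast.keyword(arg="alpha", value=second.right)
       return ast.Call(func=func, args=[x, second.left], keywords=[alpha])


   if __name__ == "__main__":
       src = ast.parse("torch.add(x, y, alpha=a)", mode="eval").body
       forward = alpha_add_to_mul(src)
       print(ast.unparse(forward))                    # jax.numpy.add(x, y * a)
       print(ast.unparse(mul_to_alpha_add(forward)))  # torch.add(x, y, alpha=a)

Because ``ast.unparse`` inserts parentheses from operator precedence, ``torch.add(x, b + c, alpha=a)`` yields ``jax.numpy.add(x, (b + c) * a)``. A libcst-based hook would need to add those parentheses itself (libcst nodes carry ``lpar``/``rpar`` fields for this), since libcst renders nodes exactly as constructed. Note also that the reverse sketch fires on any two-argument ``add`` whose second argument is a product, treating the right-hand factor as ``alpha``.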