ml_switcheroo.plugins.decompositions
Plugin module defining AST decomposition and recomposition rules.
This module provides hooks to transform complex function calls into simpler primitives (Decomposition) or reconstruct complex calls from primitives (Recomposition/Composition) to support bidirectional transpilation.
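Functions

| Function | Summary |
|---|---|
| transform_alpha_add(node, ctx) | Decomposes an add call with an alpha parameter into an add call whose second argument is multiplied by alpha. |
| transform_alpha_add_reverse(node, ctx) | Recomposes an add call whose second argument is a multiplication into an add call with an alpha parameter. |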
Module Contents
- ml_switcheroo.plugins.decompositions.transform_alpha_add(node: libcst.Call, ctx: ml_switcheroo.core.hooks.HookContext) -> libcst.Call
Decomposes an add call with an alpha parameter into an add call whose second argument is multiplied by alpha.
For example, it converts torch.add(x, y, alpha=a) into jax.numpy.add(x, y * a).
- Transformation:
Input: add(x, y, alpha=a)
Output: target_api(x, y * a)
- Parameters:
node – The CST Call node to transform.
ctx – The plugin execution context.
- Returns:
The transformed CST Call node.
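The rewrite above can be illustrated with Python's standard-library ast module (Python 3.9+ for ast.unparse). This is a minimal sketch under stated assumptions, not the plugin's implementation: it uses ast.NodeTransformer instead of libcst, takes no HookContext, and hardcodes a hypothetical TARGET_API constant in place of the target_api that the real hook resolves.

```python
import ast

# Hypothetical stand-in for the hook's target_api; the real plugin does not
# hardcode this, and it works on libcst nodes rather than stdlib ast nodes.
TARGET_API = "jax.numpy.add"


def _func_name(func: ast.expr) -> str:
    """Return the final name of a callee: 'add' for both add(...) and torch.add(...)."""
    if isinstance(func, ast.Attribute):
        return func.attr
    if isinstance(func, ast.Name):
        return func.id
    return ""


class AlphaAddDecomposer(ast.NodeTransformer):
    """Rewrite add(x, y, alpha=a) as TARGET_API(x, y * a)."""

    def visit_Call(self, node: ast.Call) -> ast.AST:
        self.generic_visit(node)  # rewrite any nested calls first
        alpha = next((kw for kw in node.keywords if kw.arg == "alpha"), None)
        # Only handle add calls shaped like add(x, y, alpha=a); leave the rest alone.
        if _func_name(node.func) != "add" or alpha is None or len(node.args) != 2:
            return node
        x, y = node.args
        scaled = ast.BinOp(left=y, op=ast.Mult(), right=alpha.value)  # y * a
        other_kwargs = [kw for kw in node.keywords if kw.arg != "alpha"]
        new_func = ast.parse(TARGET_API, mode="eval").body  # dotted name -> Attribute chain
        new_call = ast.Call(func=new_func, args=[x, scaled], keywords=other_kwargs)
        return ast.copy_location(new_call, node)


def decompose(source: str) -> str:
    """Apply the decomposition to a source string and return the rewritten code."""
    tree = AlphaAddDecomposer().visit(ast.parse(source))
    return ast.unparse(ast.fix_missing_locations(tree))


print(decompose("z = torch.add(x, y, alpha=2.0)"))
# prints: z = jax.numpy.add(x, y * 2.0)
```

Because ast.unparse inserts parentheses based on operator precedence, a compound second argument such as y + 1 comes out as (y + 1) * a, so the scaling still applies to the whole operand.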
- ml_switcheroo.plugins.decompositions.transform_alpha_add_reverse(node: libcst.Call, ctx: ml_switcheroo.core.hooks.HookContext) -> libcst.Call
Recomposes an add call whose second argument is a multiplication into an add call with an alpha parameter (the inverse of transform_alpha_add).
- Transformation:
Input: target_api(x, y * a)
Output: torch.add(x, y, alpha=a) [via mapping]
- Parameters:
node – The CST Call node to transform.
ctx – The plugin execution context.
- Returns:
The transformed CST Call node.
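The reverse direction can be sketched the same way with the standard-library ast module (Python 3.9+). This is an illustration under stated assumptions, not the plugin's code: it uses ast.NodeTransformer rather than libcst, and the hypothetical TARGET_API and SOURCE_API constants stand in for the names the real hook resolves via its mapping.

```python
import ast

# Hypothetical names; the documented hook resolves the output API "via mapping".
TARGET_API = "jax.numpy.add"
SOURCE_API = "torch.add"


class AlphaAddRecomposer(ast.NodeTransformer):
    """Rewrite TARGET_API(x, y * a) back into SOURCE_API(x, y, alpha=a)."""

    def visit_Call(self, node: ast.Call) -> ast.AST:
        self.generic_visit(node)  # rewrite any nested calls first
        # Match only TARGET_API(x, <expr>) with exactly two positional args.
        if ast.unparse(node.func) != TARGET_API or node.keywords or len(node.args) != 2:
            return node
        x, second = node.args
        if not (isinstance(second, ast.BinOp) and isinstance(second.op, ast.Mult)):
            return node
        # Assumption: the right-hand factor of y * a is the scalar that becomes alpha.
        new_func = ast.parse(SOURCE_API, mode="eval").body
        new_call = ast.Call(
            func=new_func,
            args=[x, second.left],
            keywords=[ast.keyword(arg="alpha", value=second.right)],
        )
        return ast.copy_location(new_call, node)


def recompose(source: str) -> str:
    """Apply the recomposition to a source string and return the rewritten code."""
    tree = AlphaAddRecomposer().visit(ast.parse(source))
    return ast.unparse(ast.fix_missing_locations(tree))


print(recompose("z = jax.numpy.add(x, y * 2.0)"))
# prints: z = torch.add(x, y, alpha=2.0)
```

A purely syntactic match like this cannot tell whether the right-hand factor is a scalar. torch.add accepts only a number for alpha, so a real implementation would need extra information (for example, type hints or literal checks) before recomposing calls where a might be a tensor.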