ml_switcheroo.plugins.decompositions

Plugin module defining AST decomposition and recomposition rules.

This module provides hooks that transform complex function calls into simpler primitives (decomposition) or reconstruct complex calls from those primitives (recomposition), supporting bidirectional transpilation.

Functions

transform_alpha_add(node, ctx) → libcst.Call

Transforms an add call with an alpha parameter into a multiplication.

transform_alpha_add_reverse(node, ctx) → libcst.Call

Transforms an add call whose second argument is a multiplication back into an add call with an alpha parameter.

Module Contents

ml_switcheroo.plugins.decompositions.transform_alpha_add(node: libcst.Call, ctx: ml_switcheroo.core.hooks.HookContext) → libcst.Call

Transforms an add call with an alpha parameter into a multiplication.

For example, it converts torch.add(x, y, alpha=a) to jax.numpy.add(x, y * a). Because torch.add computes x + alpha * y, scaling the second operand preserves the result.

Transformation:

  • Input: add(x, y, alpha=a)

  • Output: target_api(x, y * a)

Parameters:
  • node – The CST Call node to transform.

  • ctx – The plugin execution context.

Returns:

The transformed CST Call node.
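The rewrite rule above can be illustrated with a minimal sketch. It is hypothetical and not the plugin's code: it uses the standard-library ast module instead of libcst so that it runs without dependencies (ast.unparse needs Python 3.9+), and because the HookContext API is not documented here, the target function name is passed as a plain string (target_api).

```python
import ast
from typing import List, Optional


def alpha_add_to_mul(call: ast.Call, target_api: str) -> ast.Call:
    """Rewrite ``add(x, y, alpha=a)`` as ``<target_api>(x, y * a)``.

    Hypothetical sketch: stdlib ``ast`` stands in for libcst, and the target
    name is a plain string rather than something read from a HookContext.
    """
    alpha: Optional[ast.expr] = None
    kept: List[ast.keyword] = []
    for kw in call.keywords:
        if kw.arg == "alpha":
            alpha = kw.value  # remember the scaling factor
        else:
            kept.append(kw)  # pass every other keyword through unchanged

    # Nothing to decompose, or `other` was passed by keyword: leave the call as-is.
    if alpha is None or len(call.args) < 2:
        return call

    args = list(call.args)
    # torch.add computes input + alpha * other, so scale the second operand.
    args[1] = ast.BinOp(left=args[1], op=ast.Mult(), right=alpha)

    # Parse a dotted name such as "jax.numpy.add" into an Attribute chain.
    func = ast.parse(target_api, mode="eval").body
    return ast.Call(func=func, args=args, keywords=kept)


call = ast.parse("torch.add(x, y, alpha=2)", mode="eval").body
print(ast.unparse(alpha_add_to_mul(call, "jax.numpy.add")))  # jax.numpy.add(x, y * 2)
```

Building a BinOp node rather than splicing strings lets the unparser add parentheses where precedence requires them, so torch.add(x, y + z, alpha=a) becomes jnp.add(x, (y + z) * a).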

ml_switcheroo.plugins.decompositions.transform_alpha_add_reverse(node: libcst.Call, ctx: ml_switcheroo.core.hooks.HookContext) → libcst.Call

Transforms an add call whose second argument is a multiplication back into an add call with an alpha parameter.

Transformation:

  • Input: target_api(x, y * a)

  • Output: torch.add(x, y, alpha=a) [via mapping]

Parameters:
  • node – The CST Call node to transform.

  • ctx – The plugin execution context.

Returns:

The transformed CST Call node.
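The reverse rule can be sketched the same way. This is a hypothetical illustration, not the plugin's implementation: it uses the standard-library ast module in place of libcst, and source_api is a plain string standing in for the name the real hook resolves via its mapping.

```python
import ast


def mul_to_alpha_add(call: ast.Call, source_api: str) -> ast.Call:
    """Rewrite ``<f>(x, y * a)`` as ``<source_api>(x, y, alpha=a)``.

    Hypothetical sketch: stdlib ``ast`` stands in for libcst, and
    ``source_api`` stands in for the mapped function name.
    """
    # Only the exact two-positional-argument shape without alpha is recognised.
    if len(call.args) != 2 or any(kw.arg == "alpha" for kw in call.keywords):
        return call
    second = call.args[1]
    if not (isinstance(second, ast.BinOp) and isinstance(second.op, ast.Mult)):
        return call  # second operand is not a product: nothing to recompose

    # Convention assumed here: the right-hand factor of the product becomes alpha.
    alpha_kw = ast.keyword(arg="alpha", value=second.right)
    func = ast.parse(source_api, mode="eval").body
    return ast.Call(
        func=func,
        args=[call.args[0], second.left],
        keywords=[*call.keywords, alpha_kw],
    )


call = ast.parse("jax.numpy.add(x, y * a)", mode="eval").body
print(ast.unparse(mul_to_alpha_add(call, "torch.add")))  # torch.add(x, y, alpha=a)
```

Syntax alone cannot tell y * a from a * y, so any recomposition has to pick a convention for which factor becomes alpha. Either choice is safe because torch.add(x, y, alpha=a) and x + a * y compute the same value. The rule can also fire on products that never came from an alpha argument, which changes only the style of the output, not its meaning.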