ml_switcheroo.core.rewriter.calls.utils

Utility functions for Call Rewriting.

Validates function patterns and handles structural injections like state objects, shims, and layout permutations.

Functions

is_functional_apply(→ bool)

Detects if a call node matches the obj.apply pattern used in Flax Linen.

rewrite_stateful_call(→ libcst.Call)

Rewrites a call to a stateful object (functional patterns only).

inject_rngs_kwarg(→ libcst.Call)

Injects rngs=rngs into a constructor call.

strip_kwarg(→ libcst.Call)

Removes a keyword argument from a call node.

is_super_call(→ bool)

Helper to identify direct super() usage or super().__init__().

is_builtin(→ bool)

Checks whether a name is a standard Python builtin, so that builtins without a mapping do not spam the logs.

log_diff(→ None)

Computes a diff between the original and modified nodes and logs it if they differ.

compute_permutation(→ Optional[Tuple[int, ...]])

Computes permutation indices to transform source layout to target.

inject_custom_api_call(→ libcst.Call)

Constructs a generic Call node.

inject_permute_call(→ libcst.CSTNode)

Wraps a CST node with a permutation call valid for the target framework.

Module Contents

ml_switcheroo.core.rewriter.calls.utils.is_functional_apply(node: libcst.Call) → bool

Detects if a call node matches the obj.apply pattern used in Flax Linen.
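The check can be sketched with Python's standard ast module standing in for libcst, so the node types differ from the real libcst.Call signature. The name is_functional_apply_sketch is hypothetical, and the sketch assumes that "the obj.apply pattern" means any call whose callee is an attribute named apply, such as model.apply(params, x). The real helper may apply stricter checks.

```python
import ast


def is_functional_apply_sketch(node: ast.Call) -> bool:
    """Return True if the call has the shape <obj>.apply(...).

    Stand-in for the libcst helper: any attribute call whose attribute name is
    "apply" is treated as the Flax Linen functional pattern (an assumption).
    """
    func = node.func
    return isinstance(func, ast.Attribute) and func.attr == "apply"


call = ast.parse("model.apply(params, x)", mode="eval").body
print(is_functional_apply_sketch(call))  # True
```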

ml_switcheroo.core.rewriter.calls.utils.rewrite_stateful_call(rewriter, node: libcst.Call, instance_name: str, config: Dict[str, str]) → libcst.Call

Rewrites a call to a stateful object (functional patterns only).

ml_switcheroo.core.rewriter.calls.utils.inject_rngs_kwarg(node: libcst.Call) → libcst.Call

Injects rngs=rngs into a constructor call.
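The injection can be sketched with the standard ast module as a stand-in for libcst. The name inject_rngs_kwarg_sketch is hypothetical, the constructor Linear(3, 4) is only sample input, and leaving a call untouched when it already passes rngs is an assumption rather than documented behavior.

```python
import ast


def inject_rngs_kwarg_sketch(node: ast.Call) -> ast.Call:
    """Return a new Call with rngs=rngs appended to the keyword arguments.

    Assumption: a call that already passes an rngs keyword is returned unchanged.
    """
    if any(kw.arg == "rngs" for kw in node.keywords):
        return node
    rngs_kw = ast.keyword(arg="rngs", value=ast.Name(id="rngs", ctx=ast.Load()))
    return ast.Call(func=node.func, args=list(node.args), keywords=[*node.keywords, rngs_kw])


ctor = ast.parse("Linear(3, 4)", mode="eval").body
print(ast.unparse(inject_rngs_kwarg_sketch(ctor)))  # Linear(3, 4, rngs=rngs)
```

Building a new Call instead of mutating the input mirrors libcst, whose nodes are immutable.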

ml_switcheroo.core.rewriter.calls.utils.strip_kwarg(node: libcst.Call, kw_name: str) → libcst.Call

Removes a keyword argument from a call node.
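A sketch of keyword removal, again using the standard ast module in place of libcst. The name strip_kwarg_sketch is hypothetical. Because ast represents **kwargs entries with arg=None, those entries are never matched and always survive.

```python
import ast


def strip_kwarg_sketch(node: ast.Call, kw_name: str) -> ast.Call:
    """Return a new Call without any keyword argument named kw_name."""
    # **kwargs entries have arg=None, so they are always kept.
    kept = [kw for kw in node.keywords if kw.arg != kw_name]
    return ast.Call(func=node.func, args=list(node.args), keywords=kept)


call = ast.parse("Dropout(0.5, inplace=True)", mode="eval").body
print(ast.unparse(strip_kwarg_sketch(call, "inplace")))  # Dropout(0.5)
```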

ml_switcheroo.core.rewriter.calls.utils.is_super_call(node: libcst.Call) → bool

Helper to identify direct super() usage or super().__init__().
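The two shapes named above can be matched as shown in this sketch, which uses the standard ast module as a stand-in for libcst. The name is_super_call_sketch is hypothetical. Also accepting the two-argument form super(Cls, self) is an assumption.

```python
import ast


def _is_super_invocation(expr: ast.expr) -> bool:
    # Matches super() as well as super(Cls, self) (the latter is an assumption).
    return (
        isinstance(expr, ast.Call)
        and isinstance(expr.func, ast.Name)
        and expr.func.id == "super"
    )


def is_super_call_sketch(node: ast.Call) -> bool:
    """Return True for super(...) itself or for super(...).__init__(...)."""
    if _is_super_invocation(node):
        return True
    func = node.func
    return (
        isinstance(func, ast.Attribute)
        and func.attr == "__init__"
        and _is_super_invocation(func.value)
    )


print(is_super_call_sketch(ast.parse("super().__init__(dim)", mode="eval").body))  # True
```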

ml_switcheroo.core.rewriter.calls.utils.is_builtin(name: str) → bool

Checks whether a name is a standard Python builtin, so that builtins without a mapping do not spam the logs.
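One plausible way to test for a builtin is to look the name up on Python's builtins module. The name is_builtin_sketch is hypothetical, and excluding underscore-prefixed names such as __name__ is an assumption. Dotted names like torch.abs are never found there, so they are reported as non-builtins.

```python
import builtins


def is_builtin_sketch(name: str) -> bool:
    """Return True if name is a public attribute of the builtins module."""
    # Dunder/private names such as __name__ are excluded (an assumption).
    return not name.startswith("_") and hasattr(builtins, name)


print([n for n in ("len", "print", "torch.abs", "jnp") if is_builtin_sketch(n)])  # ['len', 'print']
```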

ml_switcheroo.core.rewriter.calls.utils.log_diff(label: str, original: libcst.CSTNode, modified: libcst.CSTNode) → None

Computes a diff between the original and modified nodes and logs it if they differ.
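A sketch of diff-and-log behavior using difflib from the standard library. It assumes both nodes have already been rendered to source strings. With libcst, one way to do that is libcst.Module([]).code_for_node(node). The name log_diff_sketch, the logger name, the DEBUG level, and the boolean return value are all illustrative assumptions. The documented helper returns None.

```python
import difflib
import logging

logger = logging.getLogger("rewriter_sketch")


def log_diff_sketch(label: str, original: str, modified: str) -> bool:
    """Log a unified diff of two source snippets if they differ.

    Returns True when something was logged (the real helper returns None).
    """
    if original == modified:
        return False  # nothing changed, so nothing is logged
    diff = difflib.unified_diff(
        original.splitlines(),
        modified.splitlines(),
        fromfile=f"{label} (before)",
        tofile=f"{label} (after)",
        lineterm="",
    )
    logger.debug("\n".join(diff))
    return True


logging.basicConfig(level=logging.DEBUG)
log_diff_sketch("abs", "torch.abs(x)", "jnp.abs(x)")
```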

ml_switcheroo.core.rewriter.calls.utils.compute_permutation(source_layout: str, target_layout: str) → Tuple[int, ...] | None

Computes permutation indices to transform source layout to target.

Example

Source: “NCHW”, Target: “NHWC”
Map: N:0, C:1, H:2, W:3
Target required: N(0), H(2), W(3), C(1)
Result: (0, 2, 3, 1)

Parameters:
  • source_layout – Source layout string (e.g. “NCHW”).

  • target_layout – Target layout string (e.g. “NHWC”).

Returns:

Tuple of integer indices, or None if invalid.
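The worked example above reduces to one lookup per target axis: find where that axis sits in the source layout. The sketch below assumes that "invalid" means the two layouts do not name the same axes exactly once each. The name compute_permutation_sketch is hypothetical.

```python
from typing import Optional, Tuple


def compute_permutation_sketch(source_layout: str, target_layout: str) -> Optional[Tuple[int, ...]]:
    """Return indices p such that target axis i is source axis p[i]."""
    # Assumed validity rule: both layouts name the same axes, each exactly once.
    if len(set(source_layout)) != len(source_layout) or sorted(source_layout) != sorted(target_layout):
        return None
    # For each axis in target order, record where it sits in the source layout.
    return tuple(source_layout.index(axis) for axis in target_layout)


print(compute_permutation_sketch("NCHW", "NHWC"))  # (0, 2, 3, 1)
print(compute_permutation_sketch("NHWC", "NCHW"))  # (0, 3, 1, 2)
print(compute_permutation_sketch("NCHW", "NHW"))   # None
```

The result follows the axes convention of numpy.transpose and jnp.transpose, where output axis i is input axis axes[i]. Passing (0, 2, 3, 1) therefore turns an NCHW array into NHWC.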

ml_switcheroo.core.rewriter.calls.utils.inject_custom_api_call(func_name_node: libcst.BaseExpression, args: List[libcst.Arg]) → libcst.Call

Constructs a generic Call node.
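A minimal sketch of building such a node with the standard ast module as a stand-in for libcst. In libcst, a single libcst.Arg list can carry both positional and keyword arguments. The ast version below only takes positional expressions, which is a simplification. The name inject_custom_api_call_sketch is hypothetical.

```python
import ast
from typing import List


def inject_custom_api_call_sketch(func_name_node: ast.expr, args: List[ast.expr]) -> ast.Call:
    """Build func_name_node(*args) as a Call node with no keyword arguments."""
    return ast.Call(func=func_name_node, args=list(args), keywords=[])


func = ast.parse("jnp.asarray", mode="eval").body
x = ast.Name(id="x", ctx=ast.Load())
print(ast.unparse(inject_custom_api_call_sketch(func, [x])))  # jnp.asarray(x)
```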

ml_switcheroo.core.rewriter.calls.utils.inject_permute_call(base_node: libcst.CSTNode, indices: Tuple[int, ...], semantics: ml_switcheroo.semantics.manager.SemanticsManager, target_fw: str) → libcst.CSTNode

Wraps a CST node with a permutation call valid for the target framework.

Logic:
  1. Finds the permute_dims definition in Semantics for the target framework.

  2. Determines the API name (e.g. jnp.transpose, torch.permute, tf.transpose).

  3. Determines the packing strategy (tuple kwarg vs. varargs).

  4. Wraps base_node in the resulting call.

Parameters:
  • base_node – The expression to verify/permute.

  • indices – Tuple of dimensions to permute (e.g., (0, 2, 3, 1)).

  • semantics – Manager to look up permute_dims syntax.

  • target_fw – Target framework key.

Returns:

CST Node representing permute(base_node, indices).
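The four logic steps can be sketched with the standard ast module standing in for libcst. Here, a plain dictionary replaces the SemanticsManager lookup. PERMUTE_SPECS, the framework keys, the toyfw.permute varargs API, and inject_permute_call_sketch are all hypothetical. The jax and tensorflow entries use the real keyword names of jnp.transpose (axes) and tf.transpose (perm). Returning base_node unchanged when no mapping exists is an assumption.

```python
import ast
from typing import Dict, Tuple

# Hypothetical stand-in for the SemanticsManager's permute_dims lookup:
# framework key -> (API name, packing strategy, keyword used for tuple packing).
PERMUTE_SPECS: Dict[str, Tuple[str, str, str]] = {
    "jax": ("jnp.transpose", "tuple_kwarg", "axes"),
    "tensorflow": ("tf.transpose", "tuple_kwarg", "perm"),
    "toyfw": ("toyfw.permute", "varargs", ""),  # hypothetical varargs-style API
}


def inject_permute_call_sketch(base_node: ast.expr, indices: Tuple[int, ...], target_fw: str) -> ast.expr:
    """Wrap base_node in the target framework's permute call."""
    spec = PERMUTE_SPECS.get(target_fw)  # step 1: look up the definition
    if spec is None:
        return base_node  # assumption: no mapping, so leave the node untouched
    api_name, packing, kw_name = spec  # steps 2 and 3: API name and packing
    func = ast.parse(api_name, mode="eval").body
    dims = [ast.Constant(value=i) for i in indices]
    if packing == "tuple_kwarg":
        # Pack all indices into one tuple passed by keyword: api(x, axes=(...)).
        packed = ast.keyword(arg=kw_name, value=ast.Tuple(elts=dims, ctx=ast.Load()))
        return ast.Call(func=func, args=[base_node], keywords=[packed])
    # Varargs: each index becomes its own positional argument: api(x, 0, 2, ...).
    return ast.Call(func=func, args=[base_node, *dims], keywords=[])


x = ast.Name(id="x", ctx=ast.Load())
print(ast.unparse(inject_permute_call_sketch(x, (0, 2, 3, 1), "jax")))    # jnp.transpose(x, axes=(0, 2, 3, 1))
print(ast.unparse(inject_permute_call_sketch(x, (0, 2, 3, 1), "toyfw")))  # toyfw.permute(x, 0, 2, 3, 1)
```

Keeping the packing strategy as data rather than per-framework branches means a new target only needs a new table entry, which is presumably why the real helper consults Semantics instead of hard-coding API names.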