ml_switcheroo.core.rewriter.calls.utils
Utility functions for Call Rewriting.
Validates function patterns and handles structural injections like state objects, shims, and layout permutations.
Functions
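is_functional_apply(node) | Detects if a call node matches the obj.apply pattern used in Flax Linen. |
rewrite_stateful_call(rewriter, node, instance_name, config) | Rewrites a call to a stateful object (Functional patterns only). |
inject_rngs_kwarg(node) | Injects rngs=rngs into a constructor call. |
strip_kwarg(node, kw_name) | Removes a keyword argument from a call node. |
is_super_call(node) | Helper to identify direct super() usage or super().__init__(). |
is_builtin(name) | Avoid spamming logs for standard Python builtins unless mapped. |
log_diff(label, original, modified) | Helper to compute diff and log if changed. |
compute_permutation(source_layout, target_layout) | Computes permutation indices to transform source layout to target. |
inject_custom_api_call(func_name_node, args) | Constructs a generic Call node. |
inject_permute_call(base_node, indices, semantics, target_fw) | Wraps a CST node with a permutation call valid for the target framework. |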
Module Contents
- ml_switcheroo.core.rewriter.calls.utils.is_functional_apply(node: libcst.Call) → bool
Detects if a call node matches the obj.apply pattern used in Flax Linen.
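The check can be pictured as a test on the callee of the call expression. The following is a minimal sketch, not the module's implementation: it uses the standard-library `ast` module as a stand-in for libcst so that it runs without dependencies, and the names `parse_call` and `looks_like_functional_apply` are hypothetical. The real function may apply stricter conditions.

```python
import ast


def parse_call(source: str) -> ast.Call:
    """Parse a single expression and require it to be a call."""
    node = ast.parse(source, mode="eval").body
    if not isinstance(node, ast.Call):
        raise TypeError(f"not a call expression: {source!r}")
    return node


def looks_like_functional_apply(call: ast.Call) -> bool:
    """Return True when the callee is an attribute access named 'apply'.

    Assumption: the obj.apply pattern means <expr>.apply(...); a bare
    apply(...) call does not count.
    """
    func = call.func
    return isinstance(func, ast.Attribute) and func.attr == "apply"


print(looks_like_functional_apply(parse_call("model.apply(params, x)")))
print(looks_like_functional_apply(parse_call("model(x)")))
```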
- ml_switcheroo.core.rewriter.calls.utils.rewrite_stateful_call(rewriter, node: libcst.Call, instance_name: str, config: Dict[str, str]) → libcst.Call
Rewrites a call to a stateful object (Functional patterns only).
- ml_switcheroo.core.rewriter.calls.utils.inject_rngs_kwarg(node: libcst.Call) → libcst.Call
Injects rngs=rngs into a constructor call.
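To illustrate the effect of the injection, here is a small sketch built on the standard-library `ast` module rather than libcst (requires Python 3.9+ for `ast.unparse`). The helper name `add_rngs_kwarg` is hypothetical, and skipping calls that already pass `rngs` is an assumption of this sketch, not documented behaviour.

```python
import ast


def add_rngs_kwarg(call: ast.Call) -> ast.Call:
    """Return a copy of the call with a trailing rngs=rngs keyword."""
    # Assumption: leave the call untouched if rngs is already passed.
    if any(kw.arg == "rngs" for kw in call.keywords):
        return call
    rngs_kw = ast.keyword(arg="rngs", value=ast.Name(id="rngs", ctx=ast.Load()))
    return ast.Call(func=call.func, args=call.args, keywords=[*call.keywords, rngs_kw])


constructor = ast.parse("Linear(3, 4)", mode="eval").body
print(ast.unparse(add_rngs_kwarg(constructor)))  # Linear(3, 4, rngs=rngs)
```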
- ml_switcheroo.core.rewriter.calls.utils.strip_kwarg(node: libcst.Call, kw_name: str) → libcst.Call
Removes a keyword argument from a call node.
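A sketch of the same idea with the standard-library `ast` module standing in for libcst (Python 3.9+ for `ast.unparse`). The name `drop_kwarg` is hypothetical; the sketch filters the keyword list and leaves positional arguments and `**kwargs` entries alone.

```python
import ast


def drop_kwarg(call: ast.Call, kw_name: str) -> ast.Call:
    """Return a copy of the call without the keyword named kw_name."""
    # A **kwargs entry has kw.arg set to None, so it is always kept.
    kept = [kw for kw in call.keywords if kw.arg != kw_name]
    return ast.Call(func=call.func, args=call.args, keywords=kept)


call = ast.parse("Dropout(0.5, inplace=True)", mode="eval").body
print(ast.unparse(drop_kwarg(call, "inplace")))  # Dropout(0.5)
```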
- ml_switcheroo.core.rewriter.calls.utils.is_super_call(node: libcst.Call) → bool
Helper to identify direct super() usage or super().__init__().
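The two shapes named in the description can be recognised structurally. The sketch below uses the standard-library `ast` module in place of libcst, and the function names are hypothetical. Treating other methods reached through super(), such as `super().forward(x)`, as non-matches is an assumption of this sketch.

```python
import ast


def _is_bare_super(node: ast.AST) -> bool:
    """True for the expression super(...) itself."""
    return (
        isinstance(node, ast.Call)
        and isinstance(node.func, ast.Name)
        and node.func.id == "super"
    )


def looks_like_super_call(call: ast.Call) -> bool:
    """True for super() and for super().__init__(...)."""
    if _is_bare_super(call):
        return True
    func = call.func
    return (
        isinstance(func, ast.Attribute)
        and func.attr == "__init__"
        and _is_bare_super(func.value)
    )


for src in ("super()", "super().__init__()", "foo()"):
    print(src, looks_like_super_call(ast.parse(src, mode="eval").body))
```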
- ml_switcheroo.core.rewriter.calls.utils.is_builtin(name: str) → bool
Reports whether name is a standard Python builtin; used to avoid spamming logs for builtins unless they are mapped.
- ml_switcheroo.core.rewriter.calls.utils.log_diff(label: str, original: libcst.CSTNode, modified: libcst.CSTNode) → None
Helper that computes the diff between the original and modified nodes and logs it only if something changed.
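A compare-then-log helper of this kind might look roughly like the sketch below. It is an illustration under stated assumptions: it works on source strings rather than CST nodes, uses `difflib` and `logging` from the standard library, and returns a bool only so the behaviour is easy to check (the documented function returns None). The logger name and the function name `log_source_diff` are hypothetical.

```python
import difflib
import logging

logger = logging.getLogger("rewrite_sketch")  # hypothetical logger name


def log_source_diff(label: str, original: str, modified: str) -> bool:
    """Log a unified diff under label; return False when nothing changed."""
    if original == modified:
        return False
    diff_lines = difflib.unified_diff(
        original.splitlines(),
        modified.splitlines(),
        fromfile="original",
        tofile="modified",
        lineterm="",
    )
    logger.info("%s\n%s", label, "\n".join(diff_lines))
    return True
```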
- ml_switcheroo.core.rewriter.calls.utils.compute_permutation(source_layout: str, target_layout: str) → Tuple[int, ...] | None
Computes permutation indices to transform source layout to target.
Example
Source: “NCHW”, Target: “NHWC”
Map: N:0, C:1, H:2, W:3
Target Required: N(0), H(2), W(3), C(1)
Result: (0, 2, 3, 1)
- Parameters:
source_layout – Source layout string (e.g. “NCHW”).
target_layout – Target layout string (e.g. “NHWC”).
- Returns:
Tuple of integer indices, or None if invalid.
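The mapping described in the example can be sketched in a few lines of plain Python. This is not the module's code: the name `permutation_for` is hypothetical, and the definition of "invalid" used here (different lengths, repeated axis letters, or different axis sets) is an assumption.

```python
from typing import Optional, Tuple


def permutation_for(source_layout: str, target_layout: str) -> Optional[Tuple[int, ...]]:
    """Return, for each target axis, its index in the source layout."""
    # Assumption: both layouts must use the same axis letters, each exactly once.
    if len(set(source_layout)) != len(source_layout):
        return None
    if sorted(source_layout) != sorted(target_layout):
        return None
    position = {axis: index for index, axis in enumerate(source_layout)}
    return tuple(position[axis] for axis in target_layout)


print(permutation_for("NCHW", "NHWC"))  # (0, 2, 3, 1)
print(permutation_for("NHWC", "NCHW"))  # (0, 3, 1, 2)
print(permutation_for("NCHW", "NCL"))   # None
```

The inverse direction is a different tuple, not the same one reversed, which is why the result is derived from the layout strings instead of being hard-coded.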
- ml_switcheroo.core.rewriter.calls.utils.inject_custom_api_call(func_name_node: libcst.BaseExpression, args: List[libcst.Arg]) → libcst.Call
Constructs a generic Call node.
- ml_switcheroo.core.rewriter.calls.utils.inject_permute_call(base_node: libcst.CSTNode, indices: Tuple[int, ...], semantics: ml_switcheroo.semantics.manager.SemanticsManager, target_fw: str) → libcst.CSTNode
Wraps a CST node with a permutation call valid for the target framework.
- Logic:
1. Finds permute_dims definition in Semantics for the target framework.
2. Determines API name (e.g. jnp.transpose, torch.permute, tf.transpose).
3. Determines packing strategy (tuple kwarg vs varargs).
4. Wraps base_node.
- Parameters:
base_node – The expression to wrap with the permutation call.
indices – Tuple of dimensions to permute (e.g., (0, 2, 3, 1)).
semantics – Manager to look up permute_dims syntax.
target_fw – Target framework key.
- Returns:
CST Node representing permute(base_node, indices).
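The lookup, packing, and wrapping steps listed under Logic can be sketched as follows. This is an illustration under several assumptions: the standard-library `ast` module replaces libcst (Python 3.9+ for `ast.unparse`), a hard-coded `SPECS` dictionary replaces the SemanticsManager lookup, the `example_varargs` entry and its `fw.permute` API are invented purely to show the varargs strategy, and returning the node unchanged for an unknown framework is a choice made for this sketch.

```python
import ast
from typing import Dict, Optional, Tuple

# Hypothetical stand-in for the permute_dims lookup:
# framework key -> (API name, keyword name for the tuple, or None for varargs).
SPECS: Dict[str, Tuple[str, Optional[str]]] = {
    "jax": ("jnp.transpose", "axes"),
    "tensorflow": ("tf.transpose", "perm"),
    "example_varargs": ("fw.permute", None),  # invented entry, not a real API
}


def wrap_with_permute(base: ast.expr, indices: Tuple[int, ...], target_fw: str) -> ast.expr:
    """Wrap base in the target framework's permutation call."""
    spec = SPECS.get(target_fw)
    if spec is None:
        return base  # assumption: unknown frameworks are left untouched
    api_name, kwarg_name = spec
    func = ast.parse(api_name, mode="eval").body  # e.g. the jnp.transpose attribute
    constants = [ast.Constant(value=i) for i in indices]
    if kwarg_name is None:
        # Varargs packing: permute(x, 0, 2, 3, 1)
        return ast.Call(func=func, args=[base, *constants], keywords=[])
    # Tuple-kwarg packing: transpose(x, axes=(0, 2, 3, 1))
    axes = ast.Tuple(elts=constants, ctx=ast.Load())
    return ast.Call(func=func, args=[base], keywords=[ast.keyword(arg=kwarg_name, value=axes)])


base = ast.parse("conv(x)", mode="eval").body
print(ast.unparse(wrap_with_permute(base, (0, 2, 3, 1), "jax")))
# jnp.transpose(conv(x), axes=(0, 2, 3, 1))
```

Keeping the API name and packing strategy in data rather than in branches per framework mirrors the description above, where both are determined from the semantics definition.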