utils
Utility functions for Call Rewriting.
This module provides helper functions for inspecting and transforming LibCST Call nodes. It handles structural tasks such as:
- Detecting functional usage patterns (e.g. layer.apply).
- Rewriting stateful calls.
- Injecting and stripping keyword arguments generically.
- Generating permutation/transpose calls based on semantic layout maps.
- Decoupling Logic:
Framework-specific decisions (e.g., whether to use permute or transpose) are delegated to the SemanticsManager, so this module contains no hardcoded framework checks. Functional unwrapping detection is driven by StructuralTraits.
Functions
is_functional_apply | Detects if a call node matches the functional execution pattern (e.g. obj.apply).
rewrite_stateful_call | Rewrites a call to a stateful object to match a functional pattern.
inject_kwarg | Generic helper to inject a keyword argument into a call.
strip_kwarg | Removes a specified keyword argument from a function call.
is_super_call | Detects if a call is super() or super().method().
is_builtin | Checks if a name corresponds to a standard Python builtin.
log_diff | Helper to compute AST diffs and log them to the tracer if changes occurred.
compute_permutation | Computes permutation indices to transform source layout string to target layout string.
inject_permute_call | Wraps a CST node with a permutation call valid for the target framework.
Module Contents
- utils.is_functional_apply(node: libcst.Call, method_name: str | None = 'apply') → bool
Detects if a call node matches the functional execution pattern (e.g. obj.apply).
Driven by the functional_execution_method trait of the source framework. This generalizes detection so that it covers Flax (apply), Haiku (apply), or custom frameworks (e.g. call_fn).
- Parameters:
node (cst.Call) – The function call node to inspect.
method_name (str, optional) – The method name to look for. Defaults to “apply”. If None, functional unwrapping is disabled.
- Returns:
True if the call is a method call whose method name matches method_name.
- Return type:
bool
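The detection described above can be sketched as follows. This is a minimal illustration, not the module's implementation: it uses the standard-library `ast` module as a stand-in for LibCST so that it runs without third-party packages, and the names `sketch_is_functional_apply` and `parse_call` are hypothetical.

```python
import ast
from typing import Optional


def parse_call(source: str) -> ast.Call:
    """Parse a single call expression (helper for this sketch only)."""
    node = ast.parse(source, mode="eval").body
    if not isinstance(node, ast.Call):
        raise ValueError(f"not a call expression: {source!r}")
    return node


def sketch_is_functional_apply(call: ast.Call, method_name: Optional[str] = "apply") -> bool:
    """Return True when `call` has the shape `<expr>.<method_name>(...)`."""
    if method_name is None:
        # Mirrors the documented behaviour: None disables functional unwrapping.
        return False
    func = call.func
    # A method call is an attribute access used as the callee.
    return isinstance(func, ast.Attribute) and func.attr == method_name


print(sketch_is_functional_apply(parse_call("model.apply(params, x)")))
print(sketch_is_functional_apply(parse_call("net.call_fn(x)"), method_name="call_fn"))
print(sketch_is_functional_apply(parse_call("apply(x)")))
```

A bare `apply(x)` is not matched because the callee is a plain name rather than an attribute access.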
- utils.rewrite_stateful_call(rewriter: Any, node: libcst.Call, instance_name: str, config: Dict[str, str]) → libcst.Call
Rewrites a call to a stateful object to match a functional pattern.
Used when converting OOP frameworks to Functional ones where state must be passed explicitly. Can inject arguments (e.g. variables) and change method names (e.g. __call__ -> apply).
- Parameters:
rewriter – The Transformer instance (must expose context).
node (cst.Call) – The original call node.
instance_name (str) – The name of the object instance being called.
config (Dict[str, str]) – Configuration dict containing ‘prepend_arg’ and ‘method’.
- Returns:
The transformed call node.
- Return type:
cst.Call
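A rough sketch of the kind of rewrite described above, under several assumptions: the standard-library `ast` module stands in for LibCST, the rewriter context is omitted, and the way the ‘prepend_arg’ and ‘method’ keys are interpreted here (prepend a positional name, call the named method on the instance) is a guess based on the description rather than the module's actual logic. `sketch_rewrite_stateful_call` is a hypothetical name. Requires Python 3.9+ for `ast.unparse`.

```python
import ast
from typing import Dict


def sketch_rewrite_stateful_call(call: ast.Call, instance_name: str, config: Dict[str, str]) -> ast.Call:
    """Turn `instance(args)` into `instance.method(prepend_arg, args)`."""
    method = config.get("method")
    prepend_arg = config.get("prepend_arg")

    func = call.func
    if method:
        # Assumed behaviour: model(x) becomes model.<method>(x).
        func = ast.Attribute(
            value=ast.Name(id=instance_name, ctx=ast.Load()),
            attr=method,
            ctx=ast.Load(),
        )

    args = list(call.args)
    if prepend_arg:
        # Assumed behaviour: explicit state is passed as the first positional argument.
        args.insert(0, ast.Name(id=prepend_arg, ctx=ast.Load()))

    return ast.Call(func=func, args=args, keywords=list(call.keywords))


original = ast.parse("model(x, train=True)", mode="eval").body
rewritten = sketch_rewrite_stateful_call(
    original, "model", {"prepend_arg": "variables", "method": "apply"}
)
print(ast.unparse(rewritten))  # model.apply(variables, x, train=True)
```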
- utils.inject_kwarg(node: libcst.Call, arg_name: str, val_name: str) → libcst.Call
Generic helper to inject a keyword argument into a call. Prevents duplication if the argument already exists.
Format: func(…, arg_name=val_name)
- Parameters:
node (cst.Call) – The call node to modify.
arg_name (str) – The keyword argument name.
val_name (str) – The variable name to pass as value.
- Returns:
The updated call node with the injected argument (if not present).
- Return type:
cst.Call
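The duplicate check and injection can be illustrated with a small sketch. It uses the standard-library `ast` module in place of LibCST, and `sketch_inject_kwarg` is a hypothetical name; the real helper operates on `cst.Call` nodes. Requires Python 3.9+ for `ast.unparse`.

```python
import ast


def sketch_inject_kwarg(call: ast.Call, arg_name: str, val_name: str) -> ast.Call:
    """Return a call with `arg_name=val_name` appended unless it is already present."""
    if any(kw.arg == arg_name for kw in call.keywords):
        # Keyword already supplied: leave the call untouched.
        return call
    new_kw = ast.keyword(arg=arg_name, value=ast.Name(id=val_name, ctx=ast.Load()))
    # Build a new node instead of mutating the input.
    return ast.Call(func=call.func, args=list(call.args), keywords=[*call.keywords, new_kw])


call = ast.parse("layer(x)", mode="eval").body
print(ast.unparse(sketch_inject_kwarg(call, "rngs", "rngs")))  # layer(x, rngs=rngs)
```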
- utils.strip_kwarg(node: libcst.Call, kw_name: str) → libcst.Call
Removes a specified keyword argument from a function call.
- Parameters:
node (cst.Call) – The call node.
kw_name (str) – The keyword string to strip.
- Returns:
The updated node with the argument removed.
- Return type:
cst.Call
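A comparable sketch for stripping a keyword, again with the standard-library `ast` module standing in for LibCST and a hypothetical function name. With LibCST an implementation would also need to deal with the comma left on the preceding argument, which `ast` does not model. Requires Python 3.9+ for `ast.unparse`.

```python
import ast


def sketch_strip_kwarg(call: ast.Call, kw_name: str) -> ast.Call:
    """Return a call without the keyword argument named `kw_name`."""
    # `**kwargs` entries have arg=None, so they never match and are kept.
    kept = [kw for kw in call.keywords if kw.arg != kw_name]
    if len(kept) == len(call.keywords):
        return call  # nothing to strip
    return ast.Call(func=call.func, args=list(call.args), keywords=kept)


call = ast.parse("layer(x, training=True, rate=0.1)", mode="eval").body
print(ast.unparse(sketch_strip_kwarg(call, "training")))  # layer(x, rate=0.1)
```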
- utils.is_super_call(node: libcst.Call) → bool
Detects if a call is super() or super().method().
- Parameters:
node (cst.Call) – The call node.
- Returns:
True if it is a super call.
- Return type:
bool
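The two shapes named above, super() and super().method(), can be matched as in the following sketch. It relies on the standard-library `ast` module as a stand-in for LibCST, and the function names are hypothetical.

```python
import ast


def _is_bare_super(node: ast.AST) -> bool:
    """True for a call whose callee is the plain name `super`."""
    return (
        isinstance(node, ast.Call)
        and isinstance(node.func, ast.Name)
        and node.func.id == "super"
    )


def sketch_is_super_call(call: ast.Call) -> bool:
    """True for `super(...)` and for `super(...).method(...)`."""
    if _is_bare_super(call):
        return True
    func = call.func
    # super().method(): the callee is an attribute whose base is a super() call.
    return isinstance(func, ast.Attribute) and _is_bare_super(func.value)


for source in ("super()", "super().__init__()", "obj.super()", "self.forward(x)"):
    node = ast.parse(source, mode="eval").body
    print(source, sketch_is_super_call(node))
```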
- utils.is_builtin(name: str) → bool
Checks if a name corresponds to a standard Python builtin. Used to prevent excessive logging/tracing of standard language features.
- Parameters:
name (str) – The function name.
- Returns:
True if builtin.
- Return type:
bool
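One plausible way to perform such a check is to look the name up in the standard `builtins` module. The sketch below assumes that approach; the module may instead use a fixed allow-list, and `sketch_is_builtin` is a hypothetical name.

```python
import builtins


def sketch_is_builtin(name: str) -> bool:
    """True when `name` is defined in Python's `builtins` module."""
    return name in vars(builtins)


print(sketch_is_builtin("len"))     # True
print(sketch_is_builtin("conv2d"))  # False
```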
- utils.log_diff(label: str, original: libcst.CSTNode, modified: libcst.CSTNode) → None
Helper to compute AST diffs and log them to the tracer if changes occurred.
- Parameters:
label (str) – Label for the log entry.
original (cst.CSTNode) – The node before transformation.
modified (cst.CSTNode) – The node after transformation.
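The "only log when something changed" behaviour can be sketched with `difflib`. This is an illustration under assumptions: it compares already-rendered source strings rather than CST nodes, it returns the diff lines instead of writing to the tracer (whose interface is not described here), and `sketch_diff_lines` is a hypothetical name.

```python
import difflib
from typing import List


def sketch_diff_lines(label: str, original_code: str, modified_code: str) -> List[str]:
    """Return unified-diff lines, or an empty list when nothing changed."""
    if original_code == modified_code:
        return []  # no change, so nothing to log
    diff = difflib.unified_diff(
        original_code.splitlines(),
        modified_code.splitlines(),
        fromfile=f"{label} (before)",
        tofile=f"{label} (after)",
        lineterm="",
    )
    return list(diff)


for line in sketch_diff_lines("stateful call", "model(x)", "model.apply(variables, x)"):
    print(line)
```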
- utils.compute_permutation(source_layout: str, target_layout: str) → Tuple[int, ...] | None
Computes permutation indices to transform source layout string to target layout string.
Example
Source: “NCHW”, Target: “NHWC”
Map: N:0, C:1, H:2, W:3
Target Required: N(0), H(2), W(3), C(1)
Result: (0, 2, 3, 1)
- Parameters:
source_layout (str) – Source layout string (e.g. “NCHW”).
target_layout (str) – Target layout string (e.g. “NHWC”).
- Returns:
Tuple of integer indices, or None if invalid.
- Return type:
tuple[int, ...] | None
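The index computation in the example can be written in a few lines. The sketch below is an assumption about how it might be done, not the module's code: in particular, treating mismatched lengths, repeated axis letters, and differing axis sets as "invalid" is a guess, and `sketch_compute_permutation` is a hypothetical name.

```python
from typing import Optional, Tuple


def sketch_compute_permutation(source_layout: str, target_layout: str) -> Optional[Tuple[int, ...]]:
    """Map each target axis to its index in the source layout."""
    if len(source_layout) != len(target_layout):
        return None
    if len(set(source_layout)) != len(source_layout):
        return None  # repeated axis letters are ambiguous
    if set(source_layout) != set(target_layout):
        return None  # layouts must name the same axes
    position = {axis: index for index, axis in enumerate(source_layout)}
    return tuple(position[axis] for axis in target_layout)


print(sketch_compute_permutation("NCHW", "NHWC"))  # (0, 2, 3, 1)
print(sketch_compute_permutation("NHWC", "NCHW"))  # (0, 3, 1, 2)
print(sketch_compute_permutation("NCHW", "NHW"))   # None
```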
- utils.inject_permute_call(base_node: libcst.CSTNode, indices: Tuple[int, ...], semantics: ml_switcheroo.semantics.manager.SemanticsManager, target_fw: str) → libcst.CSTNode
Wraps a CST node with a permutation call valid for the target framework.
- Decoupling Logic:
It queries the SemanticsManager for the permute_dims definition in the target tier. It does NOT contain hardcoded framework checks (like “if target == torch”). If a definition is missing, it returns the bare node (a no-op) rather than assuming JAX-style syntax.
- Parameters:
base_node (cst.CSTNode) – The expression to wrap (the input tensor).
indices (Tuple[int, ...]) – Tuple of dimensions to permute (e.g., (0, 2, 3, 1)).
semantics (SemanticsManager) – Manager to look up syntax.
target_fw (str) – Target framework key.
- Returns:
Node representing permute(base_node, indices) or original if unsupported.
- Return type:
cst.CSTNode
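The lookup-then-wrap flow, including the no-op fallback, can be sketched as follows. Everything here is illustrative: a plain dictionary (`HYPOTHETICAL_PERMUTE_APIS`) stands in for the SemanticsManager, whose query interface is not shown in this page; the standard-library `ast` module stands in for LibCST; and `sketch_inject_permute_call` is a hypothetical name. Requires Python 3.9+ for `ast.unparse`.

```python
import ast
from typing import Dict, Tuple

# Hypothetical stand-in for the SemanticsManager: framework key -> permute API name.
HYPOTHETICAL_PERMUTE_APIS: Dict[str, str] = {
    "torch": "torch.permute",
    "jax": "jax.numpy.transpose",
}


def sketch_inject_permute_call(
    base_node: ast.expr,
    indices: Tuple[int, ...],
    permute_apis: Dict[str, str],
    target_fw: str,
) -> ast.expr:
    """Wrap `base_node` as `<api>(base_node, indices)`, or return it unchanged."""
    api_name = permute_apis.get(target_fw)
    if api_name is None:
        # No definition for this framework: return the bare node (no-op).
        return base_node
    func = ast.parse(api_name, mode="eval").body              # dotted name -> expression
    dims = ast.parse(repr(tuple(indices)), mode="eval").body  # e.g. (0, 2, 3, 1)
    return ast.Call(func=func, args=[base_node, dims], keywords=[])


tensor = ast.parse("x", mode="eval").body
wrapped = sketch_inject_permute_call(tensor, (0, 2, 3, 1), HYPOTHETICAL_PERMUTE_APIS, "torch")
print(ast.unparse(wrapped))  # torch.permute(x, (0, 2, 3, 1))
```

Keeping the framework-to-API mapping outside the function is what allows the wrapper itself to stay free of checks such as “if target == torch”.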