utils

Utility functions for Call Rewriting.

This module provides helper functions for inspecting and transforming LibCST Call nodes. It handles structural tasks such as:

  • Detecting functional usage patterns (e.g. layer.apply).

  • Rewriting stateful calls.

  • Injecting and stripping keyword arguments generically.

  • Generating permutation/transpose calls based on semantic layout maps.

Decoupling Logic:

Logic regarding specific framework APIs (e.g., whether to use permute vs transpose) is delegated to the SemanticsManager, removing hardcoded framework checks. Functional unwrapping detection is driven by StructuralTraits.

Functions

is_functional_apply(→ bool)

Detects if a call node matches the functional execution pattern (e.g. obj.apply).

rewrite_stateful_call(→ libcst.Call)

Rewrites a call to a stateful object to match a functional pattern.

inject_kwarg(→ libcst.Call)

Generic helper to inject a keyword argument into a call.

strip_kwarg(→ libcst.Call)

Removes a specified keyword argument from a function call.

is_super_call(→ bool)

Detects if a call is super() or super().method().

is_builtin(→ bool)

Checks if a name corresponds to a standard Python builtin.

log_diff(→ None)

Helper to compute AST diffs and log them to the tracer if changes occurred.

compute_permutation(→ Optional[Tuple[int, ...]])

Computes permutation indices to transform source layout string to target layout string.

inject_permute_call(→ libcst.CSTNode)

Wraps a CST node with a permutation call valid for the target framework.

Module Contents

utils.is_functional_apply(node: libcst.Call, method_name: str | None = 'apply') → bool

Detects if a call node matches the functional execution pattern (e.g. obj.apply).

Detection is driven by the functional_execution_method trait of the source framework, which keeps it generic: it supports Flax (apply), Haiku (apply), or custom frameworks (e.g. call_fn).

Parameters:
  • node (cst.Call) – The function call node to inspect.

  • method_name (str, optional) – The method name to look for. Defaults to “apply”. If None, functional unwrapping is disabled.

Returns:

True if the call is a method call whose name matches method_name; False otherwise, including when method_name is None.

Return type:

bool

utils.rewrite_stateful_call(rewriter: Any, node: libcst.Call, instance_name: str, config: Dict[str, str]) → libcst.Call

Rewrites a call to a stateful object to match a functional pattern.

Used when converting OOP frameworks to functional ones, where state must be passed explicitly. It can inject arguments (e.g. variables) and change method names (e.g. __call__ -> apply).

Parameters:
  • rewriter – The Transformer instance (must expose context).

  • node (cst.Call) – The original call node.

  • instance_name (str) – The name of the object instance being called.

  • config (Dict[str, str]) – Configuration dict containing ‘prepend_arg’ and ‘method’.

Returns:

The transformed call node.

Return type:

cst.Call

utils.inject_kwarg(node: libcst.Call, arg_name: str, val_name: str) → libcst.Call

Generic helper to inject a keyword argument into a call. If the call already has a keyword argument named arg_name, it is not added again, which prevents duplicates.

Format: func(…, arg_name=val_name)

Parameters:
  • node (cst.Call) – The call node to modify.

  • arg_name (str) – The keyword argument name.

  • val_name (str) – The variable name to pass as value.

Returns:

The updated call node; the argument is injected only if it is not already present.

Return type:

cst.Call

utils.strip_kwarg(node: libcst.Call, kw_name: str) → libcst.Call

Removes a specified keyword argument from a function call.

Parameters:
  • node (cst.Call) – The call node.

  • kw_name (str) – The name of the keyword argument to strip.

Returns:

The updated node with the argument removed.

Return type:

cst.Call

utils.is_super_call(node: libcst.Call) → bool

Detects if a call is super() or super().method().

Parameters:

node (cst.Call) – The call node.

Returns:

True if it is a super call.

Return type:

bool

utils.is_builtin(name: str) → bool

Checks if a name corresponds to a standard Python builtin. Used to prevent excessive logging/tracing of standard language features.

Parameters:

name (str) – The function name.

Returns:

True if builtin.

Return type:

bool
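One plausible definition, assuming "builtin" means any public name exposed by Python's builtins module; the module's actual list may be narrower or wider.

```python
import builtins


def is_builtin(name: str) -> bool:
    """Sketch: True for public names defined in Python's builtins module."""
    # Underscore-prefixed names (e.g. __name__, __doc__) are skipped; most of them
    # are attributes of the builtins module itself rather than functions users call.
    return not name.startswith("_") and hasattr(builtins, name)


print([n for n in ["len", "print", "isinstance", "conv2d", "torch"] if is_builtin(n)])
# ['len', 'print', 'isinstance']
```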

utils.log_diff(label: str, original: libcst.CSTNode, modified: libcst.CSTNode) → None

Helper to compute AST diffs and log them to the tracer if changes occurred.

Parameters:
  • label (str) – Label for the log entry.

  • original (cst.CSTNode) – The node before transformation.

  • modified (cst.CSTNode) – The node after transformation.

utils.compute_permutation(source_layout: str, target_layout: str) → Tuple[int, ...] | None

Computes permutation indices to transform source layout string to target layout string.

Example

Source: “NCHW”, Target: “NHWC”
Map: N:0, C:1, H:2, W:3
Target required: N(0), H(2), W(3), C(1)
Result: (0, 2, 3, 1)

Parameters:
  • source_layout (str) – Source layout string (e.g. “NCHW”).

  • target_layout (str) – Target layout string (e.g. “NHWC”).

Returns:

Tuple of integer indices, or None if invalid.

Return type:

tuple[int, …]
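The worked example generalizes to a lookup: record each axis letter's position in the source layout, then read those positions off in target order. In this sketch, the rule for returning None is an assumption: a pair is treated as invalid when the two strings are not rearrangements of the same distinct axis letters.

```python
from typing import Optional, Tuple


def compute_permutation(source_layout: str, target_layout: str) -> Optional[Tuple[int, ...]]:
    """Sketch: indices such that permute(x, result) turns source into target layout."""
    # Assumed validity rule: both strings use the same axis letters, each exactly once.
    if len(set(source_layout)) != len(source_layout) or sorted(source_layout) != sorted(target_layout):
        return None
    # Map each axis to its source position, e.g. {"N": 0, "C": 1, "H": 2, "W": 3}.
    position = {axis: i for i, axis in enumerate(source_layout)}
    # Output dimension k takes the source dimension that holds target_layout[k].
    return tuple(position[axis] for axis in target_layout)


print(compute_permutation("NCHW", "NHWC"))  # (0, 2, 3, 1)
print(compute_permutation("NHWC", "NCHW"))  # (0, 3, 1, 2)
print(compute_permutation("NCHW", "NHW"))   # None
```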

utils.inject_permute_call(base_node: libcst.CSTNode, indices: Tuple[int, ...], semantics: ml_switcheroo.semantics.manager.SemanticsManager, target_fw: str) → libcst.CSTNode

Wraps a CST node with a permutation call valid for the target framework.

Decoupling Logic:

It queries the SemanticsManager for the permute_dims definition in the target tier and contains no hardcoded framework checks (such as “if target == torch”). If the definition is missing, it returns the bare node unchanged (a no-op) rather than assuming JAX-style syntax.

Parameters:
  • base_node (cst.CSTNode) – The expression to wrap (the input tensor).

  • indices (Tuple[int, ...]) – Tuple of dimensions to permute (e.g., (0, 2, 3, 1)).

  • semantics (SemanticsManager) – Manager to look up syntax.

  • target_fw (str) – Target framework key.

Returns:

A node representing permute(base_node, indices), or the original node if the target framework has no permute_dims definition.

Return type:

cst.CSTNode