utils
=====

.. py:module:: utils

.. autoapi-nested-parse::

   Utility functions for call rewriting.

   This module provides helper functions for inspecting and transforming LibCST ``Call`` nodes.
   It handles structural tasks such as:

   * Detecting functional usage patterns (e.g. ``layer.apply``).
   * Rewriting stateful calls.
   * Injecting and stripping keyword arguments generically.
   * Generating permutation/transpose calls based on semantic layout maps.

   Decoupling logic: decisions that depend on a specific framework API (e.g. whether to use
   ``permute`` or ``transpose``) are delegated to the ``SemanticsManager``, so this module
   contains no hardcoded framework checks. Detection of functional unwrapping is driven by
   ``StructuralTraits``.



Functions
---------

.. autoapisummary::

   utils.is_functional_apply
   utils.rewrite_stateful_call
   utils.inject_kwarg
   utils.strip_kwarg
   utils.is_super_call
   utils.is_builtin
   utils.log_diff
   utils.compute_permutation
   utils.inject_permute_call


Module Contents
---------------

.. py:function:: is_functional_apply(node: libcst.Call, method_name: Optional[str] = 'apply') -> bool

   Detects whether a call node matches the functional execution pattern (e.g. ``obj.apply``).

   The method name comes from the ``functional_execution_method`` trait of the source
   framework. This generalizes detection to Flax (``apply``), Haiku (``apply``), or custom
   frameworks (e.g. ``call_fn``).

   :param node: The function call node to inspect.
   :type node: cst.Call
   :param method_name: The method name to look for. Defaults to ``"apply"``. If ``None``, functional unwrapping is disabled.
   :type method_name: str, optional
   :returns: True if the call is a method call whose name matches ``method_name``.
   :rtype: bool


.. py:function:: rewrite_stateful_call(rewriter: Any, node: libcst.Call, instance_name: str, config: Dict[str, str]) -> libcst.Call

   Rewrites a call on a stateful object to match a functional pattern.

   Used when converting OOP frameworks to functional ones, where state must be passed
   explicitly. It can inject arguments (e.g. ``variables``) and change method names
   (e.g. ``__call__`` -> ``apply``).

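
   As a rough illustration only, the sketch below performs the same kind of rewrite with the standard library ``ast`` module rather than LibCST, so it runs without extra dependencies. ``StatefulCallRewriter`` and ``rewrite_source`` are hypothetical names that are not part of this module. The sketch assumes the stateful call is a bare call on the named instance (``layer(x)``). The documented function works on LibCST nodes and uses the ``rewriter`` context instead.

   .. code-block:: python

      import ast


      class StatefulCallRewriter(ast.NodeTransformer):
          """Hypothetical sketch: rewrite ``<instance>(args)`` into
          ``<instance>.<method>(<prepend_arg>, args)`` using stdlib ``ast``."""

          def __init__(self, instance_name: str, config: dict) -> None:
              self.instance_name = instance_name
              self.prepend_arg = config.get("prepend_arg")
              self.method = config.get("method")

          def visit_Call(self, node: ast.Call) -> ast.Call:
              self.generic_visit(node)  # rewrite nested calls first
              # Only touch direct calls on the named instance, e.g. ``layer(x)``.
              if isinstance(node.func, ast.Name) and node.func.id == self.instance_name:
                  if self.method:
                      # ``layer(...)`` -> ``layer.apply(...)``
                      node.func = ast.Attribute(
                          value=ast.Name(id=self.instance_name, ctx=ast.Load()),
                          attr=self.method,
                          ctx=ast.Load(),
                      )
                  if self.prepend_arg:
                      # Pass the state explicitly as the first positional argument.
                      node.args.insert(0, ast.Name(id=self.prepend_arg, ctx=ast.Load()))
              return node


      def rewrite_source(src: str, instance_name: str, config: dict) -> str:
          """Parse, rewrite, and unparse a snippet (``ast.unparse`` needs Python 3.9+)."""
          tree = StatefulCallRewriter(instance_name, config).visit(ast.parse(src))
          return ast.unparse(ast.fix_missing_locations(tree))


      print(rewrite_source("y = layer(x)", "layer", {"prepend_arg": "variables", "method": "apply"}))
      # y = layer.apply(variables, x)

   Prepending the state as the first positional argument follows the Flax ``Module.apply(variables, *args)`` convention. That placement is an assumption of this sketch, and other targets may expect the state elsewhere.
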

   :param rewriter: The transformer instance (must expose context).
   :type rewriter: Any
   :param node: The original call node.
   :type node: cst.Call
   :param instance_name: The name of the object instance being called.
   :type instance_name: str
   :param config: Configuration dict with the keys ``'prepend_arg'`` and ``'method'``.
   :type config: Dict[str, str]
   :returns: The transformed call node.
   :rtype: cst.Call


.. py:function:: inject_kwarg(node: libcst.Call, arg_name: str, val_name: str) -> libcst.Call

   Generic helper that injects a keyword argument into a call.

   The argument is not duplicated if it already exists. The resulting form is
   ``func(..., arg_name=val_name)``.

   :param node: The call node to modify.
   :type node: cst.Call
   :param arg_name: The keyword argument name.
   :type arg_name: str
   :param val_name: The variable name to pass as the value.
   :type val_name: str
   :returns: The call node with the argument injected, or the original node if ``arg_name`` was already present.
   :rtype: cst.Call


.. py:function:: strip_kwarg(node: libcst.Call, kw_name: str) -> libcst.Call

   Removes the specified keyword argument from a function call.

   :param node: The call node.
   :type node: cst.Call
   :param kw_name: The name of the keyword argument to strip.
   :type kw_name: str
   :returns: The updated node with the argument removed.
   :rtype: cst.Call


.. py:function:: is_super_call(node: libcst.Call) -> bool

   Detects whether a call is ``super()`` or ``super().method()``.

   :param node: The call node.
   :type node: cst.Call
   :returns: True if it is a super call.
   :rtype: bool


.. py:function:: is_builtin(name: str) -> bool

   Checks whether a name corresponds to a standard Python builtin.

   Used to avoid excessive logging/tracing of standard language features.

   :param name: The function name.
   :type name: str
   :returns: True if ``name`` is a builtin.
   :rtype: bool


.. py:function:: log_diff(label: str, original: libcst.CSTNode, modified: libcst.CSTNode) -> None

   Helper that computes the diff between the original and modified nodes and logs it to the tracer if changes occurred.

   :param label: Label for the log entry.
   :type label: str
   :param original: The node before transformation.
   :type original: cst.CSTNode
   :param modified: The node after transformation.
   :type modified: cst.CSTNode


.. py:function:: compute_permutation(source_layout: str, target_layout: str) -> Optional[Tuple[int, ...]]

   Computes the permutation indices that reorder a tensor from the source layout to the target layout.

   Each entry of the result is the position, in ``source_layout``, of the corresponding axis of
   ``target_layout``.

   .. rubric:: Example

   * Source: ``"NCHW"``, target: ``"NHWC"``
   * Source positions: N:0, C:1, H:2, W:3
   * Target order: N (0), H (2), W (3), C (1)
   * Result: ``(0, 2, 3, 1)``

   :param source_layout: Source layout string (e.g. ``"NCHW"``).
   :type source_layout: str
   :param target_layout: Target layout string (e.g. ``"NHWC"``).
   :type target_layout: str
   :returns: Tuple of integer indices, or None if the layouts are invalid.
   :rtype: Optional[tuple[int, ...]]


.. py:function:: inject_permute_call(base_node: libcst.CSTNode, indices: Tuple[int, ...], semantics: ml_switcheroo.semantics.manager.SemanticsManager, target_fw: str) -> libcst.CSTNode

   Wraps a CST node in a permutation call that is valid for the target framework.

   Decoupling logic: the function queries the ``SemanticsManager`` for the ``permute_dims``
   definition in the target tier and contains no hardcoded framework checks (such as
   ``if target == "torch"``). If no definition exists, it returns the bare node unchanged
   (a no-op) rather than assuming JAX-style syntax.

   :param base_node: The expression to wrap (the input tensor).
   :type base_node: cst.CSTNode
   :param indices: Tuple of dimensions to permute (e.g. ``(0, 2, 3, 1)``).
   :type indices: Tuple[int, ...]
   :param semantics: Manager used to look up the target syntax.
   :type semantics: SemanticsManager
   :param target_fw: Target framework key.
   :type target_fw: str
   :returns: Node representing ``permute(base_node, indices)``, or the original node if unsupported.
   :rtype: cst.CSTNode
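
   The following sketch shows how ``compute_permutation`` and ``inject_permute_call`` fit together. It renders source text instead of building LibCST nodes. ``compute_permutation_sketch`` and ``permute_call_source`` are hypothetical names. ``PERMUTE_DIMS`` is a hypothetical stand-in for the ``SemanticsManager`` lookup of ``permute_dims``. The rule that a layout pair is invalid unless both strings rearrange the same distinct axis letters is also an assumption. The index arithmetic and the no-op fallback follow the behavior documented above.

   .. code-block:: python

      from typing import Dict, Optional, Tuple

      # Hypothetical stand-in for the SemanticsManager's ``permute_dims`` lookup:
      # framework key -> fully qualified permute API.
      PERMUTE_DIMS: Dict[str, str] = {
          "torch": "torch.permute",
          "jax": "jax.numpy.transpose",
      }


      def compute_permutation_sketch(source_layout: str, target_layout: str) -> Optional[Tuple[int, ...]]:
          # Assumed validity rule: both layouts list the same distinct axis letters.
          if len(set(source_layout)) != len(source_layout) or sorted(source_layout) != sorted(target_layout):
              return None
          position = {axis: i for i, axis in enumerate(source_layout)}
          # For each target axis, take its index in the source layout.
          return tuple(position[axis] for axis in target_layout)


      def permute_call_source(expr: str, indices: Tuple[int, ...], target_fw: str) -> str:
          api = PERMUTE_DIMS.get(target_fw)
          if api is None:
              # No definition for this framework: return the bare expression (no-op).
              return expr
          return f"{api}({expr}, {indices})"


      perm = compute_permutation_sketch("NCHW", "NHWC")
      assert perm is not None
      print(perm)                                       # (0, 2, 3, 1)
      print(permute_call_source("x", perm, "torch"))    # torch.permute(x, (0, 2, 3, 1))
      print(permute_call_source("x", perm, "unknown"))  # x

   Keeping the framework-to-API mapping in data rather than in ``if`` branches mirrors the decoupling described above. A real implementation would return a LibCST ``Call`` node instead of a string.
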