ml_switcheroo.core.rewriter.calls.utils
=======================================

.. py:module:: ml_switcheroo.core.rewriter.calls.utils

.. autoapi-nested-parse::

   Utility functions for call rewriting.

   Validates function patterns and handles structural injections such as state objects,
   shims, and layout permutations.


Functions
---------

.. autoapisummary::

   ml_switcheroo.core.rewriter.calls.utils.is_functional_apply
   ml_switcheroo.core.rewriter.calls.utils.rewrite_stateful_call
   ml_switcheroo.core.rewriter.calls.utils.inject_rngs_kwarg
   ml_switcheroo.core.rewriter.calls.utils.strip_kwarg
   ml_switcheroo.core.rewriter.calls.utils.is_super_call
   ml_switcheroo.core.rewriter.calls.utils.is_builtin
   ml_switcheroo.core.rewriter.calls.utils.log_diff
   ml_switcheroo.core.rewriter.calls.utils.compute_permutation
   ml_switcheroo.core.rewriter.calls.utils.inject_custom_api_call
   ml_switcheroo.core.rewriter.calls.utils.inject_permute_call


Module Contents
---------------

.. py:function:: is_functional_apply(node: libcst.Call) -> bool

   Detects whether a call node matches the ``obj.apply(...)`` pattern used in Flax Linen.


.. py:function:: rewrite_stateful_call(rewriter, node: libcst.Call, instance_name: str, config: Dict[str, str]) -> libcst.Call

   Rewrites a call to a stateful object (functional patterns only).


.. py:function:: inject_rngs_kwarg(node: libcst.Call) -> libcst.Call

   Injects ``rngs=rngs`` into a constructor call.


.. py:function:: strip_kwarg(node: libcst.Call, kw_name: str) -> libcst.Call

   Removes the keyword argument ``kw_name`` from a call node.


.. py:function:: is_super_call(node: libcst.Call) -> bool

   Identifies a direct ``super()`` call or a ``super().__init__()`` call.


.. py:function:: is_builtin(name: str) -> bool

   Checks whether ``name`` is a standard Python builtin, so that unmapped builtins do not spam the logs.


.. py:function:: log_diff(label: str, original: libcst.CSTNode, modified: libcst.CSTNode) -> None

   Computes a diff between ``original`` and ``modified`` and logs it under ``label`` if the node changed.


.. py:function:: compute_permutation(source_layout: str, target_layout: str) -> Optional[Tuple[int, ...]]

   Computes the permutation indices that transform the source layout into the target layout.

   .. rubric:: Example

   ::

      Source: "NCHW", Target: "NHWC"
      Map: N:0, C:1, H:2, W:3
      Target requires: N(0), H(2), W(3), C(1)
      Result: (0, 2, 3, 1)

   :param source_layout: Source layout string (e.g. ``"NCHW"``).
   :param target_layout: Target layout string (e.g. ``"NHWC"``).
   :returns: Tuple of integer indices, or ``None`` if the layouts are invalid.


.. py:function:: inject_custom_api_call(func_name_node: libcst.BaseExpression, args: List[libcst.Arg]) -> libcst.Call

   Constructs a generic ``Call`` node from a function-name expression and a list of arguments.


.. py:function:: inject_permute_call(base_node: libcst.CSTNode, indices: Tuple[int, ...], semantics: ml_switcheroo.semantics.manager.SemanticsManager, target_fw: str) -> libcst.CSTNode

   Wraps a CST node in a permutation call that is valid for the target framework.

   Logic:

   1. Finds the ``permute_dims`` definition in Semantics for the target framework.
   2. Determines the API name (e.g. ``jnp.transpose``, ``torch.permute``, ``tf.transpose``).
   3. Determines the packing strategy (tuple kwarg vs. varargs).
   4. Wraps ``base_node``.

   :param base_node: The expression to permute.
   :param indices: Tuple of dimensions to permute (e.g. ``(0, 2, 3, 1)``).
   :param semantics: Manager used to look up the ``permute_dims`` syntax.
   :param target_fw: Target framework key.
   :returns: CST node representing ``permute(base_node, indices)``.
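
The predicates ``is_functional_apply``, ``is_super_call`` and ``is_builtin`` can be sketched as below. This is a minimal illustration, not the library's code: it uses Python's standard-library ``ast`` module in place of ``libcst`` so that it runs without dependencies, and the ``parse_call`` helper is hypothetical.

```python
import ast
import builtins


def parse_call(src: str) -> ast.Call:
    # Hypothetical helper: parse a single call expression into an ast.Call.
    node = ast.parse(src, mode="eval").body
    if not isinstance(node, ast.Call):
        raise ValueError(f"not a call expression: {src!r}")
    return node


def is_functional_apply(node: ast.Call) -> bool:
    # Matches `<expr>.apply(...)`, e.g. Flax Linen's `model.apply(params, x)`.
    return isinstance(node.func, ast.Attribute) and node.func.attr == "apply"


def is_super_call(node: ast.Call) -> bool:
    func = node.func
    # Direct `super()` call.
    if isinstance(func, ast.Name) and func.id == "super":
        return True
    # `super().__init__(...)`: `__init__` looked up on the result of `super()`.
    return (
        isinstance(func, ast.Attribute)
        and func.attr == "__init__"
        and isinstance(func.value, ast.Call)
        and isinstance(func.value.func, ast.Name)
        and func.value.func.id == "super"
    )


def is_builtin(name: str) -> bool:
    # True for names defined in the `builtins` module, such as "len" or "print".
    return hasattr(builtins, name)


print(is_functional_apply(parse_call("model.apply(params, x)")))  # True
print(is_super_call(parse_call("super().__init__()")))            # True
print(is_builtin("len"), is_builtin("jnp"))                       # True False
```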
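
A minimal sketch of the keyword helpers ``strip_kwarg`` and ``inject_rngs_kwarg``, again using the standard-library ``ast`` module in place of ``libcst`` (an assumption made for self-containment). Both functions return new nodes instead of mutating their input, mirroring libcst's immutable trees. Skipping the injection when ``rngs=`` is already present is a choice made by this sketch, not documented behavior, and ``nn.Linear(3, 4, device='cpu')`` is only an illustrative input.

```python
import ast


def strip_kwarg(node: ast.Call, kw_name: str) -> ast.Call:
    # Keep every keyword except `kw_name` (`**kwargs` entries have arg=None and are kept).
    kept = [kw for kw in node.keywords if kw.arg != kw_name]
    return ast.Call(func=node.func, args=node.args, keywords=kept)


def inject_rngs_kwarg(node: ast.Call) -> ast.Call:
    # Sketch choice: do nothing if the call already passes `rngs=`.
    if any(kw.arg == "rngs" for kw in node.keywords):
        return node
    rngs_kw = ast.keyword(arg="rngs", value=ast.Name(id="rngs", ctx=ast.Load()))
    # Build a new Call rather than mutating the original node.
    return ast.Call(func=node.func, args=node.args, keywords=[*node.keywords, rngs_kw])


call = ast.parse("nn.Linear(3, 4, device='cpu')", mode="eval").body
call = strip_kwarg(call, "device")
call = inject_rngs_kwarg(call)
print(ast.unparse(call))  # nn.Linear(3, 4, rngs=rngs)
```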
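
``log_diff`` can be approximated with ``difflib.unified_diff``. This sketch assumes the nodes have already been rendered to source strings (with libcst, ``Module.code_for_node`` is one way to do that); the logger name and the ``render_diff`` helper are hypothetical.

```python
import difflib
import logging

logger = logging.getLogger("rewriter")  # hypothetical logger name


def render_diff(label: str, original: str, modified: str) -> str:
    # Hypothetical helper: unified diff of two source snippets ("" if identical).
    lines = difflib.unified_diff(
        original.splitlines(),
        modified.splitlines(),
        fromfile=f"{label} (before)",
        tofile=f"{label} (after)",
        lineterm="",
    )
    return "\n".join(lines)


def log_diff(label: str, original: str, modified: str) -> None:
    # Only log when the rewrite actually changed something.
    if original != modified:
        logger.debug(render_diff(label, original, modified))


print(render_diff("call", "torch.relu(x)", "jax.nn.relu(x)"))
```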
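
The worked example for ``compute_permutation`` amounts to looking up each target axis in an index map of the source layout. A minimal sketch follows. Treating mismatched lengths, repeated axes, or differing axis sets as "invalid" (returning ``None``) is an assumption about the real function's checks.

```python
from typing import Optional, Tuple


def compute_permutation(source_layout: str, target_layout: str) -> Optional[Tuple[int, ...]]:
    # Assumed validity rule: same length, no repeated axes, same set of axes.
    if (
        len(source_layout) != len(target_layout)
        or len(set(source_layout)) != len(source_layout)
        or set(source_layout) != set(target_layout)
    ):
        return None
    # Map each axis to its source position, e.g. {"N": 0, "C": 1, "H": 2, "W": 3}.
    position = {axis: i for i, axis in enumerate(source_layout)}
    # Read off the source index of every axis in target order.
    return tuple(position[axis] for axis in target_layout)


perm = compute_permutation("NCHW", "NHWC")
print(perm)  # (0, 2, 3, 1)

# Applying the indices to an NCHW shape yields the NHWC shape.
shape = (8, 3, 32, 32)
print(tuple(shape[i] for i in perm))  # (8, 32, 32, 3)
```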
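
The four steps listed for ``inject_permute_call``, with ``inject_custom_api_call`` building the final node, can be sketched as follows. This is not the library's implementation. The ``PERMUTE_DIMS`` table is a hypothetical stand-in for the ``SemanticsManager`` lookup, ``dotted_name`` is a hypothetical helper, ``demo_fw`` is an invented framework key used only to illustrate varargs packing, and the standard-library ``ast`` module replaces ``libcst``.

```python
import ast
from typing import Dict, List, Optional, Tuple

# Hypothetical stand-in for the `permute_dims` semantics:
# framework key -> (API name, keyword for tuple packing, or None for varargs).
PERMUTE_DIMS: Dict[str, Tuple[str, Optional[str]]] = {
    "jax": ("jnp.transpose", "axes"),        # jnp.transpose(x, axes=(...))
    "tensorflow": ("tf.transpose", "perm"),  # tf.transpose(x, perm=(...))
    "demo_fw": ("demo.permute", None),       # invented: demo.permute(x, 0, 2, 3, 1)
}


def dotted_name(name: str) -> ast.expr:
    # Hypothetical helper: "jnp.transpose" -> Attribute(Name("jnp"), "transpose").
    head, *rest = name.split(".")
    node: ast.expr = ast.Name(id=head, ctx=ast.Load())
    for attr in rest:
        node = ast.Attribute(value=node, attr=attr, ctx=ast.Load())
    return node


def inject_custom_api_call(func_name_node: ast.expr, args: List[ast.expr],
                           keywords: Optional[List[ast.keyword]] = None) -> ast.Call:
    # Generic Call construction (ast keeps keywords separate from positional args).
    return ast.Call(func=func_name_node, args=args, keywords=keywords or [])


def inject_permute_call(base_node: ast.expr, indices: Tuple[int, ...], target_fw: str) -> ast.Call:
    api_name, kwarg = PERMUTE_DIMS[target_fw]  # steps 1-2: look up the API name
    dims = [ast.Constant(value=i) for i in indices]
    func = dotted_name(api_name)
    if kwarg is None:  # step 3: varargs packing
        return inject_custom_api_call(func, [base_node, *dims])
    packed = ast.Tuple(elts=dims, ctx=ast.Load())  # step 3: tuple keyword packing
    return inject_custom_api_call(func, [base_node], [ast.keyword(arg=kwarg, value=packed)])  # step 4


x = ast.Name(id="x", ctx=ast.Load())
print(ast.unparse(inject_permute_call(x, (0, 2, 3, 1), "jax")))      # jnp.transpose(x, axes=(0, 2, 3, 1))
print(ast.unparse(inject_permute_call(x, (0, 2, 3, 1), "demo_fw")))  # demo.permute(x, 0, 2, 3, 1)
```

Keeping the API name and packing strategy in data rather than in code means that supporting a new framework only requires a new table entry, which is presumably why the real helper consults the semantics manager instead of hard-coding names.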