ml_switcheroo.plugins.context_to_function_wrap
==============================================

.. py:module:: ml_switcheroo.plugins.context_to_function_wrap

.. autoapi-nested-parse::

   Plugin for handling context manager rewriting.

   This module addresses the impedance mismatch between PyTorch's global-state
   context managers (such as `torch.no_grad()`) and JAX's functional, explicit
   gradient handling. It provides a transformation hook that:

   1. Detects usage of context managers flagged with `context_to_function_wrap`.
   2. Injects a `nullcontext` shim into the function preamble.
   3. Rewrites the specific API call to use this shim, ensuring the `with ...:`
      block remains valid Python syntax while effectively disabling gradient
      tracking semantics.

Functions
---------

.. autoapisummary::

   ml_switcheroo.plugins.context_to_function_wrap.transform_context_manager

Module Contents
---------------

.. py:function:: transform_context_manager(node: libcst.Call, ctx: ml_switcheroo.core.hooks.HookContext) -> libcst.Call

   Plugin hook: transforms supported source context managers into JAX-compatible shims.

   Triggers: operations marked with `requires_plugin: "context_to_function_wrap"`
   in the semantic JSONs. Primarily targets `torch.no_grad` and `torch.enable_grad`.

   :param node: The original CST ``Call`` node (e.g., ``torch.no_grad()``).
   :param ctx: The ``HookContext`` providing injection capabilities.
   :returns: The transformed CST ``Call`` node pointing to the injected shim.
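The net effect of the rewrite can be sketched with the standard library alone. This is a minimal illustration of the shim semantics, not the plugin's implementation: the name ``_grad_shim`` below is hypothetical, standing in for whatever identifier the plugin injects into the function preamble.

.. code-block:: python

   from contextlib import nullcontext

   # Hypothetical shim name; the real plugin injects an equivalent binding
   # into the function preamble when it rewrites the call site.
   _grad_shim = nullcontext

   # Original PyTorch source:
   #     with torch.no_grad():
   #         y = model(x)
   #
   # After the rewrite, the `with` block is still valid Python but no longer
   # carries any gradient-tracking semantics:
   x = 2.0
   with _grad_shim():
       y = x * 3.0  # stand-in for the traced computation

Because ``nullcontext`` is a no-op context manager, the block body executes unchanged; this is why the rewrite preserves syntactic validity while discarding the global-state behavior that JAX has no equivalent for.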