ml_switcheroo.plugins.context_to_function_wrap
Plugin for rewriting context managers.
This module addresses the impedance mismatch between PyTorch's global-state context managers (such as torch.no_grad()) and JAX's functional, explicit gradient handling.
It provides a transformation hook that:
1. Detects usage of context managers flagged with context_to_function_wrap.
2. Injects a nullcontext shim into the function preamble.
3. Rewrites the specific API call to use this shim, so the with …: block remains valid Python syntax while the PyTorch gradient-tracking semantics are dropped (JAX computes gradients only where a transformation such as jax.grad explicitly requests them).
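The effect of steps 2 and 3 on translated code can be sketched as follows. This assumes the injected shim resolves to the standard library's contextlib.nullcontext (the text above names a nullcontext shim, but the exact injected statement is an assumption), and the function name is hypothetical. Because nullcontext() does nothing on entry or exit, the with block keeps its structure and simply runs its body.

```python
from contextlib import nullcontext


def scale_without_grad(x):
    # Hypothetical translated function. Before the rewrite this line read
    # `with torch.no_grad():`; the shim keeps the block syntactically valid.
    with nullcontext():
        # The body executes unchanged; no gradient-tracking state is toggled.
        y = x * 2
    return y
```

Keeping the with statement, rather than unwrapping its body, means the rewrite never has to re-indent or restructure the surrounding code.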
Functions
transform_context_manager(node, ctx) | Plugin Hook: Transforms supported source-framework context managers into JAX-compatible shims.
Module Contents
- ml_switcheroo.plugins.context_to_function_wrap.transform_context_manager(node: libcst.Call, ctx: ml_switcheroo.core.hooks.HookContext) -> libcst.Call
Plugin Hook: Transforms supported source-framework context managers into JAX-compatible shims.
- Triggers:
Operations marked with "requires_plugin": "context_to_function_wrap" in the semantic JSON definitions. Primarily targets torch.no_grad and torch.enable_grad.
- Parameters:
node – The original CST Call node (e.g., torch.no_grad()).
ctx – The HookContext providing injection capabilities.
- Returns:
The transformed CST Call node pointing to the injected shim.
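The hook's behaviour can be approximated with the standard library's ast module. This is a swapped-in technique, not the plugin's libcst implementation, and it does not use HookContext: FLAGGED_CONTEXT_MANAGERS, ContextToNullcontext and rewrite are hypothetical names, the flagged set stands in for the semantic JSON lookup, and injecting `from contextlib import nullcontext` at the top of the enclosing function is an assumed stand-in for the real shim. Note that ast discards comments and formatting, which is one reason a CST library such as libcst suits source-to-source translation better.

```python
import ast
from typing import Optional

# Hypothetical stand-in for APIs marked with
# "requires_plugin": "context_to_function_wrap" in the semantic JSONs.
FLAGGED_CONTEXT_MANAGERS = {"torch.no_grad", "torch.enable_grad"}


def _dotted_name(node: ast.AST) -> Optional[str]:
    """Return 'a.b.c' for a Name/Attribute chain, or None for anything else."""
    if isinstance(node, ast.Name):
        return node.id
    if isinstance(node, ast.Attribute):
        base = _dotted_name(node.value)
        return f"{base}.{node.attr}" if base else None
    return None


class ContextToNullcontext(ast.NodeTransformer):
    """Rewrite flagged `with` calls to nullcontext() and inject the import.

    Sketch only: calls outside a function are rewritten but get no import.
    """

    def __init__(self) -> None:
        self._rewrote = False

    def visit_FunctionDef(self, node):
        # Track rewrites per function so nested functions get their own shim.
        outer, self._rewrote = self._rewrote, False
        self.generic_visit(node)
        if self._rewrote:
            # Inject the shim into the function preamble.
            shim = ast.ImportFrom(
                module="contextlib",
                names=[ast.alias(name="nullcontext")],
                level=0,
            )
            node.body.insert(0, shim)
        self._rewrote = outer
        return node

    visit_AsyncFunctionDef = visit_FunctionDef

    def visit_With(self, node):
        self.generic_visit(node)
        for item in node.items:
            call = item.context_expr
            if (
                isinstance(call, ast.Call)
                and _dotted_name(call.func) in FLAGGED_CONTEXT_MANAGERS
            ):
                # Point the call at the shim; the with block itself is untouched.
                item.context_expr = ast.Call(
                    func=ast.Name(id="nullcontext", ctx=ast.Load()),
                    args=[],
                    keywords=[],
                )
                self._rewrote = True
        return node


def rewrite(source: str) -> str:
    """Parse, transform and unparse a source string (requires Python 3.9+)."""
    tree = ContextToNullcontext().visit(ast.parse(source))
    return ast.unparse(ast.fix_missing_locations(tree))


SAMPLE = '''
def predict(model, x):
    with torch.no_grad():
        return model(x)
'''

if __name__ == "__main__":
    print(rewrite(SAMPLE))
```

The rewritten predict no longer references torch, and its with block still executes the original body.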