ml_switcheroo.plugins.context_to_function_wrap¶
Plugin for handling Context Manager rewriting.
This module addresses the impedance mismatch between PyTorch’s global-state context managers (like torch.no_grad()) and JAX’s functional, explicit gradient handling.
It provides a transformation hook that:
- Detects usage of context managers flagged with context_to_function_wrap.
- Injects a nullcontext shim into the function preamble.
- Rewrites the specific API call to use this shim, ensuring the with …: block remains valid Python syntax while effectively disabling gradient-tracking semantics.
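The net effect of these steps can be sketched as follows. This is illustrative only: _no_grad_shim is a hypothetical name for the binding the plugin injects into the function preamble, not the actual identifier it generates.

```python
from contextlib import nullcontext

# Hypothetical shim injected into the function preamble by the plugin:
_no_grad_shim = nullcontext

# Original PyTorch source:
#     with torch.no_grad():
#         y = model(x)
#
# After rewriting, the with-block stays valid Python but becomes a no-op;
# JAX computes gradients explicitly (e.g. via jax.grad), so there is no
# global gradient-tracking state that needs to be disabled.
with _no_grad_shim():
    y = 2 * 3  # stand-in for model(x)
```

Because nullcontext() succeeds on entry and exit and yields nothing, the block body runs unchanged and no gradient semantics are implied.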
Functions¶
transform_context_manager | Plugin Hook: Transforms valid Source context managers into JAX-compatible shims.
Module Contents¶
- ml_switcheroo.plugins.context_to_function_wrap.transform_context_manager(node: libcst.Call, ctx: ml_switcheroo.core.hooks.HookContext) → libcst.Call[source]¶
Plugin Hook: Transforms valid Source context managers into JAX-compatible shims.
Triggers
- Operations marked with requires_plugin: "context_to_function_wrap" in Semantic JSONs.
- Primarily targets torch.no_grad and torch.enable_grad.
- Parameters:
node – The original CST Call node (e.g., torch.no_grad()).
ctx – The HookContext providing injection capabilities.
- Returns:
The transformed CST Call node pointing to the injected shim.
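The call-rewriting step this hook performs can be sketched with the stdlib ast module for a dependency-free illustration. Note the assumptions: the real hook operates on libcst.Call nodes and obtains its shim via the HookContext, and _no_grad_shim is a hypothetical name for the injected binding.

```python
import ast

class NoGradRewriter(ast.NodeTransformer):
    """Repoint torch.no_grad() calls at an injected shim."""

    def visit_Call(self, node: ast.Call) -> ast.Call:
        self.generic_visit(node)
        f = node.func
        # Match attribute calls of the exact form torch.no_grad(...)
        if (isinstance(f, ast.Attribute)
                and isinstance(f.value, ast.Name)
                and f.value.id == "torch"
                and f.attr == "no_grad"):
            # Replace the callee so the with-statement stays valid syntax.
            node.func = ast.Name(id="_no_grad_shim", ctx=ast.Load())
        return node

src = "with torch.no_grad():\n    y = model(x)\n"
tree = ast.fix_missing_locations(NoGradRewriter().visit(ast.parse(src)))
print(ast.unparse(tree))
```

Working on the syntax tree rather than raw text means only genuine torch.no_grad calls are rewritten; string occurrences or unrelated attributes named no_grad are left alone. The production plugin uses libcst instead of ast so that comments and formatting survive the rewrite.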