ml_switcheroo.plugins.context_to_function_wrap

Plugin for handling Context Manager rewriting.

This module addresses the impedance mismatch between PyTorch’s global state context managers (like torch.no_grad()) and JAX’s functional, explicit gradient handling.

It provides a transformation hook that:

1. Detects usage of context managers flagged with context_to_function_wrap.
2. Injects a nullcontext shim into the function preamble.
3. Rewrites the specific API call to use this shim, ensuring the with …: block remains valid Python syntax while effectively disabling gradient tracking semantics.
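To make the intended result of these steps concrete, here is a minimal sketch of what a rewritten function might look like. The exact code the plugin injects is not shown on this page, so the preamble line, the local name nullcontext, and the evaluate function are assumptions made for illustration only.

```python
import contextlib


def evaluate(model, x):
    # Assumed preamble: bind a do-nothing context manager to a local name.
    # The real plugin may inject a different statement or name.
    nullcontext = contextlib.nullcontext

    # The source read `with torch.no_grad():`. The call now targets the shim,
    # so the `with` block stays valid Python and simply runs its body.
    with nullcontext():
        return model(x)


# Stand-in "model" so the sketch runs without any ML framework installed.
result = evaluate(lambda v: v * 2, 21)
```

A no-op shim is a reasonable fit here because JAX only computes gradients when a transformation such as jax.grad is applied explicitly, so there is no global tracking state for the block to switch off.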

Functions

transform_context_manager(→ libcst.Call)

Plugin Hook: Transforms valid Source context managers into JAX-compatible shims.

Module Contents

ml_switcheroo.plugins.context_to_function_wrap.transform_context_manager(node: libcst.Call, ctx: ml_switcheroo.core.hooks.HookContext) → libcst.Call

Plugin Hook: Transforms valid Source context managers into JAX-compatible shims.

Triggers:

Operations marked with requires_plugin: "context_to_function_wrap" in Semantic JSONs. Primarily targets torch.no_grad and torch.enable_grad.

Parameters:
  • node – The original CST Call node (e.g., torch.no_grad()).

  • ctx – The HookContext providing injection capabilities.

Returns:

The transformed CST Call node pointing to the injected shim.
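The call-rewriting idea behind this hook can be sketched with the standard library alone. The real hook operates on libcst.Call nodes and uses a HookContext, whose API is not documented on this page. The sketch below swaps in Python's built-in ast module as a stand-in, and the names FLAGGED, SHIM_NAME, ContextCallRewriter, and rewrite are hypothetical.

```python
import ast

# Hypothetical stand-ins for what the Semantic JSONs would flag.
FLAGGED = {"torch.no_grad", "torch.enable_grad"}
# Hypothetical name for the injected shim.
SHIM_NAME = "_nullcontext_shim"


class ContextCallRewriter(ast.NodeTransformer):
    """Redirect flagged context-manager calls to the shim name."""

    def visit_Call(self, node):
        self.generic_visit(node)
        # Compare the dotted name of the callee against the flagged set.
        if ast.unparse(node.func) in FLAGGED:
            # Build a zero-argument call to the shim in place of the original.
            shim_call = ast.Call(
                func=ast.Name(id=SHIM_NAME, ctx=ast.Load()),
                args=[],
                keywords=[],
            )
            return ast.copy_location(shim_call, node)
        return node


def rewrite(source: str) -> str:
    """Parse source, rewrite flagged calls, and return the new source."""
    tree = ContextCallRewriter().visit(ast.parse(source))
    ast.fix_missing_locations(tree)
    return ast.unparse(tree)


source = "with torch.no_grad():\n    y = model(x)\n"
result = rewrite(source)
```

Unlike libcst, the ast module discards comments and original formatting when it regenerates source, which is one reason a tool that rewrites user code would prefer a concrete syntax tree.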