ml_switcheroo.plugins.context_to_function_wrap

Plugin for rewriting context managers.

This module addresses the impedance mismatch between PyTorch’s global-state context managers (such as torch.no_grad()) and JAX’s functional, explicit gradient handling.

It provides a transformation hook that:

  1. Detects usage of context managers flagged with context_to_function_wrap.

  2. Injects a nullcontext shim into the function preamble.

  3. Rewrites the specific API call to use this shim, so the with …: block remains valid Python syntax while the context manager's gradient-mode semantics become a no-op.
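Steps 2 and 3 can be pictured with a before/after sketch that runs on the standard library alone. The shim name `_no_grad_shim` and the `predict` function are hypothetical, and the sketch assumes only what the list states: the shim behaves like `contextlib.nullcontext`.

```python
import contextlib

# Before (PyTorch source, shown for reference only):
#     def predict(x):
#         with torch.no_grad():
#             return x * 2

# After (sketch): a hypothetical shim is bound in the function preamble
# and the context-manager call is redirected to it.
def predict(x):
    _no_grad_shim = contextlib.nullcontext  # injected preamble (hypothetical name)
    with _no_grad_shim():  # still valid `with` syntax; enter/exit do nothing
        return x * 2


print(predict(21))  # 42
```

Because `contextlib.nullcontext()` enters and exits without side effects, the body runs exactly as before. Nothing needs to replace the gradient switch, because JAX computes gradients only when a transform such as `jax.grad` is applied explicitly.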

Functions

transform_context_manager(node, ctx) → libcst.Call

Plugin Hook: Transforms source-framework context managers into JAX-compatible shims.

Module Contents

ml_switcheroo.plugins.context_to_function_wrap.transform_context_manager(node: libcst.Call, ctx: ml_switcheroo.core.hooks.HookContext) → libcst.Call

Plugin Hook: Transforms source-framework context managers into JAX-compatible shims.

Triggers

  • Operations marked with requires_plugin: "context_to_function_wrap" in the semantic JSON definitions.

  • Primarily targets torch.no_grad and torch.enable_grad.

Parameters:
  • node – The original CST Call node (e.g., torch.no_grad()).

  • ctx – The HookContext providing injection capabilities.

Returns:

The transformed CST Call node pointing to the injected shim.
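The hook's contract (take a Call node, ask the context to inject a shim, return a Call that points at it) can be sketched as below. To stay dependency-free, the sketch uses the standard-library `ast` module in place of libcst. `MiniHookContext`, `inject_preamble`, `transform_context_manager_sketch`, `Rewriter`, and the shim name `_ctx_shim` are hypothetical stand-ins, not the actual `HookContext` API.

```python
import ast
from typing import List

SHIM_NAME = "_ctx_shim"  # hypothetical shim name
TARGETS = {"torch.no_grad", "torch.enable_grad"}


class MiniHookContext:
    """Hypothetical stand-in for HookContext: collects preamble statements."""

    def __init__(self) -> None:
        self.preamble: List[str] = []

    def inject_preamble(self, stmt: str) -> None:
        if stmt not in self.preamble:  # inject each shim only once
            self.preamble.append(stmt)


def transform_context_manager_sketch(node: ast.Call, ctx: MiniHookContext) -> ast.Call:
    """Redirect a flagged context-manager call to a nullcontext shim."""
    ctx.inject_preamble(f"import contextlib\n{SHIM_NAME} = contextlib.nullcontext")
    # The shim is called with no arguments.
    return ast.Call(func=ast.Name(id=SHIM_NAME, ctx=ast.Load()), args=[], keywords=[])


class Rewriter(ast.NodeTransformer):
    def __init__(self, ctx: MiniHookContext) -> None:
        self.ctx = ctx

    def visit_Call(self, node: ast.Call) -> ast.AST:
        self.generic_visit(node)
        if ast.unparse(node.func) in TARGETS:  # match on the dotted call name
            new = transform_context_manager_sketch(node, self.ctx)
            return ast.copy_location(new, node)
        return node

    def visit_FunctionDef(self, node: ast.FunctionDef) -> ast.AST:
        self.generic_visit(node)  # rewrite calls inside the body first
        if self.ctx.preamble:  # prepend injected statements to the function body
            injected = ast.parse("\n".join(self.ctx.preamble)).body
            node.body = injected + node.body
            self.ctx.preamble.clear()
        return node


source = """
def predict(x):
    with torch.no_grad():
        return x * 2
"""

tree = Rewriter(MiniHookContext()).visit(ast.parse(source))
ast.fix_missing_locations(tree)
print(ast.unparse(tree))

# The rewritten code no longer references torch, so it can run as-is.
namespace: dict = {}
exec(compile(tree, "<rewritten>", "exec"), namespace)
print(namespace["predict"](21))  # 42
```

Keying the rewrite on the dotted call name keeps the hook narrow: only the flagged APIs are redirected, and the surrounding `with` statement is left untouched.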