ml_switcheroo.plugins.data_loader

Plugin for transforming Data Loaders.

Handles the mapping of torch.utils.data.DataLoader to:

  1. GenericDataLoader shim (for JAX/NumPy).

  2. Native torch.utils.data.DataLoader (pass-through for Torch).
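To make the first mapping concrete, here is a minimal sketch of what a NumPy-backed loader shim has to do. It assumes a map-style dataset (supports len() and integer indexing). The class body, defaults, and collation behaviour are illustrative assumptions, not the shim the plugin actually injects. Only the name GenericDataLoader and the num_workers, pin_memory, and drop_last keywords come from this page.

```python
import numpy as np


class GenericDataLoader:
    """Hypothetical NumPy-backed loader: a sketch, not the plugin's real shim."""

    def __init__(self, dataset, batch_size=1, shuffle=False, drop_last=False,
                 num_workers=0, pin_memory=False, **kwargs):
        self.dataset = dataset
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.drop_last = drop_last
        # Accepted only so torch-style call sites keep working; ignored here.
        self.num_workers = num_workers
        self.pin_memory = pin_memory
        # Any other keyword is swallowed by **kwargs (a hypothetical choice).
        self._rng = np.random.default_rng()

    def __len__(self):
        n = len(self.dataset)
        if self.drop_last:
            return n // self.batch_size
        return -(-n // self.batch_size)  # ceiling division

    def _collate(self, items):
        # Tuple samples such as (x, y) are stacked column-wise.
        if isinstance(items[0], tuple):
            return tuple(np.stack(column) for column in zip(*items))
        return np.stack(items)

    def __iter__(self):
        indices = np.arange(len(self.dataset))
        if self.shuffle:
            self._rng.shuffle(indices)
        for start in range(0, len(indices), self.batch_size):
            batch = indices[start:start + self.batch_size]
            if self.drop_last and len(batch) < self.batch_size:
                break  # Discard the final short batch.
            yield self._collate([self.dataset[int(i)] for i in batch])
```

Because num_workers and pin_memory are stored but never read, a torch-style call such as GenericDataLoader(ds, batch_size=2, num_workers=4, pin_memory=True) runs unchanged; the loading itself stays single-process.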

Implementation Details:

  • Filters arguments to ensure compatibility with the Generic Shim.

  • Explicitly handles num_workers, pin_memory, and drop_last by mapping names and passing them through (since the Shim now accepts them as optional kwargs).

  • Injects the Shim class definition at the top of the file on first use.

Functions

transform_dataloader(node, ctx) → libcst.CSTNode

Middleware to rewrite DataLoader instantiation.

Module Contents

ml_switcheroo.plugins.data_loader.transform_dataloader(node: libcst.Call, ctx: ml_switcheroo.core.hooks.HookContext) → libcst.CSTNode

Middleware to rewrite DataLoader instantiation.

Triggers:

Operations marked with requires_plugin: “convert_dataloader” (e.g., DataLoader in k_framework_extras.json).

Strategy:
  • If Target == Torch: Keep as is.

  • If Target != Torch: Inject Shim logic and rewrite to GenericDataLoader.

  • Maps performance arguments (num_workers, pin_memory) onto the shim's signature, which accepts them but safely ignores them.