ml_switcheroo.plugins.data_loader
Plugin for transforming Data Loaders.
Handles the mapping of torch.utils.data.DataLoader to:
1. The GenericDataLoader shim (for JAX/NumPy targets).
2. The native torch.utils.data.DataLoader (pass-through for Torch targets).
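For orientation, here is a minimal sketch of what a GenericDataLoader shim for the JAX/NumPy path could look like. The class below is hypothetical and is not the definition the plugin injects. It only batches over an indexable dataset and honors shuffle and drop_last. It accepts num_workers and pin_memory without using them. It yields plain Python lists, so collation into arrays is left out.

```python
import random
from typing import Any, Iterator, List, Sequence


class GenericDataLoader:
    """Hypothetical framework-agnostic loader (not ml_switcheroo's injected shim)."""

    def __init__(
        self,
        dataset: Sequence[Any],
        batch_size: int = 1,
        shuffle: bool = False,
        drop_last: bool = False,
        num_workers: int = 0,      # accepted for signature compatibility, ignored
        pin_memory: bool = False,  # accepted for signature compatibility, ignored
    ) -> None:
        if batch_size < 1:
            raise ValueError("batch_size must be >= 1")
        self.dataset = dataset
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.drop_last = drop_last

    def __len__(self) -> int:
        # Number of batches one full iteration produces.
        n = len(self.dataset)
        if self.drop_last:
            return n // self.batch_size
        return (n + self.batch_size - 1) // self.batch_size

    def __iter__(self) -> Iterator[List[Any]]:
        indices = list(range(len(self.dataset)))
        if self.shuffle:
            random.shuffle(indices)  # new order on every pass over the data
        for start in range(0, len(indices), self.batch_size):
            batch_idx = indices[start:start + self.batch_size]
            if self.drop_last and len(batch_idx) < self.batch_size:
                break  # discard the trailing partial batch
            yield [self.dataset[i] for i in batch_idx]
```

For example, with five items and batch_size=2 the loader yields [[0, 1], [2, 3], [4]], or [[0, 1], [2, 3]] with drop_last=True.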
Implementation details:
- Filters arguments to ensure compatibility with the generic shim.
- Explicitly handles num_workers, pin_memory, and drop_last by mapping their names and passing them through, since the shim accepts them as optional keyword arguments.
- Injects the shim class definition at the top of the file on first use.
Functions
transform_dataloader(node, ctx): Middleware to rewrite DataLoader instantiation.
Module Contents
- ml_switcheroo.plugins.data_loader.transform_dataloader(node: libcst.Call, ctx: ml_switcheroo.core.hooks.HookContext) -> libcst.CSTNode
Middleware to rewrite DataLoader instantiation.
- Triggers:
Operations marked with requires_plugin: “convert_dataloader” (e.g., DataLoader in k_framework_extras.json).
- Strategy:
If the target is Torch: keep the call as is.
If the target is not Torch: inject the shim logic and rewrite the call to GenericDataLoader.
Performance arguments (num_workers, pin_memory) are mapped onto the shim's signature, which accepts them but safely ignores them.
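Here is a per-call sketch of the strategy above. It again uses the standard ast module in place of libcst. The target framework is passed as a plain string because this page does not show how HookContext exposes it. SHIM_KWARGS is a hypothetical allowlist, not the plugin's actual filter. Dropping an argument such as collate_fn changes behavior, so a real filter has to decide case by case.

```python
import ast

# Hypothetical allowlist of keyword arguments the shim is assumed to accept.
SHIM_KWARGS = {"batch_size", "shuffle", "drop_last", "num_workers", "pin_memory"}


def rewrite_dataloader_call(call: ast.Call, target: str) -> ast.Call:
    """Keep DataLoader(...) for Torch; otherwise retarget it at the shim."""
    if target == "torch":
        return call  # pass-through: native torch.utils.data.DataLoader
    # Keep allowlisted keywords; `**kwargs` unpacking (arg is None) cannot be
    # filtered statically, so it is kept unchanged.
    kept = [kw for kw in call.keywords if kw.arg is None or kw.arg in SHIM_KWARGS]
    new_call = ast.Call(
        func=ast.Name(id="GenericDataLoader", ctx=ast.Load()),
        args=call.args,
        keywords=kept,
    )
    return ast.copy_location(new_call, call)


class _DataLoaderRewriter(ast.NodeTransformer):
    """Applies rewrite_dataloader_call to every call whose callee is named DataLoader."""

    def __init__(self, target: str) -> None:
        self.target = target

    def visit_Call(self, node: ast.Call) -> ast.AST:
        self.generic_visit(node)  # rewrite nested calls first
        func = node.func
        name = func.attr if isinstance(func, ast.Attribute) else getattr(func, "id", None)
        if name == "DataLoader":
            return rewrite_dataloader_call(node, self.target)
        return node


def rewrite_source(src: str, target: str) -> str:
    tree = _DataLoaderRewriter(target).visit(ast.parse(src))
    return ast.unparse(ast.fix_missing_locations(tree))
```

For example, rewrite_source("loader = torch.utils.data.DataLoader(ds, batch_size=32, num_workers=4, collate_fn=fn)", "jax") returns "loader = GenericDataLoader(ds, batch_size=32, num_workers=4)". With target "torch", it returns the input unchanged. ast.unparse requires Python 3.9 or newer.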