ml_switcheroo.plugins.tf_data_loader

Plugin for Native TensorFlow Data Pipeline Generation.

This module converts generic DataLoader usage (typically from PyTorch) into native tf.data.Dataset pipelines. It makes substantially more structural changes than the generic shim, rewriting the entire iterator construction into a functional method chain.

Transformation Overview:

Input:  DataLoader(TensorDataset(x, y), batch_size=64, shuffle=True)
Output: tf.data.Dataset.from_tensor_slices((x, y)).shuffle(1024).batch(64).prefetch(tf.data.AUTOTUNE)
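The output chain is close to, but not identical to, the input's semantics. As a rough illustration, the standard-library sketch below emulates what the chain does to an in-memory dataset, with plain lists standing in for tensors. The function names are hypothetical stand-ins, not the tf.data API. from_tensor_slices pairs x and y row by row; shuffle(1024) samples from a fixed-size buffer rather than drawing a full permutation as PyTorch's shuffle=True does, so on datasets larger than 1024 elements the order is only partially randomized; batch(64) keeps a shorter final batch, as DataLoader does with its default drop_last=False. prefetch is omitted because it affects throughput, not which elements are produced.

```python
import random
from typing import Iterable, Iterator, List, Sequence, Tuple, TypeVar

T = TypeVar("T")


def from_tensor_slices(x: Sequence, y: Sequence) -> Iterator[Tuple]:
    """Pair row i of x with row i of y, i.e. slice both along the first axis."""
    if len(x) != len(y):
        raise ValueError("x and y must have the same first dimension")
    return iter(zip(x, y))


def buffered_shuffle(items: Iterable[T], buffer_size: int, rng: random.Random) -> Iterator[T]:
    """Shuffle through a fixed-size buffer instead of permuting everything.

    Fill the buffer, then for each new input emit a randomly chosen buffered
    element and put the new input in its slot; finally drain the buffer.
    """
    if buffer_size < 1:
        raise ValueError("buffer_size must be at least 1")
    buffer: List[T] = []
    for item in items:
        if len(buffer) < buffer_size:
            buffer.append(item)
            continue
        idx = rng.randrange(buffer_size)
        yield buffer[idx]
        buffer[idx] = item
    rng.shuffle(buffer)  # input exhausted: emit the remainder in random order
    yield from buffer


def batch(items: Iterable[T], batch_size: int) -> Iterator[List[T]]:
    """Group items into lists of batch_size, keeping a shorter final batch."""
    current: List[T] = []
    for item in items:
        current.append(item)
        if len(current) == batch_size:
            yield current
            current = []
    if current:
        yield current


# Emulate from_tensor_slices((x, y)).shuffle(1024).batch(64) on 150 examples.
x = list(range(150))
y = [10 * v for v in x]
batches = list(batch(buffered_shuffle(from_tensor_slices(x, y), 1024, random.Random(0)), 64))
sizes = [len(b) for b in batches]  # 64 + 64 + 22 = 150

# With 10,000 examples, the first element out can only be one of the first 1024.
first = next(buffered_shuffle(range(10_000), 1024, random.Random(1)))
```

The last line shows the practical consequence of the fixed buffer: unlike a full permutation, the first element emitted always comes from the first 1024 inputs.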

Functions

transform_tf_dataloader(node, ctx) → Union[libcst.Call, ...]

Plugin Hook: Rewrites DataLoader construction into a tf.data pipeline.

Module Contents

ml_switcheroo.plugins.tf_data_loader.transform_tf_dataloader(node: libcst.Call, ctx: ml_switcheroo.core.hooks.HookContext) → libcst.Call | libcst.FlattenSentinel

Plugin Hook: Rewrites DataLoader construction into a tf.data pipeline.

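Logic:
1. Extracts the dataset argument, assumed to be the first positional argument.
2. Unwraps TensorDataset calls to get the raw tensors.
3. Constructs tf.data.Dataset.from_tensor_slices(…) from those tensors.
4. Chains .shuffle(1024), a fixed 1024-element shuffle buffer, if shuffle=True.
5. Chains .batch(…) using the batch_size argument, or 1 if it is absent (the same default as PyTorch's DataLoader).
6. Chains .prefetch(tf.data.AUTOTUNE) so that preparing the next batches overlaps with consuming the current one.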
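The six steps can be sketched with Python's standard-library ast module. This is not the plugin's code: the real hook operates on libcst nodes and receives a HookContext, whereas the hypothetical rewrite_dataloader function below takes and returns source strings and handles only the simple cases the steps describe (a literal shuffle=True and a keyword batch_size). It assumes Python 3.9+ for ast.unparse.

```python
import ast


def _attr_call(obj: ast.expr, method: str, *args: ast.expr) -> ast.Call:
    """Build obj.method(*args) as an AST node."""
    return ast.Call(
        func=ast.Attribute(value=obj, attr=method, ctx=ast.Load()),
        args=list(args),
        keywords=[],
    )


def _dotted(path: str) -> ast.expr:
    """Build a dotted name such as tf.data.Dataset as nested Attribute nodes."""
    parts = path.split(".")
    node: ast.expr = ast.Name(id=parts[0], ctx=ast.Load())
    for part in parts[1:]:
        node = ast.Attribute(value=node, attr=part, ctx=ast.Load())
    return node


def rewrite_dataloader(source: str) -> str:
    """Hypothetical string-to-string version of the DataLoader rewrite."""
    call = ast.parse(source, mode="eval").body
    if not isinstance(call, ast.Call) or not call.args:
        raise ValueError("expected a DataLoader(...) call with a dataset argument")

    # Step 1: the dataset is assumed to be the first positional argument.
    dataset = call.args[0]

    # Step 2: unwrap TensorDataset(a, b, ...) into its raw tensors.
    if (
        isinstance(dataset, ast.Call)
        and isinstance(dataset.func, ast.Name)
        and dataset.func.id == "TensorDataset"
    ):
        tensors = dataset.args
        dataset = tensors[0] if len(tensors) == 1 else ast.Tuple(elts=tensors, ctx=ast.Load())

    kwargs = {kw.arg: kw.value for kw in call.keywords if kw.arg is not None}

    # Step 3: tf.data.Dataset.from_tensor_slices(...)
    chain: ast.expr = _attr_call(_dotted("tf.data.Dataset"), "from_tensor_slices", dataset)

    # Step 4: shuffle with a fixed 1024-element buffer only for a literal shuffle=True.
    shuffle = kwargs.get("shuffle")
    if isinstance(shuffle, ast.Constant) and shuffle.value is True:
        chain = _attr_call(chain, "shuffle", ast.Constant(value=1024))

    # Step 5: batch by batch_size, defaulting to 1.
    chain = _attr_call(chain, "batch", kwargs.get("batch_size", ast.Constant(value=1)))

    # Step 6: prefetch with tf.data.AUTOTUNE.
    chain = _attr_call(chain, "prefetch", _dotted("tf.data.AUTOTUNE"))

    return ast.unparse(ast.fix_missing_locations(chain))


print(rewrite_dataloader("DataLoader(TensorDataset(x, y), batch_size=64, shuffle=True)"))
```

Working on strings keeps the sketch short; a libcst-based hook would instead return a new libcst.Call node, since libcst is designed to preserve the surrounding formatting and comments.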

Parameters:
  • node – The original DataLoader call node.

  • ctx – HookContext (not used by the rewrite logic, but required by the hook protocol).

Returns:

The transformed method chain representing the TF Dataset.