ml_switcheroo.plugins.keras_sequential
Plugin for Keras Sequential Container Translation.
This plugin bridges the difference between PyTorch’s nn.Sequential(*layers), which takes the layers as variadic positional arguments, and Keras’s keras.Sequential([layers]), which takes them as a single list.
It performs two transformations:
- API Renaming: Replaces the callee with keras.Sequential (or with the target API configured for Sequential).
- Argument Packing: Collects all positional arguments (the individual layers) into a single list argument, matching the Keras constructor signature.
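The two transformations above can be sketched with Python's standard-library ast module. This is not the plugin's implementation, which operates on libcst.Call nodes; ast is swapped in only so the sketch runs without extra dependencies. The names SequentialPacker and pack_sequential are hypothetical, and the sketch matches any call whose callee is named Sequential, whereas the real plugin is triggered by the operation mapping.

```python
import ast


class SequentialPacker(ast.NodeTransformer):
    """Hypothetical sketch: rewrite Sequential(a, b, ...) as keras.Sequential([a, b, ...])."""

    def __init__(self, target_api: str = "keras.Sequential") -> None:
        # Assumed default target; the real plugin resolves this via the context.
        self.target_api = target_api

    def visit_Call(self, node: ast.Call) -> ast.AST:
        # Rewrite inner calls first, so nested Sequential calls are handled too.
        self.generic_visit(node)
        if not self._is_sequential(node.func):
            return node
        # Argument packing: every positional argument becomes one list element.
        packed = ast.List(elts=list(node.args), ctx=ast.Load())
        # API renaming: parse the dotted target name into an expression node.
        new_func = ast.parse(self.target_api, mode="eval").body
        # Keyword arguments (if any) are carried over unchanged.
        new_call = ast.Call(func=new_func, args=[packed], keywords=node.keywords)
        return ast.copy_location(new_call, node)

    @staticmethod
    def _is_sequential(func: ast.expr) -> bool:
        # Matches `Sequential(...)` as well as `nn.Sequential(...)`, `torch.nn.Sequential(...)`.
        if isinstance(func, ast.Name):
            return func.id == "Sequential"
        return isinstance(func, ast.Attribute) and func.attr == "Sequential"


def pack_sequential(source: str) -> str:
    """Apply the rewrite to a source string and return the new source."""
    tree = SequentialPacker().visit(ast.parse(source))
    return ast.unparse(ast.fix_missing_locations(tree))


print(pack_sequential("model = nn.Sequential(nn.Linear(4, 8), nn.ReLU())"))
# model = keras.Sequential([nn.Linear(4, 8), nn.ReLU()])
```

Because each positional argument becomes one list element, a variadic call such as nn.Sequential(*layers) comes out as keras.Sequential([*layers]) in this sketch, which is still valid Python.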
Functions
- transform_keras_sequential(node, ctx) – Plugin Hook: Transforms Sequential container initialization.
Module Contents
- ml_switcheroo.plugins.keras_sequential.transform_keras_sequential(node: libcst.Call, ctx: ml_switcheroo.core.hooks.HookContext) -> libcst.Call
Plugin Hook: Transforms Sequential container initialization.
- Transformation:
Input: Sequential(layer1, layer2, …)
Output: keras.Sequential([layer1, layer2, …])
This hook is triggered for operations mapped with requires_plugin="keras_sequential_pack". It uses ctx.lookup_api("Sequential") to determine the target class name, defaulting to keras.Sequential if the lookup fails.
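The target-name resolution described above amounts to a lookup with a fallback. The following is a minimal sketch, assuming lookup_api returns the mapped API path as a string and signals failure either by returning None (or an empty string) or by raising a LookupError; the real HookContext may behave differently. DictContext, SupportsLookup and resolve_sequential_api are hypothetical names.

```python
from typing import Optional, Protocol

# Default used when no target API is configured for "Sequential".
DEFAULT_SEQUENTIAL_API = "keras.Sequential"


class SupportsLookup(Protocol):
    """Hypothetical minimal interface; the real HookContext may differ."""

    def lookup_api(self, op_name: str) -> Optional[str]: ...


class DictContext:
    """Hypothetical stand-in for HookContext, backed by a plain dict."""

    def __init__(self, mapping: dict[str, str]) -> None:
        self._mapping = mapping

    def lookup_api(self, op_name: str) -> Optional[str]:
        # Returns None when the operation has no mapping for the target.
        return self._mapping.get(op_name)


def resolve_sequential_api(ctx: SupportsLookup) -> str:
    """Return the configured target for "Sequential", else the Keras default."""
    try:
        api = ctx.lookup_api("Sequential")
    except LookupError:
        # Assumed failure mode: the lookup raises (e.g. KeyError).
        return DEFAULT_SEQUENTIAL_API
    # Assumed failure mode: the lookup returns None or an empty string.
    return api or DEFAULT_SEQUENTIAL_API


print(resolve_sequential_api(DictContext({})))  # keras.Sequential
print(resolve_sequential_api(DictContext({"Sequential": "tf.keras.Sequential"})))  # tf.keras.Sequential
```

Treating both a missing value and a raised lookup error as "not configured" keeps the hook usable even when a target framework has no explicit Sequential mapping.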
- Parameters:
node – The original function call node.
ctx – The plugin execution context.
- Returns:
The transformed function call.