ml_switcheroo.plugins.batch_norm
Plugin for Functionalizing Batch Normalization (State Unwrapping).
Addresses the semantic mismatch (sketched below) between:

1. PyTorch (in-place): y = bn(x) silently updates the running_mean/running_var attributes on bn.
2. Functional frameworks (JAX/Flax): y, new_state = bn(x, mutable=['batch_stats']) returns the updated state explicitly.
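A minimal side-by-side sketch of the two conventions, assuming torch and flax are installed; Flax is used here as the concrete JAX NN library, and the shapes and modules are arbitrary illustration:

```python
import torch
import torch.nn as nn

bn = nn.BatchNorm1d(4)
x = torch.randn(8, 4)
_ = bn(x)  # training mode: silently mutates bn.running_mean / bn.running_var

import jax
import jax.numpy as jnp
import flax.linen as fnn

fbn = fnn.BatchNorm(use_running_average=False)
xj = jnp.ones((8, 4))
variables = fbn.init(jax.random.PRNGKey(0), xj)
# Functional convention: updated batch statistics come back as an explicit value.
y, new_state = fbn.apply(variables, xj, mutable=['batch_stats'])
```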
This plugin:

1. Injects the framework-specific kwargs (use_running_average, mutable).
2. Adapts the return value to an expression context by selecting the output tensor with a [0] subscript.
Decoupling Update: Checks traits.requires_functional_state instead of hardcoding target frameworks.
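The trait flag could be modeled roughly as below; only the requires_functional_state name appears in this module, while the surrounding class and constants are hypothetical illustration:

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class PluginTraits:  # hypothetical container; the real traits object may differ
    requires_functional_state: bool

# A functional target (e.g. Flax) would advertise True; an in-place target False.
FLAX_TRAITS = PluginTraits(requires_functional_state=True)
TORCH_TRAITS = PluginTraits(requires_functional_state=False)
```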
Functions
transform_batch_norm(node, ctx)
Hook: Wraps BatchNorm calls to handle functional state returns.
Module Contents
- ml_switcheroo.plugins.batch_norm.transform_batch_norm(node: libcst.Call, ctx: ml_switcheroo.core.hooks.HookContext) → libcst.CSTNode
Hook: Wraps BatchNorm calls to handle functional state returns.
- Transformation:
Input: self.bn1(x)
Output: self.bn1(x, use_running_average=not training, mutable=['batch_stats'])[0]
- Logic:
Capability Check: Verifies whether the target framework requires functional state processing via plugin_traits.requires_functional_state.
Mode Switching: Injects use_running_average=not training.
Mutability: Injects mutable=['batch_stats'].
Unwrapping: Applies a [0] subscript to return just the tensor (see the sketch after this list).
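A minimal libcst sketch of these four steps; the function body, the SimpleNamespace stand-in for HookContext, and the exact output whitespace are assumptions, not the module's actual implementation:

```python
from types import SimpleNamespace

import libcst as cst

def transform_batch_norm_sketch(node: cst.Call, ctx) -> cst.CSTNode:
    # Capability check: leave the call untouched for in-place targets.
    if not ctx.plugin_traits.requires_functional_state:
        return node

    # Mode switching + mutability: append the two kwargs to the original args.
    new_args = list(node.args) + [
        cst.Arg(
            keyword=cst.Name("use_running_average"),
            value=cst.parse_expression("not training"),
        ),
        cst.Arg(
            keyword=cst.Name("mutable"),
            value=cst.parse_expression("['batch_stats']"),
        ),
    ]
    call = node.with_changes(args=new_args)

    # Unwrapping: subscript with [0] so the expression yields only the tensor.
    return cst.Subscript(
        value=call,
        slice=[cst.SubscriptElement(slice=cst.Index(value=cst.Integer("0")))],
    )

# Round trip matching the Transformation example above:
ctx = SimpleNamespace(plugin_traits=SimpleNamespace(requires_functional_state=True))
expr = cst.parse_expression("self.bn1(x)")
print(cst.Module([]).code_for_node(transform_batch_norm_sketch(expr, ctx)))
# -> self.bn1(x, use_running_average=not training, mutable=['batch_stats'])[0]  (up to whitespace)
```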
- Parameters:
node – The original CST Call node.
ctx – Hook Context containing target framework metadata.
- Returns:
A CST Subscript node representing the tensor output of the BN call.