ml_switcheroo.plugins.batch_norm¶
Plugin for Functionalizing Batch Normalization (State Unwrapping).
Addresses the semantic mismatch between:
1. PyTorch (in-place): y = bn(x) silently updates the running_mean/running_var attributes on bn.
2. JAX/Flax (functional): y, new_state = bn(x, mutable=['batch_stats']) returns the state updates explicitly.
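For context, a minimal sketch of the two calling conventions (module and array names are illustrative; the Flax call shows the apply-level mutable argument):

```python
import torch
import torch.nn as nn

bn = nn.BatchNorm1d(4)
x = torch.randn(8, 4)
y = bn(x)  # bn.running_mean / bn.running_var are updated silently as a side effect

import jax
import jax.numpy as jnp
import flax.linen as fnn

bn_flax = fnn.BatchNorm(use_running_average=False)
variables = bn_flax.init(jax.random.PRNGKey(0), jnp.ones((8, 4)))
# Functional call: updated running statistics are returned, not stored in-place.
y_flax, new_state = bn_flax.apply(variables, jnp.ones((8, 4)), mutable=["batch_stats"])
```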
This plugin:
1. Injects the specific kwargs required by Flax (use_running_average, mutable).
2. Adapts the return value to fit an expression context by selecting the output tensor with a [0] subscript.
Limitation: This plugin solves forward-pass compatibility only. It discards the updated state (new_state), effectively turning the layer into inference mode with respect to state persistence, unless the surrounding code is manually refactored to handle the tuple return (see the sketch below). This is a necessary compromise that allows y = bn(x) to compile in JAX without extensive dataflow analysis of the entire training loop.
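As an illustration of that compromise, a minimal sketch of the manual refactor required when state persistence matters; the model, variable handling, and training flag below are hypothetical and are not produced by this plugin:

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

class Net(nn.Module):
    @nn.compact
    def __call__(self, x, training: bool):
        x = nn.Dense(8)(x)
        x = nn.BatchNorm(use_running_average=not training)(x)
        return nn.relu(x)

model = Net()
x = jnp.ones((4, 3))
variables = model.init(jax.random.PRNGKey(0), x, training=False)

# The functional forward pass returns the updated batch statistics;
# they must be threaded back into the variable collections by hand.
y, mutated = model.apply(variables, x, training=True, mutable=["batch_stats"])
variables = {**variables, "batch_stats": mutated["batch_stats"]}
```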
Functions¶
transform_batch_norm(node, ctx): Hook that wraps BatchNorm calls to handle functional state returns.
Module Contents¶
- ml_switcheroo.plugins.batch_norm.transform_batch_norm(node: libcst.Call, ctx: ml_switcheroo.core.hooks.HookContext) → libcst.CSTNode¶
Hook: Wraps BatchNorm calls to handle functional state returns.
- Transformation:
Input: self.bn1(x)
Output: self.bn1(x, use_running_average=not training, mutable=['batch_stats'])[0]
- Logic:
Mode Switching: Injects use_running_average=not training. This assumes a training boolean variable exists in the scope (commonly injected by src/ml_switcheroo/plugins/state_flag_injection.py).
Mutability: Injects mutable=['batch_stats'] to allow tracking stats during training.
Unwrapping: Applies a [0] subscript to the resulting call. Flax returns (tensor, updates); we select the tensor to remain compatible with operators that expect a single array (such as relu). A hypothetical sketch of this transformation follows the parameter list below.
- Parameters:
node – The original CST Call node.
ctx – Hook Context containing target framework metadata.
- Returns:
A CST Subscript node representing the tensor output of the BN call.
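For illustration only, a hypothetical libcst sketch of such a hook; the real plugin's implementation, its use of HookContext, and any guard logic for detecting BatchNorm call sites may differ:

```python
import libcst as cst

def transform_batch_norm_sketch(node: cst.Call, ctx=None) -> cst.CSTNode:
    """Hypothetical re-implementation of the transform, for illustration only."""
    # Inject use_running_average=not training (assumes `training` is in scope).
    mode_arg = cst.Arg(
        value=cst.UnaryOperation(operator=cst.Not(), expression=cst.Name("training")),
        keyword=cst.Name("use_running_average"),
    )
    # Inject mutable=['batch_stats'] so Flax tracks statistics during training.
    mutable_arg = cst.Arg(
        value=cst.List([cst.Element(cst.SimpleString("'batch_stats'"))]),
        keyword=cst.Name("mutable"),
    )
    new_call = node.with_changes(args=[*node.args, mode_arg, mutable_arg])
    # Unwrap the (tensor, updates) tuple Flax returns by subscripting with [0].
    return cst.Subscript(
        value=new_call,
        slice=[cst.SubscriptElement(slice=cst.Index(value=cst.Integer("0")))],
    )

# Example: rewriting `self.bn1(x)` in an expression context.
call = cst.parse_expression("self.bn1(x)")
print(cst.Module([]).code_for_node(transform_batch_norm_sketch(call)))
# Prints (up to whitespace):
# self.bn1(x, use_running_average=not training, mutable=['batch_stats'])[0]
```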