ml_switcheroo.plugins.batch_norm
================================

.. py:module:: ml_switcheroo.plugins.batch_norm

.. autoapi-nested-parse::

   Plugin for Functionalizing Batch Normalization (State Unwrapping).

   Addresses the semantic mismatch between:

   1. **PyTorch (In-Place)**: `y = bn(x)`. Silently updates the `running_mean`/`running_var` attributes on `bn`.
   2. **JAX/Flax (Functional)**: `y, new_state = bn(x, mutable=['batch_stats'])`. State updates are returned explicitly.

   This plugin:

   1. Injects the kwargs required by Flax (`use_running_average`, `mutable`).
   2. Adapts the return value to fit into an expression context by selecting the output tensor `[0]`.

   **Limitation**: This plugin only solves *Forward Pass* compatibility. It discards the updated
   state (`new_state`), which effectively turns the layer into inference mode with respect to state
   persistence unless the surrounding code is manually refactored to handle the tuple return.
   This is a necessary compromise that allows `y = bn(x)` to compile in JAX without extensive
   dataflow analysis of the entire training loop.


Functions
---------

.. autoapisummary::

   ml_switcheroo.plugins.batch_norm.transform_batch_norm


Module Contents
---------------

.. py:function:: transform_batch_norm(node: libcst.Call, ctx: ml_switcheroo.core.hooks.HookContext) -> libcst.CSTNode

   Hook: Wraps BatchNorm calls to handle functional state returns.

   Transformation:
       Input: `self.bn1(x)`

       Output: `self.bn1(x, use_running_average=not training, mutable=['batch_stats'])[0]`

   Logic:
       1. **Mode Switching**: Injects `use_running_average=not training`.
          This assumes a `training` boolean variable exists in the scope
          (commonly injected by `src/ml_switcheroo/plugins/state_flag_injection.py`).
       2. **Mutability**: Injects `mutable=['batch_stats']` to allow tracking stats during training.
       3. **Unwrapping**: Applies a `[0]` subscript to the resulting call.
          Flax returns `(tensor, updates)`; the hook selects `tensor` to maintain compatibility
          with operators expecting a single array (like `relu`).

   :param node: The original CST Call node.
   :param ctx: Hook Context containing target framework metadata.
   :returns: A CST Subscript node representing the tensor output of the BN call.
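
The three rewrite steps described for `transform_batch_norm` (mode switching, mutability, unwrapping) can be sketched on a source string. This is an illustration under stated assumptions, not the plugin's implementation: it uses the standard-library `ast` module as a stand-in for libcst, and the helper name `unwrap_batch_norm_call` and its `training_name` parameter are hypothetical. The actual hook receives a `libcst.Call` node and a `HookContext`, and returns a CST node rather than a string.

```python
import ast  # stand-in for libcst; ast.unparse requires Python 3.9+


def unwrap_batch_norm_call(source: str, training_name: str = "training") -> str:
    """Rewrite a BatchNorm call expression into the Flax-style form (hypothetical helper)."""
    call = ast.parse(source, mode="eval").body
    if not isinstance(call, ast.Call):
        raise ValueError("expected a single call expression, e.g. 'self.bn1(x)'")

    # 1. Mode switching: inject use_running_average=not <training_name>.
    #    Like the plugin, this assumes the flag exists in the enclosing scope.
    call.keywords.append(
        ast.keyword(
            arg="use_running_average",
            value=ast.parse(f"not {training_name}", mode="eval").body,
        )
    )

    # 2. Mutability: inject mutable=['batch_stats'].
    call.keywords.append(
        ast.keyword(
            arg="mutable",
            value=ast.parse("['batch_stats']", mode="eval").body,
        )
    )

    # 3. Unwrapping: subscript the call with [0] to keep only the output tensor.
    unwrapped = ast.Subscript(value=call, slice=ast.Constant(value=0), ctx=ast.Load())
    return ast.unparse(unwrapped)


if __name__ == "__main__":
    print(unwrap_batch_norm_call("self.bn1(x)"))
```

Because the rewritten expression still evaluates to a single value, it can be nested inside other calls unchanged (for example as the argument of `relu`). The cost is the one stated in the limitation: the second element of the returned tuple, which holds the updated batch statistics, is dropped.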