ml_switcheroo.plugins.batch_norm

Plugin for Functionalizing Batch Normalization (State Unwrapping).

Addresses the semantic mismatch between:
  1. PyTorch (In-Place): y = bn(x). Silently updates the running_mean/running_var attributes on bn.

  2. JAX/Flax (Functional): y, new_state = bn(x, mutable=['batch_stats']). State updates are returned explicitly.

This plugin:
  1. Injects the kwargs required by Flax (use_running_average, mutable).

  2. Adapts the return value to fit an expression context by selecting the output tensor with [0].

Limitation: This plugin solves forward-pass compatibility only. The updated state (new_state) is discarded, so the running statistics are never persisted and the layer behaves as if in inference mode with respect to state, unless the surrounding code is manually refactored to handle the tuple return. This is a deliberate compromise: it lets y = bn(x) compile in JAX without extensive dataflow analysis of the entire training loop.
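The limitation can be illustrated with a minimal pure-Python sketch. Here toy_bn is a hypothetical stand-in for a functional batch norm (it is not part of ml_switcheroo or Flax, and it only centers a list of floats); it exists to show what selecting [0] throws away compared with unpacking the tuple by hand.

```python
def toy_bn(x, state, training, momentum=0.9):
    """Hypothetical functional batch norm: returns (output, new_state)."""
    if training:
        # Use the batch mean and fold it into the running mean.
        mean = sum(x) / len(x)
        new_state = {"mean": momentum * state["mean"] + (1 - momentum) * mean}
    else:
        # Inference: use the stored running mean, leave state untouched.
        mean = state["mean"]
        new_state = state
    y = [v - mean for v in x]
    return y, new_state


state = {"mean": 0.0}
x = [1.0, 2.0, 3.0]

# Plugin-style call: only element [0] is kept, so the update is lost
# and `state` still holds the initial running mean afterwards.
y = toy_bn(x, state, training=True)[0]

# Manual refactor: unpack the tuple and carry the new state forward.
y_manual, state_after_manual = toy_bn(x, state, training=True)
```

Both calls produce the same output tensor; only the manually refactored form retains the updated running mean for the next step.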

Functions

transform_batch_norm(→ libcst.CSTNode)

Hook: Wraps BatchNorm calls to handle functional state returns.

Module Contents

ml_switcheroo.plugins.batch_norm.transform_batch_norm(node: libcst.Call, ctx: ml_switcheroo.core.hooks.HookContext) → libcst.CSTNode

Hook: Wraps BatchNorm calls to handle functional state returns.

Transformation:

Input: self.bn1(x)

Output: self.bn1(x, use_running_average=not training, mutable=['batch_stats'])[0]

Logic:
  1. Mode Switching: Injects use_running_average=not training. This assumes a training boolean variable exists in the scope (commonly injected by src/ml_switcheroo/plugins/state_flag_injection.py).

  2. Mutability: Injects mutable=['batch_stats'] to allow tracking stats during training.

  3. Unwrapping: Applies a [0] subscript to the resulting call. Flax returns (tensor, updates); selecting tensor maintains compatibility with operators that expect a single array (like relu).
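The three steps above can be sketched with the standard-library ast module. This is an illustration only: the actual hook operates on libcst nodes and receives a HookContext, and rewrite_bn_call is a hypothetical helper name, not part of the plugin.

```python
import ast


def rewrite_bn_call(source: str) -> str:
    """Rewrite one call expression following the three documented steps."""
    call = ast.parse(source, mode="eval").body
    if not isinstance(call, ast.Call):
        raise ValueError("expected a single call expression")

    # 1. Mode switching: use_running_average=not training
    call.keywords.append(
        ast.keyword(
            arg="use_running_average",
            value=ast.UnaryOp(
                op=ast.Not(),
                operand=ast.Name(id="training", ctx=ast.Load()),
            ),
        )
    )

    # 2. Mutability: mutable=['batch_stats']
    call.keywords.append(
        ast.keyword(
            arg="mutable",
            value=ast.List(
                elts=[ast.Constant(value="batch_stats")], ctx=ast.Load()
            ),
        )
    )

    # 3. Unwrapping: subscript the call with [0] to keep only the tensor
    wrapped = ast.Subscript(
        value=call, slice=ast.Constant(value=0), ctx=ast.Load()
    )
    return ast.unparse(wrapped)  # requires Python 3.9+


result = rewrite_bn_call("self.bn1(x)")
print(result)
```

Existing positional and keyword arguments of the original call are left in place; the two kwargs are appended after them, and the outermost node becomes a subscript, matching the Subscript return described for the hook.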

Parameters:
  • node – The original CST Call node.

  • ctx – Hook Context containing target framework metadata.

Returns:

A CST Subscript node representing the tensor output of the BN call.