ml_switcheroo.plugins.batch_norm¶

Plugin for Functionalizing Batch Normalization (State Unwrapping).

Addresses the semantic mismatch between:

1. PyTorch (in-place): y = bn(x). Silently updates the running_mean/running_var attributes on bn.

2. Functional frameworks (JAX/Flax): y, new_state = bn(x, mutable=['batch_stats']).
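To make the mismatch concrete, here is a minimal, hypothetical sketch of a functional-style batch norm over a 1-D list (stdlib only; not the real Flax API). Instead of mutating running statistics on the module in place, it returns an updated state dict alongside the output, which is exactly the shape the plugin has to adapt.

```python
import statistics

def functional_batch_norm(x, state, momentum=0.9, eps=1e-5):
    """Toy functional BN: returns (output, new_state) instead of
    mutating running statistics in place. Illustrative only."""
    mean = statistics.fmean(x)
    var = statistics.pvariance(x)
    # Normalize with the current batch statistics.
    y = [(v - mean) / (var + eps) ** 0.5 for v in x]
    # The caller receives the updated running stats explicitly;
    # nothing on the "module" side is mutated.
    new_state = {
        "running_mean": momentum * state["running_mean"] + (1 - momentum) * mean,
        "running_var": momentum * state["running_var"] + (1 - momentum) * var,
    }
    return y, new_state

y, new_state = functional_batch_norm(
    [1.0, 2.0, 3.0], {"running_mean": 0.0, "running_var": 1.0}
)
```

A PyTorch caller only ever sees `y`; a functional caller must thread `new_state` back in on the next step, which is why the plugin has to unwrap the tuple.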

This plugin:

1. Injects framework-specific kwargs (use_running_average, mutable).

2. Adapts the return value to fit an expression context by selecting the output tensor via a [0] subscript.

Decoupling: checks traits.requires_functional_state instead of hardcoding target framework names.
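The trait-driven decision can be sketched as follows. The class and attribute names mirror the documented ones (plugin_traits.requires_functional_state), but these dataclasses are illustrative stand-ins for the real HookContext, not the actual ml_switcheroo types:

```python
from dataclasses import dataclass, field

@dataclass
class PluginTraits:
    # Declared by each target framework's adapter.
    requires_functional_state: bool = False

@dataclass
class HookContext:
    target: str
    plugin_traits: PluginTraits = field(default_factory=PluginTraits)

def needs_state_unwrapping(ctx: HookContext) -> bool:
    # The hook consults the trait, not the framework name, so new
    # functional targets opt in by declaring the trait themselves.
    return ctx.plugin_traits.requires_functional_state

jax_ctx = HookContext("jax", PluginTraits(requires_functional_state=True))
torch_ctx = HookContext("torch", PluginTraits(requires_functional_state=False))
```

The payoff of the trait check is that the plugin never needs a `if target == "jax"` branch; any framework whose adapter sets the trait gets the transformation for free.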

Functions¶

transform_batch_norm(→ libcst.CSTNode)

Hook: Wraps BatchNorm calls to handle functional state returns.

Module Contents¶

ml_switcheroo.plugins.batch_norm.transform_batch_norm(node: libcst.Call, ctx: ml_switcheroo.core.hooks.HookContext) → libcst.CSTNode[source]¶

Hook: Wraps BatchNorm calls to handle functional state returns.

Transformation:

Input: self.bn1(x)

Output: self.bn1(x, use_running_average=not training, mutable=['batch_stats'])[0]

Logic:
  1. Capability Check: Checks whether the target framework requires functional state handling via plugin_traits.requires_functional_state.

  2. Mode Switching: Injects use_running_average=not training.

  3. Mutability: Injects mutable=['batch_stats'].

  4. Unwrapping: Applies [0] subscript to return just the tensor.
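Steps 2 to 4 above can be sketched as a string-to-string rewrite using the stdlib ast module. The real hook operates on libcst Call nodes and preserves formatting, which ast does not, so treat this purely as an illustration of the node surgery involved:

```python
import ast

def wrap_batch_norm_call(source: str) -> str:
    """Rewrite a BatchNorm call the way the hook does: inject the
    mode/mutability kwargs, then subscript [0] to unwrap the tensor."""
    call = ast.parse(source, mode="eval").body
    if not isinstance(call, ast.Call):
        raise ValueError("expected a call expression")
    # Mode switching: use_running_average=not training
    call.keywords.append(ast.keyword(
        arg="use_running_average",
        value=ast.UnaryOp(op=ast.Not(),
                          operand=ast.Name(id="training", ctx=ast.Load())),
    ))
    # Mutability: mutable=['batch_stats']
    call.keywords.append(ast.keyword(
        arg="mutable",
        value=ast.List(elts=[ast.Constant(value="batch_stats")],
                       ctx=ast.Load()),
    ))
    # Unwrapping: apply a [0] subscript to select the output tensor.
    wrapped = ast.Subscript(value=call, slice=ast.Constant(value=0),
                            ctx=ast.Load())
    return ast.unparse(wrapped)

print(wrap_batch_norm_call("self.bn1(x)"))
# → self.bn1(x, use_running_average=not training, mutable=['batch_stats'])[0]
```

The capability check (step 1) would guard this rewrite, returning the node unchanged when the trait is absent.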

Parameters:
  • node – The original CST Call node.

  • ctx – Hook Context containing target framework metadata.

Returns:

A CST Subscript node representing the tensor output of the BN call.