ml_switcheroo.core.rewriter.initializers

Initializer Rewriter Mixin.

Handles the translation of PyTorch’s imperative initialization calls (e.g., torch.nn.init.kaiming_normal_(tensor)) into JAX/Flax-style initializer factories (e.g., jax.nn.initializers.he_normal()).

Key Differences:

1. In-place vs. factory: Torch modifies tensors in place. JAX initializers are factories that return functions.
2. Signature: Torch takes the tensor as the first argument (tensor, …). JAX factories take only configuration arguments (…) and return a function f(key, shape).
3. Naming: standard naming variations are translated (Kaiming->He, Xavier->Glorot).
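To make the first two differences concrete, here is a dependency-free toy sketch. Nested Python lists stand in for tensors, and the two functions only imitate the calling conventions of torch.nn.init.constant_(tensor, val) and jax.nn.initializers.constant(value). The names torch_style_constant_ and jax_style_constant are hypothetical and belong to neither library.

```python
from typing import Any, Callable, List, Tuple

Matrix = List[List[float]]


def torch_style_constant_(tensor: Matrix, val: float) -> Matrix:
    """In-place style (like torch.nn.init.constant_): overwrite the given tensor."""
    for row in tensor:
        for j in range(len(row)):
            row[j] = val
    return tensor  # the same object, now modified


def jax_style_constant(value: float) -> Callable[[Any, Tuple[int, int]], Matrix]:
    """Factory style (like jax.nn.initializers.constant): capture config, return f(key, shape)."""

    def init(key: Any, shape: Tuple[int, int]) -> Matrix:
        rows, cols = shape
        # A constant initializer ignores the random key; it always builds a new array.
        return [[value] * cols for _ in range(rows)]

    return init


# Torch style: the tensor must already exist and is mutated.
w = [[0.0, 0.0], [0.0, 0.0]]
torch_style_constant_(w, 0.5)

# JAX style: configure first, materialize later with a key and a shape.
init_fn = jax_style_constant(0.5)
fresh = init_fn(None, (2, 3))
```

The in-place call needs an existing tensor at call time, whereas the factory is created from configuration alone and applied later once a key and shape are known. That is why the translated call no longer carries the tensor argument.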

Transformation Logic:

Input:  nn.init.kaiming_uniform_(self.weight, a=0)
Output: jax.nn.initializers.he_uniform(a=0)
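The transformation above can be approximated with Python's standard ast module. This is a minimal sketch under stated assumptions: the real mixin operates on libcst nodes (see leave_Call below), the four-entry NAME_MAP is a hypothetical subset rather than the actual INIT_NAME_MAP, and a call is treated as an initializer whenever its dotted name ends in init.<known name>.

```python
import ast

# Hypothetical subset of a Torch -> JAX name map; the real INIT_NAME_MAP may differ.
NAME_MAP = {
    "kaiming_uniform_": "he_uniform",
    "kaiming_normal_": "he_normal",
    "xavier_uniform_": "glorot_uniform",
    "xavier_normal_": "glorot_normal",
}


class InitCallTranslator(ast.NodeTransformer):
    """Rewrite `<...>.init.<torch_name>(tensor, **kw)` into `jax.nn.initializers.<jax_name>(**kw)`."""

    def visit_Call(self, node: ast.Call) -> ast.AST:
        self.generic_visit(node)  # translate nested calls first
        parts = ast.unparse(node.func).split(".")
        if len(parts) < 2 or parts[-2] != "init" or parts[-1] not in NAME_MAP:
            return node  # not a recognised torch.nn.init call
        factory = ast.parse(f"jax.nn.initializers.{NAME_MAP[parts[-1]]}", mode="eval").body
        # The tensor is the first positional argument unless passed as tensor=...
        has_tensor_kw = any(kw.arg == "tensor" for kw in node.keywords)
        return ast.Call(
            func=factory,
            args=node.args if has_tensor_kw else node.args[1:],
            keywords=[kw for kw in node.keywords if kw.arg != "tensor"],
        )


def translate(src: str) -> str:
    """Translate a single expression and return it as source text."""
    tree = InitCallTranslator().visit(ast.parse(src, mode="eval"))
    return ast.unparse(tree.body)


result = translate("nn.init.kaiming_uniform_(self.weight, a=0)")
# result == "jax.nn.initializers.he_uniform(a=0)"
```

Keyword arguments are copied through unchanged, as in the documented example. Stdlib ast discards comments and formatting when unparsing, whereas a concrete syntax tree library such as libcst preserves them, which matters for a source-to-source rewriter.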

Note: If the original call was a standalone statement, the translated factory call is usually left “dangling” in the AST. A subsequent pass (Parameter Decl Refactor) is responsible for moving this factory into the self.param(…, kernel_init=HERE) definition. This mixin focuses solely on correctly translating the API name and its arguments.
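The Parameter Decl Refactor pass is only named here, so the following is a hypothetical sketch of how its input could be located, not a description of that pass. After translation, a standalone statement becomes an expression statement (ast.Expr) whose value is the factory call; because nothing binds the returned function, the statement has no effect until it is moved into a parameter definition. The sketch assumes factories are emitted with the fully qualified jax.nn.initializers. prefix, and find_dangling_factories is an illustrative name.

```python
import ast
from typing import List

# Assumption: translated factories are always emitted fully qualified.
FACTORY_PREFIX = "jax.nn.initializers."


def find_dangling_factories(src: str) -> List[str]:
    """Return factory calls that sit alone as expression statements (result discarded)."""
    found = []
    for node in ast.walk(ast.parse(src)):
        # An ast.Expr wrapping a Call is a call whose return value is thrown away.
        if isinstance(node, ast.Expr) and isinstance(node.value, ast.Call):
            call_src = ast.unparse(node.value)
            if call_src.startswith(FACTORY_PREFIX):
                found.append(call_src)
    return found


src = (
    "class Block:\n"
    "    def setup(self):\n"
    "        jax.nn.initializers.he_uniform(a=0)\n"
    "        self.bias_init = jax.nn.initializers.zeros\n"
)
dangling = find_dangling_factories(src)
```

Only the first statement is reported: the second binds its value, so it is not dangling.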

Attributes

INIT_NAME_MAP

Classes

InitializerMixin

Mixin for PivotRewriter to handle torch.nn.init calls.

Module Contents

ml_switcheroo.core.rewriter.initializers.INIT_NAME_MAP
class ml_switcheroo.core.rewriter.initializers.InitializerMixin

Mixin for PivotRewriter to handle torch.nn.init calls.

leave_Call(original_node: libcst.Call, updated_node: libcst.Call) → libcst.Call

Detects initialization calls and standardizes them to JAX factories.
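The leave_Call signature matches the post-order hook of libcst's CSTTransformer, where child nodes are transformed before the parent's leave_ method receives both the original and the updated node. The sketch below imitates that hook with stdlib ast to show how an initializer mixin can be composed with a rewriter base class through cooperative super() calls. LeaveCallAdapter, InitializerMixinSketch, RewriterSketch and the two-entry NAME_MAP are hypothetical, and whether the real InitializerMixin delegates to super() in this way is an assumption.

```python
import ast
import copy


class LeaveCallAdapter(ast.NodeTransformer):
    """Hypothetical base: gives stdlib ast a libcst-style leave_Call(original, updated) hook."""

    def visit_Call(self, node: ast.Call) -> ast.AST:
        original = copy.deepcopy(node)  # snapshot before children are rewritten
        updated = self.generic_visit(node)  # children first, as libcst does before leave_*
        return self.leave_Call(original, updated)

    def leave_Call(self, original_node: ast.Call, updated_node: ast.Call) -> ast.AST:
        return updated_node  # default: leave the call unchanged


class InitializerMixinSketch:
    """Hypothetical mixin: handles init calls, defers all other calls via super()."""

    NAME_MAP = {"kaiming_normal_": "he_normal", "xavier_uniform_": "glorot_uniform"}

    def leave_Call(self, original_node: ast.Call, updated_node: ast.Call) -> ast.AST:
        parts = ast.unparse(updated_node.func).split(".")
        if len(parts) >= 2 and parts[-2] == "init" and parts[-1] in self.NAME_MAP:
            factory = f"jax.nn.initializers.{self.NAME_MAP[parts[-1]]}"
            return ast.Call(
                func=ast.parse(factory, mode="eval").body,
                args=updated_node.args[1:],  # drop the tensor argument
                keywords=updated_node.keywords,
            )
        return super().leave_Call(original_node, updated_node)


class RewriterSketch(InitializerMixinSketch, LeaveCallAdapter):
    """Mixin listed first so its leave_Call runs before the base's."""


def rewrite(src: str) -> str:
    return ast.unparse(RewriterSketch().visit(ast.parse(src, mode="eval")).body)
```

Listing the mixin before the base class places it earlier in the method resolution order, so it sees every call first and falls through to the base for anything that is not an initializer.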