ml_switcheroo.core.rewriter.initializers
Initializer Rewriter Mixin.
Handles the translation of PyTorch’s imperative initialization calls (e.g., torch.nn.init.kaiming_normal_(tensor)) into JAX/Flax-style initializer factories (e.g., jax.nn.initializers.he_normal()).
Key Differences:
1. In-place vs. factory: Torch modifies tensors in place. JAX initializers are factories that return functions.
2. Signature: Torch takes the target tensor as the first argument, (tensor, …). JAX factories take only configuration arguments (…) and return a function f(key, shape[, dtype]).
3. Naming: the same schemes go by different names (Kaiming -> He, Xavier -> Glorot).
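The first two differences can be seen without either framework. The following is a toy sketch in plain Python, not PyTorch or JAX: `uniform_inplace_`, `uniform_factory`, and the flat-list "tensors" are hypothetical stand-ins, and an integer seed stands in for a JAX PRNG key.

```python
import random
from typing import Callable, List, Sequence

# Torch style (hypothetical stand-in): the target buffer is the first
# argument and is mutated in place.
def uniform_inplace_(tensor: List[float], low: float = 0.0, high: float = 1.0) -> List[float]:
    for i in range(len(tensor)):
        tensor[i] = random.uniform(low, high)
    return tensor  # returned for chaining; the caller's list has already changed

# JAX style (hypothetical stand-in): the factory takes only configuration
# and returns init(key, shape), which builds a fresh value.
Initializer = Callable[[int, Sequence[int]], List[float]]

def uniform_factory(low: float = 0.0, high: float = 1.0) -> Initializer:
    def init(key: int, shape: Sequence[int]) -> List[float]:
        rng = random.Random(key)  # an explicit seed stands in for a PRNG key
        size = 1
        for dim in shape:
            size *= dim
        return [rng.uniform(low, high) for _ in range(size)]  # new flat list
    return init

weight = [0.0] * 6
uniform_inplace_(weight, -1.0, 1.0)   # mutates `weight` itself
init_fn = uniform_factory(-1.0, 1.0)  # only captures config; allocates nothing
fresh = init_fn(0, (2, 3))            # builds 6 new values from key 0
```

Because the factory output depends only on its key and shape, calling `init_fn(0, (2, 3))` twice yields identical values, whereas the in-place call has no key and simply overwrites whatever buffer it is given.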
Transformation Logic:
- Input: nn.init.kaiming_uniform_(self.weight, a=0)
- Output: jax.nn.initializers.he_uniform(a=0)
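The transformation above can be sketched with Python's standard `ast` module (Python 3.9+ for `ast.unparse`). This is a minimal illustration, not the library's implementation: `InitializerMixinSketch`, `RewriterSketch`, `rewrite`, and the four-entry `NAME_MAP` are hypothetical, and the sketch assumes the tensor is passed positionally as the first argument. Like the documented example, it forwards the remaining arguments, such as `a=0`, unchanged and does not check whether the target factory accepts them.

```python
import ast

# Hypothetical subset of a Torch -> JAX name table (not the real INIT_NAME_MAP).
NAME_MAP = {
    "kaiming_uniform_": "he_uniform",
    "kaiming_normal_": "he_normal",
    "xavier_uniform_": "glorot_uniform",
    "xavier_normal_": "glorot_normal",
}

def _dotted(node: ast.AST) -> str:
    """Return 'a.b.c' for a Name/Attribute chain, or '' for anything else."""
    parts = []
    while isinstance(node, ast.Attribute):
        parts.append(node.attr)
        node = node.value
    if isinstance(node, ast.Name):
        parts.append(node.id)
        return ".".join(reversed(parts))
    return ""

class InitializerMixinSketch:
    """Hypothetical mixin: rewrites nn.init.<fn>_(tensor, ...) into a factory call."""

    def visit_Call(self, node: ast.Call) -> ast.AST:
        self.generic_visit(node)  # rewrite nested calls first
        prefix, _, fn = _dotted(node.func).rpartition(".")
        if prefix not in ("nn.init", "torch.nn.init") or fn not in NAME_MAP:
            return node
        factory = ast.parse(f"jax.nn.initializers.{NAME_MAP[fn]}", mode="eval").body
        # Drop the positional tensor argument; keep remaining args/kwargs verbatim.
        new = ast.Call(func=factory, args=node.args[1:], keywords=node.keywords)
        return ast.copy_location(new, node)

class RewriterSketch(InitializerMixinSketch, ast.NodeTransformer):
    """Stand-in for the host rewriter that the mixin is combined with."""

def rewrite(source: str) -> str:
    tree = RewriterSketch().visit(ast.parse(source))
    return ast.unparse(ast.fix_missing_locations(tree))

print(rewrite("nn.init.kaiming_uniform_(self.weight, a=0)"))
# prints: jax.nn.initializers.he_uniform(a=0)
```

Placing `visit_Call` on a mixin rather than directly on the transformer lets a host class combine several such mixins, with Python's MRO deciding which `visit_*` method handles a given node.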
Note: If the original call was a standalone statement, the rewrite usually leaves a “dangling” factory call in the AST. A subsequent pass (Parameter Decl Refactor) is responsible for moving this factory into the self.param(…, kernel_init=HERE) definition. This mixin focuses solely on translating the API name and arguments correctly.
Attributes
Classes
Mixin for PivotRewriter to handle torch.nn.init calls.
Module Contents
- ml_switcheroo.core.rewriter.initializers.INIT_NAME_MAP
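The contents of INIT_NAME_MAP are not listed on this page. As a hedged illustration only, the naming rule from the key differences (drop the trailing in-place underscore, rename Kaiming to He and Xavier to Glorot) can generate a table of the same shape. `TORCH_INIT_FNS`, `FAMILY_RENAMES`, `to_jax_name`, and `INIT_NAME_MAP_SKETCH` are hypothetical names, and the entries are an assumed subset.

```python
# Hypothetical sketch: the real INIT_NAME_MAP contents are not documented here.
TORCH_INIT_FNS = [
    "uniform_", "normal_", "orthogonal_",
    "kaiming_uniform_", "kaiming_normal_",
    "xavier_uniform_", "xavier_normal_",
]

# Family renames from the naming difference (Kaiming -> He, Xavier -> Glorot).
FAMILY_RENAMES = {"kaiming": "he", "xavier": "glorot"}

def to_jax_name(torch_name: str) -> str:
    """Derive a jax.nn.initializers name from a torch.nn.init function name."""
    base = torch_name.removesuffix("_")  # drop the trailing in-place marker
    family, sep, rest = base.partition("_")
    return FAMILY_RENAMES.get(family, family) + sep + rest

INIT_NAME_MAP_SKETCH = {name: to_jax_name(name) for name in TORCH_INIT_FNS}
```

This maps names only; argument semantics can still differ. For example, torch's uniform_(tensor, a, b) takes bounds, while jax.nn.initializers.uniform takes a scale. zeros_ and ones_ are left out because jax.nn.initializers.zeros and ones are initializer functions themselves rather than factories.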