ml_switcheroo.core.rewriter.initializers
========================================

.. py:module:: ml_switcheroo.core.rewriter.initializers

.. autoapi-nested-parse::

   Initializer Rewriter Mixin.

   Handles the translation of PyTorch's imperative initialization calls
   (e.g., `torch.nn.init.kaiming_normal_(tensor)`) into JAX/Flax style
   initializer factories (e.g., `jax.nn.initializers.he_normal()`).

   Key Differences:

   1. **In-place vs Factory**: Torch modifies tensors in-place. JAX inits are
      factories returning functions.
   2. **Signature**: Torch takes `(tensor, ...)` as the first arg. JAX
      factories take config args `(...)` and return a function `f(key, shape)`.
   3. **Naming**: Standard variations (Kaiming->He, Xavier->Glorot).

   Transformation Logic:

   - Input: `nn.init.kaiming_uniform_(self.weight, a=0)`
   - Output: `jax.nn.initializers.he_uniform(a=0)`

   *Note*: This usually leaves a "dangling" factory call in the AST if it was
   a standalone statement. A subsequent pass (Parameter Decl Refactor) is
   responsible for moving this factory into the
   `self.param(..., kernel_init=HERE)` definition. This Mixin focuses solely
   on correctly translating the API and arguments.


Attributes
----------

.. autoapisummary::

   ml_switcheroo.core.rewriter.initializers.INIT_NAME_MAP


Classes
-------

.. autoapisummary::

   ml_switcheroo.core.rewriter.initializers.InitializerMixin


Module Contents
---------------

.. py:data:: INIT_NAME_MAP

.. py:class:: InitializerMixin

   Mixin for PivotRewriter to handle `torch.nn.init` calls.


   .. py:method:: leave_Call(original_node: libcst.Call, updated_node: libcst.Call) -> libcst.Call

      Detects initialization calls and standardizes them to JAX factories.
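
The transformation described in the module docstring can be illustrated with a
minimal, standalone sketch. It is not the mixin's implementation: it uses the
standard-library ``ast`` module in place of ``libcst`` so that it runs without
extra dependencies, and ``EXAMPLE_NAME_MAP``, ``InitCallSketch`` and
``rewrite`` are hypothetical names introduced here. The contents of the real
``INIT_NAME_MAP`` are not shown on this page, so the four entries below are an
assumption based on the Kaiming->He and Xavier->Glorot naming noted above.

.. code-block:: python

   import ast

   # Hypothetical mapping for illustration only; the real INIT_NAME_MAP
   # may contain different or additional entries.
   EXAMPLE_NAME_MAP = {
       "kaiming_uniform_": "he_uniform",
       "kaiming_normal_": "he_normal",
       "xavier_uniform_": "glorot_uniform",
       "xavier_normal_": "glorot_normal",
   }


   class InitCallSketch(ast.NodeTransformer):
       """Rewrite ``<prefix>.init.<name>_(tensor, ...)`` into a JAX factory call."""

       def visit_Call(self, node: ast.Call) -> ast.AST:
           # Rewrite nested calls first, then inspect this one.
           self.generic_visit(node)
           func = node.func
           # Match attribute chains that end in ``init.<torch_name>``.
           if (
               isinstance(func, ast.Attribute)
               and func.attr in EXAMPLE_NAME_MAP
               and isinstance(func.value, ast.Attribute)
               and func.value.attr == "init"
           ):
               jax_name = EXAMPLE_NAME_MAP[func.attr]
               factory = ast.parse(
                   f"jax.nn.initializers.{jax_name}", mode="eval"
               ).body
               # Drop the leading tensor argument; copy the remaining
               # positional and keyword arguments unchanged.
               return ast.Call(
                   func=factory, args=node.args[1:], keywords=node.keywords
               )
           return node


   def rewrite(source: str) -> str:
       """Return ``source`` with matching init calls rewritten (Python 3.9+)."""
       tree = InitCallSketch().visit(ast.parse(source))
       return ast.unparse(ast.fix_missing_locations(tree))


   print(rewrite("nn.init.kaiming_uniform_(self.weight, a=0)"))

The sketch copies keyword arguments verbatim, mirroring the input/output pair
in the docstring. It does not check them against the signature of the target
JAX factory; whether the real mixin remaps or drops arguments is not stated on
this page.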