ml_switcheroo.plugins.checkpoint_keys¶

Plugin for Checkpoint Key Remapping.

Handles the runtime impedance mismatch between PyTorch state dictionaries and Flax/JAX parameter trees.

Problem:

  1. Naming: Torch uses layer.0.weight. Flax uses layer_0.kernel (or scale for BatchNorm).
  2. Shapes: Torch Linear/Conv weights are usually (Out, In) or (Out, In, …). Flax weights are (…, In, Out). (See the sketch below.)
  3. Semantics: model.load_state_dict(d) mutates the model in place in Torch. JAX requires returning a new tree.
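For illustration, a minimal sketch of the naming and shape mismatch; the names and shapes are hypothetical, and plain numpy arrays stand in for tensors:

import numpy as np

# Torch-style checkpoint entry: dotted name, Linear weight stored as (Out, In).
torch_state = {"layer.0.weight": np.zeros((128, 64))}

# Equivalent Flax parameter: underscored name, "kernel" stored as (In, Out).
flax_params = {"layer_0": {"kernel": np.zeros((64, 128))}}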

Solution: This plugin rewrites model.load_state_dict(state) calls to use a generated runtime utility, KeyMapper.from_torch(state).

This introduces a dependency on the KeyMapper class, which the transpiler's PreambleGenerator is expected to inject into the output file using the source code defined in KEY_MAPPER_SOURCE.
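A minimal sketch of the injection step, assuming a simple string-concatenation preamble; the emit_module helper is hypothetical and the real PreambleGenerator interface is not shown here:

from ml_switcheroo.plugins.checkpoint_keys import KEY_MAPPER_SOURCE

def emit_module(transformed_body: str) -> str:
    # Prepend the KeyMapper source so the generated module is self-contained.
    return KEY_MAPPER_SOURCE + "\n\n" + transformed_body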

Attributes¶

KEY_MAPPER_SOURCE

Functions¶

transform_checkpoint_keys(→ libcst.CSTNode)

Hook: Transforms load_state_dict calls to KeyMapper usage.

Module Contents¶

ml_switcheroo.plugins.checkpoint_keys.KEY_MAPPER_SOURCE = Multiline-String¶
"""
import jax.numpy as jnp
import numpy as np

class KeyMapper:
    @staticmethod
    def map_name(name):
        # 1. Standardize replacements
        name = name.replace("weight", "kernel")
        name = name.replace("running_mean", "mean")
        name = name.replace("running_var", "var")
        # BN scale convention
        if "bn" in name or "norm" in name:
            name = name.replace("kernel", "scale")

        # 2. Separators: layer.0 -> layer_0
        parts = name.split(".")
        new_parts = []
        for i, p in enumerate(parts):
            if p.isdigit() and i > 0:
                # Merge with previous: layer.0 -> layer_0
                prev = new_parts.pop()
                new_parts.append(f"{prev}_{p}")
            else:
                new_parts.append(p)
        return ".".join(new_parts)

    @staticmethod
    def map_value(key, val):
        val = np.array(val)
        # Heuristic Transpose for Dense/Conv kernels
        # Torch Linear: (Out, In) -> JAX (In, Out)
        # Torch Conv2d: (Out, In, H, W) -> JAX (H, W, In, Out)

        if "weight" in key or "kernel" in key:
            if val.ndim == 2:
                # Linear
                val = val.transpose((1, 0))
            elif val.ndim == 4:
                # Conv2d: (O, I, H, W) -> (H, W, I, O)
                val = val.transpose((2, 3, 1, 0))
        return jnp.array(val)

    @classmethod
    def from_torch(cls, state_dict):
        # Flattened Torch dict -> Nested Flax dict (simplified)
        # In reality, Flax usually requires unfreezing a target tree and mapping,
        # but here we produce a flat dict with mapped keys/values for loading logic.
        new_dict = {}
        for k, v in state_dict.items():
            nk = cls.map_name(k)
            nv = cls.map_value(k, v)
            new_dict[nk] = nv
        return new_dict
"""
ml_switcheroo.plugins.checkpoint_keys.transform_checkpoint_keys(node: libcst.Call, ctx: ml_switcheroo.core.hooks.HookContext) → libcst.CSTNode¶

Hook: Transforms load_state_dict calls to KeyMapper usage.

Transformation:

Input:  model.load_state_dict(state, strict=True)
Output: variables = KeyMapper.from_torch(state)

Note: In JAX/Flax, parameters are not 'loaded' into a model in place; loading yields a variable dict. This transformation therefore implies that the user's model variable may need to be rebound, or the result assigned to a parameter variable.

Since we cannot infer the exact variable to assign to in a generic expression transform, we act on the expression itself. If used as model.load_state_dict(…), it becomes KeyMapper.from_torch(…).
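Illustrative before/after of the call site, plus one way a caller might turn the flat result into a nested Flax tree; the to_flax_tree helper and the unflattening step are assumptions about downstream usage, not something this plugin emits:

# Torch source (in-place mutation; return value usually ignored):
#     model.load_state_dict(state, strict=True)
#
# Transpiled form (pure: the caller must bind the result explicitly):
#     variables = KeyMapper.from_torch(state)

from flax import traverse_util

def to_flax_tree(flat_params):
    # Nest the dotted keys produced by KeyMapper.from_torch so the result
    # can be passed to something like flax_model.apply({"params": ...}, x).
    return traverse_util.unflatten_dict(flat_params, sep=".")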