ml_switcheroo.plugins.checkpoint_keys¶
Plugin for Checkpoint Key Remapping.
Handles the runtime impedance mismatch between PyTorch state dictionaries and Flax/JAX parameter trees.
Problem:
1. Naming: Torch uses layer.0.weight; Flax uses layer_0.kernel (or scale for BatchNorm).
2. Shapes: Torch Linear/Conv weights are (Out, In) or (Out, In, …); Flax kernels are (…, In, Out). Both mismatches are illustrated in the sketch below.
3. Semantics: model.load_state_dict(d) mutates state in place in Torch; JAX requires returning a new parameter tree.
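The first two mismatches can be seen by inspecting a freshly constructed layer in each framework. A minimal sketch, assuming standard PyTorch and Flax defaults; the layer sizes (4 in, 8 out) are illustrative only:

    import torch
    import jax
    import jax.numpy as jnp
    import flax.linen as nn

    # PyTorch stores a Linear weight as (Out, In) under the name "weight".
    torch_layer = torch.nn.Linear(4, 8)
    print(torch_layer.state_dict()["weight"].shape)   # torch.Size([8, 4])

    # Flax stores the equivalent Dense kernel as (In, Out) under the name "kernel".
    flax_layer = nn.Dense(features=8)
    variables = flax_layer.init(jax.random.PRNGKey(0), jnp.zeros((1, 4)))
    print(variables["params"]["kernel"].shape)        # (4, 8)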
Solution: This plugin converts model.load_state_dict(state) calls into calls to a generated runtime utility, KeyMapper.from_torch(state).
This creates a dependency on the KeyMapper class, which the transpiler’s PreambleGenerator is expected to inject into the output file using the source code defined in KEY_MAPPER_SOURCE.
Attributes¶
KEY_MAPPER_SOURCE
Functions¶
transform_checkpoint_keys — Hook: Transforms load_state_dict calls to KeyMapper usage.
Module Contents¶
- ml_switcheroo.plugins.checkpoint_keys.KEY_MAPPER_SOURCE = Multiline-String¶
""" import jax.numpy as jnp import numpy as np import re class KeyMapper: @staticmethod def map_name(name): # 1. Standardize replacements name = name.replace("weight", "kernel") name = name.replace("running_mean", "mean") name = name.replace("running_var", "var") # BN scale convention if "bn" in name or "norm" in name: name = name.replace("kernel", "scale") # 2. Separators: layer.0 -> layer_0 parts = name.split(".") new_parts = [] for i, p in enumerate(parts): if p.isdigit() and i > 0: # Merge with previous: layer.0 -> layer_0 prev = new_parts.pop() new_parts.append(f"{prev}_{p}") else: new_parts.append(p) return ".".join(new_parts) @staticmethod def map_value(key, val): val = np.array(val) # Heuristic Transpose for Dense/Conv kernels # Torch Linear: (Out, In) -> JAX (In, Out) # Torch Conv2d: (Out, In, H, W) -> JAX (H, W, In, Out) if "weight" in key or "kernel" in key: if val.ndim == 2: # Linear val = val.transpose((1, 0)) elif val.ndim == 4: # Conv2d: (O, I, H, W) -> (H, W, I, O) val = val.transpose((2, 3, 1, 0)) return jnp.array(val) @classmethod def from_torch(cls, state_dict): # Flattened Torch dict -> Nested Flax dict (simplified) # In reality, Flax usually requires unfreezing a target tree and mapping, # but here we produce a flat dict with mapped keys/values for loading logic. new_dict = {} for k, v in state_dict.items(): nk = cls.map_name(k) nv = cls.map_value(k, v) new_dict[nk] = nv return new_dict """
- ml_switcheroo.plugins.checkpoint_keys.transform_checkpoint_keys(node: libcst.Call, ctx: ml_switcheroo.core.hooks.HookContext) → libcst.CSTNode¶
Hook: Transforms load_state_dict calls to KeyMapper usage.
- Transformation:
Input: model.load_state_dict(state, strict=True)
Output: variables = KeyMapper.from_torch(state)
Note: In JAX/Flax, one does not simply ‘load’ into a model in place; one obtains a variables dict. This transformation therefore implies that the user’s model variable may need to be reassigned, or the result assigned to a separate parameter variable.
Since we cannot infer the exact variable to assign to in a generic expression transform, we act on the expression itself: where model.load_state_dict(…) is used, it becomes KeyMapper.from_torch(…).
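The hook body is not reproduced in this documentation, but a minimal libcst sketch of the described rewrite might look as follows. This is illustrative only: the function name rewrite_load_state_dict is hypothetical, and the real hook additionally registers the KeyMapper dependency through the HookContext, whose API is not shown here.

    import libcst as cst
    import libcst.matchers as m

    def rewrite_load_state_dict(node: cst.Call) -> cst.Call:
        # Only rewrite calls of the form <expr>.load_state_dict(...)
        if not m.matches(node, m.Call(func=m.Attribute(attr=m.Name("load_state_dict")))):
            return node
        # Keep only positional arguments (drops keywords such as strict=True).
        positional = [a for a in node.args if a.keyword is None]
        if positional:
            positional[-1] = positional[-1].with_changes(comma=cst.MaybeSentinel.DEFAULT)
        return node.with_changes(
            func=cst.Attribute(value=cst.Name("KeyMapper"), attr=cst.Name("from_torch")),
            args=positional,
        )

    call = cst.parse_expression("model.load_state_dict(state, strict=True)")
    print(cst.Module(body=[]).code_for_node(rewrite_load_state_dict(call)))
    # KeyMapper.from_torch(state)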