ml_switcheroo.plugins.checkpoint_keys¶
Plugin for Checkpoint Key Remapping.
Handles the runtime impedance mismatch between PyTorch state dictionaries and flattened parameter trees.
Logic: This plugin rewrites load_state_dict(state) calls into calls to a generated runtime utility, KeyMapper.from_torch(state).
Decoupling Logic: The injected KeyMapper utility outputs plain NumPy arrays rather than framework-specific tensors. This keeps the generated code compatible with JAX, TensorFlow, MLX, and others without forcing a hard dependency on jax.
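A minimal before/after sketch of the rewrite (the state variable name is illustrative; the call shape matches the transformation documented under transform_checkpoint_keys below):

# Before (PyTorch source):
model.load_state_dict(state, strict=True)

# After (generated code; the KeyMapper source below is injected once per file):
KeyMapper.from_torch(state)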
Attributes¶
Functions¶
transform_checkpoint_keys(node, ctx) | Hook: Transforms load_state_dict calls to KeyMapper usage.
Module Contents¶
- ml_switcheroo.plugins.checkpoint_keys.KEY_MAPPER_SOURCE = Multiline-String¶
""" import numpy as np import re class KeyMapper: @staticmethod def map_name(name): # 1. Standardize replacements name = name.replace("weight", "kernel") name = name.replace("running_mean", "mean") name = name.replace("running_var", "var") # BN scale convention if "bn" in name or "norm" in name: name = name.replace("kernel", "scale") # 2. Separators: layer.0 -> layer_0 parts = name.split(".") new_parts = [] for i, p in enumerate(parts): if p.isdigit() and i > 0: # Merge with previous: layer.0 -> layer_0 prev = new_parts.pop() new_parts.append(f"{prev}_{p}") else: new_parts.append(p) return ".".join(new_parts) @staticmethod def map_value(key, val): try: # Convert Torch tensors or other formats to numpy if hasattr(val, 'cpu'): val = val.detach().cpu().numpy() else: val = np.array(val) except: return val # Fallback # Heuristic Transpose for Dense/Conv kernels if "weight" in key or "kernel" in key: if val.ndim == 2: # Linear: (Out, In) -> (In, Out) val = val.transpose((1, 0)) elif val.ndim == 4: # Conv2d: (O, I, H, W) -> (H, W, I, O) val = val.transpose((2, 3, 1, 0)) # Return as numpy array. Target frameworks (JAX, TF, etc) handle numpy inputs. return val @classmethod def from_torch(cls, state_dict): new_dict = {} for k, v in state_dict.items(): nk = cls.map_name(k) nv = cls.map_value(k, v) new_dict[nk] = nv return new_dict """
- ml_switcheroo.plugins.checkpoint_keys.transform_checkpoint_keys(node: libcst.Call, ctx: ml_switcheroo.core.hooks.HookContext) → libcst.CSTNode[source]¶
Hook: Transforms load_state_dict calls to KeyMapper usage.
Transformation: model.load_state_dict(state, strict=True) -> KeyMapper.from_torch(state)
Triggers if mapped via requires_plugin="checkpoint_mapper". Injects the KeyMapper source code once per file.
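A sketch of the core node rewrite with libcst (simplified; the real hook also consults ctx and performs the one-time source injection, which is omitted here):

import libcst as cst

def rewrite(node: cst.Call) -> cst.Call:
    # Keep only the first positional argument (the state dict), dropping
    # extras such as strict=True; clear any trailing comma it carried.
    state_arg = node.args[0].with_changes(comma=cst.MaybeSentinel.DEFAULT)
    # Retarget the call: model.load_state_dict(...) -> KeyMapper.from_torch(...)
    return node.with_changes(
        func=cst.Attribute(value=cst.Name("KeyMapper"), attr=cst.Name("from_torch")),
        args=[state_arg],
    )

call = cst.parse_expression("model.load_state_dict(state, strict=True)")
print(cst.Module(body=[]).code_for_node(rewrite(call)))
# KeyMapper.from_torch(state)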