ml_switcheroo.plugins.checkpoint_keys¶

Plugin for Checkpoint Key Remapping.

Handles the runtime impedance mismatch between PyTorch state dictionaries and flattened parameter trees.

Logic: This plugin rewrites model.load_state_dict(state) calls into calls to a generated runtime utility, KeyMapper.from_torch(state).

Decoupling Logic: The injected KeyMapper utility outputs NumPy arrays, which keeps the converted checkpoint compatible with JAX, TensorFlow, MLX, and others without forcing a hard dependency on jax in the generated code.

Attributes¶

KEY_MAPPER_SOURCE

Functions¶

transform_checkpoint_keys(→ libcst.CSTNode)

Hook: Transforms load_state_dict calls to KeyMapper usage.

Module Contents¶

ml_switcheroo.plugins.checkpoint_keys.KEY_MAPPER_SOURCE = Multiline-String¶
Show Value
"""
import numpy as np

class KeyMapper:
  @staticmethod
  def map_name(name):
    # 1. Standardize replacements
    name = name.replace("weight", "kernel")
    name = name.replace("running_mean", "mean")
    name = name.replace("running_var", "var")
    # BN scale convention
    if "bn" in name or "norm" in name:
      name = name.replace("kernel", "scale")

    # 2. Separators: layer.0 -> layer_0
    parts = name.split(".")
    new_parts = []
    for i, p in enumerate(parts):
      if p.isdigit() and i > 0:
        # Merge with previous: layer.0 -> layer_0
        prev = new_parts.pop()
        new_parts.append(f"{prev}_{p}")
      else:
        new_parts.append(p)
    return ".".join(new_parts)

  @staticmethod
  def map_value(key, val):
    # Convert Torch tensors or other array-likes to NumPy
    try:
      if hasattr(val, 'cpu'):
        val = val.detach().cpu().numpy()
      else:
        val = np.array(val)
    except Exception:
      return val  # Fallback: pass through unconverted

    # Heuristic Transpose for Dense/Conv kernels
    if "weight" in key or "kernel" in key:
      if val.ndim == 2:
        # Linear: (Out, In) -> (In, Out)
        val = val.transpose((1, 0))
      elif val.ndim == 4:
        # Conv2d: (O, I, H, W) -> (H, W, I, O)
        val = val.transpose((2, 3, 1, 0))

    # Return as numpy array. Target frameworks (JAX, TF, etc) handle numpy inputs.
    return val

  @classmethod
  def from_torch(cls, state_dict):
    new_dict = {}
    for k, v in state_dict.items():
      nk = cls.map_name(k)
      nv = cls.map_value(k, v)
      new_dict[nk] = nv
    return new_dict
"""
ml_switcheroo.plugins.checkpoint_keys.transform_checkpoint_keys(node: libcst.Call, ctx: ml_switcheroo.core.hooks.HookContext) → libcst.CSTNode[source]¶

Hook: Transforms load_state_dict calls to KeyMapper usage.

Transformation: model.load_state_dict(state, strict=True) -> KeyMapper.from_torch(state)

Triggers if mapped via requires_plugin="checkpoint_mapper". Injects the KeyMapper source code once per file.