ml_switcheroo.testing.harness_generator_template¶

Template for the Verification Harness.

This template is populated by HarnessGenerator. It accepts:

- source/target paths
- frameworks
- hints configuration
- fuzzer_implementation: the extracted source code of the fuzzer class.

Attributes¶

HARNESS_TEMPLATE

Module Contents¶

ml_switcheroo.testing.harness_generator_template.HARNESS_TEMPLATE = Multiline-String¶
Show Value
"""
import sys
import importlib.util
import inspect
import traceback
import random
import os
import re
import math
from typing import Any, List, Dict, Tuple, Optional, Union
import numpy as np

# --- HELPERS FOR STATE INJECTION ---

def _make_jax_key(seed):
    "Return a JAX PRNGKey built from seed, or a placeholder string when JAX cannot be imported."
    try:
        from jax import random as jax_random
        key = jax_random.PRNGKey(seed)
    except ImportError:
        # JAX is optional here: fall back to a sentinel so the caller still gets a value.
        key = "mock_jax_key"
    return key

def _make_flax_rngs(seed):
    "Return a Flax NNX Rngs object seeded with seed, or a placeholder string when Flax cannot be imported."
    try:
        from flax import nnx as flax_nnx
        rngs = flax_nnx.Rngs(seed)
    except ImportError:
        # Flax is optional here: fall back to a sentinel so the caller still gets a value.
        rngs = "mock_flax_rngs"
    return rngs

# --- INJECTED FUZZER LOGIC ---

{fuzzer_implementation}

# -----------------------------

def to_numpy(obj):
    "Convert a framework result (or a list/tuple of results) to NumPy values; anything else is returned unchanged."
    try:
        import numpy as _np
    except ImportError:
        # Without NumPy there is nothing to convert to, so hand the object back.
        return obj

    # The order of these checks matters: the first matching attribute wins.
    if hasattr(obj, "detach"):  # Torch
        detached = obj.detach()
        return detached.cpu().numpy()

    if hasattr(obj, "__array__"):  # JAX/Numpy
        return _np.array(obj)

    if hasattr(obj, "numpy"):  # TF
        return obj.numpy()

    if isinstance(obj, (list, tuple)):
        # Both lists and tuples come back as a list of converted items.
        return list(map(to_numpy, obj))

    return obj

# --- HARNESS LOGIC ---

def load_module_from_path(name, path):
    '''
    Import the Python file at path as a module registered under name.

    Args:
        name: Key under which the module is stored in sys.modules.
        path: Filesystem location of the source file to execute.

    Returns:
        The executed module, or None when no import spec or loader can be
        built for the path, or when executing the module raises. A message
        is printed in either failure case.
    '''
    spec = importlib.util.spec_from_file_location(name, path)
    if spec is None or spec.loader is None:
        print(f"❌ Could not load module {{name}} at {{path}}")
        return None
    mod = importlib.util.module_from_spec(spec)
    # Register before executing, as the importlib recipe does, but remember
    # what was there so a failed import does not leave a broken entry behind.
    previous = sys.modules.get(name)
    sys.modules[name] = mod
    try:
        spec.loader.exec_module(mod)
    except Exception as e:
        # Roll back the registration: a half-initialised module must not
        # stay importable under this name.
        if previous is None:
            sys.modules.pop(name, None)
        else:
            sys.modules[name] = previous
        print(f"❌ Error during import of {{name}}: {{e}}")
        return None
    return mod

def _deep_compare(a, b):
    "Compare two converted results: array-likes by tolerance, lists/tuples element-wise, anything else by equality."
    if hasattr(a, "shape") and hasattr(b, "shape"):
        # Numpy like
        return np.allclose(a, b, rtol=1e-3, atol=1e-4)
    if isinstance(a, (list, tuple)) and isinstance(b, (list, tuple)):
        if len(a) != len(b):
            return False
        return all(_deep_compare(x, y) for x, y in zip(a, b))
    return a == b


def run_verification(source_path, target_path, source_fw, target_fw, hints_json_str):
    '''
    Run every public function of the source module against its namesake in
    the target module on the same generated inputs and compare the results.

    Args:
        source_path: Path of the source module file.
        target_path: Path of the target module file.
        source_fw: Framework identifier passed to the fuzzer when adapting
            inputs for the source functions.
        target_fw: Framework identifier passed to the fuzzer when adapting
            inputs for the target functions.
        hints_json_str: Optional JSON text mapping function names to hints
            for the fuzzer. Unparseable or non-object hints are ignored
            with a warning.

    This function does not return: it prints a per-function report and a
    summary, then calls sys.exit with status 1 if module loading failed or
    any function failed, and status 0 otherwise.
    '''
    print(f"🧪 Verifying: {{source_path}} vs {{target_path}}")

    mod_src = load_module_from_path("mod_src", source_path)
    mod_tgt = load_module_from_path("mod_tgt", target_path)

    if not mod_src or not mod_tgt:
        print("❌ Module loading failed. Aborting.")
        sys.exit(1)

    all_hints = {{}}
    if hints_json_str:
        import json
        try:
            parsed_hints = json.loads(hints_json_str)
        except (ValueError, TypeError) as e:
            # Hints are best-effort, but say so instead of silently dropping them.
            print(f"⚠️  Ignoring hints: could not parse JSON ({{e}})")
        else:
            if isinstance(parsed_hints, dict):
                all_hints = parsed_hints
            else:
                # Per-function lookup below needs a mapping.
                print("⚠️  Ignoring hints: expected a JSON object.")

    # Use the injected class alias
    fuzzer = StandaloneFuzzer()
    functions = inspect.getmembers(mod_src, inspect.isfunction)

    passes = 0
    failures = 0
    skips = 0

    for func_name, src_func in functions:
        # Underscore-prefixed functions are treated as private and not verified.
        if func_name.startswith("_"):
            continue

        if not hasattr(mod_tgt, func_name):
            print(f"⚠️  Skipping {{func_name}}: Not found in target module.")
            skips += 1
            continue

        tgt_func = getattr(mod_tgt, func_name)

        try:
            sig = inspect.signature(src_func)
            params = list(sig.parameters.keys())
        except (ValueError, TypeError):
            print(f"⚠️  Skipping {{func_name}}: Could not inspect signature.")
            skips += 1
            continue

        func_hints = all_hints.get(func_name, {{}})
        inputs = fuzzer.generate_inputs(params, hints=func_hints)

        # Adapt Inputs
        try:
            src_inputs = fuzzer.adapt_to_framework(inputs, source_fw)
            tgt_inputs = fuzzer.adapt_to_framework(inputs, target_fw)
        except Exception as e:
            print(f"⚠️  Skipping {{func_name}}: Input adaptation failed ({{e}})")
            skips += 1
            continue

        # --- AUTO-INJECTION ---
        # Supply random-state arguments that exist only on the target side.
        try:
            tgt_sig = inspect.signature(tgt_func)
            for tp in tgt_sig.parameters:
                if tp in tgt_inputs:
                    continue
                val = None
                if tp in ("rng", "key"):
                    val = _make_jax_key(seed=42)
                elif tp == "rngs":
                    val = _make_flax_rngs(seed=42)

                if val is not None:
                    tgt_inputs[tp] = val
        except (ValueError, TypeError):
            # inspect.signature raises TypeError for a non-callable attribute;
            # the call below then reports it as a runtime error for this
            # function instead of aborting the whole run.
            pass

        # Execute
        try:
            res_src = src_func(**src_inputs)
            res_tgt = tgt_func(**tgt_inputs)

            val_src = to_numpy(res_src)
            val_tgt = to_numpy(res_tgt)

            if _deep_compare(val_src, val_tgt):
                print(f"✅ {{func_name}}: Match")
                passes += 1
            else:
                print(f"❌ {{func_name}}: Mismatch")
                failures += 1

        except Exception as e:
            print(f"❌ {{func_name}}: Runtime Error ({{e}})")
            failures += 1

    print("-" * 30)
    print(f"📊 Summary: {{passes}} Passed, {{failures}} Failed, {{skips}} Skipped")
    if failures > 0:
        sys.exit(1)
    sys.exit(0)

if __name__ == "__main__":
    # Script entry point of the generated harness. The single-brace fields
    # below are template placeholders that are substituted when the template
    # is populated; everything else in this call is literal.
    # NOTE(review): the values are spliced into quoted literals as-is. A path
    # containing a double quote or ending in a backslash, or hints JSON
    # containing a single quote, would yield an invalid literal - confirm the
    # generator escapes or rejects such values.
    run_verification(
        source_path=r"{source_path}",
        target_path=r"{target_path}",
        source_fw="{source_fw}",
        target_fw="{target_fw}",
        hints_json_str=r'{hints_json}'
    )
"""