ml_switcheroo.generated_tests.runtime
Runtime helpers for generated verification tests.
This module contains the reference implementation of verify_results, a recursive comparison utility that checks whether arbitrary nested data structures (arrays, lists, tuples, dicts) produced by different frameworks are equivalent.
It also defines the ensure_determinism fixture, which is injected into generated tests to enforce reproducibility by seeding the random number generators of Torch, JAX, TensorFlow, NumPy, and Python.
Functions
ensure_determinism() | Auto-injects fixed seeds for reproducibility at the start of every test.
verify_results(ref, val[, rtol, atol, exact]) | Cross-framework comparison helper.
Module Contents
- ml_switcheroo.generated_tests.runtime.ensure_determinism() → None
Auto-injects fixed seeds for reproducibility at the start of every test.
Covers:
- Python random
- NumPy np.random
- PyTorch torch.manual_seed (CPU & CUDA)
- TensorFlow tf.random.set_seed
- MLX mlx.core.random.seed
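As a rough illustration of what seeding the covered generators involves, here is a minimal sketch. The helper name seed_everything and the default seed 42 are hypothetical, not taken from ml_switcheroo, and every framework import is treated as optional (an assumption) so the helper works in environments where only some frameworks are installed. JAX is not in the Covers list above; it has no global seed, and its PRNG state is passed explicitly as keys.

```python
import random


def seed_everything(seed: int = 42) -> None:
    """Hypothetical helper: seed every RNG importable in this environment.

    The name, default seed and optional-import strategy are assumptions,
    not ml_switcheroo's actual implementation.
    """
    # Python's built-in generator.
    random.seed(seed)

    # NumPy's legacy global generator (np.random.*).
    try:
        import numpy as np
        np.random.seed(seed)
    except ImportError:
        pass

    # PyTorch: torch.manual_seed seeds the generators on all devices,
    # which includes CUDA when it is available.
    try:
        import torch
        torch.manual_seed(seed)
    except ImportError:
        pass

    # TensorFlow's global (graph-level) seed.
    try:
        import tensorflow as tf
        tf.random.set_seed(seed)
    except ImportError:
        pass

    # MLX's global PRNG.
    try:
        import mlx.core as mx
        mx.random.seed(seed)
    except ImportError:
        pass
```

In a conftest.py, decorating a function that calls this helper with `@pytest.fixture(autouse=True)` would run it before every test, which matches the "auto-injects" behavior described above.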
- ml_switcheroo.generated_tests.runtime.verify_results(ref: Any, val: Any, rtol: float = 0.001, atol: float = 0.0001, exact: bool = False) → bool
Cross-framework comparison helper.
Recursively compares data structures (Lists, Dicts, Tuples, Arrays).
Modes:
- Fuzzy (default): uses np.allclose with the rtol and atol tolerances for floating-point values.
- Exact: enforces strict equality (an identity check for None, np.array_equal for arrays).
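To make the fuzzy mode concrete: np.allclose(a, b) accepts element-wise when |a - b| <= atol + rtol * |b|, so with the defaults the allowed error grows with the magnitude of b. The sketch below uses a hypothetical helper within_tolerance and assumes the reference is passed as b; the docs above do not state which argument verify_results uses for scaling.

```python
import numpy as np

# Default tolerances from verify_results' signature.
RTOL, ATOL = 1e-3, 1e-4


def within_tolerance(ref, val, rtol=RTOL, atol=ATOL) -> bool:
    # Hypothetical helper. np.allclose(a, b) checks |a - b| <= atol + rtol * |b|;
    # passing ref as b (an assumption) scales the relative term by the reference.
    return bool(np.allclose(val, ref, rtol=rtol, atol=atol))


ref = np.array([1.0, 100.0])
# Per-element threshold: 1e-4 + 1e-3 * |ref| = [0.0011, 0.1001]
print(within_tolerance(ref, np.array([1.001, 100.05])))  # True
print(within_tolerance(ref, np.array([1.002, 100.0])))   # False: 0.002 > 0.0011
print(np.array_equal(ref, np.array([1.001, 100.05])))    # False: exact mode
```

The same absolute error of 0.002 would pass for a reference value of 100.0 but fails for 1.0, which is why the relative tolerance matters for outputs of very different scales.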
- Parameters:
ref (Any) – The reference value (e.g. the output of the source framework).
val (Any) – The candidate value (e.g. the output of the target framework).
rtol (float) – Relative tolerance for floating point comparison.
atol (float) – Absolute tolerance for floating point comparison.
exact (bool) – If True, disables fuzzy matching.
- Returns:
True if values are considered equivalent.
- Return type:
bool
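The recursive comparison described above can be sketched as follows. This is a hypothetical re-implementation (verify_results_sketch) written for illustration, not the module's code. Several behaviors are assumptions: lists and tuples are treated as interchangeable, shapes must match exactly before np.allclose runs, non-numeric arrays fall back to np.array_equal, and framework tensors are assumed to convert with np.asarray (true for CPU tensors in frameworks with NumPy interop).

```python
from typing import Any

import numpy as np


def verify_results_sketch(ref: Any, val: Any, rtol: float = 1e-3,
                          atol: float = 1e-4, exact: bool = False) -> bool:
    """Hypothetical recursive comparator, for illustration only."""
    # None matches only None (identity check).
    if ref is None or val is None:
        return ref is val
    # Dicts: identical key sets, then compare values key by key.
    if isinstance(ref, dict):
        if not isinstance(val, dict) or ref.keys() != val.keys():
            return False
        return all(verify_results_sketch(ref[k], val[k], rtol, atol, exact)
                   for k in ref)
    # Lists and tuples are treated as interchangeable (an assumption).
    if isinstance(ref, (list, tuple)):
        if not isinstance(val, (list, tuple)) or len(ref) != len(val):
            return False
        return all(verify_results_sketch(r, v, rtol, atol, exact)
                   for r, v in zip(ref, val))
    # Strings and bytes: plain equality.
    if isinstance(ref, (str, bytes)):
        return ref == val
    # Arrays, tensors and scalars: convert to NumPy arrays.
    try:
        r, v = np.asarray(ref), np.asarray(val)
    except (TypeError, ValueError):
        # e.g. a GPU tensor that must be copied to host memory first.
        return False
    # Strict shape check so broadcasting cannot hide a mismatch.
    if r.shape != v.shape:
        return False
    if exact:
        return bool(np.array_equal(r, v))
    if np.issubdtype(r.dtype, np.number) and np.issubdtype(v.dtype, np.number):
        # Reference passed second so rtol scales with |ref|.
        return bool(np.allclose(v, r, rtol=rtol, atol=atol))
    # Booleans and other non-numeric dtypes: exact comparison.
    return bool(np.array_equal(r, v))
```

For example, `{"logits": np.array([1.0, 2.0])}` and `{"logits": np.array([1.0005, 2.0])}` compare as equivalent by default but not with exact=True. The explicit shape check is a deliberate choice in this sketch: np.allclose alone would broadcast a (3,) array against a (1, 3) one and report a match.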