ml_switcheroo.generated_tests.templates

Test generation templates and configuration descriptors.

This module stores default code templates for supported frameworks and provides utilities to determine properties of test arguments (e.g., whether an argument should be static under JIT compilation).

Attributes

DEFAULT_TEST_TEMPLATES

Functions

get_template(manager, framework) → Dict[str, str]

Retrieves the code generation template for a specific framework.

is_static_arg(arg_info) → bool

Determines if an argument should be marked static for JIT compilation.

Module Contents

ml_switcheroo.generated_tests.templates.DEFAULT_TEST_TEMPLATES
ml_switcheroo.generated_tests.templates.get_template(manager: Any, framework: str) → Dict[str, str]

Retrieves the code generation template for a specific framework.

Lookup priority:

1. SemanticsManager lookup (templates loaded from snapshots).
2. Hardcoded defaults in DEFAULT_TEST_TEMPLATES.
3. An empty dict, if neither source has a template for the framework.

Parameters:
  • manager – The SemanticsManager instance (can be None).

  • framework – The framework key (e.g., ‘torch’, ‘jax’).

Returns:

Template strings keyed by purpose (e.g., imports, conversion); an empty dict if no template is found.

Return type:

Dict[str, str]
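A minimal sketch of the lookup order listed above. It assumes the manager exposes a method named get_test_template(framework), which is a hypothetical name (the real SemanticsManager API may differ), and it uses a small stand-in for DEFAULT_TEST_TEMPLATES whose keys and template strings are illustrative only:

```python
from typing import Any, Dict, Optional

# Hypothetical stand-in for DEFAULT_TEST_TEMPLATES; real keys/strings may differ.
DEFAULT_TEST_TEMPLATES: Dict[str, Dict[str, str]] = {
    "torch": {"import": "import torch", "convert_input": "torch.tensor({np_var})"},
    "jax": {"import": "import jax.numpy as jnp", "convert_input": "jnp.array({np_var})"},
}


def get_template(manager: Optional[Any], framework: str) -> Dict[str, str]:
    """Resolve a template: manager first, then hardcoded defaults, then {}."""
    # 1. SemanticsManager lookup. `get_test_template` is a hypothetical method
    #    name used only for illustration.
    if manager is not None:
        lookup = getattr(manager, "get_test_template", None)
        if callable(lookup):
            template = lookup(framework)
            if template:
                return template
    # 2. Hardcoded defaults; 3. empty dict if the framework is unknown.
    #    A copy is returned so callers cannot mutate the module-level defaults.
    return dict(DEFAULT_TEST_TEMPLATES.get(framework, {}))


class FakeManager:
    """Hypothetical manager that only knows a custom 'torch' template."""

    def get_test_template(self, framework: str) -> Optional[Dict[str, str]]:
        return {"import": "import custom"} if framework == "torch" else None


print(get_template(FakeManager(), "torch"))  # {'import': 'import custom'}
print(get_template(FakeManager(), "jax")["import"])  # import jax.numpy as jnp
print(get_template(None, "unknown"))  # {}
```

Returning a copy of the default entry is a defensive choice in this sketch; whether the real function does so is not stated.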

ml_switcheroo.generated_tests.templates.is_static_arg(arg_info: Dict[str, Any]) → bool

Determines if an argument should be marked static for JIT compilation.

Uses a heuristic: an argument is treated as static if its type is a primitive (int, bool, str) or if its name is one commonly used for axis or dimension arguments.

Parameters:

arg_info – A dictionary containing ‘name’ and ‘type’ keys.

Returns:

True if the argument should be static.

Return type:

bool
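The heuristic can be sketched as follows. This assumes the 'type' value is a type-name string such as "int", and the set of axis/dimension names is a hypothetical example rather than the module's actual list. The helper static_arg_names is also hypothetical; it shows how such a predicate might collect names for a JIT wrapper, for example the static_argnames parameter of jax.jit:

```python
from typing import Any, Dict, List

# Assumption: 'type' is stored as a type-name string (e.g. "int").
PRIMITIVE_TYPES = {"int", "bool", "str"}
# Hypothetical axis/dimension names; the module's real list is not documented.
STATIC_ARG_NAMES = {"axis", "axes", "dim", "dims", "keepdims"}


def is_static_arg(arg_info: Dict[str, Any]) -> bool:
    """Return True if the argument looks like a JIT-static argument."""
    arg_type = str(arg_info.get("type", "")).lower()
    name = str(arg_info.get("name", "")).lower()
    # Static if the type is primitive OR the name looks like an axis/dim.
    return arg_type in PRIMITIVE_TYPES or name in STATIC_ARG_NAMES


def static_arg_names(args: List[Dict[str, Any]]) -> List[str]:
    """Hypothetical helper: collect names to mark static (e.g. static_argnames)."""
    return [a["name"] for a in args if is_static_arg(a)]


args = [
    {"name": "x", "type": "Tensor"},   # tensor input: traced, not static
    {"name": "dim", "type": "Any"},    # static by name
    {"name": "keepdim", "type": "bool"},  # static by primitive type
]
print(static_arg_names(args))  # ['dim', 'keepdim']
```

Primitive values such as ints and strings typically control shapes or control flow, which is why JIT compilers prefer to treat them as compile-time constants rather than traced values.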