ml_switcheroo.generated_tests.templates
Test generation templates and configuration descriptors.
This module stores default code templates for supported frameworks and provides utilities that determine properties of test arguments (e.g., whether an argument should be static under JIT compilation).
Attributes
DEFAULT_TEST_TEMPLATES
Functions
get_template(manager, framework) | Retrieves the code generation template for a specific framework.
is_static_arg(arg_info) | Determines if an argument should be marked static for JIT compilation.
Module Contents
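- ml_switcheroo.generated_tests.templates.DEFAULT_TEST_TEMPLATES
Hardcoded default templates, keyed by framework. get_template falls back to these when the SemanticsManager provides no template.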
- ml_switcheroo.generated_tests.templates.get_template(manager: Any, framework: str) → Dict[str, str]
Retrieves the code generation template for a specific framework.
Lookup priority:
1. SemanticsManager lookup (loaded from snapshots).
2. Hardcoded defaults in DEFAULT_TEST_TEMPLATES.
3. An empty dict.
- Parameters:
manager – The SemanticsManager instance (can be None).
framework – The framework key (e.g., ‘torch’, ‘jax’).
- Returns:
A mapping of template strings (for imports, conversion, etc.).
- Return type:
Dict[str, str]
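The lookup order above amounts to a three-tier fallback. The following is a minimal sketch, not the module's implementation: `get_template_sketch`, `FALLBACK_TEMPLATES`, the `FakeManager` stub, its `get_test_template` method, and the template keys (`import`, `convert_input`) are all hypothetical, because neither the SemanticsManager API nor the contents of DEFAULT_TEST_TEMPLATES are documented here.

```python
from typing import Any, Dict, Optional

# Hypothetical stand-in for DEFAULT_TEST_TEMPLATES; the real keys and
# template strings are not documented on this page.
FALLBACK_TEMPLATES: Dict[str, Dict[str, str]] = {
    "torch": {"import": "import torch", "convert_input": "torch.tensor({np_var})"},
    "jax": {"import": "import jax.numpy as jnp", "convert_input": "jnp.array({np_var})"},
}


def get_template_sketch(manager: Optional[Any], framework: str) -> Dict[str, str]:
    """Three-tier lookup: manager first, then hardcoded defaults, then {}."""
    # 1. Ask the manager (if any). `get_test_template` is a hypothetical method.
    if manager is not None:
        found = manager.get_test_template(framework)
        if found:
            return found
    # 2. Fall back to the hardcoded defaults, returning a copy so callers
    #    cannot mutate the shared module-level dict.
    if framework in FALLBACK_TEMPLATES:
        return dict(FALLBACK_TEMPLATES[framework])
    # 3. Nothing is known about this framework.
    return {}


class FakeManager:
    """Hypothetical manager whose snapshot only knows about 'jax'."""

    def get_test_template(self, framework: str) -> Optional[Dict[str, str]]:
        if framework == "jax":
            return {"import": "import jax"}
        return None


mgr = FakeManager()
print(get_template_sketch(mgr, "jax"))               # {'import': 'import jax'}
print(get_template_sketch(mgr, "torch")["import"])   # import torch
print(get_template_sketch(None, "mlx"))              # {}
```

Returning a copy of the default entry is a defensive choice in this sketch; whether the real function does the same is not stated.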
- ml_switcheroo.generated_tests.templates.is_static_arg(arg_info: Dict[str, Any]) → bool
Determines if an argument should be marked static for JIT compilation.
Uses a heuristic: the argument is treated as static if its type is a primitive (int, bool, str) or if its name is one commonly used for axis/dimension arguments.
- Parameters:
arg_info – A dictionary containing ‘name’ and ‘type’ keys.
- Returns:
True if the argument should be static.
- Return type:
bool
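The heuristic above can be sketched as follows, together with one way its result might feed a JIT wrapper such as `jax.jit(fn, static_argnums=...)`. This is a sketch under stated assumptions: the names in `STATIC_NAMES`, the assumption that `'type'` holds a plain string such as `'int'`, and the helpers `is_static_arg_sketch` and `static_argnums` are hypothetical and not taken from the module.

```python
from typing import Any, Dict, List

# Hypothetical set of names treated as axis/dimension-like arguments.
STATIC_NAMES = {"axis", "axes", "dim", "dims", "keepdim", "keepdims"}
# Primitive type names from the docstring, assumed to be stored as strings.
PRIMITIVE_TYPES = {"int", "bool", "str"}


def is_static_arg_sketch(arg_info: Dict[str, Any]) -> bool:
    """Return True if an argument looks like a compile-time constant."""
    arg_type = str(arg_info.get("type", "")).lower()
    name = str(arg_info.get("name", "")).lower()
    return arg_type in PRIMITIVE_TYPES or name in STATIC_NAMES


def static_argnums(args: List[Dict[str, Any]]) -> List[int]:
    """Positional indices a JIT wrapper should treat as static."""
    return [i for i, info in enumerate(args) if is_static_arg_sketch(info)]


signature = [
    {"name": "x", "type": "Array"},   # array input: traced
    {"name": "axis", "type": "Any"},  # static by name
    {"name": "eps", "type": "float"}, # float is not in the primitive set
    {"name": "flag", "type": "bool"}, # static by type
]
print(static_argnums(signature))  # [1, 3]
```

Marking such arguments static matters because `jax.jit` traces array arguments abstractly, so Python-level values like an axis index or a boolean used in control flow must be known when the function is compiled.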