ml_switcheroo.analysis.purity
Static Purity Analysis for JAX Compliance.
This module provides the PurityScanner, a LibCST transformer that detects operations unsafe for functional frameworks such as JAX. JAX transformations (jit, vmap, grad) require pure functions with no side effects.
Operations flagged:
1. I/O: print, input, open, write (standard Python).
2. Global state: global keyword usage.
3. Closure state: nonlocal keyword usage (Feature 05).
4. Structure mutation: list methods such as append and extend.
5. Global RNG: seeding operations (loaded dynamically from the semantic config).
6. Framework impurities: methods such as add_ and copy_, loaded from the source framework config.
Violations are marked via the EscapeHatch mechanism.
Classes
- PurityScanner: Scans CST for impurities and wraps violations in EscapeHatch markers.
Module Contents
- class ml_switcheroo.analysis.purity.PurityScanner(semantics: Any = None, source_fw: str = 'torch')
Bases: libcst.CSTTransformer
Scans CST for impurities and wraps violations in EscapeHatch markers.
- _current_violations
Accumulator of errors for the current statement.
- Type:
List[str]
- _IO_FUNCTIONS
Standard Python I/O function names.
- Type:
Set[str]
- _MUTATION_METHODS
Standard Python container mutation methods.
- Type:
Set[str]
- _dynamic_impurity_methods
Impure methods loaded from the source framework config (e.g. add_).
- Type:
Set[str]
- _global_rng_methods
Global RNG seeding methods loaded from framework configs (e.g. manual_seed).
- Type:
Set[str]
- source_fw = 'torch'
Name of the source framework whose config supplies the framework-specific impure methods.
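The page says _dynamic_impurity_methods and _global_rng_methods come from framework configs but does not show the config format. Below is a minimal sketch of that lookup, assuming a hypothetical dictionary layout (a framework key mapping to impure_methods and global_rng_methods lists) and a hypothetical helper named load_impurity_sets. Keying the RNG methods by source_fw is also an assumption; the real semantics object may be structured quite differently.

```python
from typing import Any, Dict, List, Set, Tuple

# Hypothetical semantics layout (framework key -> trait lists). The actual
# ml_switcheroo config schema is not documented on this page.
EXAMPLE_SEMANTICS: Dict[str, Dict[str, List[str]]] = {
    "torch": {
        "impure_methods": ["add_", "copy_"],
        "global_rng_methods": ["manual_seed"],
    },
}


def load_impurity_sets(semantics: Any, source_fw: str) -> Tuple[Set[str], Set[str]]:
    """Return (dynamic impurity methods, global RNG methods) for source_fw.

    Returns two empty sets when semantics is None (the constructor default)
    or when source_fw has no entry.
    """
    if not semantics:
        return set(), set()
    traits = semantics.get(source_fw, {})
    impure = set(traits.get("impure_methods", []))
    rng = set(traits.get("global_rng_methods", []))
    return impure, rng
```

With this layout, load_impurity_sets(EXAMPLE_SEMANTICS, 'torch') yields the sets {add_, copy_} and {manual_seed}, while an unknown framework or semantics=None yields two empty sets, so the scanner can still apply its static I/O and mutation checks.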
- visit_SimpleStatementLine(node: libcst.SimpleStatementLine) → bool | None
Enters a statement line. Resets violation tracking.
- leave_SimpleStatementLine(original_node: libcst.SimpleStatementLine, updated_node: libcst.SimpleStatementLine) → libcst.SimpleStatementLine | libcst.FlattenSentinel
Exits a statement line. If any violations were found in this statement, wraps it in EscapeHatch markers.
- Parameters:
original_node – The original CST node structure.
updated_node – The node after any transformations of its children have been applied.
- Returns:
The wrapped node if unsafe, otherwise the updated node.
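The reset-on-enter, wrap-on-leave pattern described by these two methods can be approximated without LibCST. The sketch below swaps in the standard library ast module (so, unlike a CST, it does not model formatting), treats each simple statement as the unit of analysis, collects violations into a fresh list per statement, and surrounds offending statements with marker comments. The marker text, the helper names, and the narrow I/O-only check are hypothetical; the real EscapeHatch markers are not documented on this page.

```python
import ast
from typing import List, Tuple

# Simple (non-compound) statements: roughly the units LibCST groups into a
# SimpleStatementLine. Statements joined by ";" are analysed separately here.
SIMPLE_STMTS = (
    ast.Expr, ast.Assign, ast.AugAssign, ast.AnnAssign, ast.Return,
    ast.Delete, ast.Pass, ast.Break, ast.Continue, ast.Raise,
    ast.Global, ast.Nonlocal, ast.Import, ast.ImportFrom, ast.Assert,
)
IO_FUNCTIONS = {"print", "input", "open"}

# Hypothetical marker comments; the real EscapeHatch text is not shown here.
START_MARKER = "# <ESCAPE_HATCH_START> {}"
END_MARKER = "# <ESCAPE_HATCH_END>"


def violations_in(stmt: ast.stmt) -> List[str]:
    """Collect violations for a single statement into a fresh list."""
    found: List[str] = []
    for node in ast.walk(stmt):
        if (isinstance(node, ast.Call)
                and isinstance(node.func, ast.Name)
                and node.func.id in IO_FUNCTIONS):
            found.append(f"I/O call: {node.func.id}")
    return found


def wrap_impure_statements(source: str) -> str:
    """Surround every simple statement that has violations with markers."""
    lines = source.splitlines()
    spans: List[Tuple[int, int, str, List[str]]] = []
    for node in ast.walk(ast.parse(source)):
        if isinstance(node, SIMPLE_STMTS):
            reasons = violations_in(node)  # tracking is reset per statement
            if reasons:
                indent = " " * node.col_offset
                spans.append((node.lineno - 1, node.end_lineno - 1, indent, reasons))
    # Insert from the bottom up so earlier line indices stay valid.
    for first, last, indent, reasons in sorted(spans, reverse=True):
        lines.insert(last + 1, indent + END_MARKER)
        lines.insert(first, indent + START_MARKER.format("; ".join(reasons)))
    return "\n".join(lines)
```

For a function whose body is print(y) followed by return y, only the print line gets wrapped, with markers indented to match it. A LibCST transformer avoids this line-index bookkeeping: as the signature above shows, leave_SimpleStatementLine may return a libcst.FlattenSentinel, which lets a single statement be replaced by several nodes in place.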
- visit_Global(node: libcst.Global) → bool | None
Detects usage of the ‘global’ keyword.
- visit_Nonlocal(node: libcst.Nonlocal) → bool | None
Detects usage of the ‘nonlocal’ keyword.
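Both keyword checks reduce to spotting one specific statement node. A minimal sketch using the standard library ast module in place of LibCST; the function name scope_violations and the message format are illustrative assumptions, not part of ml_switcheroo.

```python
import ast
from typing import List


def scope_violations(source: str) -> List[str]:
    """Report each global/nonlocal statement and the names it rebinds."""
    found: List[str] = []
    for node in ast.walk(ast.parse(source)):
        # Both node types expose the declared identifiers as node.names.
        if isinstance(node, ast.Global):
            found.append(f"line {node.lineno}: global {', '.join(node.names)}")
        elif isinstance(node, ast.Nonlocal):
            found.append(f"line {node.lineno}: nonlocal {', '.join(node.names)}")
    return found
```

A function that declares global counter, or a nested function that declares nonlocal total, each produces one entry; code that only reads enclosing variables produces none, since reading without rebinding needs no declaration.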
- visit_Call(node: libcst.Call) → bool | None
Inspects calls for I/O functions, list mutations, or global RNG seeding.
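A sketch of the call inspection, again using the standard library ast module rather than LibCST. It resolves the callee to its trailing name (print, or append in xs.append) and checks that name against each category. The mutation names beyond append and extend, the helper names, and the framework-impurity check (listed in the module overview but not in this method's description) are assumptions.

```python
import ast
from typing import Optional, Set

IO_FUNCTIONS = {"print", "input", "open", "write"}
# append/extend come from the module overview; the rest are assumed.
MUTATION_METHODS = {"append", "extend", "insert", "pop", "remove", "clear"}


def call_name(call: ast.Call) -> Optional[str]:
    """Return the trailing callee name: print(...) -> 'print', xs.append(...) -> 'append'."""
    if isinstance(call.func, ast.Name):
        return call.func.id
    if isinstance(call.func, ast.Attribute):
        return call.func.attr
    return None  # e.g. f()(x) or obj[i](x): no simple name to match


def classify_call(call: ast.Call, rng_methods: Set[str], impure_methods: Set[str]) -> Optional[str]:
    """Return a violation label for the call, or None if it looks pure."""
    name = call_name(call)
    if name is None:
        return None
    if name in IO_FUNCTIONS:
        return f"I/O: {name}"
    if name in MUTATION_METHODS:
        return f"mutation: {name}"
    if name in rng_methods:
        return f"global RNG: {name}"
    if name in impure_methods:
        return f"framework impurity: {name}"
    return None
```

Matching only the trailing name keeps the check simple but can over-report, for example flagging a user-defined method that happens to be called append.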