ml_switcheroo.analysis.purity

Static Purity Analysis for JAX Compliance.

This module provides PurityScanner, a LibCST transformer that detects operations that are unsafe under functional frameworks such as JAX. JAX transformations (jit, vmap, grad) require pure functions with no side effects; under jit, for example, a Python print runs only when the function is traced, not on every call.

Operations flagged:

1. I/O: print, input, open, write (standard Python).
2. Global state: global keyword usage.
3. Closure state: nonlocal keyword usage (Feature 05).
4. Structure mutation: list methods such as append and extend.
5. Global RNG: seeding operations (loaded dynamically from the semantic config).
6. Framework impurities: methods such as add_ and copy_, loaded from the source framework's config.
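The call-based categories above (1, 4, 5 and 6) can be sketched as a lookup on the callee's name. This is a minimal, stdlib-only illustration, not the scanner's actual code: the function name categorize_call, the category strings, and matching on the last dotted component are all assumptions. The keyword-based categories (global, nonlocal) are statements rather than calls and are not handled here.

```python
from typing import Optional, Set

# Names taken from the list above; the real sets may hold more entries
# (the documented mutation list ends with "etc.").
IO_FUNCTIONS: Set[str] = {"print", "input", "open", "write"}
MUTATION_METHODS: Set[str] = {"append", "extend"}


def categorize_call(
    callee: str,
    rng_methods: Set[str],
    impure_methods: Set[str],
) -> Optional[str]:
    """Return an impurity category for a dotted callee name, or None if pure.

    Hypothetical helper: matching only the last dotted component
    (e.g. 'manual_seed' in 'torch.manual_seed') is an assumption of this sketch.
    """
    leaf = callee.rsplit(".", 1)[-1]
    if leaf in IO_FUNCTIONS:
        return "I/O"
    if leaf in rng_methods:
        return "Global RNG"
    if leaf in impure_methods:
        return "Framework Impurity"
    if leaf in MUTATION_METHODS:
        return "Structure Mutation"
    return None
```

For example, categorize_call("torch.manual_seed", {"manual_seed"}, set()) yields "Global RNG", while a pure call such as jnp.add yields None.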

Violations are marked via the EscapeHatch mechanism.

Classes

PurityScanner

Scans CST for impurities and wraps violations in EscapeHatch markers.

Module Contents

class ml_switcheroo.analysis.purity.PurityScanner(semantics: Any = None, source_fw: str = 'torch')

Bases: libcst.CSTTransformer

Scans CST for impurities and wraps violations in EscapeHatch markers.

_current_violations

Violation messages accumulated for the statement currently being visited.

Type:

List[str]

_IO_FUNCTIONS

Standard Python I/O function names.

Type:

Set[str]

_MUTATION_METHODS

Standard Python container mutation methods.

Type:

Set[str]

_dynamic_impurity_methods

Methods loaded from framework configs (e.g. add_).

Type:

Set[str]

_global_rng_methods

Methods loaded from framework configs (e.g. manual_seed).

Type:

Set[str]
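The two config-driven sets could be populated along these lines. This is a sketch under stated assumptions: the config layout (a per-framework dict with "impure_methods" and "rng_seed_methods" keys) and the helper load_dynamic_sets are hypothetical, since the real semantic config schema is not documented in this reference. The only grounded details are the example method names and the constructor defaults semantics=None and source_fw='torch'.

```python
from typing import Any, Dict, Set, Tuple

# Hypothetical config layout; the real semantic config schema is not shown here.
EXAMPLE_CONFIG: Dict[str, Dict[str, Any]] = {
    "torch": {
        "impure_methods": ["add_", "copy_"],
        "rng_seed_methods": ["manual_seed"],
    },
}


def load_dynamic_sets(config: Any, source_fw: str) -> Tuple[Set[str], Set[str]]:
    """Return (dynamic_impurity_methods, global_rng_methods) for source_fw.

    A missing config (mirroring the constructor default semantics=None) or an
    unknown framework yields empty sets, so only the static rules would apply.
    """
    if not config:
        return set(), set()
    fw_cfg = config.get(source_fw, {})
    impure = set(fw_cfg.get("impure_methods", []))
    rng = set(fw_cfg.get("rng_seed_methods", []))
    return impure, rng
```

Falling back to empty sets keeps the scanner usable without any semantic config, at the cost of missing framework-specific impurities.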

source_fw = 'torch'

Name of the source framework whose config supplies the framework-specific impurity methods.

visit_SimpleStatementLine(node: libcst.SimpleStatementLine) → bool | None

Called when entering a statement line. Resets _current_violations so that violations are tracked per statement.

leave_SimpleStatementLine(original_node: libcst.SimpleStatementLine, updated_node: libcst.SimpleStatementLine) → libcst.SimpleStatementLine | libcst.FlattenSentinel

Called when leaving a statement line. If any violations were recorded within the statement, wraps it in EscapeHatch markers.

Parameters:
  • original_node – The original CST node structure.

  • updated_node – The node after its children have been visited and possibly transformed.

Returns:

The wrapped node if unsafe, otherwise the updated node.

visit_Global(node: libcst.Global) → bool | None

Detects usage of the ‘global’ keyword.

visit_Nonlocal(node: libcst.Nonlocal) → bool | None

Detects usage of the ‘nonlocal’ keyword.

visit_Call(node: libcst.Call) → bool | None

Inspects calls for I/O functions, list mutations, or global RNG seeding.