ml_switcheroo.frameworks.jax
JAX Core Framework Adapter (Level 0 & Level 1).
This adapter provides support for the functional JAX ecosystem without binding to a high-level neural network library such as Flax or Haiku. It maps:
1. Level 0 (Core): the JAX Array API (jnp), activations (jax.nn), and types.
2. Level 1 (Common Libs): Optax (optimization) and Orbax (checkpointing).
3. IO & Devices: save/load via Orbax and jax.devices mapping.
It also enables requires_explicit_rng in its plugin traits.
Attributes
Classes
| JaxCoreAdapter | Adapter for Core JAX (jax + optax + orbax) without a Neural Framework. |
Module Contents
- class ml_switcheroo.frameworks.jax.JaxCoreAdapter
Bases: ml_switcheroo.frameworks.common.jax_stack.JAXStackMixin
Adapter for Core JAX (jax + optax + orbax) without a Neural Framework.
Handles translations for:
  - Math: jnp.abs, jnp.sum, etc.
  - Types: jnp.float32, jnp.int32, jnp.bfloat16.
  - Casting: .astype(…) synthesis via plugins.
  - Optimization: optax.adam, optax.sgd.
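As a rough illustration of the translation categories listed above, the sketch below models a tiny lookup table from abstract operation names to JAX Core and Optax call paths, shortening jax.numpy to its canonical jnp alias. The standard names (Abs, Sum, Float32, BFloat16, Adam, SGD), the table, and the translate helper are all hypothetical; they are not ml_switcheroo's real definitions format.

```python
# Hypothetical lookup table: abstract op name -> fully qualified JAX-stack target.
# These entries are illustrative only; they are not ml_switcheroo's real definitions.
JAX_CORE_TARGETS = {
    "Abs": "jax.numpy.abs",
    "Sum": "jax.numpy.sum",
    "Float32": "jax.numpy.float32",
    "BFloat16": "jax.numpy.bfloat16",
    "Adam": "optax.adam",
    "SGD": "optax.sgd",
}

# Canonical alias used when emitting code (mirrors the ('jax.numpy', 'jnp') alias).
ALIASES = {"jax.numpy": "jnp"}


def translate(op_name: str) -> str:
    """Return the aliased call path for an abstract op (KeyError if unknown)."""
    target = JAX_CORE_TARGETS[op_name]
    module, _, attr = target.rpartition(".")
    # Shorten the module path when a canonical alias exists for it.
    return f"{ALIASES.get(module, module)}.{attr}"


if __name__ == "__main__":
    for name in ("Abs", "Float32", "Adam"):
        print(name, "->", translate(name))
```

Here Abs resolves to jnp.abs, while Adam stays optax.adam because no alias is registered for optax.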
- display_name: str = 'JAX (Core)'
- inherits_from: str | None = None
- ui_priority: int = 10
- property search_modules: List[str]
Scans only core math and optimization libraries (no neural layers).
- Returns:
List of module names.
- Return type:
List[str]
- property unsafe_submodules: Set[str]
Returns a set of submodule names to exclude from recursive introspection.
- Returns:
Explicitly empty set (safe to scan default paths).
- Return type:
Set[str]
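The recursive scan that unsafe_submodules guards can be approximated with the standard library's pkgutil. The list_submodules helper below is hypothetical: it lists a package's direct submodules and skips any names in an exclusion set, so the empty set this adapter returns skips nothing. The json package stands in for a JAX module so the example runs without JAX installed.

```python
import importlib
import pkgutil
from typing import List, Set


def list_submodules(package_name: str, unsafe: Set[str]) -> List[str]:
    """Return sorted direct submodule names of a package, minus excluded ones.

    Hypothetical helper for illustration; pkgutil.iter_modules only reads the
    package path and does not import the submodules it finds.
    """
    package = importlib.import_module(package_name)
    names = [info.name for info in pkgutil.iter_modules(package.__path__)]
    return sorted(name for name in names if name not in unsafe)


# An empty exclusion set (as returned by unsafe_submodules here) keeps everything.
print(list_submodules("json", set()))
# A non-empty exclusion set removes those names from the scan.
print(list_submodules("json", {"tool"}))
```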
- property import_alias: Tuple[str, str]
Defines the canonical import alias (‘jax.numpy’, ‘jnp’).
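The (module, alias) pair is all a code generator needs to emit the import line for translated files. The render_import helper below is hypothetical and simply formats such a tuple.

```python
from typing import Tuple


def render_import(alias: Tuple[str, str]) -> str:
    """Turn a (module, alias) pair into an import statement (hypothetical helper)."""
    module, name = alias
    return f"import {module} as {name}"


print(render_import(("jax.numpy", "jnp")))  # import jax.numpy as jnp
```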
- property import_namespaces: Dict[str, ml_switcheroo.frameworks.base.ImportConfig]
Self-declared namespace roles.
- Returns:
Map of paths to configuration.
- Return type:
Dict[str, ImportConfig]
- property discovery_heuristics: Dict[str, List[str]]
Regex patterns for identifying API categories.
- Returns:
Mapping from tier name to a list of regex patterns.
- Return type:
Dict[str, List[str]]
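A tier-to-patterns mapping like this is typically applied by testing an API path against each tier's patterns in order and taking the first match. The tier names and patterns below are assumptions for illustration; the adapter's real heuristics may differ.

```python
import re
from typing import Dict, List, Optional

# Hypothetical tier -> regex patterns table; not the adapter's actual heuristics.
HEURISTICS: Dict[str, List[str]] = {
    "activation": [r"^jax\.nn\.(relu|gelu|sigmoid|softmax)$"],
    "optimizer": [r"^optax\.[a-z_]+$"],
    "array": [r"^jax\.numpy\."],
}


def classify(api_path: str,
             heuristics: Dict[str, List[str]] = HEURISTICS) -> Optional[str]:
    """Return the first tier whose pattern matches api_path, else None."""
    for tier, patterns in heuristics.items():
        if any(re.search(pattern, api_path) for pattern in patterns):
            return tier
    return None


print(classify("jax.nn.relu"))    # activation
print(classify("optax.adam"))     # optimizer
print(classify("jax.numpy.sum"))  # array
print(classify("jax.devices"))    # None
```

Because dictionaries keep insertion order, more specific tiers should be listed before broad catch-all patterns.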
- property test_config: Dict[str, str]
Returns standard JIT-enabled test templates.
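JIT-enabled test templates can be pictured as format strings that wrap the function under test in jax.jit before calling it. The template keys and placeholders below are assumptions for illustration, not the actual contents of test_config; the snippet only builds strings, so it runs without JAX.

```python
from typing import List

# Hypothetical template set; keys and placeholders are illustrative only.
TEST_TEMPLATES = {
    "import": "import jax\nimport jax.numpy as jnp",
    "call": "jax.jit({func})({args})",
}


def render_call(func: str, args: List[str]) -> str:
    """Fill the hypothetical 'call' template with a function name and arguments."""
    return TEST_TEMPLATES["call"].format(func=func, args=", ".join(args))


print(render_call("jnp.abs", ["x"]))  # jax.jit(jnp.abs)(x)
```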
- property harness_imports: List[str]
Imports required for JAX initialization logic.
- property declared_magic_args: List[str]
Declares ‘key’ (the JAX PRNG key) as a magic state argument.
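One plausible use for magic state arguments is to set them aside when comparing signatures across frameworks, since imperative frameworks thread RNG state implicitly while JAX functions often take it as an explicit key parameter. The user_params helper and the dropout stub below are hypothetical and sketch that idea under this assumption.

```python
import inspect
from typing import Callable, List

MAGIC_ARGS = ["key"]  # assumed to mirror declared_magic_args for this adapter


def user_params(fn: Callable, magic: List[str] = MAGIC_ARGS) -> List[str]:
    """Return fn's parameter names with magic state arguments filtered out."""
    return [p for p in inspect.signature(fn).parameters if p not in magic]


def dropout(x, rate, key):  # hypothetical JAX-style function with an explicit key
    ...


print(user_params(dropout))  # ['x', 'rate']
```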
- property structural_traits: ml_switcheroo.frameworks.base.StructuralTraits
Defines JAX structural behavior (transformation rules). Specifies the JIT static arguments needed for compilation safety.
- Returns:
Configuration object.
- Return type:
StructuralTraits
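In JAX, arguments used in Python control flow or to determine array shapes (for example a boolean training flag) must be marked static when a function is compiled with jax.jit, via its static_argnums or static_argnames parameters. The hypothetical jit_decorator helper below sketches how a code generator might emit such a decorator; it only produces source text, so it runs without JAX.

```python
from typing import Sequence


def jit_decorator(static_argnames: Sequence[str]) -> str:
    """Emit a jax.jit decorator line, marking the given arguments as static."""
    if not static_argnames:
        return "@jax.jit"
    names = ", ".join(repr(name) for name in static_argnames)
    # A trailing comma keeps a single name a valid tuple literal.
    return f"@functools.partial(jax.jit, static_argnames=({names},))"


print(jit_decorator([]))            # @jax.jit
print(jit_decorator(["training"]))  # @functools.partial(jax.jit, static_argnames=('training',))
```

The emitted decorator assumes the generated file also contains import functools and import jax.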
- property plugin_traits: ml_switcheroo.frameworks.base.PluginTraits
Defines the logic capabilities available to plugins. Enables NumPy compatibility and explicit RNG threading.
IMPORTANT: Enforces purity analysis to catch side effects that are unsafe under functional tracing.
- Returns:
Configuration flags.
- Return type:
PluginTraits
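Purity analysis flags code whose side effects misbehave under jax.jit tracing, such as mutating global state, printing (which runs at trace time rather than on every call), or writing into arrays in place (JAX arrays are immutable and use x.at[i].set(v) instead). The checker below is a deliberately small, hypothetical sketch built on Python's ast module; its rules are assumptions and far narrower than a real analysis.

```python
import ast
from typing import List


class PurityChecker(ast.NodeVisitor):
    """Collect simple side-effect patterns that break functional tracing."""

    def __init__(self) -> None:
        self.issues: List[str] = []

    def visit_Global(self, node: ast.Global) -> None:
        self.issues.append(f"line {node.lineno}: global statement")
        self.generic_visit(node)

    def visit_Call(self, node: ast.Call) -> None:
        # print() executes once at trace time under jax.jit, not on every call.
        if isinstance(node.func, ast.Name) and node.func.id == "print":
            self.issues.append(f"line {node.lineno}: print call")
        self.generic_visit(node)

    def visit_Assign(self, node: ast.Assign) -> None:
        # x[i] = v is an in-place write; JAX arrays need x.at[i].set(v).
        for target in node.targets:
            if isinstance(target, ast.Subscript):
                self.issues.append(f"line {node.lineno}: in-place item assignment")
        self.generic_visit(node)


def check_purity(source: str) -> List[str]:
    """Parse source and return the side-effect issues found, in visit order."""
    checker = PurityChecker()
    checker.visit(ast.parse(source))
    return checker.issues


code = """
def step(x):
    global counter
    print(x)
    x[0] = 1.0
    return x
"""
for issue in check_purity(code):
    print(issue)
```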
- property rng_seed_methods: List[str]
JAX has no global, imperative seeding methods; randomness comes from PRNG keys that are passed explicitly.
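Without a global seed, a translator has to replace a global seeding call with an explicit key and split that key wherever fresh randomness is needed. The hypothetical emit_key_threading helper below produces that JAX pattern (jax.random.PRNGKey followed by jax.random.split) as source text; its name and output layout are assumptions, and it runs without JAX.

```python
from typing import List


def emit_key_threading(seed: int, n_draws: int) -> List[str]:
    """Emit JAX source lines that create a key and split off one subkey per draw."""
    lines = [f"key = jax.random.PRNGKey({seed})"]
    for i in range(n_draws):
        # Split before every draw so no key is ever reused.
        lines.append(f"key, subkey{i} = jax.random.split(key)")
    return lines


print("\n".join(emit_key_threading(seed=0, n_draws=2)))
```

Each subkey is consumed by exactly one random call, while key is carried forward for later splits.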
- property definitions: Dict[str, ml_switcheroo.frameworks.base.StandardMap]
Static definitions for JAX Core, Optax, Orbax, and types, loaded dynamically from frameworks/definitions/jax.json.
- Returns:
Mapping of definitions.
- Return type:
Dict[str, StandardMap]
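Loading definitions from a JSON file generally means parsing it and validating each entry before use. The sketch below shows that shape with an inline JSON string standing in for frameworks/definitions/jax.json; the schema (op name mapped to an object with an api field) is an assumption, not the real file format, and plain dicts stand in for StandardMap.

```python
import json
from typing import Dict

# Hypothetical stand-in for frameworks/definitions/jax.json; the real schema may differ.
RAW = """
{
  "Abs": {"api": "jax.numpy.abs"},
  "Adam": {"api": "optax.adam"}
}
"""


def load_definitions(text: str) -> Dict[str, Dict[str, str]]:
    """Parse a definitions document and check that every entry names an API."""
    data = json.loads(text)
    for op, entry in data.items():
        if "api" not in entry:
            raise ValueError(f"definition {op!r} is missing 'api'")
    return data


defs = load_definitions(RAW)
print(defs["Adam"]["api"])  # optax.adam
```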
- collect_api(category: ml_switcheroo.frameworks.base.StandardCategory) → List[ml_switcheroo.frameworks.base.GhostRef]
Collects API signatures used to discover new Standards. Supports both live introspection and Ghost Mode snapshots.
- Parameters:
category (StandardCategory) – Category to scan.
- Returns:
Found API references.
- Return type:
List[GhostRef]
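Live introspection of this kind can be approximated with the inspect module: walk a module's public callables and record their signatures, falling back to a previously saved snapshot when the library is unavailable. ApiRef, collect, and the snapshot layout below are hypothetical stand-ins (not GhostRef or the real Ghost Mode format), and a synthetic module replaces JAX so the example runs anywhere.

```python
import inspect
import types
from typing import Dict, List, NamedTuple, Optional


class ApiRef(NamedTuple):
    """Hypothetical stand-in for a collected API reference."""
    name: str
    signature: str


def collect(module: Optional[types.ModuleType],
            snapshot: Dict[str, str]) -> List[ApiRef]:
    """Introspect module live if available, otherwise replay a saved snapshot."""
    if module is None:
        # Snapshot fallback: rebuild references from stored signature strings.
        return [ApiRef(n, s) for n, s in sorted(snapshot.items())]
    refs = []
    for name, obj in inspect.getmembers(module, callable):
        if name.startswith("_"):
            continue  # skip private helpers
        refs.append(ApiRef(name, str(inspect.signature(obj))))
    return refs


# Synthetic module standing in for a real library.
fake = types.ModuleType("fake_numpy")


def _helper(): ...


def clip(x, a_min=None, a_max=None): ...


fake.clip = clip
fake._helper = _helper

print(collect(fake, {}))
print(collect(None, {"clip": "(x, a_min=None, a_max=None)"}))
```

Both calls yield the same reference, which is the property a snapshot mode needs: offline results should match what live introspection would have produced.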
- convert(data: Any) → Any
Converts input data to a JAX array for verification.
- Parameters:
data (Any) – Input data.
- Returns:
JAX Array.
- Return type:
Any
- apply_wiring(snapshot: Dict[str, Any]) → None
Applies Level 0/1 stack wiring by populating the JSON snapshot in place with manually wired logic.
- Parameters:
snapshot (Dict[str, Any]) – The snapshot to modify.
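Because the method returns None, its effect is entirely a mutation of the snapshot passed in. A minimal sketch of that in-place pattern follows; wire_snapshot is a hypothetical standalone function, and the snapshot layout (a mappings section keyed by op name) and the wired entries are assumptions for illustration only.

```python
from typing import Any, Dict

# Hypothetical manually wired entries that introspection alone might not find.
MANUAL_WIRING = {
    "Devices": {"api": "jax.devices"},
    "SGD": {"api": "optax.sgd"},
}


def wire_snapshot(snapshot: Dict[str, Any]) -> None:
    """Merge manual entries into snapshot['mappings'] in place."""
    mappings = snapshot.setdefault("mappings", {})
    for op, entry in MANUAL_WIRING.items():
        # Keep any existing entry; only fill gaps (an assumed design choice).
        mappings.setdefault(op, dict(entry))


snapshot = {"mappings": {"Devices": {"api": "custom.devices"}}}
result = wire_snapshot(snapshot)
print(result)                                # None
print(snapshot["mappings"]["Devices"]["api"])  # custom.devices
print(snapshot["mappings"]["SGD"]["api"])      # optax.sgd
```

Copying each entry with dict(entry) keeps later edits to the snapshot from mutating the shared MANUAL_WIRING table.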