ml_switcheroo.frameworks.jax

JAX Core Framework Adapter (Level 0 & Level 1).

This adapter provides support for the functional JAX ecosystem without binding to a high-level neural network library like Flax or Haiku.

Includes comprehensive definitions for core JAX operations, Optax optimizers, and Orbax checkpointing, exposed through the definitions property.

Attributes

jax

JaxAdapter

Classes

JaxCoreAdapter

Adapter for core JAX (jax + optax + orbax) without a neural network framework.

Module Contents

ml_switcheroo.frameworks.jax.jax = None
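The default value of None suggests that jax is imported optionally, so the package can still load when JAX is not installed. That reading is an assumption based only on this attribute; a minimal sketch of the usual guard pattern, with a hypothetical jax_available helper:

```python
# Optional-import guard (assumed pattern; the real module may differ).
try:
    import jax  # bind the real package when JAX is installed
except ImportError:
    jax = None  # keep this module importable without JAX


def jax_available() -> bool:
    """Hypothetical helper: report whether the optional dependency loaded."""
    return jax is not None


print("JAX installed:", jax_available())
```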
class ml_switcheroo.frameworks.jax.JaxCoreAdapter

Bases: ml_switcheroo.frameworks.common.jax_stack.JAXStackMixin

Adapter for core JAX (jax + optax + orbax) without a neural network framework.

display_name: str = 'JAX (Core)'
inherits_from: str | None = None
ui_priority: int = 10
property search_modules: List[str]

Scans only core math and optimization libraries.
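The page does not list the actual module names. As a sketch, assuming the property returns importable module paths such as jax.numpy, jax.random, and optax (hypothetical values), a discovery scan might import each one and collect its public callables, skipping libraries that are not installed. The function name scan_public_callables is hypothetical.

```python
import importlib
from typing import Dict, List

# Hypothetical list; the adapter's real search_modules value may differ.
CORE_SEARCH_MODULES: List[str] = ["jax.numpy", "jax.random", "optax"]


def scan_public_callables(modules: List[str]) -> Dict[str, List[str]]:
    """Map each importable module to its public callable names; skip missing ones."""
    results: Dict[str, List[str]] = {}
    for name in modules:
        try:
            module = importlib.import_module(name)
        except ImportError:
            continue  # library not installed: nothing to scan
        results[name] = sorted(
            attr
            for attr, value in vars(module).items()
            if not attr.startswith("_") and callable(value)
        )
    return results
```

Calling scan_public_callables(["math"]) with a standard-library module is a quick way to exercise the function without JAX installed.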

property import_alias: Tuple[str, str]

Defines the canonical import alias for the framework root, ensuring that the statement import jax.numpy as jnp is generated.
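A sketch of how such a pair could be turned into an import statement, assuming the tuple is ordered (module_path, alias). The order and the render_import helper are assumptions, not documented behavior of the adapter.

```python
from typing import Tuple

# Assumption: the tuple is (module_path, alias); the order is not documented.
JAX_IMPORT_ALIAS: Tuple[str, str] = ("jax.numpy", "jnp")


def render_import(alias: Tuple[str, str]) -> str:
    """Turn a (module, alias) pair into a Python import statement."""
    module, name = alias
    if module == name:
        return f"import {module}"  # aliasing a module to its own name is redundant
    return f"import {module} as {name}"


print(render_import(JAX_IMPORT_ALIAS))  # import jax.numpy as jnp
```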

property import_namespaces: Dict[str, Dict[str, str]]

Maps source namespaces to JAX equivalents.
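The inner dictionary keys are not documented on this page. The sketch below assumes a hypothetical shape of {"root": ..., "alias": ...} per source namespace and shows how a lookup might produce the JAX-side import; the torch entries and the rewrite_import helper are illustrative, not the adapter's real table.

```python
from typing import Dict

# Hypothetical shape: source namespace -> {"root": target module, "alias": name}.
# The real inner keys used by ml_switcheroo are not documented here.
IMPORT_NAMESPACES: Dict[str, Dict[str, str]] = {
    "torch": {"root": "jax.numpy", "alias": "jnp"},
    "torch.nn.functional": {"root": "jax.nn", "alias": "nn"},
}


def rewrite_import(source_module: str, namespaces: Dict[str, Dict[str, str]]) -> str:
    """Emit the JAX-side import for a source namespace, or raise if it is unmapped."""
    try:
        target = namespaces[source_module]
    except KeyError:
        raise KeyError(f"no JAX equivalent registered for {source_module!r}") from None
    return f"import {target['root']} as {target['alias']}"
```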

property discovery_heuristics: Dict[str, List[str]]
property test_config: Dict[str, str]
property structural_traits: ml_switcheroo.frameworks.base.StructuralTraits
property definitions: Dict[str, ml_switcheroo.frameworks.base.StandardMap]

Static definitions for JAX Core, Optax, and Orbax.
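The fields of StandardMap are not shown on this page, so the sketch uses a hypothetical StandardMapSketch with a single api field. It illustrates the general idea of a static table keyed by abstract operation names that resolve to JAX, Optax, or Orbax paths; the keys and entries are illustrative assumptions.

```python
from dataclasses import dataclass
from typing import Dict, Optional


@dataclass(frozen=True)
class StandardMapSketch:
    """Hypothetical stand-in for ml_switcheroo.frameworks.base.StandardMap."""
    api: str  # fully qualified target API path


# Illustrative static table; the real keys and targets may differ.
DEFINITIONS: Dict[str, StandardMapSketch] = {
    "Abs": StandardMapSketch(api="jax.numpy.abs"),
    "MatMul": StandardMapSketch(api="jax.numpy.matmul"),
    "Adam": StandardMapSketch(api="optax.adam"),
    "SaveCheckpoint": StandardMapSketch(api="orbax.checkpoint.StandardCheckpointer"),
}


def resolve(op_name: str, definitions: Dict[str, StandardMapSketch]) -> Optional[str]:
    """Return the target API path for an abstract operation, or None if unmapped."""
    entry = definitions.get(op_name)
    return entry.api if entry is not None else None


print(resolve("Adam", DEFINITIONS))  # optax.adam
```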

property rng_seed_methods: List[str]
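This property has no docstring. One plausible reading, stated as an assumption, is that it lists the names of functions that create or seed random keys, so converted code can locate where RNG state originates. jax.random.PRNGKey and jax.random.key are real JAX functions; whether the adapter uses exactly these names, and the find_seed_calls helper, are hypothetical.

```python
import ast
from typing import List

# Hypothetical values: jax.random.PRNGKey and jax.random.key exist in JAX,
# but whether the adapter lists exactly these names is an assumption.
RNG_SEED_METHODS: List[str] = ["PRNGKey", "key"]


def find_seed_calls(source: str, seed_methods: List[str]) -> List[str]:
    """Return dotted names of calls whose final attribute is a seeding method."""
    found = []
    for node in ast.walk(ast.parse(source)):
        if isinstance(node, ast.Call):
            name = ast.unparse(node.func)  # ast.unparse requires Python 3.9+
            if name.rsplit(".", 1)[-1] in seed_methods:
                found.append(name)
    return found


code = "import jax\nrng = jax.random.PRNGKey(0)\nx = jax.numpy.ones(3)\n"
print(find_seed_calls(code, RNG_SEED_METHODS))  # ['jax.random.PRNGKey']
```

Matching only the final attribute name is deliberately coarse: a generic name such as key would also match unrelated calls like config.key(), so a fuller implementation would likely check the complete path.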
collect_api(category: ml_switcheroo.frameworks.base.StandardCategory) → List[ml_switcheroo.frameworks.base.GhostRef]
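The signature suggests a per-category lookup that returns lightweight API references. The sketch below uses hypothetical stand-ins for StandardCategory and GhostRef (their real members and fields are not documented here) and a small static catalogue of real jax.nn and optax functions.

```python
from dataclasses import dataclass
from enum import Enum
from typing import Dict, List


class CategorySketch(Enum):
    """Hypothetical stand-in for StandardCategory."""
    ACTIVATION = "activation"
    OPTIMIZER = "optimizer"


@dataclass(frozen=True)
class GhostRefSketch:
    """Hypothetical stand-in for GhostRef: a lightweight reference to an API."""
    name: str
    api_path: str


# Hypothetical static catalogue keyed by category.
CATALOGUE: Dict[CategorySketch, List[GhostRefSketch]] = {
    CategorySketch.ACTIVATION: [
        GhostRefSketch("relu", "jax.nn.relu"),
        GhostRefSketch("gelu", "jax.nn.gelu"),
    ],
    CategorySketch.OPTIMIZER: [
        GhostRefSketch("adam", "optax.adam"),
        GhostRefSketch("sgd", "optax.sgd"),
    ],
}


def collect_api(category: CategorySketch) -> List[GhostRefSketch]:
    """Return a copy of the references for one category; unknown ones yield []."""
    return list(CATALOGUE.get(category, []))
```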
convert(data)
apply_wiring(snapshot: Dict[str, Any]) → None

Applies Level 0/1 Stack wiring.

classmethod get_example_code() → str
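No docstring is given; judging from the name and return type, this likely returns a short sample program for the framework. A hedged sketch of that pattern, using a hypothetical ExampleAdapterSketch class, which also checks with ast.parse that the snippet is valid Python without executing it:

```python
import ast


class ExampleAdapterSketch:
    """Hypothetical adapter showing a classmethod that returns demo source code."""

    @classmethod
    def get_example_code(cls) -> str:
        # A plain JAX snippet returned as text; it is never executed here.
        return (
            "import jax\n"
            "import jax.numpy as jnp\n"
            "\n"
            "@jax.jit\n"
            "def predict(w, x):\n"
            "    return jnp.tanh(x @ w)\n"
        )


source = ExampleAdapterSketch.get_example_code()
ast.parse(source)  # raises SyntaxError if the example is malformed
print(source.splitlines()[0])  # import jax
```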
ml_switcheroo.frameworks.jax.JaxAdapter