ml_switcheroo.frameworks.jax
JAX Core Framework Adapter (Level 0 & Level 1).
This adapter provides support for the functional JAX ecosystem without binding to a high-level neural network library like Flax or Haiku.
It includes comprehensive definitions for JAX ops and for Optax/Orbax, exposed through the definitions property.
Attributes
Classes
JaxCoreAdapter: Adapter for Core JAX (jax + optax + orbax) without a Neural Framework.
Module Contents
- ml_switcheroo.frameworks.jax.jax = None
- class ml_switcheroo.frameworks.jax.JaxCoreAdapter
Bases: ml_switcheroo.frameworks.common.jax_stack.JAXStackMixin
Adapter for Core JAX (jax + optax + orbax) without a Neural Framework.
- display_name: str = 'JAX (Core)'
- inherits_from: str | None = None
- ui_priority: int = 10
- property search_modules: List[str]
Scans only core math and optimization libraries.
- property import_alias: Tuple[str, str]
Defines the canonical import alias for the framework root. Ensures import jax.numpy as jnp is generated.
- property import_namespaces: Dict[str, Dict[str, str]]
Maps source namespaces to JAX equivalents.
- property discovery_heuristics: Dict[str, List[str]]
- property test_config: Dict[str, str]
- property structural_traits: ml_switcheroo.frameworks.base.StructuralTraits
- property definitions: Dict[str, ml_switcheroo.frameworks.base.StandardMap]
Static Definitions for JAX Core, Optax, Orbax.
- property rng_seed_methods: List[str]
- collect_api(category: ml_switcheroo.frameworks.base.StandardCategory) -> List[ml_switcheroo.frameworks.base.GhostRef]
- convert(data)
- apply_wiring(snapshot: Dict[str, Any]) -> None
Applies Level 0/1 Stack wiring.
- classmethod get_example_code() -> str
- ml_switcheroo.frameworks.jax.JaxAdapter
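The import_alias property above is typed Tuple[str, str] and is documented as ensuring that import jax.numpy as jnp is generated. The following is a minimal sketch of how such a pair could be turned into an import line. It does not use ml_switcheroo itself: the helper render_import, the constant JAX_IMPORT_ALIAS, and the (module path, alias name) ordering of the tuple are all assumptions made for illustration, not the library's confirmed behavior.

```python
from typing import Tuple


def render_import(alias: Tuple[str, str]) -> str:
    """Render a (module_path, alias_name) pair as an import statement.

    Hypothetical helper: the tuple ordering is an assumption, not taken
    from the ml_switcheroo source.
    """
    module_path, alias_name = alias
    return f"import {module_path} as {alias_name}"


# Assumed value of JaxCoreAdapter.import_alias; the documentation only states
# that `import jax.numpy as jnp` is the statement that ends up being generated.
JAX_IMPORT_ALIAS: Tuple[str, str] = ("jax.numpy", "jnp")

if __name__ == "__main__":
    print(render_import(JAX_IMPORT_ALIAS))
```

Keeping the alias as plain data rather than a preformatted string would let a code generator reuse the same pair both for emitting the import and for rewriting qualified names, which may be why the property returns a tuple.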