ml_switcheroo.frameworks.jax
============================

.. py:module:: ml_switcheroo.frameworks.jax

.. autoapi-nested-parse::

   JAX Core Framework Adapter (Level 0 & Level 1).

   This adapter supports the functional JAX ecosystem *without* binding to a
   high-level neural network library such as Flax or Haiku.

   It also includes definitions for JAX ops and for Optax/Orbax via the
   ``definitions`` property.


Attributes
----------

.. autoapisummary::

   ml_switcheroo.frameworks.jax.jax
   ml_switcheroo.frameworks.jax.JaxAdapter


Classes
-------

.. autoapisummary::

   ml_switcheroo.frameworks.jax.JaxCoreAdapter


Module Contents
---------------

.. py:data:: jax
   :value: None

.. py:class:: JaxCoreAdapter

   Bases: :py:obj:`ml_switcheroo.frameworks.common.jax_stack.JAXStackMixin`

   Adapter for core JAX (jax + optax + orbax) without a neural network framework.

   .. py:attribute:: display_name
      :type: str
      :value: 'JAX (Core)'

   .. py:attribute:: inherits_from
      :type: Optional[str]
      :value: None

   .. py:attribute:: ui_priority
      :type: int
      :value: 10

   .. py:property:: search_modules
      :type: List[str]

      Scans only the core math and optimization libraries.

   .. py:property:: import_alias
      :type: Tuple[str, str]

      Defines the canonical import alias for the framework root.
      Ensures ``import jax.numpy as jnp`` is generated.

   .. py:property:: import_namespaces
      :type: Dict[str, Dict[str, str]]

      Maps source namespaces to their JAX equivalents.

   .. py:property:: discovery_heuristics
      :type: Dict[str, List[str]]

   .. py:property:: test_config
      :type: Dict[str, str]

   .. py:property:: structural_traits
      :type: ml_switcheroo.frameworks.base.StructuralTraits

   .. py:property:: definitions
      :type: Dict[str, ml_switcheroo.frameworks.base.StandardMap]

      Static definitions for JAX Core, Optax, and Orbax.

   .. py:property:: rng_seed_methods
      :type: List[str]

   .. py:method:: collect_api(category: ml_switcheroo.frameworks.base.StandardCategory) -> List[ml_switcheroo.frameworks.base.GhostRef]

   .. py:method:: convert(data)

   .. py:method:: apply_wiring(snapshot: Dict[str, Any]) -> None

      Applies Level 0/1 stack wiring.

   .. py:method:: get_example_code() -> str
      :classmethod:

.. py:data:: JaxAdapter
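
The ``import_alias`` property of ``JaxCoreAdapter`` is typed ``Tuple[str, str]``
and is documented as ensuring that ``import jax.numpy as jnp`` is generated. The
sketch below shows one way such a pair could be rendered into an import line. It
assumes, without confirmation from this reference, that the tuple is ordered as
``(module_path, alias)``; ``render_import`` is a hypothetical helper and is not
part of ml_switcheroo.

.. code-block:: python

   from typing import Tuple


   def render_import(alias_spec: Tuple[str, str]) -> str:
       """Hypothetical helper: build an import line from (module_path, alias)."""
       module_path, alias = alias_spec
       # Skip the redundant "as" clause when the alias equals the module path.
       if alias == module_path:
           return f"import {module_path}"
       return f"import {module_path} as {alias}"


   # Assumed value of JaxCoreAdapter().import_alias (not confirmed here).
   print(render_import(("jax.numpy", "jnp")))  # import jax.numpy as jnp

Keeping the alias as data rather than a hard-coded string lets each adapter
declare its own canonical import while a shared code path emits the statement.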