ml_switcheroo.frameworks.jax

JAX Core Framework Adapter (Level 0 & Level 1).

This adapter provides support for the functional JAX ecosystem without binding to a high-level neural network library such as Flax or Haiku. It maps:

1. Level 0 (Core): JAX Array API (jnp), Activations (jax.nn), and Types.
2. Level 1 (Common Libs): Optax (Optimization) and Orbax (Checkpointing).
3. IO & Devices: Handles save/load via Orbax and jax.devices mapping.

It also enables requires_explicit_rng in its plugin traits, because JAX random functions take an explicit PRNG key instead of reading a global seed.
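As an illustration of what explicit RNG threading means in JAX itself, the sketch below passes a PRNG key into every random call and splits it before use. The helper name sample_noise is made up for this example and is not part of ml_switcheroo.

```python
import jax


def sample_noise(key, shape):
    """Draw Gaussian noise and return it together with a fresh key."""
    # Split so the caller never reuses the key consumed here.
    key, subkey = jax.random.split(key)
    return jax.random.normal(subkey, shape), key


key = jax.random.PRNGKey(0)  # explicit seed, no global state
first, key = sample_noise(key, (3,))
second, key = sample_noise(key, (3,))

# Restarting from the same seed reproduces the same first draw.
replay, _ = sample_noise(jax.random.PRNGKey(0), (3,))
```

Because the key is an ordinary value, reproducibility depends only on what the caller passes in, which is the property an explicit-RNG flag would need to preserve when translating code from frameworks that seed globally.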

Attributes

jax

Classes

JaxCoreAdapter

Adapter for Core JAX (jax + optax + orbax) without a Neural Framework.

Module Contents

ml_switcheroo.frameworks.jax.jax = None

class ml_switcheroo.frameworks.jax.JaxCoreAdapter

Bases: ml_switcheroo.frameworks.common.jax_stack.JAXStackMixin

Adapter for Core JAX (jax + optax + orbax) without a Neural Framework.

Handles translations for:

- Math: jnp.abs, jnp.sum, etc.
- Types: jnp.float32, jnp.int32, jnp.bfloat16.
- Casting: .astype(…) synthesis via plugins.
- Optimization: optax.adam, optax.sgd.
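For orientation, this is what the target side of those translations looks like in plain JAX. It is a minimal sketch of the named calls only, not output produced by the adapter; the Optax entries are left out so the example depends on jax alone.

```python
import jax.numpy as jnp

x = jnp.array([-1.5, 2.0, -3.0], dtype=jnp.float32)

magnitudes = jnp.abs(x)           # Math: element-wise absolute value
total = jnp.sum(magnitudes)       # Math: reduction to a scalar
as_int = x.astype(jnp.int32)      # Casting: method-style dtype conversion
as_bf16 = x.astype(jnp.bfloat16)  # Types: bfloat16 is a regular JAX dtype
```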

display_name: str = 'JAX (Core)'
inherits_from: str | None = None
ui_priority: int = 10
property search_modules: List[str]

Scans only core math and optimization libraries (no neural layers).

Returns:

List of module names.

Return type:

List[str]

property unsafe_submodules: Set[str]

Returns a set of submodule names to exclude from recursive introspection.

Returns:

Explicitly empty set (safe to scan default paths).

Return type:

Set[str]

property import_alias: Tuple[str, str]

Defines the canonical import alias (‘jax.numpy’, ‘jnp’).
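One way such a (module, alias) pair could be consumed is to render the import line for generated code. The helper render_import below is hypothetical and only illustrates the tuple's shape.

```python
from typing import Tuple


def render_import(alias: Tuple[str, str]) -> str:
    """Turn a (module, alias) pair into an import statement."""
    module, short_name = alias
    return f"import {module} as {short_name}"


import_alias = ("jax.numpy", "jnp")
statement = render_import(import_alias)
```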

property import_namespaces: Dict[str, ml_switcheroo.frameworks.base.ImportConfig]

Self-declared namespace roles.

Returns:

Map of paths to configuration.

Return type:

Dict[str, ImportConfig]

property discovery_heuristics: Dict[str, List[str]]

Regex patterns for identifying API categories.

Returns:

Tier to regex patterns mapping.

Return type:

Dict[str, List[str]]
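A tier-to-patterns mapping of this type can be applied with a simple first-match lookup. The tier names, the patterns, and the classify helper below are assumptions made for illustration; the adapter's actual values are not shown on this page.

```python
import re
from typing import Dict, List, Optional

# Hypothetical tier names and patterns, for illustration only.
HEURISTICS: Dict[str, List[str]] = {
    "array": [r"^jax\.numpy\."],
    "activation": [r"^jax\.nn\.(relu|gelu|softmax)$"],
    "optimizer": [r"^optax\."],
}


def classify(api_name: str, heuristics: Dict[str, List[str]]) -> Optional[str]:
    """Return the first tier with a pattern matching api_name, else None."""
    for tier, patterns in heuristics.items():
        if any(re.search(pattern, api_name) for pattern in patterns):
            return tier
    return None
```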

property test_config: Dict[str, str]

Returns standard JIT-enabled test templates.

property harness_imports: List[str]

Imports required for JAX initialization logic.

get_harness_init_code() → str

Returns source code that creates JAX PRNG keys.
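Since the return type is str, the method presumably hands back a code snippet for the test harness to embed. The sketch below shows one plausible shape for such a snippet; the function harness_init_code, its seed parameter, and the exact lines it emits are assumptions, not the adapter's real output.

```python
import ast


def harness_init_code(seed: int = 0) -> str:
    """Return a source snippet that seeds a key and derives a subkey."""
    return (
        f"key = jax.random.PRNGKey({seed})\n"
        "key, subkey = jax.random.split(key)\n"
    )


snippet = harness_init_code(42)

# Parse only (no execution) to confirm the snippet is valid Python.
statement_count = len(ast.parse(snippet).body)
```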

property declared_magic_args: List[str]

Returns key as a magic state argument.

property structural_traits: ml_switcheroo.frameworks.base.StructuralTraits

Defines JAX structural behavior (transformation rules) and specifies JIT static arguments for compilation safety.

Returns:

Configuration object.

Return type:

StructuralTraits
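To show why static arguments matter for compilation safety, the sketch below marks an argument as static under jax.jit. It demonstrates the JAX mechanism only; which arguments the adapter actually declares static is not stated here, and reduce_sum is a made-up function.

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnames=("axis",))
def reduce_sum(x, axis):
    # axis must be a concrete Python int while tracing, so it is marked
    # static; each distinct value gets its own compiled version.
    return jnp.sum(x, axis=axis)


x = jnp.ones((2, 3))
rows = reduce_sum(x, axis=1)  # one value per row
cols = reduce_sum(x, axis=0)  # one value per column
```

Without the static marking, axis would arrive as a traced value and jnp.sum would reject it, so a translator that wraps functions in jax.jit needs to know which parameters to treat this way.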

property plugin_traits: ml_switcheroo.frameworks.base.PluginTraits

Defines logic capabilities for plugins. Enables NumPy compatibility and explicit RNG threading.

IMPORTANT: Enforces purity analysis to catch side effects that are unsafe under functional tracing.

Returns:

Configuration flags.

Return type:

PluginTraits
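The kind of check a purity analysis performs can be sketched with the standard ast module. This is a deliberately small approximation under assumed rules (a deny-list of built-in calls plus global/nonlocal declarations); it is not the analysis ml_switcheroo implements, and find_side_effects is a hypothetical name.

```python
import ast
from typing import List

IMPURE_CALLS = {"print", "open", "input"}  # assumed deny-list


def find_side_effects(source: str) -> List[str]:
    """List constructs in source that are risky under functional tracing."""
    findings = []
    for node in ast.walk(ast.parse(source)):
        if isinstance(node, (ast.Global, ast.Nonlocal)):
            findings.append(f"line {node.lineno}: global/nonlocal declaration")
        elif (
            isinstance(node, ast.Call)
            and isinstance(node.func, ast.Name)
            and node.func.id in IMPURE_CALLS
        ):
            findings.append(f"line {node.lineno}: call to {node.func.id}()")
    return findings


impure = (
    "def step(x):\n"
    "    global counter\n"
    "    counter += 1\n"
    "    print(x)\n"
    "    return x * 2\n"
)
pure = "def step(x):\n    return x * 2\n"
```

Under jax.jit a function body runs once at trace time, so a counter update or print like the ones flagged here would not repeat on later calls, which is why such constructs are worth catching before translation.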

property rng_seed_methods: List[str]

JAX does not use global seeding methods in the imperative sense.

property definitions: Dict[str, ml_switcheroo.frameworks.base.StandardMap]

Static Definitions for JAX Core, Optax, Orbax, and Types. Loaded dynamically from frameworks/definitions/jax.json.

Returns:

Mapping of definitions.

Return type:

Dict[str, StandardMap]

collect_api(category: ml_switcheroo.frameworks.base.StandardCategory) → List[ml_switcheroo.frameworks.base.GhostRef]

Collects API signatures for discovering new Standards. Supports both Live introspection and Ghost Mode snapshots.

Parameters:

category (StandardCategory) – Category to scan.

Returns:

Found API references.

Return type:

List[GhostRef]

convert(data: Any) → Any

Converts input data to a JAX array for verification.

Parameters:

data (Any) – Input data.

Returns:

JAX Array.

Return type:

Any
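A conversion of this kind is commonly a thin wrapper over jnp.asarray. The sketch below assumes that approach; the adapter's actual implementation is not shown on this page, and to_jax_array is a stand-in name.

```python
import jax
import jax.numpy as jnp
import numpy as np


def to_jax_array(data):
    """Coerce lists, scalars, or NumPy arrays to a JAX array."""
    return jnp.asarray(data)


from_list = to_jax_array([1.0, 2.0, 3.0])
from_numpy = to_jax_array(np.arange(4, dtype=np.int32))
```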

apply_wiring(snapshot: Dict[str, Any]) → None

Applies Level 0/1 Stack wiring. Populates the JSON snapshot with manually wired logic.

Parameters:

snapshot (Dict[str, Any]) – The snapshot to modify.

get_tiered_examples() → Dict[str, str]

Provides default tiered examples for the base adapter.

Returns:

Mapping of tier name to source code.

Return type:

Dict[str, str]

get_doc_url(api_name: str) → str | None

Generates the JAX core documentation URL for an API path.

Parameters:

api_name (str) – API path.

Returns:

URL string.

Return type:

Optional[str]
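A URL builder with this signature might look like the sketch below. The base URL, the page layout, and the rule that non-jax names return None are all assumptions made for illustration; the adapter's real pattern may differ.

```python
from typing import Optional

# Assumed base URL and page layout; not taken from the adapter's source.
DOCS_BASE = "https://docs.jax.dev/en/latest/_autosummary"


def doc_url(api_name: str) -> Optional[str]:
    """Build a docs URL for a jax.* API path, or None for other names."""
    if not api_name.startswith("jax."):
        return None
    return f"{DOCS_BASE}/{api_name}.html"
```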