ml_switcheroo.frameworks.paxml

PaxML (Praxis) Framework Adapter (Level 2).

This adapter specializes the core JAX stack for Google’s PaxML framework, specifically targeting the Praxis layer library.

It inherits Level 0 (Core JAX) and Level 1 (Optax/Orbax) capabilities from JAXStackMixin but implements the unique structural traits of the Praxis library, such as the setup() lifecycle method for layer definition.

Attributes

praxis

Classes

PaxmlAdapter

Adapter for PaxML (Praxis Layers) running on JAX.

Module Contents

ml_switcheroo.frameworks.paxml.praxis = None
class ml_switcheroo.frameworks.paxml.PaxmlAdapter

Bases: ml_switcheroo.frameworks.common.jax_stack.JAXStackMixin

Adapter for PaxML (Praxis Layers) running on JAX.

Features:

- Lifecycle Translation: Maps standard __init__ definitions to Praxis setup() methods.
- Layer Mapping: Maps Torch/Flax layers to praxis.layers.*.
- Stack Reuse: Inherits optimization and math logic from the JAX Core adapter.

display_name: str = 'PaxML / Praxis'
inherits_from: str = 'jax'
ui_priority: int = 60
collect_api(category: ml_switcheroo.frameworks.base.StandardCategory) → List[ml_switcheroo.frameworks.base.GhostRef]

Collects API definitions for the given category.

Delegates to JaxCoreAdapter for Math, Loss, and Optimizer categories, while handling Layer discovery specifically for Praxis.

Parameters:

category (StandardCategory) – The API category to scan.

Returns:

The API signatures found for the category.

Return type:

List[GhostRef]
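The delegation rule above can be sketched as a category dispatch. This is a minimal, hypothetical illustration. `Category`, `core_collect`, and `discover_praxis_layers` are stand-ins invented for this sketch and are not ml_switcheroo APIs. Plain strings take the place of GhostRef objects.

```python
from enum import Enum
from typing import Callable, List


class Category(Enum):
    """Hypothetical stand-in for StandardCategory (member names are assumed)."""

    MATH = "math"
    LOSS = "loss"
    OPTIMIZER = "optimizer"
    LAYER = "layer"


# Categories handed to the core JAX collector, per the description above.
DELEGATED = {Category.MATH, Category.LOSS, Category.OPTIMIZER}


def collect_api(
    category: Category,
    core_collect: Callable[[Category], List[str]],
    discover_praxis_layers: Callable[[], List[str]],
) -> List[str]:
    """Route a category to the core collector or to Praxis layer discovery."""
    if category in DELEGATED:
        return core_collect(category)
    if category is Category.LAYER:
        return discover_praxis_layers()
    # Any other category yields nothing in this sketch.
    return []


def fake_core(category: Category) -> List[str]:
    return [f"core:{category.value}"]


def fake_layers() -> List[str]:
    return ["praxis:layer"]


print(collect_api(Category.LOSS, fake_core, fake_layers))   # ['core:loss']
print(collect_api(Category.LAYER, fake_core, fake_layers))  # ['praxis:layer']
```

With this design the adapter owns only layer discovery. Everything else stays with the core JAX collector, so math and optimizer coverage cannot drift between the two adapters.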

property search_modules: List[str]

Returns the list of modules to scan during manual scaffolding.

Returns:

Module names including praxis.layers and praxis.base_layer.

Return type:

List[str]

property unsafe_submodules: Set[str]

Returns submodules flagged as unsafe. Uses a safe default in which no submodule is flagged.

Returns:

Empty set.

Return type:

Set[str]

property import_alias: Tuple[str, str]

Returns the primary import alias for the framework.

Returns:

("praxis.layers", "pl").

Return type:

Tuple[str, str]
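A code generator could consume the (module, alias) pair by rendering it as an import statement. The sketch below is hypothetical: `render_import` is not part of ml_switcheroo. It only shows that ("praxis.layers", "pl") corresponds to a valid Python import.

```python
import ast
from typing import Tuple


def render_import(alias: Tuple[str, str]) -> str:
    """Turn a (module, short_name) pair into an import statement."""
    module, short_name = alias
    statement = f"import {module} as {short_name}"
    ast.parse(statement)  # raises SyntaxError if the pair is not a valid import
    return statement


print(render_import(("praxis.layers", "pl")))  # import praxis.layers as pl
```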

property import_namespaces: Dict[str, ml_switcheroo.frameworks.base.ImportConfig]

Defines the semantic roles of Praxis namespaces.

Returns:

Mapping of namespaces to tiers.

Return type:

Dict[str, ImportConfig]

property discovery_heuristics: Dict[str, List[str]]

Returns regex patterns for heuristic categorization.

Returns:

Patterns identifying neural components in Praxis.

Return type:

Dict[str, List[str]]
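Regex-based categorization of this kind can be sketched as a first-match lookup. The patterns below are illustrative assumptions, because this reference does not list the adapter's real heuristics. `categorize` is likewise a hypothetical helper.

```python
import re
from typing import Dict, List, Optional

# Hypothetical heuristics; the adapter's actual patterns are not documented here.
HEURISTICS: Dict[str, List[str]] = {
    "neural": [r"^praxis\.layers\.", r"\.base_layer\."],
    "loss": [r"(?i)loss"],
}


def categorize(api_path: str, heuristics: Dict[str, List[str]]) -> Optional[str]:
    """Return the first category whose patterns match api_path, else None."""
    for category, patterns in heuristics.items():
        if any(re.search(pattern, api_path) for pattern in patterns):
            return category
    return None


print(categorize("praxis.layers.Linear", HEURISTICS))       # neural
print(categorize("my_pkg.cross_entropy_loss", HEURISTICS))  # loss
print(categorize("jax.numpy.add", HEURISTICS))              # None
```

Dictionary order decides ties in this sketch, so a path that matches several categories gets the first one listed.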

property test_config: Dict[str, str]

Returns templates for generating physical test files. Extends the JAX base config with Praxis imports.

Returns:

Code generation templates.

Return type:

Dict[str, str]
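One way to extend a base template set is to copy it and append the extra imports. This is a minimal sketch. The template keys ("import", "convert_input") and their contents are assumptions, and `extend_with_praxis` is a hypothetical helper rather than the adapter's code.

```python
from typing import Dict

# Hypothetical JAX base templates; the key names and contents are assumptions.
JAX_BASE_CONFIG: Dict[str, str] = {
    "import": "import jax\nimport jax.numpy as jnp",
    "convert_input": "jnp.array({np_var})",
}


def extend_with_praxis(base: Dict[str, str]) -> Dict[str, str]:
    """Return a copy of the base templates with a Praxis import appended."""
    config = dict(base)  # copy, so the shared JAX base config is not mutated
    config["import"] = base["import"] + "\nimport praxis.layers as pl"
    return config


praxis_config = extend_with_praxis(JAX_BASE_CONFIG)
print(praxis_config["import"])
print(praxis_config["convert_input"].format(np_var="x"))  # jnp.array(x)
```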

property harness_imports: List[str]

Returns imports required for the verification harness.

Returns:

['import jax', 'import jax.random'].

Return type:

List[str]

get_harness_init_code() → str

Returns the source code of a Python helper that initializes JAX random keys in the harness.

Returns:

Source code for _make_jax_key.

Return type:

str
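A method like this returns source text that the harness later executes. The helper body below is an assumption, since the real _make_jax_key source is not shown here. It relies on `import jax` having already run, as listed in harness_imports. jax.random.PRNGKey is a real JAX function.

```python
import textwrap


def get_harness_init_code() -> str:
    """Return the source of a hypothetical _make_jax_key helper."""
    # Assumes the harness has already executed `import jax` (see harness_imports).
    return textwrap.dedent(
        """\
        def _make_jax_key(seed):
            # Build a JAX PRNG key from an integer seed.
            return jax.random.PRNGKey(seed)
        """
    )


source = get_harness_init_code()
# compile() checks the syntax without running it, so JAX need not be installed.
compiled = compile(source, "<harness>", "exec")
print(source)
```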

property supported_tiers: List[Any]

Returns supported semantic tiers.

Returns:

Array API, Neural, and Extras.

Return type:

List[SemanticTier]

property declared_magic_args: List[str]

Returns the list of magic arguments to strip from calls. Praxis typically handles RNG context internally, differently from Flax, so no arguments are stripped.

Returns:

Empty list.

Return type:

List[str]
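Stripping magic arguments can be sketched as an AST pass that drops named keyword arguments from calls. The `rngs` keyword is only an illustrative Flax-style example. The transformer is hypothetical and is not how ml_switcheroo is documented to do it. An empty list, as returned for Praxis, leaves calls unchanged. Requires Python 3.9+ for `ast.unparse`.

```python
import ast
from typing import List


class StripMagicArgs(ast.NodeTransformer):
    """Remove the named keyword arguments from every call expression."""

    def __init__(self, magic_args: List[str]) -> None:
        self.magic_args = set(magic_args)

    def visit_Call(self, node: ast.Call) -> ast.AST:
        self.generic_visit(node)  # rewrite nested calls first
        # kw.arg is None for **kwargs, which is always kept.
        node.keywords = [kw for kw in node.keywords if kw.arg not in self.magic_args]
        return node


def strip_magic_args(source: str, magic_args: List[str]) -> str:
    tree = StripMagicArgs(magic_args).visit(ast.parse(source))
    return ast.unparse(tree)


src = "layer = Linear(4, 8, rngs=rngs)"
print(strip_magic_args(src, ["rngs"]))  # layer = Linear(4, 8)
print(strip_magic_args(src, []))        # layer = Linear(4, 8, rngs=rngs)
```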

property structural_traits: ml_switcheroo.frameworks.base.StructuralTraits

Defines structural rewriting rules for Praxis.

Key Differences:

- Module Base: praxis.base_layer.BaseLayer.
- Init Method: Replaces __init__ with setup.
- Super Init: Disabled (not required in Praxis setup).

Returns:

The configuration object.

Return type:

StructuralTraits
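The three key differences can be illustrated with a small AST rewrite. This is a hypothetical sketch, not the library's rewriter. It assumes the target code imports `from praxis import base_layer`, replaces every class base, and does not migrate constructor arguments. Requires Python 3.9+ for `ast.unparse`.

```python
import ast


def _is_super_init(stmt: ast.stmt) -> bool:
    """True for a bare `super().__init__(...)` expression statement."""
    if not (isinstance(stmt, ast.Expr) and isinstance(stmt.value, ast.Call)):
        return False
    func = stmt.value.func
    return (
        isinstance(func, ast.Attribute)
        and func.attr == "__init__"
        and isinstance(func.value, ast.Call)
        and isinstance(func.value.func, ast.Name)
        and func.value.func.id == "super"
    )


class PraxisLifecycle(ast.NodeTransformer):
    """Sketch of the three structural traits listed above."""

    def __init__(self, base_class: str = "base_layer.BaseLayer") -> None:
        # Assumes `from praxis import base_layer` in the generated module.
        self.base_class = base_class

    def visit_ClassDef(self, node: ast.ClassDef) -> ast.AST:
        # Module base: swap any existing bases for the Praxis base layer.
        if node.bases:
            node.bases = [ast.parse(self.base_class, mode="eval").body]
        self.generic_visit(node)
        return node

    def visit_FunctionDef(self, node: ast.FunctionDef) -> ast.AST:
        if node.name == "__init__":
            # Init method: rename __init__ to setup.
            node.name = "setup"
            # Super init: drop super().__init__(...), keeping the body non-empty.
            node.body = [s for s in node.body if not _is_super_init(s)] or [ast.Pass()]
        self.generic_visit(node)
        return node


SOURCE = """
class MLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 8)
"""

rewritten = ast.unparse(PraxisLifecycle().visit(ast.parse(SOURCE)))
print(rewritten)
```

Layer mapping (for example nn.Linear to a praxis.layers class) is a separate step and is not attempted here.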

property plugin_traits: ml_switcheroo.frameworks.base.PluginTraits

Returns plugin capability flags. Enables functional control flow and purity analysis (inherited from JAX requirements).

Returns:

The capability flags.

Return type:

PluginTraits

property definitions: Dict[str, ml_switcheroo.frameworks.base.StandardMap]

Returns the static definitions for Praxis layers, loaded at runtime from frameworks/definitions/paxml.json.

Returns:

The mapping dictionary.

Return type:

Dict[str, StandardMap]
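Loading definitions from a JSON file can be sketched with the standard library. The file's schema (an "api" key per entry) is a hypothetical assumption. Falling back to an empty mapping when the file is missing is a design choice of this sketch, not documented adapter behavior.

```python
import json
import tempfile
from pathlib import Path
from typing import Any, Dict


def load_definitions(path: Path) -> Dict[str, Dict[str, Any]]:
    """Load a JSON object mapping operation names to definition entries."""
    if not path.exists():
        # Assumed fallback: a missing file means no definitions.
        return {}
    with path.open(encoding="utf-8") as fh:
        data = json.load(fh)
    if not isinstance(data, dict):
        raise ValueError(f"{path} must contain a JSON object at the top level")
    return data


with tempfile.TemporaryDirectory() as tmp:
    json_path = Path(tmp) / "paxml.json"
    json_path.write_text(json.dumps({"Linear": {"api": "praxis.layers.Linear"}}), encoding="utf-8")
    defs = load_definitions(json_path)

print(defs["Linear"]["api"])  # praxis.layers.Linear
```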

property rng_seed_methods: List[str]

Returns the list of global RNG seed methods. Empty for PaxML.

Returns:

Empty list.

Return type:

List[str]

convert(data: Any) → Any

Converts input data to JAX arrays.

Parameters:

data (Any) – Input data (a NumPy array or a list).

Returns:

JAX Array.

Return type:

Any

apply_wiring(snapshot: Dict[str, Any]) → None

Applies JAX Stack wiring.

Injects core JAX math operations and Optax optimizer mappings into the snapshot.

Parameters:

snapshot (Dict[str, Any]) – The snapshot dictionary to modify.
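In-place wiring of this kind can be sketched as a dictionary merge. The snapshot layout (a "mappings" key) and the two small tables are hypothetical. Only the target names jax.numpy.add, jax.numpy.abs, optax.adam, and optax.sgd are real APIs. Letting entries already in the snapshot win is an assumption of this sketch.

```python
from typing import Any, Dict

# Hypothetical wiring tables; the adapter's real tables are not shown here.
CORE_MATH = {"Add": "jax.numpy.add", "Abs": "jax.numpy.abs"}
OPTAX_OPTIMIZERS = {"Adam": "optax.adam", "SGD": "optax.sgd"}


def apply_wiring(snapshot: Dict[str, Any]) -> None:
    """Merge core math and Optax mappings into snapshot['mappings'] in place."""
    mappings = snapshot.setdefault("mappings", {})
    for table in (CORE_MATH, OPTAX_OPTIMIZERS):
        for op, api in table.items():
            # setdefault keeps any mapping already present in the snapshot.
            mappings.setdefault(op, {"api": api})


snap: Dict[str, Any] = {"mappings": {"Add": {"api": "custom.add"}}}
apply_wiring(snap)
print(snap["mappings"]["Add"])   # {'api': 'custom.add'}
print(snap["mappings"]["Adam"])  # {'api': 'optax.adam'}
```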

get_doc_url(api_name: str) → str | None

Generates a GitHub search URL for a PaxML API, since its hosted documentation is sparse.

Parameters:

api_name (str) – The API path.

Returns:

The GitHub search URL.

Return type:

str | None
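A search URL of this kind can be sketched with `urllib.parse.urlencode` and GitHub's `repo:` search qualifier. Several parts are assumptions: the target repository (google/praxis), searching on the last path component, and returning None for an empty name. The adapter's exact query format is not documented here.

```python
from typing import Optional
from urllib.parse import urlencode


def get_doc_url(api_name: str, repo: str = "google/praxis") -> Optional[str]:
    """Build a GitHub code-search URL for an API path, or None for an empty name."""
    if not api_name:
        return None
    # Assumption: search for the final component (e.g. "Linear") within the repo.
    symbol = api_name.rsplit(".", 1)[-1]
    query = urlencode({"q": f"repo:{repo} {symbol}", "type": "code"})
    return f"https://github.com/search?{query}"


print(get_doc_url("praxis.layers.Linear"))
# https://github.com/search?q=repo%3Agoogle%2Fpraxis+Linear&type=code
```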

get_tiered_examples() → Dict[str, str]

Returns tiered example code snippets for documentation.

Returns:

Mapping of tier IDs to code.

Return type:

Dict[str, str]