ml_switcheroo.frameworks.paxml
PaxML (Praxis) Framework Adapter (Level 2).
This adapter specializes the core JAX stack for Google’s PaxML framework, specifically targeting the Praxis layer library.
It inherits Level 0 (Core JAX) and Level 1 (Optax/Orbax) capabilities from
JAXStackMixin but implements the unique structural traits of the Praxis library,
such as the setup() lifecycle method for layer definition.
Attributes
praxis
Classes
PaxmlAdapter | Adapter for PaxML (Praxis Layers) running on JAX.
Module Contents
- ml_switcheroo.frameworks.paxml.praxis = None
- class ml_switcheroo.frameworks.paxml.PaxmlAdapter
Bases: ml_switcheroo.frameworks.common.jax_stack.JAXStackMixin
Adapter for PaxML (Praxis Layers) running on JAX.
Features:
  - Lifecycle Translation: Maps standard __init__ definitions to Praxis setup() methods.
  - Layer Mapping: Maps Torch/Flax layers to praxis.layers.*.
  - Stack Reuse: Inherits optimization and math logic from the JAX Core adapter.
- display_name: str = 'PaxML / Praxis'
- inherits_from: str = 'jax'¶
- ui_priority: int = 60¶
- collect_api(category: ml_switcheroo.frameworks.base.StandardCategory) → List[ml_switcheroo.frameworks.base.GhostRef]
Collects API definitions for the given category.
Delegates to JaxCoreAdapter for the Math, Loss, and Optimizer categories, while handling Layer discovery specifically for Praxis.
- Parameters:
category (StandardCategory) – The API category to scan.
- Returns:
Found API signatures.
- Return type:
List[GhostRef]
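The delegation described above can be sketched as a simple dispatch on the category. This is a minimal, hypothetical illustration: StandardCategory and GhostRef below are local stand-ins rather than the classes from ml_switcheroo.frameworks.base, and collect_core / collect_praxis_layers are placeholder names for the JAX-core delegation and the Praxis layer scan.

```python
from dataclasses import dataclass
from enum import Enum
from typing import List


class StandardCategory(Enum):
    # Hypothetical stand-in for ml_switcheroo.frameworks.base.StandardCategory
    MATH = "math"
    LOSS = "loss"
    OPTIMIZER = "optimizer"
    LAYER = "layer"


@dataclass
class GhostRef:
    # Hypothetical stand-in: only the fields this sketch needs
    name: str
    source: str


def collect_core(category: StandardCategory) -> List[GhostRef]:
    # Placeholder for delegating to the core JAX adapter
    return [GhostRef(name=f"core:{category.value}", source="jax")]


def collect_praxis_layers() -> List[GhostRef]:
    # Placeholder for Praxis-specific layer discovery
    return [GhostRef(name="praxis:layer", source="praxis")]


def collect_api(category: StandardCategory) -> List[GhostRef]:
    # Layers are discovered locally; math, loss and optimizers are delegated
    if category is StandardCategory.LAYER:
        return collect_praxis_layers()
    if category in (StandardCategory.MATH, StandardCategory.LOSS, StandardCategory.OPTIMIZER):
        return collect_core(category)
    return []
```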
- property search_modules: List[str]
Returns the list of modules to scan during manual scaffolding.
- Returns:
Module names including praxis.layers and praxis.base_layer.
- Return type:
List[str]
- property unsafe_submodules: Set[str]
Submodules to skip during scanning. Praxis uses safe defaults, so none are excluded.
- Returns:
Empty set.
- Return type:
Set[str]
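One plausible way a scaffolding step could consume search_modules together with unsafe_submodules is sketched below. The helper iter_safe_modules is hypothetical, not part of ml_switcheroo; it imports each root (for PaxML that would be praxis.layers and praxis.base_layer), walks its direct submodules with pkgutil, and skips any name in the unsafe set, which is empty for Praxis.

```python
import importlib
import pkgutil
from types import ModuleType
from typing import Iterable, List, Set


def iter_safe_modules(roots: Iterable[str], unsafe: Set[str]) -> List[ModuleType]:
    """Import each root module and its direct submodules, skipping unsafe names."""
    found: List[ModuleType] = []
    for root in roots:
        try:
            module = importlib.import_module(root)
        except ImportError:
            # Framework not installed: skip it rather than fail the whole scan
            continue
        found.append(module)
        # Only packages define __path__; an empty list yields no submodules
        for info in pkgutil.iter_modules(getattr(module, "__path__", [])):
            full_name = f"{root}.{info.name}"
            if full_name in unsafe or info.name in unsafe:
                continue
            try:
                found.append(importlib.import_module(full_name))
            except ImportError:
                continue
    return found
```

Missing roots are skipped instead of raising, so the scan degrades gracefully when Praxis is not installed.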
- property import_alias: Tuple[str, str]
Returns the primary import alias for the framework.
- Returns:
("praxis.layers", "pl").
- Return type:
Tuple[str, str]
- property import_namespaces: Dict[str, ml_switcheroo.frameworks.base.ImportConfig]
Defines the semantic roles of Praxis namespaces.
- Returns:
Mapping of namespaces to tiers.
- Return type:
Dict[str, ImportConfig]
- property discovery_heuristics: Dict[str, List[str]]
Returns regex patterns for heuristic categorization.
- Returns:
Patterns identifying neural components in Praxis.
- Return type:
Dict[str, List[str]]
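A minimal sketch of regex-based categorization, assuming discovery_heuristics maps a category name to a list of patterns matched against API names. The category keys and patterns here are hypothetical examples, not the adapter's actual heuristics.

```python
import re
from typing import Dict, List, Optional

# Hypothetical patterns; the adapter's real heuristics are not reproduced here
HEURISTICS: Dict[str, List[str]] = {
    "neural": [r"Layer$", r"^Linear", r"Conv", r"Attention"],
    "loss": [r"Loss$", r"cross_entropy"],
}


def categorize(name: str, heuristics: Dict[str, List[str]] = HEURISTICS) -> Optional[str]:
    # Return the first category (in dict order) with a pattern found in the name
    for category, patterns in heuristics.items():
        if any(re.search(p, name) for p in patterns):
            return category
    return None
```

Because dictionaries preserve insertion order, earlier categories win when a name matches patterns from more than one category.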
- property test_config: Dict[str, str]
Returns templates for generating physical test files. Extends the JAX base config with Praxis imports.
- Returns:
Code generation templates.
- Return type:
Dict[str, str]
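The sketch below illustrates extending a JAX base template set with the Praxis import derived from import_alias. The template keys (import, convert_input) and their contents are assumptions for illustration; the real test_config keys are not listed here.

```python
from typing import Dict, Tuple

# Hypothetical JAX base templates; key names are illustrative only
JAX_BASE_CONFIG: Dict[str, str] = {
    "import": "import jax\nimport jax.numpy as jnp",
    "convert_input": "jnp.array({np_var})",
}


def extend_with_praxis(base: Dict[str, str], alias: Tuple[str, str]) -> Dict[str, str]:
    # Copy the base templates (leaving them unmodified) and append the Praxis import
    module, short = alias
    config = dict(base)
    config["import"] = f"{base['import']}\nimport {module} as {short}"
    return config


config = extend_with_praxis(JAX_BASE_CONFIG, ("praxis.layers", "pl"))
print(config["import"])
print(config["convert_input"].format(np_var="x"))  # jnp.array(x)
```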
- property harness_imports: List[str]
Returns imports required for the verification harness.
- Returns:
['import jax', 'import jax.random'].
- Return type:
List[str]
- get_harness_init_code() → str
Returns Python source for a helper that initializes JAX random keys in the harness.
- Returns:
Source code for _make_jax_key.
- Return type:
str
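A sketch of what get_harness_init_code might return, assuming _make_jax_key simply wraps jax.random.PRNGKey. The helper body is hypothetical; only the helper name comes from the documentation. Because the result is plain source text, it can be checked with compile() without importing JAX.

```python
import textwrap


def get_harness_init_code() -> str:
    # Hypothetical helper body; only the name _make_jax_key is documented.
    # The snippet assumes `import jax` (from harness_imports) has already run.
    return textwrap.dedent(
        """
        def _make_jax_key(seed: int = 0):
            # JAX has no global RNG state, so each test gets an explicit key
            return jax.random.PRNGKey(seed)
        """
    ).lstrip()
```

A harness would typically exec() this string after the imports from harness_imports and then call _make_jax_key(seed) to obtain an explicit key.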
- property supported_tiers: List[Any]
Returns supported semantic tiers.
- Returns:
Array API, Neural, and Extras.
- Return type:
List[SemanticTier]
- property declared_magic_args: List[str]
Returns the list of magic arguments to strip. Praxis handles RNG context differently from Flax (typically internally), so no arguments need stripping.
- Returns:
Empty list.
- Return type:
List[str]
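Stripping magic arguments amounts to filtering keyword arguments by name. In this hypothetical sketch, rngs (the keyword that Flax NNX layer constructors accept) stands in for a magic argument; because declared_magic_args is empty for Praxis, the Praxis case passes every argument through unchanged.

```python
from typing import Any, Dict, List


def strip_magic_args(kwargs: Dict[str, Any], magic: List[str]) -> Dict[str, Any]:
    # Drop framework-specific keyword arguments the target framework does not accept
    return {k: v for k, v in kwargs.items() if k not in magic}


call_kwargs = {"features": 128, "rngs": "<rng handle>"}
flax_style = strip_magic_args(call_kwargs, ["rngs"])  # hypothetical magic list
praxis_style = strip_magic_args(call_kwargs, [])      # declared_magic_args is empty
```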
- property structural_traits: ml_switcheroo.frameworks.base.StructuralTraits
Defines structural rewriting rules for Praxis.
Key differences:
  - Module Base: praxis.base_layer.BaseLayer.
  - Init Method: Replaces __init__ with setup.
  - Super Init: Disabled (not required in Praxis setup).
- Returns:
The configuration object.
- Return type:
StructuralTraits
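The three structural traits can be illustrated with a small ast.NodeTransformer. This is a simplified sketch, not ml_switcheroo's rewriter: it rebases every class onto praxis.base_layer.BaseLayer, renames __init__ to setup, and drops bare super().__init__() calls. It leaves constructor arguments untouched, although setup() conventionally takes only self, so a real translation would also have to relocate them. Requires Python 3.9+ for ast.unparse.

```python
import ast

# Target base class named in structural_traits
PRAXIS_BASE = "praxis.base_layer.BaseLayer"


def _is_super_init(stmt: ast.stmt) -> bool:
    # Matches a bare `super().__init__(...)` expression statement
    if not (isinstance(stmt, ast.Expr) and isinstance(stmt.value, ast.Call)):
        return False
    func = stmt.value.func
    return (
        isinstance(func, ast.Attribute)
        and func.attr == "__init__"
        and isinstance(func.value, ast.Call)
        and isinstance(func.value.func, ast.Name)
        and func.value.func.id == "super"
    )


class PraxisLifecycle(ast.NodeTransformer):
    """Apply the three structural traits to class definitions."""

    def visit_ClassDef(self, node: ast.ClassDef) -> ast.ClassDef:
        # Module Base: replace the bases (simplification: every class is rebased)
        node.bases = [ast.parse(PRAXIS_BASE, mode="eval").body]
        self.generic_visit(node)
        return node

    def visit_FunctionDef(self, node: ast.FunctionDef) -> ast.FunctionDef:
        if node.name == "__init__":
            # Init Method: __init__ becomes setup (arguments are left as-is here)
            node.name = "setup"
            # Super Init: drop super().__init__(), keeping the body non-empty
            node.body = [s for s in node.body if not _is_super_init(s)] or [ast.Pass()]
        return node


def to_praxis(source: str) -> str:
    tree = PraxisLifecycle().visit(ast.parse(source))
    return ast.unparse(ast.fix_missing_locations(tree))


SOURCE = """\
class Block(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = make_proj()
"""
print(to_praxis(SOURCE))
```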
- property plugin_traits: ml_switcheroo.frameworks.base.PluginTraits
Returns plugin capability flags. Enables functional control flow and purity analysis (inherited from JAX requirements).
- Returns:
The capability flags.
- Return type:
PluginTraits
- property definitions: Dict[str, ml_switcheroo.frameworks.base.StandardMap]
Returns static definitions for Praxis Layers. Loaded dynamically from frameworks/definitions/paxml.json.
- Returns:
The mapping dictionary.
- Return type:
Dict[str, StandardMap]
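A minimal sketch of loading the definitions file, assuming frameworks/definitions/paxml.json holds a single JSON object keyed by operation name. The example entry and its api field are a hypothetical schema, and this sketch returns plain dictionaries rather than StandardMap objects.

```python
import json
from pathlib import Path
from typing import Any, Dict


def parse_definitions(text: str) -> Dict[str, Dict[str, Any]]:
    # Each top-level key is an operation name mapped to its target description
    data = json.loads(text)
    if not isinstance(data, dict):
        raise ValueError("definitions file must contain a JSON object")
    return data


def load_definitions(path: Path) -> Dict[str, Dict[str, Any]]:
    # e.g. Path("frameworks/definitions/paxml.json")
    return parse_definitions(path.read_text(encoding="utf-8"))


EXAMPLE = '{"Linear": {"api": "praxis.layers.Linear"}}'  # hypothetical schema
```

Splitting parsing from file access keeps the parsing logic testable without touching the filesystem.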
- property rng_seed_methods: List[str]
Returns the list of global RNG seed methods (empty for PaxML).
- Returns:
Empty list.
- Return type:
List[str]
- convert(data: Any) → Any
Converts input data to JAX arrays.
- Parameters:
data (Any) – Input data (e.g. a NumPy array or a Python list).
- Returns:
JAX Array.
- Return type:
Any
- apply_wiring(snapshot: Dict[str, Any]) → None
Applies JAX Stack wiring.
Injects core JAX math operations and Optax optimizer mappings into the snapshot.
- Parameters:
snapshot (Dict[str, Any]) – The snapshot dictionary to modify.
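A hypothetical sketch of the wiring step, assuming the snapshot keeps a mappings dictionary keyed by operation name. The snapshot layout and the entries in CORE_MATH and OPTAX_OPTIMIZERS are illustrative assumptions; jax.numpy.abs, jax.numpy.add, optax.adam and optax.sgd are real functions, but the adapter's actual mapping set is not shown. This sketch keeps any mapping already present rather than overwriting it.

```python
from typing import Any, Dict

# Hypothetical mapping entries; the adapter's real wiring content is not shown here
CORE_MATH = {"Abs": "jax.numpy.abs", "Add": "jax.numpy.add"}
OPTAX_OPTIMIZERS = {"Adam": "optax.adam", "SGD": "optax.sgd"}


def apply_wiring(snapshot: Dict[str, Any]) -> None:
    # Mutates the snapshot in place, matching the documented None return
    mappings = snapshot.setdefault("mappings", {})
    for op, api in {**CORE_MATH, **OPTAX_OPTIMIZERS}.items():
        # Keep mappings that are already present in the snapshot
        mappings.setdefault(op, {"api": api})


snapshot = {"mappings": {"Abs": {"api": "custom.abs"}}}
apply_wiring(snapshot)
```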