ml_switcheroo.frameworks.flax_nnx¶
Flax NNX Framework Adapter (Level 2).
This adapter builds upon the core JAX stack to support the Flax NNX
neural network library. It inherits math/optimizer logic from JAXStackMixin
but implements dynamic discovery for NNX Modules.
Refactor: Populates definitions for Neural Layers and Import Namespaces.
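The "dynamic discovery" mentioned above can be pictured as a reflective scan of a module for subclasses of a base class. The following is only a minimal sketch under stated assumptions: `discover_subclasses` and the `fake_nnx` stand-in module are hypothetical names introduced for illustration, and the adapter's actual scanning logic is not shown on this page.

```python
import inspect
import types


def discover_subclasses(module, base):
    """Return the names of public classes in `module` that subclass `base`."""
    found = []
    for name, obj in inspect.getmembers(module, inspect.isclass):
        if name.startswith("_"):
            continue  # skip private helpers
        if issubclass(obj, base) and obj is not base:
            found.append(name)
    return sorted(found)


# Stand-in classes for a layer library (hypothetical; not the real flax.nnx).
class Module:
    pass


class Linear(Module):
    pass


class Dropout(Module):
    pass


class Rngs:  # not a Module subclass, so the scan should ignore it
    pass


fake_nnx = types.ModuleType("fake_nnx")
for cls in (Module, Linear, Dropout, Rngs):
    setattr(fake_nnx, cls.__name__, cls)

layer_names = discover_subclasses(fake_nnx, Module)
print(layer_names)
```

With Flax installed, the same helper could presumably be pointed at the real package, for example `discover_subclasses(flax.nnx, flax.nnx.Module)`.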
Attributes¶
- flax_nnx
Classes¶
- FlaxNNXAdapter: Adapter for the Flax NNX Framework (The Object-Oriented JAX API).
Module Contents¶
- ml_switcheroo.frameworks.flax_nnx.flax_nnx = None¶
- class ml_switcheroo.frameworks.flax_nnx.FlaxNNXAdapter¶
Bases: ml_switcheroo.frameworks.common.jax_stack.JAXStackMixin
Adapter for the Flax NNX Framework (The Object-Oriented JAX API).
Links standard Neural Layer definitions to flax.nnx.*.
- display_name: str = 'Flax NNX'¶
- inherits_from: str = 'jax'¶
- ui_priority: int = 15¶
- collect_api(category: ml_switcheroo.frameworks.base.StandardCategory) → List[ml_switcheroo.frameworks.base.GhostRef]¶
Collects API definitions.
Delegates Loss, Optimizer, and Activation scanning to the core JAX adapter (which scans Optax and jax.nn) and implements dedicated logic for scanning Flax NNX layers.
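The delegation described for collect_api can be sketched as a method override that handles one category locally and forwards the rest to the parent class. This is an illustrative sketch only: `Category`, `CoreStackSketch`, `NNXAdapterSketch`, and `_scan_layers` are hypothetical stand-ins, the member names of the real StandardCategory are assumed, and plain strings are returned here instead of GhostRef objects.

```python
from enum import Enum
from typing import List


class Category(Enum):
    """Hypothetical stand-in for StandardCategory (member names assumed)."""

    LAYER = "layer"
    LOSS = "loss"
    OPTIMIZER = "optimizer"
    ACTIVATION = "activation"


class CoreStackSketch:
    """Stand-in for the shared JAX stack mixin."""

    def collect_api(self, category: Category) -> List[str]:
        # Placeholder results; the real mixin is described as scanning
        # Optax and jax.nn rather than using a fixed table.
        table = {
            Category.LOSS: ["optax.l2_loss"],
            Category.OPTIMIZER: ["optax.adam"],
            Category.ACTIVATION: ["jax.nn.relu"],
        }
        return table.get(category, [])


class NNXAdapterSketch(CoreStackSketch):
    """Handles layers itself and delegates every other category upward."""

    def collect_api(self, category: Category) -> List[str]:
        if category is Category.LAYER:
            return self._scan_layers()
        return super().collect_api(category)

    def _scan_layers(self) -> List[str]:
        # Placeholder for the NNX-specific layer scan.
        return ["flax.nnx.Dropout", "flax.nnx.Linear"]


adapter = NNXAdapterSketch()
print(adapter.collect_api(Category.LAYER))
print(adapter.collect_api(Category.OPTIMIZER))
```

Keeping the layer branch in the subclass means the shared loss, optimizer, and activation logic lives in one place for every adapter built on the same stack.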
- property search_modules: List[str]¶
Modules to scan during Scaffolding.
- property import_alias: Tuple[str, str]¶
- property import_namespaces: Dict[str, Dict[str, str]]¶
- property discovery_heuristics: Dict[str, List[str]]¶
- property test_config: Dict[str, str]¶
- property supported_tiers: List[Any]¶
- property structural_traits: ml_switcheroo.frameworks.base.StructuralTraits¶
Defines Flax NNX-specific structural transformations (Level 2). Stochastic layers require an injected rngs argument.
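To make the rngs injection concrete: Flax NNX layer constructors such as `nnx.Linear` and `nnx.Dropout` accept an `rngs` keyword, so converted code needs that argument added to each such call. The sketch below shows one way this kind of rewrite could be done with the standard-library `ast` module (Python 3.9+ for `ast.unparse`). It is not the adapter's implementation; `InjectRngs`, `inject_rngs`, and the `LAYERS_NEEDING_RNGS` set are hypothetical names chosen for illustration.

```python
import ast

# Assumed set of layer names that should receive `rngs` (illustrative only).
LAYERS_NEEDING_RNGS = {"Linear", "Dropout"}


class InjectRngs(ast.NodeTransformer):
    """Append `rngs=rngs` to calls of the form `nnx.<Layer>(...)`."""

    def visit_Call(self, node):
        self.generic_visit(node)  # rewrite nested calls first
        func = node.func
        is_nnx_layer = (
            isinstance(func, ast.Attribute)
            and isinstance(func.value, ast.Name)
            and func.value.id == "nnx"
            and func.attr in LAYERS_NEEDING_RNGS
        )
        has_rngs = any(kw.arg == "rngs" for kw in node.keywords)
        if is_nnx_layer and not has_rngs:
            node.keywords.append(
                ast.keyword(arg="rngs", value=ast.Name(id="rngs", ctx=ast.Load()))
            )
        return node


def inject_rngs(source: str) -> str:
    """Return `source` with `rngs=rngs` added to matching layer calls."""
    tree = InjectRngs().visit(ast.parse(source))
    ast.fix_missing_locations(tree)
    return ast.unparse(tree)


converted = inject_rngs("layer = nnx.Linear(4, 8)")
print(converted)
```

Calls that already pass `rngs`, and calls to names outside the assumed set, are left unchanged, so running the transformer twice produces the same result as running it once.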
- property definitions: Dict[str, ml_switcheroo.frameworks.base.StandardMap]¶
Static definitions for Flax NNX.
- property rng_seed_methods: List[str]¶
- convert(data)¶
- apply_wiring(snapshot: Dict[str, Any]) → None¶
Applies the JAX stack wiring plus Flax NNX-specific logic.
- classmethod get_example_code() → str¶
- get_tiered_examples() → Dict[str, str]¶