ml_switcheroo.frameworks.flax_nnx

Flax NNX Framework Adapter (Level 2).

This adapter builds upon the core JAX stack to support the Flax NNX neural network library. It inherits math/optimizer logic from JAXStackMixin but implements dynamic discovery for NNX Modules.

It also populates the definitions for neural layers and import namespaces.

Attributes

flax_nnx

Classes

FlaxNNXAdapter

Adapter for the Flax NNX Framework (The Object-Oriented JAX API).

Module Contents

ml_switcheroo.frameworks.flax_nnx.flax_nnx = None
class ml_switcheroo.frameworks.flax_nnx.FlaxNNXAdapter

Bases: ml_switcheroo.frameworks.common.jax_stack.JAXStackMixin

Adapter for the Flax NNX Framework (The Object-Oriented JAX API).

Links standard Neural Layer definitions to flax.nnx.*.

display_name: str = 'Flax NNX'
inherits_from: str = 'jax'
ui_priority: int = 15
collect_api(category: ml_switcheroo.frameworks.base.StandardCategory) → List[ml_switcheroo.frameworks.base.GhostRef]

Collects API definitions.

Delegates Loss, Optimizer, and Activation scanning to the Core JAX adapter (which scans Optax and jax.nn), and implements its own logic for scanning Flax NNX layers.
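The delegation described above can be sketched as follows. This is a minimal illustration, not the adapter's actual code: `StandardCategory`, `GhostRef`, `CoreJaxStub`, and `NNXAdapterSketch` are hypothetical stand-ins defined here, their fields and members are assumptions, and the layer scan is hard-coded where the real adapter discovers modules dynamically.

```python
from dataclasses import dataclass
from enum import Enum
from typing import List, Optional


class StandardCategory(Enum):
    """Hypothetical stand-in for ml_switcheroo.frameworks.base.StandardCategory."""
    LOSS = "loss"
    OPTIMIZER = "optimizer"
    ACTIVATION = "activation"
    LAYER = "layer"


@dataclass(frozen=True)
class GhostRef:
    """Hypothetical stand-in: a discovered API recorded by name and dotted path."""
    name: str
    api_path: str


class CoreJaxStub:
    """Plays the role of the Core JAX adapter, which scans Optax and jax.nn."""

    _KNOWN = {
        StandardCategory.LOSS: [GhostRef("softmax_cross_entropy", "optax.softmax_cross_entropy")],
        StandardCategory.OPTIMIZER: [GhostRef("adam", "optax.adam")],
        StandardCategory.ACTIVATION: [GhostRef("relu", "jax.nn.relu")],
    }

    def collect_api(self, category: StandardCategory) -> List[GhostRef]:
        return list(self._KNOWN.get(category, []))


class NNXAdapterSketch:
    """Delegates shared categories to the core adapter and handles layers itself."""

    _DELEGATED = {
        StandardCategory.LOSS,
        StandardCategory.OPTIMIZER,
        StandardCategory.ACTIVATION,
    }

    def __init__(self, core: Optional[CoreJaxStub] = None) -> None:
        self._core = core or CoreJaxStub()

    def collect_api(self, category: StandardCategory) -> List[GhostRef]:
        if category in self._DELEGATED:
            # Shared JAX-stack categories: forward to the core adapter.
            return self._core.collect_api(category)
        if category is StandardCategory.LAYER:
            return self._scan_layers()
        return []

    def _scan_layers(self) -> List[GhostRef]:
        # Assumption: a fixed list stands in for dynamic discovery of NNX modules.
        return [GhostRef(n, f"flax.nnx.{n}") for n in ("Linear", "Conv", "Dropout")]


adapter = NNXAdapterSketch()
layer_paths = [ref.api_path for ref in adapter.collect_api(StandardCategory.LAYER)]
optimizer_paths = [ref.api_path for ref in adapter.collect_api(StandardCategory.OPTIMIZER)]
```

Keeping the delegated categories in one set means a new shared category only needs to be added in one place.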

property search_modules: List[str]

Modules to scan during Scaffolding.

property import_alias: Tuple[str, str]
property import_namespaces: Dict[str, Dict[str, str]]
property discovery_heuristics: Dict[str, List[str]]
property test_config: Dict[str, str]
property supported_tiers: List[Any]
property structural_traits: ml_switcheroo.frameworks.base.StructuralTraits

Defines the structural transformations specific to Flax NNX (Level 2). Stochastic layers require an injected rngs argument.
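To make the rngs injection concrete, here is a small sketch that rewrites source text so that calls to stochastic layers gain an `rngs=rngs` keyword. It uses the standard-library `ast` module as a stand-in; the adapter's actual rewriting mechanism is not shown on this page. The `STOCHASTIC_LAYERS` set, the `RngsInjector` class, and the `inject_rngs` helper are hypothetical names introduced for illustration.

```python
import ast

# Assumption: a hand-picked set of layers treated as stochastic for this sketch.
STOCHASTIC_LAYERS = {"Dropout"}


class RngsInjector(ast.NodeTransformer):
    """Append `rngs=rngs` to `nnx.<Layer>(...)` calls for stochastic layers."""

    def visit_Call(self, node: ast.Call) -> ast.Call:
        self.generic_visit(node)  # handle nested calls first
        func = node.func
        is_stochastic_nnx_call = (
            isinstance(func, ast.Attribute)
            and isinstance(func.value, ast.Name)
            and func.value.id == "nnx"
            and func.attr in STOCHASTIC_LAYERS
        )
        already_has_rngs = any(kw.arg == "rngs" for kw in node.keywords)
        if is_stochastic_nnx_call and not already_has_rngs:
            node.keywords.append(
                ast.keyword(arg="rngs", value=ast.Name(id="rngs", ctx=ast.Load()))
            )
        return node


def inject_rngs(source: str) -> str:
    """Return `source` with rngs keywords added where the sketch's rule applies."""
    tree = RngsInjector().visit(ast.parse(source))
    ast.fix_missing_locations(tree)
    return ast.unparse(tree)  # requires Python 3.9+


converted = inject_rngs("drop = nnx.Dropout(rate=0.5)")
untouched = inject_rngs("y = nnx.relu(x)")
```

The check for an existing `rngs` keyword keeps the rewrite idempotent, so running it twice on the same source does not add a duplicate argument.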

property definitions: Dict[str, ml_switcheroo.frameworks.base.StandardMap]

Static definitions for Flax NNX.
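As a rough illustration of what a static definitions table combined with an import alias might look like, the sketch below maps standard layer names to `flax.nnx` targets and shortens them with an alias. The plain-string values, the `(module path, local name)` shape of the alias pair, and the `resolve` helper are all assumptions; the real `StandardMap` structure is not documented on this page.

```python
from typing import Dict, Tuple

# Assumption: the alias pair is (module path, local name), as in `from flax import nnx`.
IMPORT_ALIAS: Tuple[str, str] = ("flax.nnx", "nnx")

# Hypothetical static table: standard layer name -> fully qualified target.
LAYER_DEFINITIONS: Dict[str, str] = {
    "Linear": "flax.nnx.Linear",
    "Conv": "flax.nnx.Conv",
    "Dropout": "flax.nnx.Dropout",
    "BatchNorm": "flax.nnx.BatchNorm",
}


def resolve(standard_name: str) -> str:
    """Return the aliased call target for a standard layer name."""
    module_path, alias = IMPORT_ALIAS
    target = LAYER_DEFINITIONS[standard_name]  # raises KeyError if unmapped
    if target.startswith(module_path + "."):
        # Swap the module prefix for the local alias.
        return alias + target[len(module_path):]
    return target


linear_target = resolve("Linear")
batchnorm_target = resolve("BatchNorm")
```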

property rng_seed_methods: List[str]
convert(data)
apply_wiring(snapshot: Dict[str, Any]) → None

Applies the shared JAX stack wiring, then the Flax NNX-specific logic.

classmethod get_example_code() → str
get_tiered_examples() → Dict[str, str]