ml_switcheroo.frameworks.flax_nnx
Flax NNX Framework Adapter (Level 2).
Extends the JAX core adapter with Flax’s Neural Network Extensions (nnx). The adapter:
- Uses dynamic or snapshot-mode discovery.
- Provides a clear import alias for from flax import nnx.
- Defines the correct base class, flax.nnx.Module.
- Wires important plugins and structural traits.
Attributes
Classes
FlaxNNXAdapter | Adapter class for Flax NNX.
Module Contents
- class ml_switcheroo.frameworks.flax_nnx.FlaxNNXAdapter
Bases: ml_switcheroo.frameworks.common.jax_stack.JAXStackMixin
Adapter class for Flax NNX.
Inherits from JAXStackMixin for core math/optax behavior and adds:
  - Explicit neural network layers and activations.
  - Correct import aliasing for from flax import nnx.
  - Structural traits targeting Flax’s nnx.Module base.
- display_name: str = 'Flax NNX'
- inherits_from: str = 'jax'
- ui_priority: int = 15
- collect_api(category: ml_switcheroo.frameworks.base.StandardCategory) → List[ml_switcheroo.frameworks.base.GhostRef]
Collect API definitions for discovery.
- Parameters:
category (StandardCategory) – The API category (layer, activation, etc.)
- Returns:
List of found API references.
- Return type:
List[GhostRef]
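A minimal sketch of category-based collection, assuming discovery walks a module's public names and keeps the classes derived from a given base (roughly what a "layer" category needs). GhostRefSketch, collect_subclasses, and the fake module are hypothetical stand-ins so the example runs without Flax; the real collect_api may apply different rules and return richer GhostRef objects.

```python
import inspect
import types
from dataclasses import dataclass
from typing import List


@dataclass(frozen=True)
class GhostRefSketch:
    """Hypothetical stand-in for GhostRef: just a name and a dotted path."""
    name: str
    api_path: str


def collect_subclasses(module: types.ModuleType, base: type) -> List[GhostRefSketch]:
    """Collect public classes in `module` that subclass `base` (excluding `base` itself)."""
    found = []
    for name, obj in sorted(vars(module).items()):
        if name.startswith("_"):
            continue  # skip private names and module dunders
        if inspect.isclass(obj) and issubclass(obj, base) and obj is not base:
            found.append(GhostRefSketch(name=name, api_path=f"{module.__name__}.{name}"))
    return found


# A fake module so the sketch runs without Flax installed.
fake_nnx = types.ModuleType("fake_nnx")


class Module:  # stand-in for flax.nnx.Module
    pass


class Linear(Module):
    pass


class Conv(Module):
    pass


def relu(x):
    return x


fake_nnx.Module = Module
fake_nnx.Linear = Linear
fake_nnx.Conv = Conv
fake_nnx.relu = relu

refs = collect_subclasses(fake_nnx, Module)
print([r.api_path for r in refs])  # ['fake_nnx.Conv', 'fake_nnx.Linear']
```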
- property search_modules: List[str]
Modules to scan for discovery.
- Returns:
Ordered list of module names.
- Return type:
List[str]
- property import_alias: Tuple[str, str]
Returns the base package and alias used to guide import injection. ImportFixer uses it to map flax.nnx root usage to the nnx alias.
- Returns:
(root_package, alias)
- Return type:
Tuple[str, str]
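The following sketch shows how a (root_package, alias) pair could drive a simple rewrite, assuming import_alias returns ("flax.nnx", "nnx"). apply_import_alias is a hypothetical helper, not the actual ImportFixer; a regex pass is used here only to keep the example short.

```python
import re
from typing import Tuple


def apply_import_alias(source: str, import_alias: Tuple[str, str]) -> str:
    """Rewrite fully qualified uses of the root package to the alias and
    prepend the matching import statement (sketch only)."""
    root, alias = import_alias
    parent, _, leaf = root.rpartition(".")
    # "flax.nnx.Linear" -> "nnx.Linear"; \b avoids matching inside e.g. "myflax.nnx".
    pattern = re.compile(r"\b" + re.escape(root) + r"\.")
    body = pattern.sub(alias + ".", source)
    if parent:
        stmt = f"from {parent} import {leaf}"
        if leaf != alias:
            stmt += f" as {alias}"
    else:
        stmt = f"import {root}" + (f" as {alias}" if root != alias else "")
    return stmt + "\n" + body


src = "layer = flax.nnx.Linear(2, 3, rngs=flax.nnx.Rngs(0))\n"
print(apply_import_alias(src, ("flax.nnx", "nnx")))
```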
- property import_namespaces: Dict[str, ml_switcheroo.frameworks.base.ImportConfig]
Declares the adapter’s own namespaces, with their tiers and recommended aliases.
- Returns:
Mapping of package paths to configs.
- Return type:
Dict[str, ImportConfig]
- property discovery_heuristics: Dict[str, List[str]]
Regex patterns for heuristic category assignment during scaffolding.
- Returns:
Mapping of tiers to regex lists.
- Return type:
Dict[str, List[str]]
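A sketch of regex-based tier assignment, assuming each regex is matched against the last segment of a dotted API path and the first matching tier wins. The tier names and patterns in HEURISTICS are illustrative only and are not taken from the adapter.

```python
import re
from typing import Dict, List, Optional

# Hypothetical heuristics: tier name -> regexes matched against the API's leaf name.
HEURISTICS: Dict[str, List[str]] = {
    "neural": [r"^(Linear|Conv\w*|Embed|Dropout)$", r"Norm$"],
    "activation": [r"^(relu|gelu|sigmoid|tanh|softmax)$"],
}


def assign_tier(api_path: str, heuristics: Dict[str, List[str]]) -> Optional[str]:
    """Return the first tier whose patterns match the leaf of api_path, else None."""
    leaf = api_path.rsplit(".", 1)[-1]
    for tier, patterns in heuristics.items():  # dicts keep insertion order
        if any(re.search(p, leaf) for p in patterns):
            return tier
    return None


for path in ["flax.nnx.Linear", "flax.nnx.LayerNorm", "flax.nnx.relu", "flax.nnx.split"]:
    print(path, "->", assign_tier(path, HEURISTICS))
```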
- property test_config: Dict[str, str]
Test code templates, extended from those of the JAX core adapter.
- Returns:
Test harness code snippets/templates.
- Return type:
Dict[str, str]
- property harness_imports: List[str]
Import statements used when generating the test harness.
- property supported_tiers: List[ml_switcheroo.enums.SemanticTier]
Semantic tiers supported by this adapter.
- Returns:
Supported tiers.
- Return type:
List[SemanticTier]
- property declared_magic_args: List[str]
Returns the list of argument names that represent injected state (‘rngs’).
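One plausible use of declared magic arguments is stripping them when translating out of Flax NNX into a framework without explicit RNG state. The transformer below is a hypothetical sketch built on Python's ast module, assuming declared_magic_args returns ["rngs"]; it only handles call-site keywords and keyword-only parameters.

```python
import ast
import textwrap
from typing import Iterable


class StripMagicArgs(ast.NodeTransformer):
    """Drop keyword arguments and keyword-only parameters named in `magic`."""

    def __init__(self, magic: Iterable[str]):
        super().__init__()
        self.magic = set(magic)

    def visit_Call(self, node: ast.Call) -> ast.Call:
        self.generic_visit(node)
        # kw.arg is None for **kwargs splats, which are kept as-is.
        node.keywords = [kw for kw in node.keywords if kw.arg not in self.magic]
        return node

    def visit_FunctionDef(self, node: ast.FunctionDef) -> ast.FunctionDef:
        self.generic_visit(node)
        args = node.args
        # kwonlyargs and kw_defaults are parallel lists, so filter them together.
        kept = [(a, d) for a, d in zip(args.kwonlyargs, args.kw_defaults)
                if a.arg not in self.magic]
        args.kwonlyargs = [a for a, _ in kept]
        args.kw_defaults = [d for _, d in kept]
        return node


MAGIC_ARGS = ["rngs"]  # assumed value of declared_magic_args

source = textwrap.dedent("""
    class Block(nnx.Module):
        def __init__(self, din, dout, *, rngs):
            self.linear = nnx.Linear(din, dout, rngs=rngs)
""")
out = ast.unparse(StripMagicArgs(MAGIC_ARGS).visit(ast.parse(source)))
print(out)
```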
- property structural_traits: ml_switcheroo.frameworks.base.StructuralTraits
Structural rewriting traits guiding the pivot rewriter. Explicitly defines flax.nnx.Module to ensure clean inheritance rewriting without internal submodule leakage.
- Returns:
Configuration of base class, methods, and injections.
- Return type:
StructuralTraits
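A rough sketch of the three kinds of change these traits describe (base class, method names, injected arguments), written as a hypothetical ast transformer that turns a PyTorch-style class into NNX form. The source base names, the forward to __call__ rename, and the keyword-only rngs injection are assumptions chosen to match common NNX conventions, not values read from StructuralTraits.

```python
import ast
import textwrap

TORCH_BASES = {"nn.Module", "torch.nn.Module"}  # assumed source-side base names


class PivotToNNX(ast.NodeTransformer):
    """Hypothetical pivot: rewrite a PyTorch-style module class into NNX form."""

    def visit_ClassDef(self, node: ast.ClassDef) -> ast.ClassDef:
        self.generic_visit(node)
        if not any(ast.unparse(base) in TORCH_BASES for base in node.bases):
            return node
        # 1. Base class: point at flax.nnx.Module through the nnx alias.
        node.bases = [ast.parse("nnx.Module", mode="eval").body]
        for item in node.body:
            if not isinstance(item, ast.FunctionDef):
                continue
            if item.name == "forward":
                # 2. Method rename: NNX modules are invoked through __call__.
                item.name = "__call__"
            elif item.name == "__init__":
                # 3. Injection: add the RNG state as a keyword-only parameter.
                item.args.kwonlyargs.append(ast.arg(arg="rngs", annotation=None))
                item.args.kw_defaults.append(None)
        return node


source = textwrap.dedent("""
    class Net(nn.Module):
        def __init__(self, din, dout):
            super().__init__()
            self.fc = nn.Linear(din, dout)

        def forward(self, x):
            return self.fc(x)
""")
out = ast.unparse(PivotToNNX().visit(ast.parse(source)))
print(out)
```

Layer calls such as nn.Linear are left untouched here; mapping them to nnx.Linear (and passing rngs into them) would be handled by the operation mappings rather than by this structural pass.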
- property plugin_traits: ml_switcheroo.frameworks.base.PluginTraits
Plugin capabilities indicating required behaviors in the target framework.
- Returns:
Flags controlling plugin execution.
- Return type:
PluginTraits
- property definitions: Dict[str, ml_switcheroo.frameworks.base.StandardMap]
Static standard operation definitions specific to Flax NNX. Loaded dynamically from frameworks/definitions/flax_nnx.json.
- Returns:
Mapping of standard op names to framework implementations.
- Return type:
Dict[str, StandardMap]
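A minimal sketch of loading such a definitions file once and caching the result, assuming the JSON root is an object keyed by standard op name. The {"Linear": {"api": ...}} layout and load_definitions are hypothetical; the real StandardMap schema is not documented here.

```python
import json
import tempfile
from functools import lru_cache
from pathlib import Path
from typing import Any, Dict


@lru_cache(maxsize=None)
def load_definitions(path: str) -> Dict[str, Any]:
    """Read a definitions JSON file once and cache the parsed mapping (sketch).

    The cached dict is shared between callers, so it should be treated as read-only.
    """
    with open(path, encoding="utf-8") as fh:
        data = json.load(fh)
    if not isinstance(data, dict):
        raise ValueError(f"expected a JSON object in {path}")
    return data


with tempfile.TemporaryDirectory() as tmp:
    defs_path = Path(tmp) / "flax_nnx.json"  # hypothetical file contents below
    defs_path.write_text(json.dumps({"Linear": {"api": "flax.nnx.Linear"}}), encoding="utf-8")
    defs = load_definitions(str(defs_path))
    print(defs["Linear"]["api"])  # flax.nnx.Linear
```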
- convert(data: Any) → Any
Converts generic data to framework-specific Pytrees/arrays. The logic is self-contained so that the Harness Generator, which does not preserve external dependencies such as JaxCoreAdapter class references, can extract it safely.
- Parameters:
data (Any) – Input data (numpy/list).
- Returns:
Converted data tailored to the JAX/Flax ecosystem.
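A sketch of the self-contained style the docstring describes: every import happens inside the function body, so its source can be copied into a generated harness on its own. It assumes conversion means jax.numpy.asarray on list and array-like leaves while recursing into dicts and tuples; the real method may treat dtypes or containers differently, and the ImportError fallback exists only so the sketch runs without JAX.

```python
from typing import Any


def convert(data: Any) -> Any:
    """Convert lists and array-likes to JAX arrays, recursing into dicts/tuples (sketch)."""
    try:
        import jax.numpy as jnp  # local import keeps the function copyable on its own
    except ImportError:  # fallback only so this sketch runs without JAX installed
        jnp = None

    if isinstance(data, dict):
        return {k: convert(v) for k, v in data.items()}
    if isinstance(data, tuple):
        return tuple(convert(v) for v in data)  # keep tuples as pytree nodes
    if isinstance(data, list) or hasattr(data, "__array__"):
        return jnp.asarray(data) if jnp is not None else data
    return data  # scalars, strings, etc. pass through unchanged


out = convert({"w": [1.0, 2.0], "meta": ("a", [3, 4])})
print(type(out["w"]).__name__, out["meta"][0])
```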
- apply_wiring(snapshot: Dict[str, Any]) → None
Applies manual wiring and modifies the snapshot to alias ‘flax.nnx.’ to ‘nnx.’.
Adds plugin wiring for key interface methods, ensuring correctness during Ghost Mode synchronization.
- Parameters:
snapshot (Dict[str, Any]) – The mapping snapshot dictionary to mutate.
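The aliasing step can be sketched as an in-place prefix rewrite, assuming a flat snapshot layout where each entry is a dict with an "api" key. That layout and alias_snapshot are hypothetical, and the plugin wiring the method also performs is not covered here.

```python
from typing import Any, Dict

OLD_PREFIX = "flax.nnx."
NEW_PREFIX = "nnx."


def alias_snapshot(snapshot: Dict[str, Any]) -> None:
    """Rewrite 'flax.nnx.' API paths to 'nnx.' in place (hypothetical snapshot layout)."""
    for entry in snapshot.values():
        if not isinstance(entry, dict):
            continue  # ignore metadata or other non-mapping entries
        api = entry.get("api")
        if isinstance(api, str) and api.startswith(OLD_PREFIX):
            entry["api"] = NEW_PREFIX + api[len(OLD_PREFIX):]


snap = {
    "Linear": {"api": "flax.nnx.Linear"},
    "Abs": {"api": "jax.numpy.abs"},
}
alias_snapshot(snap)
print(snap["Linear"]["api"], snap["Abs"]["api"])  # nnx.Linear jax.numpy.abs
```

Mutating in place and returning None mirrors the documented signature, so callers keep working with the same snapshot object.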
- get_tiered_examples() → Dict[str, str]
Provides tier-specific example usages for documentation and tests.
- Returns:
Dictionary mapping tier names to code snippets.
- Return type:
Dict[str, str]
- get_doc_url(api_name: str) → str | None
Returns the official Flax documentation URL for a given API string. Defaults to a ReadTheDocs search query, which stays robust for newly added NNX APIs.
- Parameters:
api_name (str) – The fully qualified API name.
- Returns:
The URL to the documentation page.
- Return type:
Optional[str]
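A sketch of the search-query fallback, assuming the docs live at flax.readthedocs.io and expose the standard Sphinx search page (search.html?q=...). FLAX_DOCS_SEARCH and doc_search_url are hypothetical names; the real method may return direct API pages when it can resolve them.

```python
from typing import Optional
from urllib.parse import urlencode

# Assumed base URL: Sphinx sites on Read the Docs serve search at search.html?q=...
FLAX_DOCS_SEARCH = "https://flax.readthedocs.io/en/latest/search.html"


def doc_search_url(api_name: str) -> Optional[str]:
    """Build a docs search URL for a fully qualified API name, or None if it is blank."""
    name = api_name.strip()
    if not name:
        return None
    return f"{FLAX_DOCS_SEARCH}?{urlencode({'q': name})}"


print(doc_search_url("flax.nnx.Linear"))
```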