ml_switcheroo.frameworks.flax_nnx

Flax NNX Framework Adapter (Level 2).

Extends the JAX core adapter with Flax’s Neural Network Extensions (nnx).

  • Supports discovery in either dynamic or snapshot mode.

  • Provides a clean import alias for from flax import nnx.

  • Defines flax.nnx.Module as the correct base class.

  • Wires in the key plugins and structural traits.

Attributes

jax

flax_nnx

Classes

FlaxNNXAdapter

Adapter class for Flax NNX.

Module Contents

ml_switcheroo.frameworks.flax_nnx.jax = None
ml_switcheroo.frameworks.flax_nnx.flax_nnx
class ml_switcheroo.frameworks.flax_nnx.FlaxNNXAdapter

Bases: ml_switcheroo.frameworks.common.jax_stack.JAXStackMixin

Adapter class for Flax NNX.

Inherits from JAXStackMixin for core math and optax behavior, and adds:

  • Explicit neural network layers and activations.

  • Correct import aliasing for from flax import nnx.

  • Structural traits targeting the flax.nnx.Module base class.

display_name: str = 'Flax NNX'
inherits_from: str = 'jax'
ui_priority: int = 15
collect_api(category: ml_switcheroo.frameworks.base.StandardCategory) → List[ml_switcheroo.frameworks.base.GhostRef]

Collect API definitions for discovery.

Parameters:

category (StandardCategory) – The API category (e.g. layer, activation).

Returns:

List of discovered API references.

Return type:

List[GhostRef]

property search_modules: List[str]

Modules to scan for discovery.

Returns:

Ordered list of module names.

Return type:

List[str]

property import_alias: Tuple[str, str]

Returns the base package and alias that guide import injection. ImportFixer uses this pair to map usage of the flax.nnx root to the nnx alias.

Returns:

(root_package, alias)

Return type:

Tuple[str, str]
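To illustrate how a (root_package, alias) pair can drive import injection, the sketch below assumes the property returns ("flax.nnx", "nnx"). The helpers build_import_line and apply_alias are hypothetical and are not part of ml_switcheroo; they only show the shape of the mapping.

```python
from typing import Tuple

# Assumed value of FlaxNNXAdapter.import_alias (hypothetical; check the real property).
IMPORT_ALIAS: Tuple[str, str] = ("flax.nnx", "nnx")


def build_import_line(import_alias: Tuple[str, str]) -> str:
    """Turn ("flax.nnx", "nnx") into "from flax import nnx" (hypothetical helper)."""
    root_package, alias = import_alias
    parent, _, leaf = root_package.rpartition(".")
    if not parent:
        # Top-level package: fall back to a plain aliased import.
        return f"import {root_package} as {alias}"
    if leaf == alias:
        return f"from {parent} import {leaf}"
    return f"from {parent} import {leaf} as {alias}"


def apply_alias(api_path: str, import_alias: Tuple[str, str]) -> str:
    """Rewrite a fully qualified path such as "flax.nnx.Linear" to "nnx.Linear"."""
    root_package, alias = import_alias
    if api_path == root_package:
        return alias
    prefix = root_package + "."
    if api_path.startswith(prefix):
        return alias + "." + api_path[len(prefix):]
    return api_path  # paths outside the root package are left untouched
```

Matching on the full "flax.nnx." prefix, rather than on "flax.nnx" alone, avoids rewriting unrelated names that merely share the same leading characters.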

property import_namespaces: Dict[str, ml_switcheroo.frameworks.base.ImportConfig]

Declares the adapter’s own namespaces, together with their tiers and recommended aliases.

Returns:

Mapping of package paths to configs.

Return type:

Dict[str, ImportConfig]

property discovery_heuristics: Dict[str, List[str]]

Regex patterns for heuristic category assignment during scaffolding.

Returns:

Mapping of tiers to regex lists.

Return type:

Dict[str, List[str]]
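A minimal sketch of heuristic category assignment, assuming the property returns a dict that maps each tier name to a list of regex patterns. The tier names and patterns below are invented for illustration and are not the adapter’s actual values; assign_tier is a hypothetical helper.

```python
import re
from typing import Dict, List, Optional

# Hypothetical heuristics in the same shape as discovery_heuristics: tier -> regexes.
HEURISTICS: Dict[str, List[str]] = {
    "neural": [r"\.nnx\.(Linear|Conv|BatchNorm|Dropout)\b"],
    "activation": [r"\.nnx\.(relu|gelu|sigmoid|softmax)\b"],
}


def assign_tier(api_path: str, heuristics: Dict[str, List[str]]) -> Optional[str]:
    """Return the first tier whose patterns match api_path, or None if none match."""
    for tier, patterns in heuristics.items():
        # Dicts preserve insertion order, so earlier tiers take precedence.
        if any(re.search(p, api_path) for p in patterns):
            return tier
    return None
```

The trailing \b keeps a pattern for Linear from also matching longer names such as LinearGeneral, which then fall through to None and can be categorized by other means.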

property test_config: Dict[str, str]

Test code templates, extending those provided by the JAX core adapter.

Returns:

Test harness code snippets/templates.

Return type:

Dict[str, str]

property harness_imports: List[str]

Import statements required for harness generation.

get_harness_init_code() → str

Returns the source code that creates the Flax NNX Rngs object used by the harness.
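Because this method returns source code rather than an object, a sketch of such a string may help. nnx.Rngs is the real Flax NNX RNG container; the helper name harness_init_code_sketch, the variable name rngs, and the seed parameter are assumptions for illustration (the documented method takes no arguments).

```python
import textwrap


def harness_init_code_sketch(seed: int = 42) -> str:
    """Return Python source that creates a Flax NNX Rngs object (hypothetical helper)."""
    # nnx.Rngs is the real Flax NNX RNG container; the variable name `rngs`
    # and the default seed are assumptions made for this sketch.
    return textwrap.dedent(
        f"""\
        from flax import nnx

        rngs = nnx.Rngs({seed})
        """
    )
```

Emitting a string keeps the harness generator independent of Flax: the code is only parsed here, and Flax needs to be installed only where the generated harness actually runs.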

property supported_tiers: List[ml_switcheroo.enums.SemanticTier]

Semantic tiers supported by this adapter.

Returns:

Supported tiers.

Return type:

List[SemanticTier]

property declared_magic_args: List[str]

Returns the list of argument names that represent injected state (here, ‘rngs’).
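One plausible use of these names is to separate injected state from user-facing arguments when comparing signatures across frameworks. The sketch assumes the property returns ["rngs"]; user_facing_params and the Block class are hypothetical stand-ins, not ml_switcheroo code.

```python
import inspect
from typing import Callable, List, Sequence

# Assumed value of FlaxNNXAdapter.declared_magic_args.
MAGIC_ARGS: List[str] = ["rngs"]


def user_facing_params(fn: Callable, magic_args: Sequence[str]) -> List[str]:
    """List fn's parameter names, dropping 'self' and injected-state arguments."""
    names = inspect.signature(fn).parameters
    return [n for n in names if n != "self" and n not in magic_args]


class Block:
    # Stand-in for an NNX-style constructor that receives rngs as injected state.
    def __init__(self, din: int, dout: int, *, rngs=None):
        self.din, self.dout = din, dout
```

With these definitions, user_facing_params(Block.__init__, MAGIC_ARGS) yields ["din", "dout"], so a constructor from a framework without explicit RNG state can be matched against the NNX one.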

property structural_traits: ml_switcheroo.frameworks.base.StructuralTraits

Structural rewriting traits that guide the pivot rewriter. The base class is set explicitly to flax.nnx.Module so that inheritance is rewritten cleanly, without leaking internal submodule paths.

Returns:

Configuration of base class, methods, and injections.

Return type:

StructuralTraits
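To illustrate the kind of inheritance rewrite these traits guide, the sketch below uses Python’s ast module to replace a source-framework base class with nnx.Module. It is not the pivot rewriter itself; SOURCE_BASES, BaseClassRewriter, and rewrite_bases are hypothetical, and PyTorch-style base names are an assumed input. ast.unparse requires Python 3.9 or later.

```python
import ast

# Hypothetical source-framework base classes to replace.
SOURCE_BASES = {"nn.Module", "torch.nn.Module"}
TARGET_BASE = "nnx.Module"


class BaseClassRewriter(ast.NodeTransformer):
    """Swap recognized base classes for the Flax NNX Module base."""

    def visit_ClassDef(self, node: ast.ClassDef) -> ast.ClassDef:
        self.generic_visit(node)  # also rewrite nested classes
        node.bases = [
            ast.parse(TARGET_BASE, mode="eval").body
            if ast.unparse(base) in SOURCE_BASES
            else base
            for base in node.bases
        ]
        return node


def rewrite_bases(source: str) -> str:
    """Parse source, rewrite matching base classes, and return the new code."""
    tree = BaseClassRewriter().visit(ast.parse(source))
    return ast.unparse(ast.fix_missing_locations(tree))
```

Targeting the public nnx.Module path, rather than whatever internal module defines the class, keeps the rewritten source tied to the stable import alias.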

property plugin_traits: ml_switcheroo.frameworks.base.PluginTraits

Plugin capabilities indicating required behaviors in the target framework.

Returns:

Flags controlling plugin execution.

Return type:

PluginTraits

property definitions: Dict[str, ml_switcheroo.frameworks.base.StandardMap]

Static standard operation definitions specific to Flax NNX, loaded at runtime from frameworks/definitions/flax_nnx.json.

Returns:

Mapping of standard op names to framework implementations.

Return type:

Dict[str, StandardMap]

convert(data: Any) → Any

Converts generic data into framework-specific pytrees or arrays. The logic is self-contained so that the Harness Generator can extract it safely, because the generator does not preserve external dependencies such as references to the ‘JaxCoreAdapter’ class.

Parameters:

data (Any) – Input data (a NumPy array or a list).

Returns:

Converted data suited to the JAX/Flax ecosystem.

apply_wiring(snapshot: Dict[str, Any]) → None

Applies manual wiring and mutates the snapshot so that ‘flax.nnx.’ is aliased to ‘nnx.’.

Also adds plugin wiring for key interface methods to ensure correctness during Ghost Mode synchronization.

Parameters:

snapshot (Dict[str, Any]) – The mapping snapshot dictionary to mutate.
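The aliasing half of this method can be pictured as an in-place walk over the snapshot. The snapshot layout used below ({"mappings": {op: {"api": path}}}) is an assumption for illustration, alias_snapshot is a hypothetical helper, and the plugin wiring step is not shown.

```python
from typing import Any

OLD_PREFIX, NEW_PREFIX = "flax.nnx.", "nnx."


def alias_snapshot(node: Any) -> Any:
    """Recursively rewrite 'flax.nnx.' prefixes to 'nnx.'; dicts and lists change in place."""
    if isinstance(node, dict):
        for key, value in node.items():
            # Reassigning existing keys is safe while iterating (the dict size is unchanged).
            node[key] = alias_snapshot(value)
    elif isinstance(node, list):
        node[:] = [alias_snapshot(v) for v in node]
    elif isinstance(node, str) and node.startswith(OLD_PREFIX):
        return NEW_PREFIX + node[len(OLD_PREFIX):]
    return node
```

Mutating containers in place matches the documented contract (the method returns None and modifies the snapshot it receives), while strings, being immutable, are replaced through their parent container.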

get_tiered_examples() → Dict[str, str]

Provides tier-specific example usages for documentation and tests.

Returns:

Dictionary mapping tier names to code snippets.

Return type:

Dict[str, str]

get_doc_url(api_name: str) → str | None

Returns the official Flax documentation URL for a given API string. Defaults to a ReadTheDocs search query so that newly added NNX APIs still resolve to a useful page.

Parameters:

api_name (str) – The fully qualified API name.

Returns:

The URL to the documentation page.

Return type:

Optional[str]
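A minimal sketch of the search-query fallback, assuming the Flax documentation is served from flax.readthedocs.io with the standard Sphinx search.html page. The base URL constant and the helper doc_search_url are assumptions, not values taken from ml_switcheroo.

```python
from typing import Optional
from urllib.parse import urlencode

# Assumed search endpoint of the Flax documentation hosted on ReadTheDocs.
FLAX_SEARCH_URL = "https://flax.readthedocs.io/en/latest/search.html"


def doc_search_url(api_name: str) -> Optional[str]:
    """Build a documentation search URL for api_name; return None for a blank name."""
    api_name = api_name.strip()
    if not api_name:
        return None
    # urlencode percent-escapes characters such as parentheses in the query.
    return f"{FLAX_SEARCH_URL}?{urlencode({'q': api_name})}"
```

A search link never points at a page that does not exist yet, which is why it is a safer default than guessing a direct URL for a recently added API.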