ml_switcheroo.plugins.nnx_to_torch_params

Plugin for converting Flax NNX Variable definitions to PyTorch-style Parameters.

This module provides AST transformations to handle the impedance mismatch between Flax NNX’s explicit variable declarations (nnx.Param, nnx.BatchStat) and PyTorch’s nn.Parameter pattern.

It handles:

  1. Trainable Parameters: nnx.Param(val) -> TargetParam(val).

  2. Non-Trainable State: nnx.BatchStat(val) -> TargetParam(val, requires_grad=False).

Decoupling:

Logic is triggered solely by the requires_plugin="nnx_param_to_torch" wiring. The target class name is resolved via ctx.lookup_api from the abstract Operation ID. If the Knowledge Base has no mapping, the transformation aborts and returns the original node unchanged, so there is no hardcoded fallback to torch.nn.Parameter.
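
The following is a rough sketch of this lookup-or-abort contract. It uses Python's standard ast module in place of libcst so that it runs without extra dependencies. StubHookContext, its api_map dictionary, resolve_target, and the operation ID "Param" are all hypothetical. The real HookContext.lookup_api signature is not documented on this page.

```python
import ast
from dataclasses import dataclass, field
from typing import Dict, Optional


@dataclass
class StubHookContext:
    """Hypothetical stand-in for ml_switcheroo.core.hooks.HookContext."""

    # Hypothetical Knowledge Base: abstract Operation ID -> target API path.
    api_map: Dict[str, str] = field(default_factory=dict)

    def lookup_api(self, op_id: str) -> Optional[str]:
        # None signals that the Knowledge Base has no mapping for op_id.
        return self.api_map.get(op_id)


def resolve_target(node: ast.Call, ctx: StubHookContext, op_id: str) -> ast.Call:
    target = ctx.lookup_api(op_id)
    if target is None:
        # Abort: return the untouched node instead of guessing torch.nn.Parameter.
        return node
    # Swap only the callee; positional and keyword arguments are carried over.
    return ast.Call(
        func=ast.parse(target, mode="eval").body,
        args=node.args,
        keywords=node.keywords,
    )


call = ast.parse("nnx.Param(jnp.zeros(3))", mode="eval").body

unmapped = resolve_target(call, StubHookContext(), "Param")
mapped = resolve_target(call, StubHookContext({"Param": "torch.nn.Parameter"}), "Param")

print(unmapped is call)     # True: the original node object is returned as-is
print(ast.unparse(mapped))  # torch.nn.Parameter(jnp.zeros(3))
```

In this sketch, an unmapped operation yields the identical node object. This mirrors the documented "returns original node" behavior.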

Functions

transform_nnx_param(node, ctx) → libcst.Call

Plugin Hook: Transforms valid NNX Variable declarations into PyTorch-style Parameters.

Module Contents

ml_switcheroo.plugins.nnx_to_torch_params.transform_nnx_param(node: libcst.Call, ctx: ml_switcheroo.core.hooks.HookContext) → libcst.Call

Plugin Hook: Transforms valid NNX Variable declarations into PyTorch-style Parameters.

Triggers:

Operations marked with requires_plugin: “nnx_param_to_torch”, targeting flax.nnx.Param, flax.nnx.Variable, and flax.nnx.BatchStat.

Logic:
  1. Determines if the source variable was trainable (Param) or not (BatchStat, Variable).

  2. Looks up the target API for the current operation.

  3. Injects requires_grad=False if not trainable.
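
The steps above can be sketched as follows, again with the standard ast module standing in for libcst. This is an illustration under assumptions, not the plugin's source. The TRAINABLE table and the rewrite_variable helper are hypothetical. Reading the class name from the callee's last attribute is also an assumption. The target name is passed in as if step 2 (ctx.lookup_api) had already resolved it.

```python
import ast

# Hypothetical table for step 1: NNX variable class -> trainable?
TRAINABLE = {"Param": True, "BatchStat": False, "Variable": False}


def rewrite_variable(node: ast.Call, target: str) -> ast.Call:
    """Rewrite an NNX variable constructor call into a call to `target`."""
    func = node.func
    # Accept both `nnx.Param(...)` (Attribute) and a bare `Param(...)` (Name).
    name = func.attr if isinstance(func, ast.Attribute) else getattr(func, "id", None)
    if name not in TRAINABLE:
        return node  # Not an NNX variable declaration; leave it alone.

    keywords = list(node.keywords)
    if not TRAINABLE[name]:
        # Step 3: mark non-trainable state as frozen.
        keywords.append(ast.keyword(arg="requires_grad", value=ast.Constant(False)))

    return ast.Call(
        func=ast.parse(target, mode="eval").body,  # step 2 result, resolved by the caller
        args=list(node.args),
        keywords=keywords,
    )


for src in ("nnx.Param(jnp.zeros(3))", "nnx.BatchStat(jnp.ones(4))"):
    call = ast.parse(src, mode="eval").body
    print(ast.unparse(rewrite_variable(call, "torch.nn.Parameter")))
# torch.nn.Parameter(jnp.zeros(3))
# torch.nn.Parameter(jnp.ones(4), requires_grad=False)
```

Keeping the trainability decision in a small table keeps steps 1 and 3 independent of the target framework. Only the resolved target name changes between backends.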

Parameters:
  • node – The original CST Call node (e.g., nnx.Param(zeros(1))).

  • ctx – The HookContext providing configuration and the ctx.lookup_api resolver for target APIs.

Returns:

The transformed CST Call node, or the original node unchanged if no target mapping is found.

Return type:

libcst.Call