ml_switcheroo.plugins.nnx_to_torch_params

Plugin for converting Flax NNX Variable definitions to PyTorch Parameters.

This module provides AST transformations to handle the impedance mismatch between Flax NNX’s explicit variable declarations (nnx.Param, nnx.BatchStat) and PyTorch’s nn.Parameter pattern.

It handles:

1. Trainable Parameters: nnx.Param(val) -> torch.nn.Parameter(val).

2. Non-Trainable State: nnx.BatchStat(val) -> torch.nn.Parameter(val, requires_grad=False).

Note: While PyTorch typically uses `register_buffer` for non-trainable state, converting an assignment expression into a `register_buffer` statement is structurally complex in AST replacement, since the rewrite must target the enclosing statement rather than the call itself. Using non-grad Parameters is a semantic equivalent for state persistence.
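To illustrate why the `register_buffer` form is harder, here is a minimal sketch using Python's stdlib `ast` module (the real plugin operates on `libcst` nodes; the `BufferRewrite` class and the sample source are hypothetical). Producing `self.register_buffer("mean", v)` requires visiting the whole `Assign` statement to recover the attribute name, which a hook that only sees the `Call` node cannot do:

```python
import ast

class BufferRewrite(ast.NodeTransformer):
    """Statement-level rewrite: needs the whole Assign node, not just
    the Call -- which is why the plugin avoids this form."""

    def visit_Assign(self, node: ast.Assign):
        target, value = node.targets[0], node.value
        if (
            isinstance(value, ast.Call)
            and isinstance(value.func, ast.Attribute)
            and value.func.attr == "BatchStat"
            and isinstance(target, ast.Attribute)
        ):
            # Build self.register_buffer('<name>', <original value>)
            new = ast.parse(f"self.register_buffer({target.attr!r}, _)").body[0]
            new.value.args[1] = value.args[0]
            return ast.fix_missing_locations(new)
        return node

def rewrite(src: str) -> str:
    return ast.unparse(BufferRewrite().visit(ast.parse(src)))

print(rewrite("self.mean = nnx.BatchStat(zeros(4))"))
# -> self.register_buffer('mean', zeros(4))
```

The assignment target disappears entirely in the output, so a call-level replacement (Call in, Call out) cannot express this transformation.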

Functions

transform_nnx_param(node, ctx) → libcst.Call

Plugin Hook: Transforms valid NNX Variable declarations into PyTorch Parameters.

Module Contents

ml_switcheroo.plugins.nnx_to_torch_params.transform_nnx_param(node: libcst.Call, ctx: ml_switcheroo.core.hooks.HookContext) → libcst.Call

Plugin Hook: Transforms valid NNX Variable declarations into PyTorch Parameters.

Triggers:

  • Operations marked with requires_plugin: “nnx_param_to_torch” in Semantic JSONs.

  • Targeting: flax.nnx.Param, flax.nnx.Variable, flax.nnx.BatchStat.

Transformation Logic:
  • Checks the source function signature/name via context or node analysis.

  • Maps nnx.Param -> torch.nn.Parameter (Trainable).

  • Maps nnx.BatchStat/Variable -> torch.nn.Parameter(…, requires_grad=False).
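The mapping above can be sketched as a call-level rewrite. This is a simplified illustration using the stdlib `ast` module rather than `libcst` (the `NnxToTorch` class, `convert` helper, and sample snippets are hypothetical stand-ins for the plugin's actual hook):

```python
import ast

TRAINABLE = {"Param"}
NON_TRAINABLE = {"BatchStat", "Variable"}

class NnxToTorch(ast.NodeTransformer):
    """Call-level rewrite: a Call maps to a Call, so the surrounding
    assignment survives unchanged."""

    def visit_Call(self, node: ast.Call) -> ast.Call:
        self.generic_visit(node)
        # Match calls of the form nnx.<Name>(...)
        if not (
            isinstance(node.func, ast.Attribute)
            and isinstance(node.func.value, ast.Name)
            and node.func.value.id == "nnx"
        ):
            return node
        name = node.func.attr
        if name in TRAINABLE | NON_TRAINABLE:
            # Rewrite the callee to torch.nn.Parameter
            node.func = ast.parse("torch.nn.Parameter", mode="eval").body
            if name in NON_TRAINABLE:
                # Non-trainable state keeps its value but freezes gradients
                node.keywords.append(
                    ast.keyword(arg="requires_grad", value=ast.Constant(value=False))
                )
        return node

def convert(src: str) -> str:
    tree = NnxToTorch().visit(ast.parse(src))
    ast.fix_missing_locations(tree)
    return ast.unparse(tree)

print(convert("self.w = nnx.Param(zeros(1))"))
# -> self.w = torch.nn.Parameter(zeros(1))
print(convert("self.mean = nnx.BatchStat(zeros(1))"))
# -> self.mean = torch.nn.Parameter(zeros(1), requires_grad=False)
```

In the real plugin the same decision is driven by the HookContext and the Semantic JSON trigger rather than a hard-coded `nnx` name check.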

Parameters:
  • node (cst.Call) – The original CST Call node (e.g., nnx.Param(zeros(1))).

  • ctx (HookContext) – The HookContext containing config and semantics.

Returns:

The transformed CST Call node.

Return type:

cst.Call