ml_switcheroo.plugins.nnx_to_torch_params¶
Plugin for converting Flax NNX Variable definitions to PyTorch-style Parameters.
This module provides AST transformations to handle the impedance mismatch between Flax NNX’s explicit variable declarations (nnx.Param, nnx.BatchStat) and PyTorch’s nn.Parameter pattern.
It handles:
- Trainable Parameters: nnx.Param(val) -> TargetParam(val).
- Non-Trainable State: nnx.BatchStat(val) -> TargetParam(val, requires_grad=False).
- Decoupling:
Logic is triggered solely by the requires_plugin="nnx_param_to_torch" wiring. The target class name is resolved via ctx.lookup_api based on the abstract Operation ID. If no mapping is found in the Knowledge Base, the transformation aborts (returns the original node), preventing hardcoded fallbacks to torch.nn.Parameter.
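The lookup-and-abort pattern described above can be sketched as follows. This is a minimal illustration using Python's stdlib ast module rather than libcst, and KNOWLEDGE_BASE / transform_param_call are hypothetical stand-ins for the plugin's ctx.lookup_api wiring, not the module's actual API:

```python
import ast

# Hypothetical stand-in for the Knowledge Base; the real plugin calls ctx.lookup_api.
KNOWLEDGE_BASE = {"nnx_param_to_torch": "nn.Parameter"}

def transform_param_call(node: ast.Call, operation_id: str) -> ast.Call:
    """Rewrite a variable-declaration call to the mapped target class, or abort."""
    target = KNOWLEDGE_BASE.get(operation_id)
    if target is None:
        # No mapping found: return the original node unchanged instead of
        # falling back to a hardcoded torch.nn.Parameter.
        return node
    mod, attr = target.split(".")
    new_func = ast.Attribute(value=ast.Name(id=mod, ctx=ast.Load()),
                             attr=attr, ctx=ast.Load())
    return ast.Call(func=new_func, args=node.args, keywords=node.keywords)

src = ast.parse("nnx.Param(zeros(1))", mode="eval")
out = transform_param_call(src.body, "nnx_param_to_torch")
print(ast.unparse(out))  # nn.Parameter(zeros(1))
```

An unknown operation ID leaves the call untouched, which mirrors the abort behaviour stated above.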
Functions¶
transform_nnx_param — Plugin Hook: Transforms valid NNX Variable declarations into PyTorch-style Parameters.
Module Contents¶
- ml_switcheroo.plugins.nnx_to_torch_params.transform_nnx_param(node: libcst.Call, ctx: ml_switcheroo.core.hooks.HookContext) → libcst.Call[source]¶
Plugin Hook: Transforms valid NNX Variable declarations into PyTorch-style Parameters.
- Triggers:
Operations marked with requires_plugin: "nnx_param_to_torch". Targeting: flax.nnx.Param, flax.nnx.Variable, flax.nnx.BatchStat.
- Logic:
1. Determines whether the source variable was trainable (Param) or not (BatchStat, Variable).
2. Looks up the target API for the current operation.
3. Injects requires_grad=False if the variable is not trainable.
- Parameters:
node – The original CST Call node (e.g., nnx.Param(zeros(1))).
ctx – The HookContext containing configuration.
- Returns:
The transformed CST Call node, or the original node if the mapping is missing.
- Return type:
cst.Call
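The trigger and logic steps documented above can be sketched end to end. As before, this is a hedged illustration on stdlib ast nodes (the real hook operates on libcst.Call nodes and resolves its target through ctx.lookup_api); the rewrite helper and the hardcoded "nn.Parameter" target are assumptions for the example only:

```python
import ast

TRAINABLE_SOURCES = {"Param"}                     # nnx.Param -> trainable
NON_TRAINABLE_SOURCES = {"BatchStat", "Variable"} # frozen state

def rewrite(call: ast.Call, target: str = "nn.Parameter") -> ast.Call:
    """Map an NNX variable-declaration call to a PyTorch-style Parameter call."""
    if not isinstance(call.func, ast.Attribute):
        return call
    name = call.func.attr
    if name not in TRAINABLE_SOURCES | NON_TRAINABLE_SOURCES:
        return call
    mod, attr = target.split(".")
    keywords = list(call.keywords)
    if name in NON_TRAINABLE_SOURCES:
        # Non-trainable state: inject requires_grad=False on the target call.
        keywords.append(ast.keyword(arg="requires_grad",
                                    value=ast.Constant(value=False)))
    return ast.Call(
        func=ast.Attribute(value=ast.Name(id=mod, ctx=ast.Load()),
                           attr=attr, ctx=ast.Load()),
        args=call.args,
        keywords=keywords,
    )

for snippet in ("nnx.Param(zeros(1))", "nnx.BatchStat(zeros(1))"):
    tree = ast.parse(snippet, mode="eval")
    print(ast.unparse(rewrite(tree.body)))
# nn.Parameter(zeros(1))
# nn.Parameter(zeros(1), requires_grad=False)
```

Note how only the non-trainable branch appends the requires_grad=False keyword, matching the Param vs. BatchStat distinction in the function's documented logic.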