ml_switcheroo.plugins.nnx_to_torch_params
=========================================

.. py:module:: ml_switcheroo.plugins.nnx_to_torch_params

.. autoapi-nested-parse::

   Plugin for converting Flax NNX Variable definitions to PyTorch Parameters.

   This module provides AST transformations that bridge the mismatch between
   Flax NNX's explicit variable declarations (`nnx.Param`, `nnx.BatchStat`)
   and PyTorch's `nn.Parameter` pattern.

   It handles:

   1. **Trainable Parameters**: `nnx.Param(val)` -> `torch.nn.Parameter(val)`.
   2. **Non-Trainable State**: `nnx.BatchStat(val)` -> `torch.nn.Parameter(val, requires_grad=False)`.

   .. note::

      PyTorch would normally store such state with `register_buffer`, but rewriting
      an assignment expression into a `register_buffer` statement is structurally
      complex for an AST replacement. A Parameter with `requires_grad=False` is used
      instead: like a buffer, it is saved in `state_dict()` and moved by `.to()`,
      but unlike a buffer it is still returned by `parameters()`.

Functions
---------

.. autoapisummary::

   ml_switcheroo.plugins.nnx_to_torch_params.transform_nnx_param

Module Contents
---------------

.. py:function:: transform_nnx_param(node: libcst.Call, ctx: ml_switcheroo.core.hooks.HookContext) -> libcst.Call

   Plugin Hook: Transforms valid NNX Variable declarations into PyTorch Parameters.

   Triggers:
       Operations marked with `requires_plugin: "nnx_param_to_torch"` in Semantic JSONs.

   Targeting:
       `flax.nnx.Param`, `flax.nnx.Variable`, `flax.nnx.BatchStat`.

   Transformation Logic:
       - Checks the source function signature/name via context or node analysis.
       - Maps `nnx.Param` -> `torch.nn.Parameter` (trainable).
       - Maps `nnx.BatchStat`/`nnx.Variable` -> `torch.nn.Parameter(..., requires_grad=False)`.

   :param node: The original CST Call node (e.g., `nnx.Param(zeros(1))`).
   :type node: cst.Call
   :param ctx: The HookContext containing config and semantics.
   :type ctx: HookContext

   :returns: The transformed CST Call node.
   :rtype: cst.Call
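The mapping rules for `transform_nnx_param` can be illustrated with a minimal,
self-contained sketch. It is not the plugin's implementation: it uses Python's
standard-library `ast` module instead of libcst, ignores `HookContext`, and the
names `NnxToTorchParams`, `convert`, and `_nnx_kind` are hypothetical. It also
assumes NNX metadata keyword arguments can simply be dropped, because
`torch.nn.Parameter` only accepts `data` and `requires_grad`.

.. code-block:: python

   import ast
   from typing import Optional

   # Hypothetical table mirroring the documented rules:
   # Param is trainable; BatchStat and Variable are not.
   _NNX_TRAINABLE = {"Param": True, "BatchStat": False, "Variable": False}


   def _nnx_kind(func: ast.expr) -> Optional[str]:
       """Return the NNX class name if `func` is `nnx.X` or `flax.nnx.X`, else None."""
       if not isinstance(func, ast.Attribute) or func.attr not in _NNX_TRAINABLE:
           return None
       base = func.value
       # Accept the short alias `nnx.X`.
       if isinstance(base, ast.Name) and base.id == "nnx":
           return func.attr
       # Accept the fully qualified `flax.nnx.X`.
       if (
           isinstance(base, ast.Attribute)
           and base.attr == "nnx"
           and isinstance(base.value, ast.Name)
           and base.value.id == "flax"
       ):
           return func.attr
       return None


   class NnxToTorchParams(ast.NodeTransformer):
       """Rewrite nnx.Param / nnx.BatchStat / nnx.Variable calls into torch.nn.Parameter."""

       def visit_Call(self, node: ast.Call) -> ast.Call:
           self.generic_visit(node)  # rewrite any nested calls first
           kind = _nnx_kind(node.func)
           if kind is None:
               return node  # not an NNX variable constructor: leave untouched
           # Build the attribute chain `torch.nn.Parameter`.
           new_func = ast.Attribute(
               value=ast.Attribute(
                   value=ast.Name(id="torch", ctx=ast.Load()), attr="nn", ctx=ast.Load()
               ),
               attr="Parameter",
               ctx=ast.Load(),
           )
           # Assumption: keep only positional args; NNX metadata keywords are dropped.
           keywords = []
           if not _NNX_TRAINABLE[kind]:
               # Non-trainable state becomes a Parameter with gradients disabled.
               keywords.append(
                   ast.keyword(arg="requires_grad", value=ast.Constant(value=False))
               )
           new_call = ast.Call(func=new_func, args=node.args, keywords=keywords)
           return ast.copy_location(new_call, node)


   def convert(source: str) -> str:
       """Parse `source`, apply the rewrite, and return the regenerated code."""
       tree = NnxToTorchParams().visit(ast.parse(source))
       return ast.unparse(ast.fix_missing_locations(tree))

For example, `convert("self.mean = nnx.BatchStat(jnp.zeros(3))")` returns
`self.mean = torch.nn.Parameter(jnp.zeros(3), requires_grad=False)`. The inner
`jnp.zeros` call is left as-is, since array-construction calls are outside the
scope of this sketch.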