ml_switcheroo.core.rewriter.passes.structure
============================================

.. py:module:: ml_switcheroo.core.rewriter.passes.structure

.. autoapi-nested-parse::

   Structural Rewriting Pass.

   This module consolidates all structural transformation logic, including:

   1. **Class Inheritance Rewriting**: Swapping framework base classes (e.g., ``torch.nn.Module`` -> ``flax.nnx.Module``).
   2. **Function Signature Rewriting**: Injecting or stripping state/context arguments (e.g., ``rngs``, ``ctx``).
   3. **Method Renaming**: Mapping lifecycle methods (e.g., ``forward`` -> ``__call__``).
   4. **Body Injection**: Handling ``super().__init__`` calls and preamble injection.
   5. **Type Annotation Rewriting**: Mapping framework-specific types (e.g., ``torch.Tensor`` -> ``jax.Array``).



Classes
-------

.. autoapisummary::

   ml_switcheroo.core.rewriter.passes.structure.StructuralPass
   ml_switcheroo.core.rewriter.passes.structure.StructuralTransformer


Module Contents
---------------

.. py:class:: StructuralPass

   Bases: :py:obj:`ml_switcheroo.core.rewriter.interface.RewriterPass`

   Pass responsible for modifying the structural scaffolding of the code.

   It transforms class definitions, function signatures, and type hints to match
   the idioms of the target framework.


   .. py:method:: transform(module: libcst.Module, context: ml_switcheroo.core.rewriter.context.RewriterContext) -> libcst.Module

      Executes the structural transformation.

      :param module: The source CST.
      :param context: Shared rewriter state.
      :returns: The transformed CST.


.. py:class:: StructuralTransformer(context: ml_switcheroo.core.rewriter.context.RewriterContext)

   Bases: :py:obj:`libcst.CSTTransformer`

   LibCST transformer encapsulating all structural logic.

   Maintains internal state for scope depth, annotation context, and target
   framework traits.


   .. py:attribute:: context


   .. py:property:: target_traits
      :type: ml_switcheroo.semantics.schema.StructuralTraits

      Lazily load the structural traits for the target framework.


   .. py:method:: leave_Module(original_node: libcst.Module, updated_node: libcst.Module) -> libcst.Module

      Injects the accumulated module-level preamble statements (e.g. imports,
      shim classes) that plugins requested during the rewrite, then flushes and
      clears the buffer to prevent double injection in subsequent passes.


   .. py:method:: visit_Annotation(node: libcst.Annotation) -> Optional[bool]

      Flag entry into a type annotation context.


   .. py:method:: leave_Annotation(original_node: libcst.Annotation, updated_node: libcst.Annotation) -> libcst.Annotation

      Flag exit from a type annotation context.


   .. py:method:: leave_Name(original_node: libcst.Name, updated_node: libcst.Name) -> libcst.BaseExpression

      Rewrite type names if inside an annotation.


   .. py:method:: leave_Attribute(original_node: libcst.Attribute, updated_node: libcst.Attribute) -> libcst.BaseExpression

      Rewrite dotted type attributes (e.g. ``torch.Tensor``) if inside an annotation.

      This handles compound types such as ``torch.Tensor``, which would otherwise
      fail to rewrite because ``leave_Name`` only sees each name segment
      individually, without the surrounding dotted path.


   .. py:method:: visit_ClassDef(node: libcst.ClassDef) -> Optional[bool]

      Detect ``Module`` inheritance to set the processing state.


   .. py:method:: leave_ClassDef(original_node: libcst.ClassDef, updated_node: libcst.ClassDef) -> Union[libcst.ClassDef, libcst.CSTNode]

      Rewrite class inheritance if necessary.


   .. py:method:: visit_FunctionDef(node: libcst.FunctionDef) -> Optional[bool]

      Push the function context onto the signature stack.


   .. py:method:: leave_FunctionDef(original_node: libcst.FunctionDef, updated_node: libcst.FunctionDef) -> libcst.FunctionDef

      Apply signature, name, and body rewrites to functions.
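
The annotation-scoped type rewriting performed by ``visit_Annotation``,
``leave_Annotation``, ``leave_Name``, and ``leave_Attribute`` can be sketched with
plain LibCST. This is a minimal, hypothetical illustration rather than the
ml_switcheroo code: ``AnnotationTypeRewriter`` and ``TYPE_MAP`` are invented names,
and the mapping is hard-coded instead of being read from the target framework's
semantics. One depth counter marks when the visitor is inside an ``Annotation``;
a second suppresses ``leave_Name`` for segments of a dotted path, so
``torch.Tensor`` is matched as a whole in ``leave_Attribute``.

.. code-block:: python

   import libcst as cst
   from libcst.helpers import get_full_name_for_node

   # Hypothetical mapping; keys are spelled exactly as they appear in annotations.
   TYPE_MAP = {"torch.Tensor": "jax.Array", "Tensor": "jax.Array"}


   class AnnotationTypeRewriter(cst.CSTTransformer):
       """Rewrite mapped type names, but only inside annotations (sketch)."""

       def __init__(self, type_map):
           super().__init__()
           self.type_map = type_map
           self.annotation_depth = 0  # > 0 while inside an Annotation node
           self.attribute_depth = 0  # > 0 while inside a dotted name

       def visit_Annotation(self, node):
           self.annotation_depth += 1
           return True  # keep visiting the annotation's children

       def leave_Annotation(self, original_node, updated_node):
           self.annotation_depth -= 1
           return updated_node

       def visit_Attribute(self, node):
           self.attribute_depth += 1
           return True

       def _rewrite(self, node):
           # Outside annotations, names are runtime values: leave them alone.
           if self.annotation_depth == 0:
               return node
           target = self.type_map.get(get_full_name_for_node(node))
           return cst.parse_expression(target) if target else node

       def leave_Name(self, original_node, updated_node):
           # Inside a dotted name, a Name is only one segment (the "Tensor" in
           # torch.Tensor); leave_Attribute rewrites the full path instead.
           if self.attribute_depth > 0:
               return updated_node
           return self._rewrite(updated_node)

       def leave_Attribute(self, original_node, updated_node):
           self.attribute_depth -= 1
           return self._rewrite(updated_node)


   source = (
       "import torch\n"
       "from torch import Tensor\n"
       "def f(x: torch.Tensor, y: Tensor) -> torch.Tensor:\n"
       "    z = torch.Tensor(3)\n"
       "    return x\n"
   )
   out = cst.parse_module(source).visit(AnnotationTypeRewriter(TYPE_MAP)).code
   print(out)

Both parameter annotations and the return annotation become ``jax.Array``, while
the call ``torch.Tensor(3)`` and the ``from torch import Tensor`` line are left
alone because they are not inside an ``Annotation``. The sketch does not add
``import jax``; a complete rewrite would also need to manage imports.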
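
The class-level rewrites (base class swap, lifecycle method renaming, and
signature injection) can be approximated by a single transformer. The sketch
below is hypothetical: ``ModuleScaffoldRewriter`` and the constants at the top
stand in for the target framework's ``StructuralTraits``, whose real fields are
not documented on this page. It assumes the PyTorch -> Flax NNX direction,
injects ``rngs`` as a keyword-only ``__init__`` argument, leaves
``super().__init__()`` calls in place, and only touches methods defined directly
in a class whose base is spelled ``torch.nn.Module``; import aliases such as
``nn.Module`` are not resolved.

.. code-block:: python

   import libcst as cst
   from libcst.helpers import get_full_name_for_node

   # Hypothetical traits; ml_switcheroo reads the real values from StructuralTraits.
   SOURCE_BASE = "torch.nn.Module"
   TARGET_BASE = "flax.nnx.Module"
   METHOD_RENAMES = {"forward": "__call__"}
   INIT_ARG = "rngs"  # keyword-only argument injected into __init__


   class ModuleScaffoldRewriter(cst.CSTTransformer):
       """Rewrite base class, method names, and __init__ signature (sketch)."""

       def __init__(self):
           super().__init__()
           # Stack of ("class", is_module_subclass) or ("function", False) entries.
           self.scopes = []

       def visit_ClassDef(self, node):
           is_module = any(
               get_full_name_for_node(base.value) == SOURCE_BASE for base in node.bases
           )
           self.scopes.append(("class", is_module))

       def leave_ClassDef(self, original_node, updated_node):
           _, is_module = self.scopes.pop()
           if not is_module:
               return updated_node
           bases = [
               base.with_changes(value=cst.parse_expression(TARGET_BASE))
               if get_full_name_for_node(base.value) == SOURCE_BASE
               else base
               for base in updated_node.bases
           ]
           return updated_node.with_changes(bases=bases)

       def visit_FunctionDef(self, node):
           self.scopes.append(("function", False))

       def leave_FunctionDef(self, original_node, updated_node):
           self.scopes.pop()
           # Only touch methods defined directly in a Module subclass.
           if not self.scopes or self.scopes[-1] != ("class", True):
               return updated_node
           name = updated_node.name.value
           if name in METHOD_RENAMES:
               updated_node = updated_node.with_changes(name=cst.Name(METHOD_RENAMES[name]))
           if name == "__init__":
               params = updated_node.params
               existing = {
                   p.name.value
                   for p in (*params.posonly_params, *params.params, *params.kwonly_params)
               }
               if INIT_ARG not in existing:
                   # With no explicit star separator, LibCST emits a bare "*"
                   # before keyword-only parameters when generating code.
                   new_params = params.with_changes(
                       kwonly_params=[*params.kwonly_params, cst.Param(name=cst.Name(INIT_ARG))]
                   )
                   updated_node = updated_node.with_changes(params=new_params)
           return updated_node


   source = (
       "import torch\n"
       "\n"
       "class MLP(torch.nn.Module):\n"
       "    def __init__(self, dim):\n"
       "        super().__init__()\n"
       "        self.dim = dim\n"
       "\n"
       "    def forward(self, x):\n"
       "        return x\n"
       "\n"
       "class Plain:\n"
       "    def forward(self):\n"
       "        return 1\n"
   )
   out = cst.parse_module(source).visit(ModuleScaffoldRewriter()).code
   print(out)

A scope stack, rather than a single flag, keeps nested functions and nested
classes from being mistaken for methods of the enclosing ``Module`` subclass,
which is why ``Plain.forward`` keeps its name.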
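
The flush-once behaviour documented for ``leave_Module`` can be illustrated with
a small transformer. In this hypothetical sketch the buffer is a plain list of
source strings passed to ``PreambleInjector`` (an invented name); how
ml_switcheroo's plugins actually record preamble requests on the
``RewriterContext`` is not shown here. The statements are inserted after any
module docstring and ``from __future__`` imports, which must stay at the top of
the file, and the list is cleared so that running the transformer again adds
nothing.

.. code-block:: python

   import libcst as cst


   def header_length(body):
       # Count leading statements that must stay at the top of the module:
       # an optional docstring followed by any `from __future__` imports.
       count = 0
       for stmt in body:
           if not isinstance(stmt, cst.SimpleStatementLine) or len(stmt.body) != 1:
               break
           small = stmt.body[0]
           is_docstring = (
               count == 0
               and isinstance(small, cst.Expr)
               and isinstance(small.value, cst.SimpleString)
           )
           is_future = (
               isinstance(small, cst.ImportFrom)
               and isinstance(small.module, cst.Name)
               and small.module.value == "__future__"
           )
           if not (is_docstring or is_future):
               break
           count += 1
       return count


   class PreambleInjector(cst.CSTTransformer):
       """Insert buffered preamble statements once, then clear the buffer (sketch)."""

       def __init__(self, preamble):
           super().__init__()
           self.preamble = preamble  # hypothetical shared list of source strings

       def leave_Module(self, original_node, updated_node):
           if not self.preamble:
               return updated_node
           new_stmts = [cst.parse_statement(src) for src in self.preamble]
           self.preamble.clear()  # flush so a later pass cannot inject these again
           body = list(updated_node.body)
           cut = header_length(body)
           return updated_node.with_changes(body=[*body[:cut], *new_stmts, *body[cut:]])


   buffer = ["import jax", "import jax.numpy as jnp"]
   source = '"""Converted model."""\nfrom __future__ import annotations\nx = 1\n'
   out = cst.parse_module(source).visit(PreambleInjector(buffer)).code
   print(out)
   # The buffer is now empty, so a second run leaves the code unchanged.
   again = cst.parse_module(out).visit(PreambleInjector(buffer)).code

Clearing the shared list inside ``leave_Module`` is what makes repeated passes
safe: the second run finds an empty buffer and returns the module untouched.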
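
For the opposite direction, when PyTorch is the target, every
``torch.nn.Module`` subclass must call ``super().__init__()`` before assigning
submodules or parameters. A hypothetical body-injection step for that case is
sketched below; ``SuperInitInjector`` is an invented name and is not part of
ml_switcheroo. It inserts the call at the top of each ``__init__`` defined
directly in a class (after a docstring, if present) and skips methods that
already contain it, so repeated runs are idempotent. For brevity it does not
check the class's base.

.. code-block:: python

   import libcst as cst
   import libcst.matchers as m

   # Matches a statement line that is exactly `super().__init__(...)`.
   SUPER_INIT = m.SimpleStatementLine(
       body=[
           m.Expr(
               value=m.Call(
                   func=m.Attribute(value=m.Call(func=m.Name("super")), attr=m.Name("__init__"))
               )
           )
       ]
   )
   # Matches a statement line that is exactly a string literal (a docstring).
   DOCSTRING = m.SimpleStatementLine(body=[m.Expr(value=m.SimpleString())])


   class SuperInitInjector(cst.CSTTransformer):
       """Ensure __init__ methods call super().__init__() (sketch)."""

       def __init__(self):
           super().__init__()
           self.scopes = []  # "class" / "function" entries for enclosing scopes

       def visit_ClassDef(self, node):
           self.scopes.append("class")

       def leave_ClassDef(self, original_node, updated_node):
           self.scopes.pop()
           return updated_node

       def visit_FunctionDef(self, node):
           self.scopes.append("function")

       def leave_FunctionDef(self, original_node, updated_node):
           self.scopes.pop()
           is_method = self.scopes[-1:] == ["class"]
           if updated_node.name.value != "__init__" or not is_method:
               return updated_node
           block = updated_node.body
           if not isinstance(block, cst.IndentedBlock):
               return updated_node  # e.g. `def __init__(self): pass` on one line
           stmts = list(block.body)
           if any(m.matches(stmt, SUPER_INIT) for stmt in stmts):
               return updated_node  # already present: keep the pass idempotent
           at = 1 if stmts and m.matches(stmts[0], DOCSTRING) else 0
           call = cst.parse_statement("super().__init__()")
           new_block = block.with_changes(body=[*stmts[:at], call, *stmts[at:]])
           return updated_node.with_changes(body=new_block)


   source = (
       "class Block(torch.nn.Module):\n"
       "    def __init__(self, dim):\n"
       '        """Build the block."""\n'
       "        self.dim = dim\n"
       "\n"
       "def __init__(x):\n"
       "    return x\n"
   )
   out = cst.parse_module(source).visit(SuperInitInjector()).code
   print(out)

The module-level ``__init__`` function is left alone because the scope stack
shows it is not directly inside a class.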