ml_switcheroo.plugins.state_container
=====================================

.. py:module:: ml_switcheroo.plugins.state_container

.. autoapi-nested-parse::

   Plugin for handling stateful container logic.

   Maps container-management methods between frameworks. In particular, it bridges the
   impedance mismatch between PyTorch's imperative `register_buffer`/`parameters` system
   and Flax NNX's explicit state management.

   Mappings (Torch -> JAX/NNX):

   1. `self.register_buffer("name", t)` -> `setattr(self, "name", flax.nnx.BatchStat(t))`
   2. `self.register_parameter("name", p)` -> `setattr(self, "name", flax.nnx.Param(p))`
   3. `model.state_dict()` -> `flax.nnx.state(model).to_pure_dict()`
   4. `model.load_state_dict(sd)` -> `flax.nnx.update(model, sd)`
   5. `model.parameters()` -> `flax.nnx.state(model, flax.nnx.Param).values()`


Functions
---------

.. autoapisummary::

   ml_switcheroo.plugins.state_container.convert_register_buffer
   ml_switcheroo.plugins.state_container.convert_register_parameter
   ml_switcheroo.plugins.state_container.convert_state_dict
   ml_switcheroo.plugins.state_container.convert_load_state_dict
   ml_switcheroo.plugins.state_container.convert_parameters


Module Contents
---------------

.. py:function:: convert_register_buffer(node: libcst.Call, ctx: ml_switcheroo.core.hooks.HookContext) -> libcst.Call

   Transforms `self.register_buffer('name', tensor)` -> `setattr(self, 'name', nnx.BatchStat(tensor))`.


.. py:function:: convert_register_parameter(node: libcst.Call, ctx: ml_switcheroo.core.hooks.HookContext) -> libcst.Call

   Transforms `self.register_parameter('name', param)` -> `setattr(self, 'name', nnx.Param(param))`.


.. py:function:: convert_state_dict(node: libcst.Call, ctx: ml_switcheroo.core.hooks.HookContext) -> libcst.Call

   Transforms `model.state_dict()` -> `flax.nnx.state(model).to_pure_dict()`.


.. py:function:: convert_load_state_dict(node: libcst.Call, ctx: ml_switcheroo.core.hooks.HookContext) -> libcst.Call

   Transforms `model.load_state_dict(state)` -> `flax.nnx.update(model, state)`.


.. py:function:: convert_parameters(node: libcst.Call, ctx: ml_switcheroo.core.hooks.HookContext) -> libcst.Call

   Transforms `model.parameters()` -> `flax.nnx.state(model, flax.nnx.Param).values()`.
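The five source-to-source rewrites listed for this module can be illustrated with a small, self-contained sketch. The real hooks operate on `libcst.Call` nodes and receive a `HookContext`; to stay dependency-free, this sketch instead uses the standard-library `ast` module, and the names `StateContainerRewriter` and `convert` are hypothetical, not part of ml_switcheroo. Unlike libcst, `ast.unparse` discards comments and original formatting, which is one reason a concrete syntax tree library is the better fit for the actual plugin.

```python
import ast

# Hypothetical table: Torch registration method -> NNX Variable type used to wrap the value
# (mirrors mappings 1 and 2 of the plugin).
_REGISTER_WRAPPERS = {
    "register_buffer": "BatchStat",
    "register_parameter": "Param",
}


def _nnx_attr(name: str) -> ast.Attribute:
    """Build the expression `flax.nnx.<name>`."""
    nnx = ast.Attribute(value=ast.Name("flax", ast.Load()), attr="nnx", ctx=ast.Load())
    return ast.Attribute(value=nnx, attr=name, ctx=ast.Load())


def _method_call(receiver: ast.expr, method: str) -> ast.Call:
    """Build the expression `<receiver>.<method>()`."""
    func = ast.Attribute(value=receiver, attr=method, ctx=ast.Load())
    return ast.Call(func=func, args=[], keywords=[])


class StateContainerRewriter(ast.NodeTransformer):
    """Hypothetical stdlib-`ast` sketch of the five Torch -> Flax NNX call rewrites."""

    def visit_Call(self, node: ast.Call) -> ast.AST:
        self.generic_visit(node)  # rewrite nested calls first, e.g. list(model.parameters())
        func = node.func
        # Only plain `receiver.method(...)` calls are handled; calls with keyword
        # arguments (e.g. persistent=False) are left untouched instead of silently dropped.
        if not isinstance(func, ast.Attribute) or node.keywords:
            return node
        receiver, method, args = func.value, func.attr, node.args

        # 1-2. self.register_buffer/parameter("name", v) -> setattr(self, "name", flax.nnx.X(v))
        if method in _REGISTER_WRAPPERS and len(args) == 2:
            wrapper = _nnx_attr(_REGISTER_WRAPPERS[method])
            wrapped = ast.Call(func=wrapper, args=[args[1]], keywords=[])
            return ast.Call(
                func=ast.Name("setattr", ast.Load()),
                args=[receiver, args[0], wrapped],
                keywords=[],
            )

        # 3. model.state_dict() -> flax.nnx.state(model).to_pure_dict()
        if method == "state_dict" and not args:
            state = ast.Call(func=_nnx_attr("state"), args=[receiver], keywords=[])
            return _method_call(state, "to_pure_dict")

        # 4. model.load_state_dict(sd) -> flax.nnx.update(model, sd)
        if method == "load_state_dict" and len(args) == 1:
            return ast.Call(func=_nnx_attr("update"), args=[receiver, args[0]], keywords=[])

        # 5. model.parameters() -> flax.nnx.state(model, flax.nnx.Param).values()
        if method == "parameters" and not args:
            state = ast.Call(
                func=_nnx_attr("state"), args=[receiver, _nnx_attr("Param")], keywords=[]
            )
            return _method_call(state, "values")

        return node


def convert(source: str) -> str:
    """Parse `source`, apply the rewrites, and return the regenerated Python code."""
    tree = StateContainerRewriter().visit(ast.parse(source))
    return ast.unparse(ast.fix_missing_locations(tree))
```

For example, `convert("sd = model.state_dict()")` returns `sd = flax.nnx.state(model).to_pure_dict()`. The sketch folds all five mappings into one transformer for brevity, whereas the plugin exposes one hook per mapping (`convert_register_buffer`, `convert_state_dict`, and so on), which keeps each rewrite independently testable.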