ml_switcheroo.plugins.state_container
Plugin for handling Stateful Container Logic.
Maps container-management methods between frameworks, in particular bridging the impedance mismatch between imperative state registration (PyTorch's register_buffer) and explicit, functional state management (JAX/Flax/MLX).
This module converts:
1. self.register_buffer("name", t) -> setattr(self, "name", Wrapper(t))
2. self.register_parameter("name", p) -> setattr(self, "name", ParamWrapper(p))
3. model.state_dict() -> StateFunc(model).to_pure_dict()
4. model.load_state_dict(sd) -> UpdateFunc(model, sd)
5. model.parameters() -> StateFunc(model, ParamWrapper).values()
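The first rewrite in the list above can be sketched with the standard-library ast module. This is only an illustration under stated assumptions: the plugin itself operates on libcst.Call nodes through a HookContext, and the names RegisterBufferRewriter, rewrite, and the literal wrapper name Wrapper are hypothetical stand-ins, not part of ml_switcheroo.

```python
import ast


class RegisterBufferRewriter(ast.NodeTransformer):
    """Hypothetical helper: rewrites X.register_buffer(name, t) calls."""

    def __init__(self, wrapper_name: str) -> None:
        # Assumption: in the real plugin this name comes from the knowledge base.
        self.wrapper_name = wrapper_name

    def visit_Call(self, node: ast.Call) -> ast.AST:
        self.generic_visit(node)
        func = node.func
        is_target = (
            isinstance(func, ast.Attribute)
            and func.attr == "register_buffer"
            and len(node.args) >= 2
        )
        if not is_target:
            return node  # leave unrelated calls untouched
        name_arg, tensor_arg = node.args[0], node.args[1]
        # Wrapper(t)
        wrapped = ast.Call(
            func=ast.Name(id=self.wrapper_name, ctx=ast.Load()),
            args=[tensor_arg],
            keywords=[],
        )
        # setattr(self, "name", Wrapper(t))
        new_call = ast.Call(
            func=ast.Name(id="setattr", ctx=ast.Load()),
            args=[func.value, name_arg, wrapped],
            keywords=[],
        )
        return ast.copy_location(new_call, node)


def rewrite(source: str, wrapper_name: str = "Wrapper") -> str:
    tree = RegisterBufferRewriter(wrapper_name).visit(ast.parse(source))
    ast.fix_missing_locations(tree)
    return ast.unparse(tree)  # requires Python 3.9+


result = rewrite('self.register_buffer("running_mean", torch.zeros(4))')
print(result)
```

Unlike libcst, ast discards comments and original formatting, which is one reason a source-to-source tool would prefer a concrete syntax tree for this kind of rewrite.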
- Decoupling:
The specific wrapper definitions (e.g. flax.nnx.BatchStat or custom.State) must be defined in the Semantic Knowledge Base. Lookups are strict; if no mapping exists for the Abstract Operation (e.g. ‘BatchStat’), the hook aborts and preserves the original code.
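The strict-lookup behaviour described above might look roughly like the following. This is a minimal sketch that works on strings rather than syntax trees; KNOWLEDGE_BASE, resolve, convert_buffer_call, and the framework keys are hypothetical names chosen for illustration and do not reflect the actual knowledge-base API.

```python
from typing import Dict, Optional

# Hypothetical stand-in for the Semantic Knowledge Base:
# framework key -> {abstract operation -> concrete wrapper name}.
KNOWLEDGE_BASE: Dict[str, Dict[str, str]] = {
    "flax_nnx": {"BatchStat": "flax.nnx.BatchStat"},
    "unmapped_framework": {},  # no mapping for "BatchStat"
}


def resolve(framework: str, abstract_op: str) -> Optional[str]:
    """Strict lookup: returns None instead of guessing a fallback."""
    return KNOWLEDGE_BASE.get(framework, {}).get(abstract_op)


def convert_buffer_call(original: str, name: str, tensor_expr: str, framework: str) -> str:
    wrapper = resolve(framework, "BatchStat")
    if wrapper is None:
        # Abort: no mapping exists, so the original code is preserved as-is.
        return original
    return f'setattr(self, "{name}", {wrapper}({tensor_expr}))'


original = 'self.register_buffer("running_mean", torch.zeros(4))'
converted = convert_buffer_call(original, "running_mean", "torch.zeros(4)", "flax_nnx")
preserved = convert_buffer_call(original, "running_mean", "torch.zeros(4)", "unmapped_framework")
print(converted)
print(preserved)
```

Returning the input unchanged when the lookup fails keeps the hook from emitting a call to a wrapper that the target framework does not define.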
Functions
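| convert_register_buffer(node, ctx) | Transforms register_buffer. |
| convert_register_parameter(node, ctx) | Transforms register_parameter. |
| convert_state_dict(node, ctx) | Transforms state_dict. |
| convert_load_state_dict(node, ctx) | Transforms load_state_dict. |
| convert_parameters(node, ctx) | Transforms parameters. |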
Module Contents
- ml_switcheroo.plugins.state_container.convert_register_buffer(node: libcst.Call, ctx: ml_switcheroo.core.hooks.HookContext) -> libcst.Call
Transforms register_buffer.
Target: setattr(self, 'name', Wrapper(tensor))
Abstract Op Lookup: “BatchStat”
- ml_switcheroo.plugins.state_container.convert_register_parameter(node: libcst.Call, ctx: ml_switcheroo.core.hooks.HookContext) -> libcst.Call
Transforms register_parameter.
Target: setattr(self, 'name', ParamWrapper(param))
Abstract Op Lookup: “Param”
- ml_switcheroo.plugins.state_container.convert_state_dict(node: libcst.Call, ctx: ml_switcheroo.core.hooks.HookContext) -> libcst.Call
Transforms state_dict.
Target: StateFunc(model).to_pure_dict()
Abstract Op Lookup: “ModuleState”
- ml_switcheroo.plugins.state_container.convert_load_state_dict(node: libcst.Call, ctx: ml_switcheroo.core.hooks.HookContext) -> libcst.Call
Transforms load_state_dict.
Target: UpdateFunc(model, state)
Abstract Op Lookup: “UpdateState”
- ml_switcheroo.plugins.state_container.convert_parameters(node: libcst.Call, ctx: ml_switcheroo.core.hooks.HookContext) -> libcst.Call
Transforms parameters.
Target: StateFunc(model, ParamType).values()
Abstract Op Lookup: “ModuleState”, “Param”
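The three module-level targets documented above (state_dict, load_state_dict, parameters) share one shape: a method call on a model becomes a free-function call that takes the model as an argument. The sketch below illustrates that shape with the standard-library ast module. It is not the plugin's implementation, which works on libcst nodes; StateMethodRewriter, rewrite_state_calls, and the placeholder names StateFunc, UpdateFunc, and ParamType are assumptions standing in for whatever the knowledge base resolves.

```python
import ast

# Placeholder names copied from the documented targets; a real run would
# resolve them from the "ModuleState", "UpdateState", and "Param" lookups.
STATE_FUNC, UPDATE_FUNC, PARAM_TYPE = "StateFunc", "UpdateFunc", "ParamType"


def _name(identifier: str) -> ast.Name:
    return ast.Name(id=identifier, ctx=ast.Load())


def _call(func: ast.expr, *args: ast.expr) -> ast.Call:
    return ast.Call(func=func, args=list(args), keywords=[])


class StateMethodRewriter(ast.NodeTransformer):
    """Hypothetical helper: turns state-related method calls into function calls."""

    def visit_Call(self, node: ast.Call) -> ast.AST:
        self.generic_visit(node)
        func = node.func
        if not isinstance(func, ast.Attribute):
            return node
        receiver = func.value  # the `model` in model.state_dict()
        no_args = not node.args and not node.keywords
        if func.attr == "state_dict" and no_args:
            inner = _call(_name(STATE_FUNC), receiver)
            new = _call(ast.Attribute(value=inner, attr="to_pure_dict", ctx=ast.Load()))
        elif func.attr == "load_state_dict" and len(node.args) == 1 and not node.keywords:
            new = _call(_name(UPDATE_FUNC), receiver, node.args[0])
        elif func.attr == "parameters" and no_args:
            inner = _call(_name(STATE_FUNC), receiver, _name(PARAM_TYPE))
            new = _call(ast.Attribute(value=inner, attr="values", ctx=ast.Load()))
        else:
            return node  # anything else is preserved unchanged
        return ast.copy_location(new, node)


def rewrite_state_calls(source: str) -> str:
    tree = StateMethodRewriter().visit(ast.parse(source))
    ast.fix_missing_locations(tree)
    return ast.unparse(tree)  # requires Python 3.9+


print(rewrite_state_calls("sd = model.state_dict()"))
print(rewrite_state_calls("model.load_state_dict(sd)"))
print(rewrite_state_calls("params = model.parameters()"))
```

Calls with unexpected arguments, such as load_state_dict with extra keywords, are left untouched in this sketch; whether the plugin handles them differently is not stated in the documentation above.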