ml_switcheroo.plugins.state_container

Plugin for handling stateful container logic.

Maps container-management methods between frameworks, in particular bridging the impedance mismatch between imperative state registration (PyTorch's register_buffer) and explicit, functional state management (JAX/Flax/MLX).

This module converts:

1. self.register_buffer("name", t) -> setattr(self, "name", Wrapper(t))
2. self.register_parameter("name", p) -> setattr(self, "name", ParamWrapper(p))
3. model.state_dict() -> StateFunc(model).to_pure_dict()
4. model.load_state_dict(sd) -> UpdateFunc(model, sd)
5. model.parameters() -> StateFunc(model, ParamWrapper).values()

Decoupling:

The concrete wrapper definitions (e.g. flax.nnx.BatchStat or custom.State) must be defined in the Semantic Knowledge Base. Lookups are strict: if no mapping exists for the abstract operation (e.g. "BatchStat"), the hook aborts and leaves the original code unchanged.

Functions

convert_register_buffer(node, ctx) → libcst.Call

Transforms register_buffer.

convert_register_parameter(node, ctx) → libcst.Call

Transforms register_parameter.

convert_state_dict(node, ctx) → libcst.Call

Transforms state_dict.

convert_load_state_dict(node, ctx) → libcst.Call

Transforms load_state_dict.

convert_parameters(node, ctx) → libcst.Call

Transforms parameters.

Module Contents

ml_switcheroo.plugins.state_container.convert_register_buffer(node: libcst.Call, ctx: ml_switcheroo.core.hooks.HookContext) → libcst.Call

Transforms register_buffer.

Target: setattr(self, "name", Wrapper(tensor))

Abstract Op Lookup: "BatchStat"

ml_switcheroo.plugins.state_container.convert_register_parameter(node: libcst.Call, ctx: ml_switcheroo.core.hooks.HookContext) → libcst.Call

Transforms register_parameter.

Target: setattr(self, "name", ParamWrapper(param))

Abstract Op Lookup: "Param"

ml_switcheroo.plugins.state_container.convert_state_dict(node: libcst.Call, ctx: ml_switcheroo.core.hooks.HookContext) → libcst.Call

Transforms state_dict.

Target: StateFunc(model).to_pure_dict()

Abstract Op Lookup: "ModuleState"

ml_switcheroo.plugins.state_container.convert_load_state_dict(node: libcst.Call, ctx: ml_switcheroo.core.hooks.HookContext) → libcst.Call

Transforms load_state_dict.

Target: UpdateFunc(model, state)

Abstract Op Lookup: "UpdateState"

ml_switcheroo.plugins.state_container.convert_parameters(node: libcst.Call, ctx: ml_switcheroo.core.hooks.HookContext) → libcst.Call

Transforms parameters.

Target: StateFunc(model, ParamType).values()

Abstract Op Lookup: "ModuleState", "Param"