ml_switcheroo.plugins.state_container

Plugin for handling stateful container logic.

Maps container-management methods between frameworks. In particular, it bridges the impedance mismatch between PyTorch's imperative register_buffer/register_parameter system and the explicit state management of Flax NNX.

Mappings (Torch -> JAX/NNX):

1. self.register_buffer("name", t) -> setattr(self, "name", flax.nnx.BatchStat(t))
2. self.register_parameter("name", p) -> setattr(self, "name", flax.nnx.Param(p))
3. model.state_dict() -> flax.nnx.state(model).to_pure_dict()
4. model.load_state_dict(sd) -> flax.nnx.update(model, sd)
5. model.parameters() -> flax.nnx.state(model, flax.nnx.Param).values()

Functions

convert_register_buffer(node, ctx) → libcst.Call

Transforms self.register_buffer('name', tensor) -> setattr(self, 'name', nnx.BatchStat(tensor)).

convert_register_parameter(node, ctx) → libcst.Call

Transforms self.register_parameter('name', param) -> setattr(self, 'name', nnx.Param(param)).

convert_state_dict(node, ctx) → libcst.Call

Transforms model.state_dict() -> flax.nnx.state(model).to_pure_dict().

convert_load_state_dict(node, ctx) → libcst.Call

Transforms model.load_state_dict(state) -> flax.nnx.update(model, state).

convert_parameters(node, ctx) → libcst.Call

Transforms model.parameters() -> flax.nnx.state(model, flax.nnx.Param).values().

Module Contents

ml_switcheroo.plugins.state_container.convert_register_buffer(node: libcst.Call, ctx: ml_switcheroo.core.hooks.HookContext) → libcst.Call

Transforms self.register_buffer('name', tensor) -> setattr(self, 'name', nnx.BatchStat(tensor)).

ml_switcheroo.plugins.state_container.convert_register_parameter(node: libcst.Call, ctx: ml_switcheroo.core.hooks.HookContext) → libcst.Call

Transforms self.register_parameter('name', param) -> setattr(self, 'name', nnx.Param(param)).
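The two register_* rules share one shape: the receiver and the name move into a setattr call, and the value is wrapped in an NNX variable class. The sketch below reproduces that rewrite with the standard library's ast module instead of libcst, so it runs without libcst or Flax installed and only emits source text. RegisterToSetattr, _WRAPPERS, _dotted and rewrite are hypothetical names, not part of ml_switcheroo. How the real hooks treat extra arguments is not described in this reference; this sketch simply leaves calls with keyword arguments (such as persistent=False) or an unexpected argument count untouched.

```python
import ast

# Hypothetical stand-in for the libcst hooks: uses stdlib ast and emits source text only.
# Maps each PyTorch registration method to the flax.nnx wrapper class it becomes.
_WRAPPERS = {
    "register_buffer": "BatchStat",
    "register_parameter": "Param",
}


def _dotted(path: str) -> ast.expr:
    """Build an expression node for a dotted name such as 'flax.nnx.Param'."""
    head, *rest = path.split(".")
    node: ast.expr = ast.Name(id=head, ctx=ast.Load())
    for attr in rest:
        node = ast.Attribute(value=node, attr=attr, ctx=ast.Load())
    return node


class RegisterToSetattr(ast.NodeTransformer):
    """Rewrite obj.register_*(name, value) into setattr(obj, name, flax.nnx.<Wrapper>(value))."""

    def visit_Call(self, node: ast.Call) -> ast.AST:
        self.generic_visit(node)  # rewrite nested calls first
        func = node.func
        if (
            isinstance(func, ast.Attribute)
            and func.attr in _WRAPPERS
            and len(node.args) == 2
            and not node.keywords  # e.g. persistent=False: left untouched in this sketch
        ):
            name_arg, value_arg = node.args
            # flax.nnx.BatchStat(value) or flax.nnx.Param(value)
            wrapped = ast.Call(
                func=_dotted("flax.nnx." + _WRAPPERS[func.attr]),
                args=[value_arg],
                keywords=[],
            )
            # setattr(receiver, name, wrapped)
            return ast.Call(
                func=ast.Name(id="setattr", ctx=ast.Load()),
                args=[func.value, name_arg, wrapped],
                keywords=[],
            )
        return node


def rewrite(source: str) -> str:
    """Parse source, apply the rewrite, and return the regenerated code."""
    tree = RegisterToSetattr().visit(ast.parse(source))
    return ast.unparse(ast.fix_missing_locations(tree))
```

For example, rewrite('self.register_buffer("running_mean", t)') returns "setattr(self, 'running_mean', flax.nnx.BatchStat(t))"; ast.unparse normalizes string literals to single quotes, whereas a libcst-based hook would typically preserve the original formatting.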

ml_switcheroo.plugins.state_container.convert_state_dict(node: libcst.Call, ctx: ml_switcheroo.core.hooks.HookContext) → libcst.Call

Transforms model.state_dict() -> flax.nnx.state(model).to_pure_dict().

ml_switcheroo.plugins.state_container.convert_load_state_dict(node: libcst.Call, ctx: ml_switcheroo.core.hooks.HookContext) → libcst.Call

Transforms model.load_state_dict(state) -> flax.nnx.update(model, state).

ml_switcheroo.plugins.state_container.convert_parameters(node: libcst.Call, ctx: ml_switcheroo.core.hooks.HookContext) → libcst.Call

Transforms model.parameters() -> flax.nnx.state(model, flax.nnx.Param).values().
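The three model-level rules (state_dict, load_state_dict and parameters) follow a different pattern: the receiver is lifted into the argument list of a flax.nnx function. Below is a minimal sketch of all three, again using the standard library's ast module as a stand-in for libcst; ModelMethodsToNNX, _dotted, _call and rewrite are hypothetical names. It matches on the method name alone, so it would also rewrite calls such as optimizer.state_dict(); how the actual hooks decide which calls to transform (for instance via the HookContext argument) is not described in this reference.

```python
import ast


def _dotted(path: str) -> ast.expr:
    """Build an expression node for a dotted name such as 'flax.nnx.state'."""
    head, *rest = path.split(".")
    node: ast.expr = ast.Name(id=head, ctx=ast.Load())
    for attr in rest:
        node = ast.Attribute(value=node, attr=attr, ctx=ast.Load())
    return node


def _call(func: ast.expr, *args: ast.expr) -> ast.Call:
    """Build a call node with positional arguments only."""
    return ast.Call(func=func, args=list(args), keywords=[])


class ModelMethodsToNNX(ast.NodeTransformer):
    """Hypothetical stdlib-ast sketch: lift the receiver of model-level methods into flax.nnx calls."""

    def visit_Call(self, node: ast.Call) -> ast.AST:
        self.generic_visit(node)  # rewrite nested calls first
        func = node.func
        if not isinstance(func, ast.Attribute) or node.keywords:
            return node  # e.g. strict=False: left untouched in this sketch
        model, method, args = func.value, func.attr, node.args
        if method == "state_dict" and not args:
            # model.state_dict() -> flax.nnx.state(model).to_pure_dict()
            state = _call(_dotted("flax.nnx.state"), model)
            return _call(ast.Attribute(value=state, attr="to_pure_dict", ctx=ast.Load()))
        if method == "load_state_dict" and len(args) == 1:
            # model.load_state_dict(sd) -> flax.nnx.update(model, sd)
            return _call(_dotted("flax.nnx.update"), model, args[0])
        if method == "parameters" and not args:
            # model.parameters() -> flax.nnx.state(model, flax.nnx.Param).values()
            state = _call(_dotted("flax.nnx.state"), model, _dotted("flax.nnx.Param"))
            return _call(ast.Attribute(value=state, attr="values", ctx=ast.Load()))
        return node


def rewrite(source: str) -> str:
    """Parse source, apply the rewrite, and return the regenerated code."""
    tree = ModelMethodsToNNX().visit(ast.parse(source))
    return ast.unparse(ast.fix_missing_locations(tree))
```

For example, rewrite("model.load_state_dict(sd)") returns "flax.nnx.update(model, sd)", while a call carrying a keyword argument such as strict=False is returned unchanged.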