ml_switcheroo.plugins.state_container
Plugin for handling Stateful Container Logic.
Maps container management methods between frameworks. Its main job is to bridge the impedance mismatch between PyTorch’s imperative register_buffer/register_parameter system and Flax NNX’s explicit state management.
Mappings (Torch -> JAX/NNX):
1. self.register_buffer("name", t) -> setattr(self, "name", flax.nnx.BatchStat(t))
2. self.register_parameter("name", p) -> setattr(self, "name", flax.nnx.Param(p))
3. model.state_dict() -> flax.nnx.state(model).to_pure_dict()
4. model.load_state_dict(sd) -> flax.nnx.update(model, sd)
5. model.parameters() -> flax.nnx.state(model, flax.nnx.Param).values()
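The first mapping can be illustrated with a minimal sketch. It assumes the standard-library ast module in place of libcst, which the plugin itself uses, so it shows the shape of the rewrite rather than the plugin's implementation. The names RegisterBufferRewriter and rewrite_register_buffer are hypothetical.

```python
import ast


class RegisterBufferRewriter(ast.NodeTransformer):
    """Hypothetical stand-in for the register_buffer mapping, built on stdlib ast."""

    def visit_Call(self, node: ast.Call) -> ast.AST:
        self.generic_visit(node)  # rewrite any nested calls first
        func = node.func
        # Match only <receiver>.register_buffer(name, tensor, ...) calls.
        if (
            isinstance(func, ast.Attribute)
            and func.attr == "register_buffer"
            and len(node.args) >= 2
        ):
            name_arg, tensor_arg = node.args[0], node.args[1]
            # Build flax.nnx.BatchStat(<tensor>) by parsing a template call
            # and attaching the original tensor argument to it.
            wrapped = ast.parse("flax.nnx.BatchStat()", mode="eval").body
            wrapped.args = [tensor_arg]
            # Build setattr(<receiver>, <name>, flax.nnx.BatchStat(<tensor>)).
            new_call = ast.Call(
                func=ast.Name(id="setattr", ctx=ast.Load()),
                args=[func.value, name_arg, wrapped],
                keywords=[],
            )
            return ast.copy_location(new_call, node)
        return node


def rewrite_register_buffer(source: str) -> str:
    """Return `source` with register_buffer calls rewritten to setattr calls."""
    tree = RegisterBufferRewriter().visit(ast.parse(source))
    ast.fix_missing_locations(tree)
    return ast.unparse(tree)


if __name__ == "__main__":
    print(rewrite_register_buffer('self.register_buffer("running_mean", t)'))
```

Unlike libcst, ast discards comments and original formatting, and this sketch ignores keyword arguments such as persistent. A production transform would need to preserve or handle both.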
Functions
| convert_register_buffer | Transforms self.register_buffer('name', tensor) -> setattr(self, 'name', nnx.BatchStat(tensor)). |
| convert_register_parameter | Transforms self.register_parameter('name', param) -> setattr(self, 'name', nnx.Param(param)). |
| convert_state_dict | Transforms model.state_dict() -> flax.nnx.state(model).to_pure_dict(). |
| convert_load_state_dict | Transforms model.load_state_dict(state) -> flax.nnx.update(model, state). |
| convert_parameters | Transforms model.parameters() -> flax.nnx.state(model, flax.nnx.Param).values(). |
Module Contents
- ml_switcheroo.plugins.state_container.convert_register_buffer(node: libcst.Call, ctx: ml_switcheroo.core.hooks.HookContext) -> libcst.Call
Transforms self.register_buffer('name', tensor) -> setattr(self, 'name', nnx.BatchStat(tensor)).
- ml_switcheroo.plugins.state_container.convert_register_parameter(node: libcst.Call, ctx: ml_switcheroo.core.hooks.HookContext) -> libcst.Call
Transforms self.register_parameter('name', param) -> setattr(self, 'name', nnx.Param(param)).
- ml_switcheroo.plugins.state_container.convert_state_dict(node: libcst.Call, ctx: ml_switcheroo.core.hooks.HookContext) -> libcst.Call
Transforms model.state_dict() -> flax.nnx.state(model).to_pure_dict().
- ml_switcheroo.plugins.state_container.convert_load_state_dict(node: libcst.Call, ctx: ml_switcheroo.core.hooks.HookContext) -> libcst.Call
Transforms model.load_state_dict(state) -> flax.nnx.update(model, state).
- ml_switcheroo.plugins.state_container.convert_parameters(node: libcst.Call, ctx: ml_switcheroo.core.hooks.HookContext) -> libcst.Call
Transforms model.parameters() -> flax.nnx.state(model, flax.nnx.Param).values().
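The three method-call transforms share one pattern: the receiver of the method becomes the first argument of a flax.nnx function. The sketch below illustrates that pattern with a template table and the standard-library ast module. It is an assumption-laden illustration, not the plugin's libcst code; TEMPLATES, ContainerMethodRewriter and rewrite_container_calls are hypothetical names, and only the target expressions are taken from the descriptions above.

```python
import ast

# Method name -> replacement template. {obj} is the receiver expression and
# {args} the original positional arguments. Targets follow the documented mappings.
TEMPLATES = {
    "state_dict": "flax.nnx.state({obj}).to_pure_dict()",
    "load_state_dict": "flax.nnx.update({obj}, {args})",
    "parameters": "flax.nnx.state({obj}, flax.nnx.Param).values()",
}


class ContainerMethodRewriter(ast.NodeTransformer):
    """Hypothetical rewriter: <obj>.<method>(...) -> flax.nnx expression."""

    def visit_Call(self, node: ast.Call) -> ast.AST:
        self.generic_visit(node)  # handle nested calls first
        func = node.func
        if isinstance(func, ast.Attribute) and func.attr in TEMPLATES:
            obj = ast.unparse(func.value)
            # Keyword arguments are ignored in this sketch.
            args = ", ".join(ast.unparse(arg) for arg in node.args)
            new_source = TEMPLATES[func.attr].format(obj=obj, args=args)
            new_node = ast.parse(new_source, mode="eval").body
            return ast.copy_location(new_node, node)
        return node


def rewrite_container_calls(source: str) -> str:
    """Return `source` with the three container method calls rewritten."""
    tree = ContainerMethodRewriter().visit(ast.parse(source))
    ast.fix_missing_locations(tree)
    return ast.unparse(tree)


if __name__ == "__main__":
    for line in (
        "sd = model.state_dict()",
        "model.load_state_dict(sd)",
        "ps = list(model.parameters())",
    ):
        print(rewrite_container_calls(line))
```

The sketch matches on method name alone, so it would also rewrite a parameters() call on an unrelated object. A real transform would need extra context, such as the ctx argument in the signatures above, to decide whether the receiver is actually a model.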