ml_switcheroo.plugins.state_flag_injection
Plugin for injecting state flags (e.g., training=True/False) based on context.
This module handles the impedance mismatch between Object-Oriented state management (PyTorch’s model.eval(), model.train()) and Functional statelessness (JAX, Keras functional).
The plugin consists of two cooperating hooks:
1. capture_eval_state: Intercepts model.eval()/model.train() calls, records the state change in the HookContext, and removes the imperative call from the AST.
2. inject_training_flag: Intercepts calls to the model (e.g. model(x)), checks whether a state was recorded, and injects the generic training=… keyword argument.
State is tracked via a metadata dictionary in HookContext keyed by the object name.
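As a rough illustration of the bookkeeping described above, the following sketch stores one boolean per object name inside a metadata dictionary. FakeContext, STATE_KEY, record_mode, and lookup_mode are hypothetical names invented for this example; the real HookContext and the key the plugin uses may differ.

```python
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class FakeContext:
    """Hypothetical stand-in for HookContext; only a `metadata` dict is assumed."""
    metadata: dict = field(default_factory=dict)


# Hypothetical key under which per-object training state is kept.
STATE_KEY = "training_state"


def record_mode(ctx: FakeContext, receiver: str, training: bool) -> None:
    """Remember the mode last requested for `receiver` (e.g. "model")."""
    ctx.metadata.setdefault(STATE_KEY, {})[receiver] = training


def lookup_mode(ctx: FakeContext, receiver: str) -> Optional[bool]:
    """Return the recorded mode, or None if nothing was recorded for `receiver`."""
    return ctx.metadata.get(STATE_KEY, {}).get(receiver)


ctx = FakeContext()
record_mode(ctx, "model", False)   # what seeing model.eval() would record
print(lookup_mode(ctx, "model"))   # mode recorded for "model"
print(lookup_mode(ctx, "other"))   # nothing recorded for "other"
```

Keying by object name keeps the two hooks decoupled: the first hook only writes, the second only reads.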
Functions
inject_training_flag_call(node, ctx) | Hook: Injects training=True/False kwargs into function calls.
capture_eval_state(node, ctx) | Hook: Intercepts eval()/train() calls to record the training state and remove the call.
Module Contents
- ml_switcheroo.plugins.state_flag_injection.inject_training_flag_call(node: libcst.Call, ctx: ml_switcheroo.core.hooks.HookContext) → libcst.Call
Hook: Injects training=True/False kwargs into function calls.
This hook is triggered on function calls (like model(x) or model.forward(x)) if they map to an abstract operation configured with requires_plugin="inject_training_flag".
Logic:
1. Resolve the name of the object being called (the receiver). Both the full callable name (e.g. self.layer in self.layer(x)) and the parent object (e.g. model in model.forward(x)) are checked.
2. Check ctx.metadata to see whether capture_eval_state previously recorded a state.
3. If a state exists, inject the training argument.
- Parameters:
node – The original CST Call node.
ctx – HookContext containing global metadata state.
- Returns:
The modified Call node with the injected argument, or the original node if no state was found.
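The lookup-then-inject logic can be sketched as follows. To stay dependency-free, this sketch uses the standard-library ast module instead of libcst, and a plain dict instead of ctx.metadata. dotted_name, inject_flag, and rewrite are hypothetical helpers, not the plugin's actual implementation.

```python
import ast
from typing import Dict, Optional


def dotted_name(node: ast.AST) -> Optional[str]:
    """Return 'a.b.c' for Name/Attribute chains, or None for anything else."""
    if isinstance(node, ast.Name):
        return node.id
    if isinstance(node, ast.Attribute):
        base = dotted_name(node.value)
        return f"{base}.{node.attr}" if base else None
    return None


def inject_flag(call: ast.Call, state: Dict[str, bool]) -> ast.Call:
    """Append training=<recorded mode> if the receiver has a recorded state."""
    full = dotted_name(call.func)
    candidates = [full]
    if full and "." in full:
        # Also try the parent object, e.g. "model" for "model.forward".
        candidates.append(full.rsplit(".", 1)[0])
    for name in candidates:
        if name in state:
            flag = ast.keyword(arg="training", value=ast.Constant(value=state[name]))
            call.keywords.append(flag)
            break
    return call  # unchanged when no state was recorded


def rewrite(src: str, state: Dict[str, bool]) -> str:
    """Parse a single call expression, inject the flag, and render it back."""
    call = ast.parse(src, mode="eval").body
    return ast.unparse(inject_flag(call, state))


print(rewrite("model(x)", {"model": False}))
print(rewrite("model.forward(x)", {"model": False}))
print(rewrite("other(x)", {"model": False}))
```

Trying the full callable name first and the parent object second mirrors step 1 of the logic above, so both self.layer(x) and model.forward(x) resolve to a receiver that may have a recorded state.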
- ml_switcheroo.plugins.state_flag_injection.capture_eval_state(node: libcst.Call, ctx: ml_switcheroo.core.hooks.HookContext) → libcst.CSTNode
Hook: Intercepts eval()/train() calls to record the training state and remove the call.
Action:
1. Identifies the receiver object (model).
2. Determines the mode (training=True for .train(), False for .eval()).
3. Updates ctx.metadata with this knowledge.
4. Returns a no-op node to strip the imperative call from the output code.
- Parameters:
node – The call node (e.g. model.eval()).
ctx – Hook Context.
- Returns:
cst.Name("None"), which replaces the call with the expression None. As an expression statement this is valid Python and does nothing.
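The four actions above can be sketched as follows. The sketch again uses the standard-library ast module in place of libcst and a plain dict in place of ctx.metadata. capture and dotted_name are hypothetical helpers, and any arguments passed to train() are ignored here.

```python
import ast
from typing import Dict, Optional


def dotted_name(node: ast.AST) -> Optional[str]:
    """Return 'a.b.c' for Name/Attribute chains, or None for anything else."""
    if isinstance(node, ast.Name):
        return node.id
    if isinstance(node, ast.Attribute):
        base = dotted_name(node.value)
        return f"{base}.{node.attr}" if base else None
    return None


def capture(call: ast.Call, state: Dict[str, bool]) -> ast.expr:
    """Record the mode for receiver.eval()/receiver.train() and return a no-op."""
    func = call.func
    if isinstance(func, ast.Attribute) and func.attr in ("eval", "train"):
        receiver = dotted_name(func.value)
        if receiver is not None:
            # .train() means training=True, .eval() means training=False.
            state[receiver] = func.attr == "train"
            # Replace the imperative call with a bare None expression.
            return ast.Constant(value=None)
    return call  # not a recognised state call; leave it alone


state: Dict[str, bool] = {}
node = ast.parse("model.eval()", mode="eval").body
replacement = capture(node, state)
print(ast.unparse(replacement))  # the no-op left in place of the call
print(state)                     # the recorded mode for "model"
```

Returning an expression rather than deleting the statement keeps the surrounding block syntactically valid even when the call was the only statement in it.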