ml_switcheroo.plugins.rng_threading
Plugin for RNG State Threading (The "JAX Pointer" Pattern).
PyTorch handles randomness via global state (torch.manual_seed and generator objects), whereas JAX requires explicit passing and splitting of PRNG keys.
This plugin automates the transition by:
- Signature Injection: Adds an rng argument to the function definition.
- Preamble Injection: Asks the target Adapter for the correct split syntax (e.g. rng, key = jax.random.split(rng)) and injects it at the top of the function body.
- Call Rewriting:
  - Appends key=key to supported stochastic calls.
  - Strips Torch-specific generator arguments, which are incompatible with JAX.
- Decoupling: Checks traits.requires_explicit_rng to decide whether the plugin runs, and calls adapter.get_rng_split_syntax() to obtain the split code to generate.
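The first three steps can be sketched on a toy function. This is only an illustration: it uses the standard-library ast module as a stand-in for libcst, and the names STOCHASTIC_CALLS, ThreadRng, and thread_rng are hypothetical, not part of ml_switcheroo. It also leaves the Torch call name untouched, since mapping to the target API is outside what this plugin description covers.

```python
import ast

# Hypothetical allow-list for illustration; the real plugin is driven by
# the Semantic Knowledge Base, not a hard-coded set.
STOCHASTIC_CALLS = {"dropout", "randn", "bernoulli"}
# Split syntax quoted in the text above; the plugin asks the adapter for it.
SPLIT_PREAMBLE = "rng, key = jax.random.split(rng)"


def _call_name(func: ast.expr) -> str:
    """Return the last dotted component of a call target, e.g. 'dropout'."""
    if isinstance(func, ast.Attribute):
        return func.attr
    if isinstance(func, ast.Name):
        return func.id
    return ""


class ThreadRng(ast.NodeTransformer):
    def visit_FunctionDef(self, node: ast.FunctionDef) -> ast.FunctionDef:
        self.generic_visit(node)  # rewrite the calls in the body first
        # Signature injection: add an `rng` parameter.
        node.args.args.append(ast.arg(arg="rng"))
        # Preamble injection: split the key at the top of the body.
        node.body.insert(0, ast.parse(SPLIT_PREAMBLE).body[0])
        return node

    def visit_Call(self, node: ast.Call) -> ast.Call:
        self.generic_visit(node)
        if _call_name(node.func) not in STOCHASTIC_CALLS:
            return node
        # Call rewriting: drop `generator=...`, then append `key=key`.
        node.keywords = [kw for kw in node.keywords if kw.arg != "generator"]
        node.keywords.append(
            ast.keyword(arg="key", value=ast.Name(id="key", ctx=ast.Load()))
        )
        return node


def thread_rng(source: str) -> str:
    """Apply the three rewrites to `source` and return the new code."""
    tree = ThreadRng().visit(ast.parse(source))
    ast.fix_missing_locations(tree)
    return ast.unparse(tree)


SOURCE = """
def forward(x):
    return torch.nn.functional.dropout(x, 0.5, generator=gen)
"""
RESULT = thread_rng(SOURCE)
print(RESULT)
```

The printed function gains an rng parameter, begins with the split statement, and its dropout call carries key=key in place of generator=gen.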
Functions
| inject_prng_threading(node, ctx) | Plugin Hook: Thread PRNG keys for stochastic operations. |
Module Contents
- ml_switcheroo.plugins.rng_threading.inject_prng_threading(node: libcst.Call, ctx: ml_switcheroo.core.hooks.HookContext) -> libcst.Call
Plugin Hook: Thread PRNG keys for stochastic operations.
- Triggers:
Operations marked with requires_plugin: "inject_prng" in the Semantic Knowledge Base. Examples: torch.nn.functional.dropout, torch.randn, torch.bernoulli.
- Logic:
Checks ctx.plugin_traits.requires_explicit_rng. If True, applies JAX-style threading. Because the check reads a trait rather than a framework name, any framework (not just JAX) can opt in to this behavior via configuration.
- Parameters:
node – The original CST Call node (e.g., torch.dropout(x, 0.5)).
ctx – The HookContext used to request global scope changes (signature/preamble) and read configuration.
- Returns:
The transformed CST Call node with the "key" keyword argument appended and "generator" arguments removed.
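The trait check described under Logic can be sketched in isolation. The classes below are hypothetical stand-ins, not the real HookContext or adapter types. Only the attribute plugin_traits.requires_explicit_rng and the method get_rng_split_syntax() are taken from the text above. The adapter attribute and the injected_args and preamble lists are assumptions made for this illustration.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class PluginTraits:
    # Named after the trait in the docs; the default is an assumption.
    requires_explicit_rng: bool = False


class SketchAdapter:
    """Hypothetical adapter exposing only the method named in the docs."""

    def get_rng_split_syntax(self) -> str:
        return "rng, key = jax.random.split(rng)"


@dataclass
class SketchContext:
    """Hypothetical stand-in for HookContext; records requested changes."""

    plugin_traits: PluginTraits
    adapter: SketchAdapter
    injected_args: List[str] = field(default_factory=list)
    preamble: List[str] = field(default_factory=list)


def plan_rng_threading(ctx: SketchContext) -> bool:
    """Request signature and preamble changes only if the trait is set."""
    if not ctx.plugin_traits.requires_explicit_rng:
        return False  # target framework does not use explicit keys
    ctx.injected_args.append("rng")
    ctx.preamble.append(ctx.adapter.get_rng_split_syntax())
    return True


explicit_ctx = SketchContext(PluginTraits(requires_explicit_rng=True), SketchAdapter())
implicit_ctx = SketchContext(PluginTraits(), SketchAdapter())
print(plan_rng_threading(explicit_ctx), explicit_ctx.preamble)
print(plan_rng_threading(implicit_ctx), implicit_ctx.preamble)
```

Keeping the decision (the trait) separate from the generated text (the adapter) is what lets a non-JAX target reuse the same hook by setting the trait and returning its own split statement.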