ml_switcheroo.plugins.rng_threading
Plugin for RNG State Threading (The "JAX Pointer" Pattern).
PyTorch manages randomness through global state (seeded with torch.manual_seed) and optional torch.Generator objects, whereas JAX has no global RNG state: PRNG keys must be passed explicitly and split to derive fresh keys.
This plugin automates the transition in three steps:
1. Signature Injection: adds an rng argument to the function definition.
2. Preamble Injection: adds rng, key = jax.random.split(rng) at the start of the function.
3. Call Rewriting:
   - Appends key=key to supported stochastic calls.
   - Strips Torch-specific generator arguments (incompatible with JAX).
For example, a PyTorch call such as torch.randn(…, generator=g) becomes jax.random.normal(…, key=key).
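The three steps can be sketched end to end with Python's standard ast module (Python 3.9+ for ast.unparse). This is a minimal sketch, not the plugin's implementation: the plugin works on libcst nodes, and the API_MAP renaming below is a hypothetical stand-in for the torch-to-JAX mapping that the Semantic Knowledge Base provides.

```python
import ast

# Hypothetical torch -> JAX name map, for illustration only. In ml_switcheroo
# this mapping comes from the Semantic Knowledge Base, not from this plugin.
API_MAP = {"torch.randn": "jax.random.normal"}


def dotted_name(node: ast.AST) -> str:
    """Return 'a.b.c' for a Name/Attribute chain, or '' for anything else."""
    if isinstance(node, ast.Name):
        return node.id
    if isinstance(node, ast.Attribute):
        base = dotted_name(node.value)
        return f"{base}.{node.attr}" if base else ""
    return ""


class ThreadPRNG(ast.NodeTransformer):
    def __init__(self, rng_arg_name: str = "rng", key_var_name: str = "key"):
        self.rng_arg_name = rng_arg_name
        self.key_var_name = key_var_name

    def visit_FunctionDef(self, node: ast.FunctionDef) -> ast.FunctionDef:
        self.generic_visit(node)  # step 3 runs on the body's calls first
        # Step 1: signature injection. Naive: assumes no parameter defaults,
        # otherwise a trailing positional `rng` would be a syntax error.
        node.args.args.append(ast.arg(arg=self.rng_arg_name, annotation=None))
        # Step 2: preamble injection at the top of the function body.
        rng, key = self.rng_arg_name, self.key_var_name
        split = f"{rng}, {key} = jax.random.split({rng})"
        node.body.insert(0, ast.parse(split).body[0])
        return node

    def visit_Call(self, node: ast.Call) -> ast.Call:
        self.generic_visit(node)
        target = API_MAP.get(dotted_name(node.func))
        if target is None:
            return node  # not a supported stochastic call: leave it alone
        node.func = ast.parse(target, mode="eval").body
        # Step 3: drop generator=..., then append key=<key_var_name>.
        node.keywords = [kw for kw in node.keywords if kw.arg != "generator"]
        node.keywords.append(
            ast.keyword(arg="key", value=ast.Name(id=self.key_var_name, ctx=ast.Load()))
        )
        return node


def thread_prng(source: str, **names: str) -> str:
    """Apply all three steps to every function in `source`."""
    tree = ThreadPRNG(**names).visit(ast.parse(source))
    return ast.unparse(ast.fix_missing_locations(tree))


if __name__ == "__main__":
    src = "def sample(shape, g):\n    return torch.randn(shape, generator=g)\n"
    print(thread_prng(src))
```

Calls are rewritten before the preamble is inserted, so the injected jax.random.split call is never itself treated as a stochastic call. One caveat: jax.random.normal takes the key as its first parameter, so a positional argument left in place (like shape above) would bind to key and collide with key=key at runtime. A working translation also needs positional arguments remapped (for example to shape=); this sketch does not attempt that.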
Configuration: users can customize the variable names via pyproject.toml or CLI arguments:
- rng_arg_name: name of the argument injected into the signature (default: "rng").
- key_var_name: name of the local key variable split from rng (default: "key").
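The two settings only change which identifiers appear in the generated code. A minimal sketch, assuming a hypothetical RngNames holder (not ml_switcheroo's configuration API; how pyproject.toml or CLI values reach it is not shown), renders the three generated fragments from them:

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class RngNames:
    """Hypothetical holder for the two documented settings and their defaults."""

    rng_arg_name: str = "rng"  # argument injected into the signature
    key_var_name: str = "key"  # local key variable split from rng

    def __post_init__(self) -> None:
        # Both names are emitted as Python identifiers in generated code.
        for name in (self.rng_arg_name, self.key_var_name):
            if not name.isidentifier():
                raise ValueError(f"not a valid identifier: {name!r}")
        # With equal names, `key, key = jax.random.split(key)` would keep only
        # the second key, so the first (carried) key would be discarded.
        if self.rng_arg_name == self.key_var_name:
            raise ValueError("rng_arg_name and key_var_name must differ")

    def signature_arg(self) -> str:
        return self.rng_arg_name

    def preamble(self) -> str:
        rng, key = self.rng_arg_name, self.key_var_name
        return f"{rng}, {key} = jax.random.split({rng})"

    def call_keyword(self) -> str:
        # In this sketch the keyword stays `key` (the parameter name used by
        # jax.random functions); only the variable it refers to changes.
        return f"key={self.key_var_name}"


if __name__ == "__main__":
    names = RngNames(rng_arg_name="prng", key_var_name="subkey")
    print(names.preamble())      # prng, subkey = jax.random.split(prng)
    print(names.call_keyword())  # key=subkey
```

Validating eagerly in __post_init__ means a bad setting fails when configuration is loaded rather than surfacing later as broken generated code.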
Functions
Plugin Hook: Thread PRNG keys for stochastic operations.
Module Contents
- ml_switcheroo.plugins.rng_threading.inject_prng_threading(node: libcst.Call, ctx: ml_switcheroo.core.hooks.HookContext) -> libcst.Call
Plugin Hook: Thread PRNG keys for stochastic operations.
- Triggers:
Operations marked with requires_plugin: "inject_prng" in the Semantic Knowledge Base. Examples: torch.nn.functional.dropout, torch.randn, torch.bernoulli.
- Parameters:
node: The original CST Call node (e.g., torch.nn.functional.dropout(x, 0.5)).
ctx: The HookContext used to request changes outside the call node (the enclosing function's signature and preamble) and to read configuration.
- Returns:
The transformed CST Call node, with the "key" keyword argument appended and any "generator" argument removed.
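The hook's contract (triggered only for marked operations, edits only the call, delegates signature and preamble changes to the context) can be sketched with Python's standard ast module instead of libcst, so it runs without extra dependencies. KNOWLEDGE_BASE, StubHookContext with its request_signature_arg and request_preamble methods, inject_prng_sketch, and dispatch are all hypothetical stand-ins, not ml_switcheroo's API. Like the documented hook, the sketch does not rename the function being called.

```python
import ast
from dataclasses import dataclass, field
from typing import Dict, List

# Hypothetical, simplified stand-in for Semantic Knowledge Base entries;
# only the `requires_plugin` marker from the docs is modelled.
KNOWLEDGE_BASE: Dict[str, Dict[str, str]] = {
    "torch.nn.functional.dropout": {"requires_plugin": "inject_prng"},
    "torch.randn": {"requires_plugin": "inject_prng"},
    "torch.bernoulli": {"requires_plugin": "inject_prng"},
    "torch.add": {},  # deterministic op: no plugin marker
}


@dataclass
class StubHookContext:
    """Hypothetical stand-in for HookContext; the method names are invented."""

    rng_arg_name: str = "rng"
    key_var_name: str = "key"
    signature_args: List[str] = field(default_factory=list)
    preamble: List[str] = field(default_factory=list)

    def request_signature_arg(self, name: str) -> None:
        if name not in self.signature_args:  # several calls, one parameter
            self.signature_args.append(name)

    def request_preamble(self, stmt: str) -> None:
        if stmt not in self.preamble:  # several calls, one split
            self.preamble.append(stmt)


def inject_prng_sketch(node: ast.Call, ctx: StubHookContext) -> ast.Call:
    """Return a new Call with generator= removed and key=<key_var> appended."""
    rng, key = ctx.rng_arg_name, ctx.key_var_name
    # Function-level edits are requested through the context, not made here.
    ctx.request_signature_arg(rng)
    ctx.request_preamble(f"{rng}, {key} = jax.random.split({rng})")
    keywords = [kw for kw in node.keywords if kw.arg != "generator"]
    keywords.append(ast.keyword(arg="key", value=ast.Name(id=key, ctx=ast.Load())))
    # Build a new node instead of mutating the input.
    return ast.Call(func=node.func, args=list(node.args), keywords=keywords)


def dispatch(node: ast.Call, ctx: StubHookContext) -> ast.Call:
    """Run the hook only for ops marked requires_plugin: "inject_prng"."""
    op = ast.unparse(node.func)
    if KNOWLEDGE_BASE.get(op, {}).get("requires_plugin") == "inject_prng":
        return inject_prng_sketch(node, ctx)
    return node


if __name__ == "__main__":
    ctx = StubHookContext()
    call = ast.parse("torch.bernoulli(p, generator=g)", mode="eval").body
    print(ast.unparse(dispatch(call, ctx)))  # torch.bernoulli(p, key=key)
    print(ctx.preamble)  # ['rng, key = jax.random.split(rng)']
```

Returning a new node rather than mutating the input mirrors libcst, whose nodes are immutable and are changed by building new nodes (for example with with_changes). Deduplicating requests in the context keeps a function with several stochastic calls to a single injected parameter and a single split.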