ml_switcheroo.plugins.rng_threading

Plugin for RNG State Threading (The “JAX Pointer” Pattern).

PyTorch manages randomness through implicit global state (seeded with torch.manual_seed) and optional torch.Generator objects, whereas JAX requires PRNG keys to be passed and split explicitly.

This plugin automates the transition by:

1. Signature Injection: Adds an rng argument to the function definition.

2. Preamble Injection: Adds rng, key = jax.random.split(rng) at the start of the function.

3. Call Rewriting:

     • Appends key=key to supported stochastic calls.

     • Strips Torch-specific generator arguments (incompatible with JAX).
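Steps 1 and 2 can be sketched at the level of a single function definition. This sketch uses Python's standard-library ast module (ast.unparse needs Python 3.9+) in place of libcst, which the real plugin uses. The helper name inject_signature_and_preamble is hypothetical. Step 3 is omitted because it works per call, not per function.

```python
import ast


def inject_signature_and_preamble(func_src: str, rng_arg: str = "rng",
                                  key_var: str = "key") -> str:
    """Apply steps 1 and 2 to one function definition (hypothetical sketch)."""
    module = ast.parse(func_src)
    func = module.body[0]
    if not isinstance(func, ast.FunctionDef):
        raise ValueError("expected source containing a single function definition")
    if func.args.defaults:
        # Appending after defaulted parameters would shift the defaults onto
        # the wrong names, so this sketch refuses that case.
        raise ValueError("parameters with defaults are not handled by this sketch")

    # Step 1: signature injection, appended as a trailing positional parameter.
    func.args.args.append(ast.arg(arg=rng_arg))

    # Step 2: preamble injection, e.g. `rng, key = jax.random.split(rng)`.
    preamble = ast.parse(f"{rng_arg}, {key_var} = jax.random.split({rng_arg})").body[0]
    body = func.body
    has_docstring = (
        isinstance(body[0], ast.Expr)
        and isinstance(body[0].value, ast.Constant)
        and isinstance(body[0].value.value, str)
    )
    # Keep a leading docstring first and place the split right after it.
    body.insert(1 if has_docstring else 0, preamble)

    return ast.unparse(module)


print(inject_signature_and_preamble("def layer(x):\n    return dropout(x, 0.5)\n"))
```

The rng_arg and key_var parameters correspond to the rng_arg_name and key_var_name settings described under Configuration. A parsed syntax tree is used instead of text substitution so that the docstring and existing parameters are handled structurally.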

As a result, a standard PyTorch call such as torch.randn(..., generator=g) becomes jax.random.normal(..., key=key).

Configuration: Users can customize the variable names via pyproject.toml or CLI arguments:

  • rng_arg_name: Name of the argument injected into the function signature (default: “rng”).

  • key_var_name: Name of the local key variable split from rng (default: “key”).

Functions

inject_prng_threading(node, ctx) → libcst.Call

Plugin Hook: Thread PRNG keys for stochastic operations.

Module Contents

ml_switcheroo.plugins.rng_threading.inject_prng_threading(node: libcst.Call, ctx: ml_switcheroo.core.hooks.HookContext) → libcst.Call

Plugin Hook: Thread PRNG keys for stochastic operations.

Triggers:

Operations marked with requires_plugin: “inject_prng” in the Semantic Knowledge Base. Examples: torch.nn.functional.dropout, torch.randn, torch.bernoulli.
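This is a minimal sketch of how the trigger could be resolved, assuming the Semantic Knowledge Base can be viewed as a mapping from API names to entry dicts. The dict layout, KNOWLEDGE_BASE and ops_requiring are hypothetical. Only the requires_plugin: “inject_prng” marker and the example operation names come from the text above.

```python
from typing import Dict, List

# Hypothetical, heavily simplified view of Knowledge Base entries.
KNOWLEDGE_BASE: Dict[str, dict] = {
    "torch.nn.functional.dropout": {"requires_plugin": "inject_prng"},
    "torch.randn": {"requires_plugin": "inject_prng"},
    "torch.bernoulli": {"requires_plugin": "inject_prng"},
    "torch.matmul": {},  # deterministic: no plugin required
}


def ops_requiring(plugin: str, kb: Dict[str, dict]) -> List[str]:
    """Return, sorted, the API names whose entry requires `plugin`."""
    return sorted(
        name for name, entry in kb.items() if entry.get("requires_plugin") == plugin
    )


print(ops_requiring("inject_prng", KNOWLEDGE_BASE))
# ['torch.bernoulli', 'torch.nn.functional.dropout', 'torch.randn']
```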

Parameters:
  • node – The original CST Call node (e.g., torch.dropout(x, 0.5)).

  • ctx – The HookContext used to request global scope changes (signature/preamble) and read configuration.

Returns:

The transformed CST Call node with the ‘key’ keyword argument appended and ‘generator’ arguments removed.
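The contract above takes a Call node and returns a Call node. Function-level changes go through ctx as requests, and the hook does not edit the enclosing definition directly. This sketch illustrates that contract using the stdlib ast module instead of libcst. StubHookContext, its config dict and its request_* methods are hypothetical stand-ins, not the real HookContext API. The call keeps its Torch name because mapping it to the JAX counterpart is assumed to happen elsewhere in the pipeline.

```python
import ast
from dataclasses import dataclass, field
from typing import Dict, List


@dataclass
class StubHookContext:
    """Hypothetical HookContext stand-in: holds config and records requests."""
    config: Dict[str, str] = field(default_factory=dict)
    signature_args: List[str] = field(default_factory=list)
    preamble: List[str] = field(default_factory=list)

    def request_signature_arg(self, name: str) -> None:
        if name not in self.signature_args:  # avoid adding the parameter twice
            self.signature_args.append(name)

    def request_preamble(self, stmt: str) -> None:
        if stmt not in self.preamble:  # avoid emitting the split twice
            self.preamble.append(stmt)


def inject_prng_sketch(node: ast.Call, ctx: StubHookContext) -> ast.Call:
    """Hypothetical hook body: request scope changes, then rewrite the call."""
    rng_arg = ctx.config.get("rng_arg_name", "rng")
    key_var = ctx.config.get("key_var_name", "key")

    # Request signature and preamble changes instead of editing the def here.
    ctx.request_signature_arg(rng_arg)
    ctx.request_preamble(f"{rng_arg}, {key_var} = jax.random.split({rng_arg})")

    # Drop Torch-only `generator=` keywords, then append `key=<key_var>`.
    keywords = [kw for kw in node.keywords if kw.arg != "generator"]
    keywords.append(ast.keyword(arg="key", value=ast.Name(id=key_var, ctx=ast.Load())))
    return ast.Call(func=node.func, args=node.args, keywords=keywords)


ctx = StubHookContext()
call = ast.parse("torch.bernoulli(p, generator=g)", mode="eval").body
print(ast.unparse(inject_prng_sketch(call, ctx)))  # torch.bernoulli(p, key=key)
```

The stub deduplicates requests, so several stochastic calls in one function do not add the rng parameter or the split twice.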