ml_switcheroo.plugins.rng_threading

Plugin for RNG State Threading (The “JAX Pointer” Pattern).

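PyTorch handles randomness through implicit, stateful generators: a global default generator seeded by torch.manual_seed, or a torch.Generator object passed through a generator= argument. JAX has no hidden RNG state. Every random function takes an explicit PRNG key, and keys are split to obtain fresh randomness instead of being reused.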

This plugin automates the transition by:

  1. Signature Injection: Adds an rng parameter to the enclosing function definition.

  2. Preamble Injection: Asks the target Adapter for the correct split syntax (e.g. rng, key = jax.random.split(rng)) and injects it at the start of the function body.

  3. Call Rewriting:
       - Appends key=key to supported stochastic calls.
       - Strips Torch-specific generator arguments (incompatible with JAX).
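The three steps above can be sketched end to end on a source string. This is an illustration only: it uses the stdlib ast module instead of libcst, hard-codes a hypothetical STOCHASTIC_CALLS set and the split statement that the real plugin would obtain from the Knowledge Base and the target Adapter, and does not resolve import aliases such as F.dropout.

```python
import ast

# Hypothetical stand-ins: the real plugin reads these from the Knowledge Base
# and the target Adapter rather than hard-coding them.
STOCHASTIC_CALLS = {"torch.randn", "torch.bernoulli", "torch.nn.functional.dropout"}
SPLIT_SYNTAX = "rng, key = jax.random.split(rng)"


class RngThreadingSketch(ast.NodeTransformer):
    """Applies the three steps with the stdlib ast module (not libcst)."""

    def __init__(self) -> None:
        self._saw_stochastic = False

    def visit_FunctionDef(self, node: ast.FunctionDef) -> ast.FunctionDef:
        outer = self._saw_stochastic  # save the enclosing function's flag
        self._saw_stochastic = False
        self.generic_visit(node)  # rewrites calls in the body first (step 3)
        if self._saw_stochastic:
            params = [a.arg for a in node.args.posonlyargs + node.args.args]
            if "rng" not in params:
                # Step 1: add `rng` just before any defaulted parameters so the
                # signature stays valid (positional-only edge cases ignored).
                idx = max(len(node.args.args) - len(node.args.defaults), 0)
                node.args.args.insert(idx, ast.arg(arg="rng", annotation=None))
            # Step 2: put the split statement at the top of the body.
            node.body.insert(0, ast.parse(SPLIT_SYNTAX).body[0])
        self._saw_stochastic = outer
        return node

    def visit_Call(self, node: ast.Call) -> ast.Call:
        self.generic_visit(node)
        if ast.unparse(node.func) in STOCHASTIC_CALLS:
            self._saw_stochastic = True
            # Step 3a: drop Torch-only generator= keywords.
            node.keywords = [k for k in node.keywords if k.arg != "generator"]
            # Step 3b: append key=key unless a key is already passed.
            if all(k.arg != "key" for k in node.keywords):
                node.keywords.append(
                    ast.keyword(arg="key", value=ast.Name(id="key", ctx=ast.Load()))
                )
        return node


def thread_rng(source: str) -> str:
    """Parse, transform, and unparse a source string (Python 3.9+ for ast.unparse)."""
    tree = RngThreadingSketch().visit(ast.parse(source))
    return ast.unparse(ast.fix_missing_locations(tree))
```

Applied to a function that calls torch.randn(3, generator=g), thread_rng yields a signature containing rng, the split statement as the first line of the body, and torch.randn(3, key=key). A now-unused torch.Generator() line is left in place; cleaning it up is outside this sketch.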

Decoupling: The plugin checks traits.requires_explicit_rng to decide whether to run at all, and calls adapter.get_rng_split_syntax() to decide which preamble code to emit. Neither decision is hard-coded to JAX.
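The split between the two decisions can be sketched as follows. Only the names requires_explicit_rng and get_rng_split_syntax come from the text; PluginTraits, RngAdapter, both adapter classes, build_preamble, and the myfw.random.split_key call are hypothetical and chosen for illustration.

```python
import ast
from dataclasses import dataclass
from typing import Optional, Protocol


@dataclass(frozen=True)
class PluginTraits:
    """Hypothetical container for the requires_explicit_rng trait."""
    requires_explicit_rng: bool = False


class RngAdapter(Protocol):
    """Hypothetical protocol: only the method name comes from the docs."""
    def get_rng_split_syntax(self) -> str: ...


class JaxStyleAdapter:
    def get_rng_split_syntax(self) -> str:
        return "rng, key = jax.random.split(rng)"  # example from the docs


class OtherFrameworkAdapter:
    """Hypothetical framework that opts in with its own (invented) split call."""
    def get_rng_split_syntax(self) -> str:
        return "rng, key = myfw.random.split_key(rng)"


def build_preamble(traits: PluginTraits, adapter: RngAdapter) -> Optional[ast.stmt]:
    # The trait decides *whether* threading happens at all ...
    if not traits.requires_explicit_rng:
        return None
    # ... and the adapter decides *what* statement is generated.
    body = ast.parse(adapter.get_rng_split_syntax()).body
    if len(body) != 1:
        raise ValueError("split syntax must be exactly one statement")
    return body[0]
```

Swapping OtherFrameworkAdapter for JaxStyleAdapter changes the emitted statement without touching build_preamble, which is the point of keeping code generation behind the adapter.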

Functions

inject_prng_threading(node, ctx) → libcst.Call

Plugin Hook: Thread PRNG keys for stochastic operations.

Module Contents

ml_switcheroo.plugins.rng_threading.inject_prng_threading(node: libcst.Call, ctx: ml_switcheroo.core.hooks.HookContext) → libcst.Call

Plugin Hook: Thread PRNG keys for stochastic operations.

Triggers:

The hook runs on operations marked with requires_plugin: "inject_prng" in the Semantic Knowledge Base, for example torch.nn.functional.dropout, torch.randn, and torch.bernoulli.
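A minimal sketch of the trigger check, assuming the Knowledge Base can be viewed as a mapping from API name to a metadata dict. The KNOWLEDGE_BASE excerpt and triggers_prng_hook are hypothetical; the real schema and lookup code may differ.

```python
from typing import Any, Mapping

# Hypothetical excerpt of the Semantic Knowledge Base; the real schema may differ.
KNOWLEDGE_BASE: Mapping[str, Mapping[str, Any]] = {
    "torch.nn.functional.dropout": {"requires_plugin": "inject_prng"},
    "torch.randn": {"requires_plugin": "inject_prng"},
    "torch.bernoulli": {"requires_plugin": "inject_prng"},
    "torch.add": {},  # deterministic: no plugin requested
}


def triggers_prng_hook(
    api: str, kb: Mapping[str, Mapping[str, Any]] = KNOWLEDGE_BASE
) -> bool:
    """True when the entry for `api` asks for the inject_prng plugin."""
    return kb.get(api, {}).get("requires_plugin") == "inject_prng"
```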

Logic:

Checks ctx.plugin_traits.requires_explicit_rng. If True, applies JAX-style threading. Because the check reads a configured trait rather than the target's name, any framework (not just JAX) can opt in to this behavior via configuration.

Parameters:
  • node – The original CST Call node (e.g., torch.nn.functional.dropout(x, 0.5)).

  • ctx – The HookContext, used to request changes outside the call node itself (the enclosing function's signature and preamble) and to read configuration.

Returns:

The transformed CST Call node, with a key=key keyword argument appended and any generator argument removed.
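The hook contract (call node in, rewritten call node out, signature and preamble changes requested through the context) can be sketched with the stdlib ast module. StubHookContext and its request_signature_arg and request_preamble methods are hypothetical stand-ins, not the real HookContext API; only plugin_traits.requires_explicit_rng is taken from the text, and the split statement is hard-coded where the plugin would ask the adapter.

```python
import ast
from dataclasses import dataclass, field
from types import SimpleNamespace
from typing import List


@dataclass
class StubHookContext:
    """Hypothetical stand-in for HookContext; method names are invented."""
    plugin_traits: SimpleNamespace
    requested_args: List[str] = field(default_factory=list)
    requested_preamble: List[str] = field(default_factory=list)

    def request_signature_arg(self, name: str) -> None:
        if name not in self.requested_args:
            self.requested_args.append(name)

    def request_preamble(self, code: str) -> None:
        if code not in self.requested_preamble:
            self.requested_preamble.append(code)


def inject_prng_threading_sketch(node: ast.Call, ctx: StubHookContext) -> ast.Call:
    if not ctx.plugin_traits.requires_explicit_rng:
        return node  # target did not opt in: leave the call untouched
    # Scope-level changes are requested, not applied, by the call hook.
    ctx.request_signature_arg("rng")
    ctx.request_preamble("rng, key = jax.random.split(rng)")  # adapter-provided in the plugin
    keywords = [k for k in node.keywords if k.arg != "generator"]
    if all(k.arg != "key" for k in keywords):
        keywords.append(ast.keyword(arg="key", value=ast.Name(id="key", ctx=ast.Load())))
    # Build a new node instead of mutating the input.
    return ast.Call(func=node.func, args=node.args, keywords=keywords)
```

Returning a fresh Call rather than mutating the input mirrors how libcst nodes are immutable and are normally rewritten with with_changes.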