ml_switcheroo.plugins.rng_threading
===================================

.. py:module:: ml_switcheroo.plugins.rng_threading

.. autoapi-nested-parse::

   Plugin for RNG State Threading (The "JAX Pointer" Pattern).

   PyTorch handles randomness via global state (``torch.manual_seed`` and ``generator`` objects),
   whereas JAX requires explicit passing and splitting of PRNG keys.

   This plugin automates the transition by:

   1. **Signature Injection**: Adds an ``rng`` argument to the function definition.
   2. **Preamble Injection**: Adds ``rng, key = jax.random.split(rng)`` at the start of the function.
   3. **Call Rewriting**:

      - Appends ``key=key`` to supported stochastic calls.
      - Strips Torch-specific ``generator`` arguments (incompatible with JAX).

   This ensures standard PyTorch calls like ``torch.randn(..., generator=g)`` become
   ``jax.random.normal(..., key=key)``.

   **Configuration**:
   Users can customize the variable names via ``pyproject.toml`` or CLI arguments:

   - ``rng_arg_name``: Name of the argument injected into the signature (default: ``"rng"``).
   - ``key_var_name``: Name of the local key variable split from ``rng`` (default: ``"key"``).



Functions
---------

.. autoapisummary::

   ml_switcheroo.plugins.rng_threading.inject_prng_threading


Module Contents
---------------

.. py:function:: inject_prng_threading(node: libcst.Call, ctx: ml_switcheroo.core.hooks.HookContext) -> libcst.Call

   Plugin Hook: Thread PRNG keys for stochastic operations.

   Triggers:
       Operations marked with ``requires_plugin: "inject_prng"`` in the Semantic Knowledge Base.
       Examples: ``torch.nn.functional.dropout``, ``torch.randn``, ``torch.bernoulli``.

   :param node: The original CST Call node (e.g., ``torch.dropout(x, 0.5)``).
   :param ctx: The HookContext used to request global scope changes (signature/preamble) and read configuration.
   :returns: The transformed CST Call node, with the ``key`` keyword argument appended and ``generator`` arguments removed.
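
The three steps described above can be pictured with a minimal, self-contained sketch. It is an illustration under stated assumptions, not the plugin's code: it uses the standard-library ``ast`` module (Python 3.9+ for ``ast.unparse``) instead of ``libcst``, rewrites a whole source string in one pass instead of being invoked per ``Call`` node through a ``HookContext``, and only threads keys (it does not map ``torch.randn`` to ``jax.random.normal``). The names ``STOCHASTIC_OPS``, ``PRNGThreader`` and ``thread_prng`` are hypothetical, and ``STOCHASTIC_OPS`` stands in for the Semantic Knowledge Base lookup.

.. code-block:: python

   import ast

   # Hypothetical stand-in for the ops marked ``requires_plugin: "inject_prng"``
   # in the Semantic Knowledge Base.
   STOCHASTIC_OPS = {"torch.randn", "torch.bernoulli", "torch.nn.functional.dropout"}


   def dotted_name(node: ast.AST) -> str:
       """Return 'a.b.c' for a Name/Attribute chain, or '' for anything else."""
       parts = []
       while isinstance(node, ast.Attribute):
           parts.append(node.attr)
           node = node.value
       if isinstance(node, ast.Name):
           parts.append(node.id)
           return ".".join(reversed(parts))
       return ""


   class PRNGThreader(ast.NodeTransformer):
       """Hypothetical transformer applying the three steps to each function."""

       def __init__(self, rng_arg_name: str = "rng", key_var_name: str = "key"):
           self.rng = rng_arg_name
           self.key = key_var_name
           self.rewrote_call = False

       def visit_Call(self, node: ast.Call) -> ast.Call:
           self.generic_visit(node)
           if dotted_name(node.func) in STOCHASTIC_OPS:
               # Step 3: drop Torch-only ``generator=`` and append ``key=<key>``.
               node.keywords = [kw for kw in node.keywords if kw.arg != "generator"]
               node.keywords.append(
                   ast.keyword(arg=self.key, value=ast.Name(id=self.key, ctx=ast.Load()))
               )
               self.rewrote_call = True
           return node

       def visit_FunctionDef(self, node: ast.FunctionDef) -> ast.FunctionDef:
           outer = self.rewrote_call  # save state so nested defs don't clobber it
           self.rewrote_call = False
           self.generic_visit(node)
           if self.rewrote_call:
               # Step 1: add the rng argument (keyword-only, so defaults stay valid).
               node.args.kwonlyargs.append(ast.arg(arg=self.rng, annotation=None))
               node.args.kw_defaults.append(None)
               # Step 2: prepend ``rng, key = jax.random.split(rng)``.
               preamble = ast.parse(f"{self.rng}, {self.key} = jax.random.split({self.rng})")
               node.body.insert(0, preamble.body[0])
           self.rewrote_call = outer
           return node


   def thread_prng(source: str, rng_arg_name: str = "rng", key_var_name: str = "key") -> str:
       """Parse ``source``, thread PRNG keys through it, and return the new source."""
       tree = PRNGThreader(rng_arg_name, key_var_name).visit(ast.parse(source))
       return ast.unparse(ast.fix_missing_locations(tree))


   if __name__ == "__main__":
       src = "def noisy(x, g):\n    return x + torch.randn(3, generator=g)\n"
       print(thread_prng(src))
       # def noisy(x, g, *, rng):
       #     rng, key = jax.random.split(rng)
       #     return x + torch.randn(3, key=key)

Injecting ``rng`` as a keyword-only parameter is a choice made for this sketch so that it can follow parameters with default values (``def d(x, p=0.5)`` becomes ``def d(x, p=0.5, *, rng)``); where the plugin itself places the argument is not stated here. The now-unused ``g`` parameter is left in place.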