ForkRngs

Forks the (nested) Rng states of the given node.

Abstract Signature:

ForkRngs(node, split)
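The abstract operation can be made concrete with a minimal pure-JAX sketch. It assumes that RNG state is stored as typed PRNG keys (created with `jax.random.key`) at the leaves of an arbitrary pytree. The helper name `fork_rngs`, its `(parent, child)` return value, and the reading of `split` as an optional batch size are all illustrative assumptions. This is not the Flax implementation.

```python
import jax
import jax.numpy as jnp


def _is_prng_key(leaf):
    # Typed PRNG keys (from jax.random.key) carry the special prng_key dtype.
    dtype = getattr(leaf, "dtype", None)
    return dtype is not None and jnp.issubdtype(dtype, jax.dtypes.prng_key)


def fork_rngs(node, split=None):
    """Hypothetical sketch of ForkRngs(node, split); not the Flax implementation.

    Returns (parent, child), both with the same structure as `node`.
    Each key leaf in `parent` is replaced by a fresh key, and the matching
    leaf in `child` holds one new key (split=None) or `split` new keys.
    Non-key leaves are copied through unchanged.
    """
    n = 1 if split is None else split
    leaves, treedef = jax.tree_util.tree_flatten(node)
    parent_leaves, child_leaves = [], []
    for leaf in leaves:
        if not _is_prng_key(leaf):
            parent_leaves.append(leaf)
            child_leaves.append(leaf)
            continue
        # A single split call yields the advanced parent key plus n child keys.
        keys = jax.random.split(leaf, n + 1)
        parent_leaves.append(keys[0])
        child_leaves.append(keys[1] if split is None else keys[1:])
    return (
        jax.tree_util.tree_unflatten(treedef, parent_leaves),
        jax.tree_util.tree_unflatten(treedef, child_leaves),
    )
```

JAX keys are immutable values, so the sketch returns the advanced parent explicitly instead of mutating `node`. Callers keep `parent` for later use and pass `child` to the sub-computation that needs its own randomness.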

PyTorch

API: none (no direct equivalent)
Strategy: Custom / Partial

JAX (Core)

API: jax.random.split
Strategy: Direct Mapping
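For a single key, the direct mapping looks like the example below. It shows the default two-way split and the `num` argument for forking a batch of keys in one call. The seed value, the variable names, and the use of `jax.vmap` to consume the batch are arbitrary illustrative choices.

```python
import jax

# Start from a single typed PRNG key.
root = jax.random.key(42)

# Default: split into two independent keys, one to keep and one to consume.
root, child = jax.random.split(root)

# Pass num to fork several keys at once, e.g. one per ensemble member.
member_keys = jax.random.split(root, 8)

# Each key drives its own independent stream of samples.
samples = jax.vmap(lambda k: jax.random.normal(k, (3,)))(member_keys)
```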

Flax NNX

API: flax.nnx.fork_rngs
Strategy: Direct Mapping