Multinomial
Returns a tensor where each row contains num_samples indices sampled from the multinomial probability distribution given by the corresponding row of input.
Abstract Signature:
Multinomial(input: Tensor, num_samples: int, replacement: bool = False)
JAX (Core)
API:
jax.random.categorical
Strategy: Plugin (inject_prng)
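The sketch below shows one way this mapping might look in JAX. It is not the plugin's actual output. It assumes that the inject_prng strategy amounts to passing an explicit PRNG key as an extra leading argument. The function name multinomial and its key parameter are hypothetical.

Like torch.multinomial, input is treated as non-negative, possibly unnormalized weights. With replacement, the sketch calls jax.random.categorical on the log-weights. For sampling without replacement, it swaps in the Gumbel-top-k trick rather than relying on categorical.

```python
import jax
import jax.numpy as jnp


def multinomial(key, input, num_samples, replacement=False):
    """Hypothetical JAX counterpart of Multinomial(input, num_samples, replacement).

    `key` is an explicit PRNG key, standing in for what the inject_prng
    plugin is assumed to supply. `input` holds non-negative, possibly
    unnormalized weights along its last axis (1-D or 2-D).
    """
    # Weights -> unnormalized log-probabilities; zero weights become -inf.
    logits = jnp.log(input)
    if replacement:
        # categorical puts the extra sample dimension in front of the batch
        # dimensions, so sample with shape (num_samples, *batch) ...
        batch_shape = logits.shape[:-1]
        samples = jax.random.categorical(
            key, logits, axis=-1, shape=(num_samples, *batch_shape)
        )
        # ... then move it to the end to get (*batch, num_samples).
        return jnp.moveaxis(samples, 0, -1)
    # Without replacement: Gumbel-top-k trick. Perturb each log-weight with
    # independent Gumbel noise and keep the num_samples largest entries.
    # Note: unlike torch.multinomial, this does not raise an error when
    # num_samples exceeds the number of non-zero weights.
    gumbel = jax.random.gumbel(key, logits.shape, dtype=logits.dtype)
    _, indices = jax.lax.top_k(logits + gumbel, num_samples)
    return indices


key = jax.random.PRNGKey(0)
probs = jnp.array([[0.1, 0.0, 0.9], [0.5, 0.5, 0.0]])
draws = multinomial(key, probs, 5, replacement=True)  # shape (2, 5)
```

Reusing the same key reproduces the same samples. To get independent draws across calls, split a fresh key for each call, for example with jax.random.split.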