Pmap

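Parallel map over devices: the function is replicated on each device and applied to one slice of the input's leading axis per device.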

Abstract Signature:

Pmap(f: Callable)

PyTorch

API: none
Strategy: Custom / Partial

JAX (Core)

API: jax.pmap
Strategy: Direct Mapping
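
As a rough illustration of the direct mapping, the sketch below wraps a small function with jax.pmap and compares the result against jax.vmap on the same input. It assumes jax is installed; the names square_sum, xs, per_device, and reference are hypothetical and chosen only for this example. The leading axis is sized to the local device count, since jax.pmap cannot map over more entries than there are devices.

```python
import jax
import jax.numpy as jnp

# Number of devices visible to this process (1 on a plain CPU install).
n_dev = jax.local_device_count()

def square_sum(x):
    # Hypothetical per-device computation: sum of squares of one slice.
    return jnp.sum(x ** 2)

# One row per device; the leading axis is the one pmap splits across devices.
xs = jnp.arange(n_dev * 3, dtype=jnp.float32).reshape(n_dev, 3)

# Each device receives one row and returns one scalar.
per_device = jax.pmap(square_sum)(xs)

# vmap computes the same values on a single device, as a reference.
reference = jax.vmap(square_sum)(xs)

print(per_device.shape)
```

On a single-device machine this still runs, with a leading axis of size 1; the output always has one entry per device.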

Flax NNX

API: flax.nnx.pmap
Strategy: Direct Mapping