Pmap
====

Parallel map over devices.

**Abstract Signature:** ``Pmap(f: Callable)``
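
The following is a minimal sketch of the abstract semantics. It uses ``jax.pmap``, the JAX counterpart of ``Pmap``. The function ``f`` and the input shape are illustrative assumptions only. Each device receives one slice of the input's leading axis, and that axis may not exceed the number of local devices. Under those conditions, the result should match applying ``f`` to every slice and stacking the outputs.

.. code-block:: python

   import jax
   import jax.numpy as jnp


   def f(x):
       # Hypothetical per-device body: each device receives one leading-axis slice.
       return jnp.sin(x) * 2.0


   # The mapped (leading) axis may not be larger than the number of local devices.
   n_devices = jax.local_device_count()
   xs = jnp.arange(n_devices * 3, dtype=jnp.float32).reshape(n_devices, 3)

   # Runs f once per device in parallel; the output keeps the leading device axis.
   parallel = jax.pmap(f)(xs)

   # Reference semantics: apply f to every leading-axis slice and stack the results.
   reference = jnp.stack([f(x) for x in xs])

On a single-device host the sketch still runs, with a leading axis of size 1.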

.. list-table::
   :header-rows: 1

   * - Framework
     - API
     - Strategy
   * - PyTorch
     -
     - Custom / Partial
   * - JAX (Core)
     - ``jax.pmap``
     - Direct Mapping
   * - Flax NNX
     - ``flax.nnx.pmap``
     - Direct Mapping
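
``jax.pmap`` also accepts an ``axis_name``, which lets collectives such as ``jax.lax.psum`` combine values across the mapped devices. The sketch below normalizes one scalar per device by the cross-device sum. It is a minimal sketch under stated assumptions: the axis label ``"devices"`` and the ``normalize`` function are hypothetical choices made for this illustration.

.. code-block:: python

   import jax
   import jax.numpy as jnp

   n_devices = jax.local_device_count()
   # One scalar per device; values start at 1 so the cross-device sum is never zero.
   xs = jnp.arange(1, n_devices + 1, dtype=jnp.float32)


   def normalize(x):
       # psum adds x over every device that participates in the "devices" axis.
       total = jax.lax.psum(x, axis_name="devices")
       return x / total


   ys = jax.pmap(normalize, axis_name="devices")(xs)

On a single-device host, ``psum`` sums over one device, so the single output is ``1.0``. In general, the outputs sum to 1.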