ShardMap
A transformation that runs a function in parallel across the devices of a mesh, with each device operating on its own shard of the inputs.
Abstract Signature:
ShardMap(f: Callable, mesh)
JAX (Core)
API: jax.experimental.shard_map.shard_map
Strategy: Direct Mapping
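A minimal sketch of the direct mapping onto the JAX API follows. The abstract signature lists only `f` and `mesh`, but the JAX call also takes `in_specs` and `out_specs` (PartitionSpecs describing how inputs are split and outputs reassembled). The function name `local_sum_then_allreduce` and the mesh axis name `"x"` are illustrative assumptions. The import falls back to `jax.shard_map`, the top-level name used by some newer JAX releases.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P

try:
    from jax.experimental.shard_map import shard_map
except ImportError:
    # Assumption: newer JAX releases expose the same transform as jax.shard_map.
    from jax import shard_map

# 1-D mesh over all available devices (a single CPU device also works).
devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=("x",))

def local_sum_then_allreduce(block):
    # Each device sees only its shard: sum it locally, then all-reduce over "x".
    return jax.lax.psum(jnp.sum(block), axis_name="x")

sharded_total = shard_map(
    local_sum_then_allreduce,
    mesh=mesh,
    in_specs=P("x"),  # split the input's first axis across the "x" mesh axis
    out_specs=P(),    # the psum result is replicated, so no output axis is sharded
)

n = len(devices)
x = jnp.arange(8 * n, dtype=jnp.float32)  # length divisible by the device count
result = sharded_total(x)
```

The collective (`psum`) is written explicitly inside `f`. This is the main difference from automatic sharding with `jax.jit`: under shard_map the function body is per-device code.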
Flax NNX
API: flax.nnx.shard_map
Strategy: Direct Mapping