ShardMap

Transformation that executes a function in parallel over shards of its inputs distributed across a device mesh.

Abstract Signature:

ShardMap(f: Callable, mesh)

JAX (Core)

API: jax.experimental.shard_map.shard_map
Strategy: Direct Mapping
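A minimal sketch of the JAX mapping, assuming a 1-D mesh built from whatever devices are visible (often a single CPU device). The concrete JAX function takes `in_specs` and `out_specs` in addition to the abstract `f` and `mesh`. The axis name `"x"`, the helper `local_sum`, and the input size are illustrative choices, not part of this page. The fallback import is an assumption for newer JAX releases that expose `shard_map` at the top level.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, PartitionSpec as P

try:
    # API named on this page.
    from jax.experimental.shard_map import shard_map
except ImportError:
    # Assumption: newer JAX releases expose it as jax.shard_map instead.
    from jax import shard_map

# Build a 1-D mesh over all visible devices; the axis name "x" is illustrative.
devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=("x",))

# Hypothetical per-shard function: each device sums only its local block,
# then psum adds the partial sums across the "x" mesh axis.
def local_sum(a):
    return jax.lax.psum(jnp.sum(a), axis_name="x")

sharded_sum = shard_map(
    local_sum,
    mesh=mesh,
    in_specs=P("x"),  # split the leading axis of the input across "x"
    out_specs=P(),    # the psum result is replicated on every device
)

# Input length is a multiple of the device count so it splits evenly.
x = jnp.arange(8.0 * len(devices))
result = jax.jit(sharded_sum)(x)
print(result)
```

Because `psum` leaves the same total on every device, declaring the output as replicated with `out_specs=P()` is consistent; returning the unreduced local sums would instead call for a sharded output spec.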

Flax NNX

API: flax.nnx.shard_map
Strategy: Direct Mapping