GetNamedSharding

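Construct named shardings for distributed computation by binding the partitioning described in `tree` to the device mesh `mesh`. A named sharding pairs a mesh with a `PartitionSpec` that maps array dimensions to named mesh axes.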

Abstract Signature:

GetNamedSharding(tree, mesh: jax.sharding.Mesh)
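The following is a minimal sketch of the abstract operation in plain JAX. It assumes `tree` holds `PartitionSpec` leaves. The helper name `get_named_sharding`, the axis name `"data"`, and the parameter names `kernel`/`bias` are hypothetical and chosen only for illustration; this is not necessarily how either backend implements the operation.

```python
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P


def get_named_sharding(tree, mesh: Mesh):
    """Hypothetical helper: map every PartitionSpec leaf of `tree` to a
    NamedSharding bound to `mesh`, preserving the tree structure."""
    return jax.tree_util.tree_map(
        lambda spec: NamedSharding(mesh, spec),
        tree,
        # Treat each PartitionSpec as a single leaf rather than traversing it.
        is_leaf=lambda x: isinstance(x, P),
    )


# A 1-D mesh over all available devices, with one axis named "data".
mesh = Mesh(np.array(jax.devices()), axis_names=("data",))

# Partition specs for a toy parameter tree: shard the kernel's first
# dimension along "data" and replicate the bias on every device.
specs = {"kernel": P("data", None), "bias": P()}

shardings = get_named_sharding(specs, mesh)
print(shardings["kernel"].spec)
```

The output tree has the same structure as the input, so it can be passed wherever JAX expects a matching tree of shardings.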

JAX (Core)

API: jax.sharding.NamedSharding
Strategy: Direct Mapping
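The `jax.sharding.NamedSharding` constructor takes a mesh and a single `PartitionSpec`, so it describes the layout of one array. The sketch below builds one and places an array with `jax.device_put`. The axis name `"data"` and the array shape are illustrative assumptions. On a single-CPU host the mesh contains one device, so the array ends up in a single shard.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# 1-D device mesh over all available devices, axis named "data".
mesh = Mesh(np.array(jax.devices()), axis_names=("data",))

# Split rows across the "data" axis; keep columns whole on every device.
sharding = NamedSharding(mesh, P("data", None))

# Use a row count that is a multiple of the device count so rows split evenly.
n = len(jax.devices())
x = jnp.arange(4 * n * 3, dtype=jnp.float32).reshape(4 * n, 3)

# Lay the array out across the mesh according to the sharding.
x_sharded = jax.device_put(x, sharding)
print(x_sharded.sharding)
```

Each device holds a `(4, 3)` block of rows. Values are unchanged; only their placement differs.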

Flax NNX

API: flax.nnx.get_named_sharding
Strategy: Direct Mapping