WithPartitioning

Wraps a parameter initializer so that the arrays it produces are annotated with a sharding (partitioning) specification.

Abstract Signature:

WithPartitioning(initializer: Callable, sharding)

JAX (Core)

API: jax.lax.with_sharding_constraint
Strategy: Direct Mapping
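`jax.lax.with_sharding_constraint` constrains an array, not an initializer, so one way to realize this mapping in core JAX is to apply it to the initializer's output. The sketch below is a minimal illustration under stated assumptions: `with_sharding` is a hypothetical helper (not a JAX API), the 1x1 mesh with axis names `data` and `model` is chosen only so it runs on a single CPU device, and `lecun_normal` stands in for any initializer with the `(key, shape, dtype)` signature.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P


def with_sharding(initializer, sharding):
    """Hypothetical helper: wrap an initializer so its output is sharding-constrained."""

    def init(key, shape, dtype=jnp.float32):
        value = initializer(key, shape, dtype)
        # Ask the compiler to lay out the new array according to `sharding`.
        return jax.lax.with_sharding_constraint(value, sharding)

    return init


# Single-device 1x1 mesh so the sketch runs on CPU; axis names are illustrative.
mesh = Mesh(np.array(jax.devices()[:1]).reshape(1, 1), axis_names=("data", "model"))
# Replicate dim 0, split dim 1 across the "model" mesh axis.
sharding = NamedSharding(mesh, P(None, "model"))

init_fn = with_sharding(jax.nn.initializers.lecun_normal(), sharding)

# The shape is closed over so it stays a static Python tuple under jit.
params = jax.jit(lambda key: init_fn(key, (4, 8)))(jax.random.PRNGKey(0))
print(params.shape, params.dtype)
```

Because the constraint is applied to the returned array itself, no parameter-level metadata is kept; the sharding lives only on the array that the wrapped initializer returns.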

Flax NNX

API: flax.nnx.with_partitioning
Strategy: Direct Mapping