GetPartitionSpec

Extracts a PartitionSpec tree from a PyTree whose leaves carry sharding information. The returned tree mirrors the structure of the input.

Abstract Signature:

GetPartitionSpec(tree)

JAX (Core)

API: jax.sharding.PartitionSpec
Strategy: Direct Mapping
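In core JAX, `jax.sharding.PartitionSpec` is the type that describes how each array axis maps onto named mesh axes. The sketch below shows one way to build a PartitionSpec tree from a PyTree in plain JAX. It assumes the leaves were placed with a `NamedSharding`. The helper name `partition_spec_tree` is hypothetical and is not a JAX API. Falling back to `P()` (fully replicated) for arrays without a `NamedSharding`, and to `None` for non-array leaves, is likewise an assumption.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P


def partition_spec_tree(tree):
    """Hypothetical helper (not a JAX API): map every leaf to a PartitionSpec.

    - jax.Array placed with a NamedSharding -> that sharding's spec
    - any other array-like leaf (has .shape) -> P(), i.e. replicated (assumption)
    - non-array leaves                       -> None (assumption)
    """
    def leaf_spec(x):
        if isinstance(x, jax.Array) and isinstance(x.sharding, NamedSharding):
            return x.sharding.spec
        if hasattr(x, "shape"):
            return P()
        return None

    # tree_map preserves the input structure, so the result mirrors `tree`.
    return jax.tree_util.tree_map(leaf_spec, tree)


# A 1-device mesh so the example also runs on a single CPU.
mesh = Mesh(np.array(jax.devices()[:1]), ("data",))

params = {
    # First axis sharded over the "data" mesh axis, second axis replicated.
    "w": jax.device_put(jnp.ones((4, 8)), NamedSharding(mesh, P("data", None))),
    "b": jnp.zeros((8,)),  # default single-device placement
    "step": 0,             # plain Python scalar, not an array
}

specs = partition_spec_tree(params)
print(specs)
```

Because the helper only reads existing sharding metadata, it does not move or reshard any data.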

Flax NNX

API: flax.nnx.get_partition_spec
Strategy: Direct Mapping