GetPartitionSpec
================

Extracts a PartitionSpec tree from a PyTree.

**Abstract Signature:**

``GetPartitionSpec(tree)``

- ``jax.sharding.PartitionSpec``
- ``flax.nnx.get_partition_spec``
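
The abstract operation can be illustrated with a minimal sketch that uses only ``jax.tree_util`` and ``jax.sharding.PartitionSpec``. It is not the Flax implementation: the ``Annotated`` container, the ``get_partition_spec_sketch`` helper, and the choice to map unannotated leaves to an empty (replicated) ``PartitionSpec`` are all assumptions made for illustration.

.. code-block:: python

    from dataclasses import dataclass
    from typing import Any, Optional, Tuple

    import jax
    import jax.numpy as jnp
    from jax.sharding import PartitionSpec


    @dataclass
    class Annotated:
        """Hypothetical leaf: a value plus one mesh-axis name per dimension."""
        value: Any
        axes: Optional[Tuple[Optional[str], ...]] = None


    def _is_annotated(x):
        return isinstance(x, Annotated)


    def get_partition_spec_sketch(tree):
        """Return a tree with the same structure whose leaves are PartitionSpecs."""
        def to_spec(leaf):
            if _is_annotated(leaf) and leaf.axes is not None:
                return PartitionSpec(*leaf.axes)
            # Assumption of this sketch: no annotation means fully replicated.
            return PartitionSpec()

        # Treat each Annotated object as a single leaf so it is mapped as a whole.
        return jax.tree_util.tree_map(to_spec, tree, is_leaf=_is_annotated)


    params = {
        "dense": {
            "kernel": Annotated(jnp.zeros((4, 8)), axes=("data", "model")),
            "bias": Annotated(jnp.zeros((8,))),
        }
    }
    specs = get_partition_spec_sketch(params)

The result mirrors the nesting of ``params``, so it can be paired leaf by leaf with the original tree when building shardings for a device mesh. With Flax NNX, the analogous call would be ``flax.nnx.get_partition_spec`` applied to a tree of variables that carry sharding metadata.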