LogicalAxisRules

Context manager that binds logical axis names to device mesh axis names.

Abstract Signature:

LogicalAxisRules(rules: Sequence)

JAX (Core)

API: jax.sharding.Mesh
Strategy: Plugin (sharding_rules_shim)
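As far as we know, core JAX does not ship a logical-axis-rules context manager. That would be consistent with this entry being handled by a plugin rather than a direct mapping. The internals of `sharding_rules_shim` are not documented here. The sketch below only illustrates the kind of translation such a shim would need: it resolves logical names to mesh axis names and builds a `jax.sharding.PartitionSpec` and `NamedSharding` over a `jax.sharding.Mesh`. The `RULES` table, the `logical_to_spec` helper, and the `("data", "model")` axis names are hypothetical.

```python
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Hypothetical logical-to-mesh rules; None means "replicate along this dim".
RULES = (("batch", "data"), ("embed", None), ("mlp", "model"))


def logical_to_spec(logical_axes, rules=RULES):
    """Translate logical axis names into a PartitionSpec (first match wins)."""
    mesh_axes = []
    for name in logical_axes:
        match = None if name is None else next(
            (mesh for logical, mesh in rules if logical == name), None
        )
        mesh_axes.append(match)
    return PartitionSpec(*mesh_axes)


# 2-D mesh over all local devices; on a single CPU this is a 1x1 mesh.
n = len(jax.devices())
mesh = Mesh(np.array(jax.devices()).reshape(n, 1), axis_names=("data", "model"))

spec = logical_to_spec(("batch", "mlp"))  # PartitionSpec('data', 'model')
sharding = NamedSharding(mesh, spec)
# Leading dim is a multiple of n so it divides evenly over the "data" axis.
x = jax.device_put(np.zeros((2 * n, 4), dtype=np.float32), sharding)
```

Keeping the rules table separate from the mesh means the same model annotations can be re-targeted to a different mesh layout by swapping only the rules.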

Flax NNX

API: flax.nnx.logical_axis_rules
Strategy: Direct Mapping