AllToShardedLinear

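Distributed linear layer whose output is sharded across the group: every member receives the full input, and each member computes only its own slice of the output features.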
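To make the sharding concrete, here is a minimal single-process sketch in plain Python. It uses no framework and performs no real communication. It assumes the weight is stored as (out_features, in_features), that out_features divides evenly by the group size, and that member r owns the r-th contiguous block of output rows. The helper names `linear`, `all_to_sharded_forward`, and `input_grad_all_reduce` are hypothetical. The sketch also shows why the backward pass needs a cross-member sum: every member consumes the same full input, so the input gradient is the sum of each member's partial contribution.

```python
from typing import List, Optional

Matrix = List[List[float]]  # weight stored as (out_features, in_features)
Vector = List[float]


def linear(x: Vector, weight: Matrix, bias: Optional[Vector] = None) -> Vector:
    """Ordinary dense layer y = W x (+ b) for a single input vector."""
    y = [sum(w * xi for w, xi in zip(row, x)) for row in weight]
    if bias is not None:
        y = [yi + bi for yi, bi in zip(y, bias)]
    return y


def all_to_sharded_forward(x: Vector, weight: Matrix, bias: Optional[Vector],
                           group_size: int) -> List[Vector]:
    """Simulate each member: it sees the full input x but owns only a
    contiguous block of output rows, so it returns only its output slice."""
    k = len(weight) // group_size  # out_features assumed divisible by group_size
    outputs = []
    for rank in range(group_size):
        lo, hi = rank * k, (rank + 1) * k
        local_bias = None if bias is None else bias[lo:hi]
        outputs.append(linear(x, weight[lo:hi], local_bias))
    return outputs


def input_grad_all_reduce(weight: Matrix, grad_out: Vector, group_size: int) -> Vector:
    """Backward pass w.r.t. the replicated input: each member computes
    W_r^T g_r from its own rows, and the partial results are summed
    (the loop below stands in for an all-reduce across the group)."""
    k = len(weight) // group_size
    in_features = len(weight[0])
    total = [0.0] * in_features
    for rank in range(group_size):
        rows = weight[rank * k:(rank + 1) * k]
        g = grad_out[rank * k:(rank + 1) * k]
        partial = [sum(row[j] * gi for row, gi in zip(rows, g)) for j in range(in_features)]
        total = [t + p for t, p in zip(total, partial)]
    return total


W = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]  # out_features=4, in_features=2
b = [0.5, 0.5, 0.5, 0.5]
x = [1.0, 2.0]

shards = all_to_sharded_forward(x, W, b, group_size=2)
print(shards)  # [[5.5, 11.5], [17.5, 23.5]]
print(sum(shards, []) == linear(x, W, b))  # True
print(input_grad_all_reduce(W, [1.0] * 4, group_size=2))  # [16.0, 20.0]
```

Concatenating the per-member slices reproduces the unsharded layer's output, which is the defining property of "sharded across the group"; no member ever materializes the full output.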

Abstract Signature:

AllToShardedLinear(in_features: int, out_features: int, bias: bool = True, group)
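The abstract signature implies a per-member parameter layout. The following is a minimal sketch, assuming the (out_features, in_features) weight layout used by PyTorch's `nn.Linear` and assuming `out_features` must divide evenly by the group size. `local_param_shapes` is a hypothetical helper, and its `group_size` argument stands in for the size of `group`.

```python
from typing import Dict, Tuple


def local_param_shapes(in_features: int, out_features: int, bias: bool = True,
                       group_size: int = 1) -> Dict[str, Tuple[int, ...]]:
    """Shapes of the parameters held by one group member (hypothetical helper).

    Assumes an (out_features, in_features) weight layout and that the
    output features are split evenly across the group.
    """
    if group_size < 1:
        raise ValueError("group_size must be at least 1")
    if out_features % group_size != 0:
        raise ValueError(
            f"cannot shard out_features={out_features} across {group_size} members"
        )
    local_out = out_features // group_size
    shapes: Dict[str, Tuple[int, ...]] = {"weight": (local_out, in_features)}
    if bias:
        shapes["bias"] = (local_out,)
    return shapes


print(local_param_shapes(1024, 4096, bias=True, group_size=4))
# {'weight': (1024, 1024), 'bias': (1024,)}
```

Under these assumptions, an input of shape (batch, in_features) yields a local output of shape (batch, out_features // group_size) on each member.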

PyTorch

API: None
Strategy: Custom / Partial

Apple MLX

API: mlx.nn.AllToShardedLinear
Strategy: Direct Mapping

Flax NNX

API: None
Strategy: Custom / Partial
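For PyTorch and Flax NNX, which have no direct equivalent, one piece of a custom implementation is slicing a member's shard out of full, unsharded parameters. The sketch below assumes each rank owns a contiguous block of output features; `shard_for_rank` is a hypothetical helper. The two frameworks store the weight differently: PyTorch `nn.Linear` keeps `weight` as (out_features, in_features), so the output axis is 0, while Flax NNX `Linear` keeps `kernel` as (in_features, out_features), so the output axis is 1.

```python
from typing import List, Optional, Tuple

Matrix = List[List[float]]


def shard_for_rank(weight: Matrix, bias: Optional[List[float]], rank: int,
                   group_size: int, out_axis: int = 0) -> Tuple[Matrix, Optional[List[float]]]:
    """Return this rank's block of output features from full parameters.

    out_axis=0: weight is (out_features, in_features), as in PyTorch nn.Linear.
    out_axis=1: weight is (in_features, out_features), as in Flax NNX Linear.
    """
    out_features = len(weight) if out_axis == 0 else len(weight[0])
    if out_features % group_size != 0:
        raise ValueError(f"cannot shard {out_features} outputs across {group_size} ranks")
    k = out_features // group_size
    lo, hi = rank * k, (rank + 1) * k
    if out_axis == 0:
        local_w = [list(row) for row in weight[lo:hi]]  # keep a block of whole rows
    else:
        local_w = [row[lo:hi] for row in weight]  # keep a block of columns
    local_b = None if bias is None else bias[lo:hi]
    return local_w, local_b


torch_style = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]  # (out=4, in=2)
flax_style = [list(col) for col in zip(*torch_style)]           # (in=2, out=4)
bias = [0.1, 0.2, 0.3, 0.4]

print(shard_for_rank(torch_style, bias, rank=1, group_size=2, out_axis=0))
# ([[5.0, 6.0], [7.0, 8.0]], [0.3, 0.4])
print(shard_for_rank(flax_style, bias, rank=1, group_size=2, out_axis=1))
# ([[5.0, 7.0], [6.0, 8.0]], [0.3, 0.4])
```

Each rank can then run an ordinary local linear layer on the full input. The remaining custom part is the backward pass: because every rank consumes the same input, its gradient has to be summed across the group (in PyTorch, typically with an all-reduce such as `torch.distributed.all_reduce`).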