ShardedToAllLinear

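Distributed linear layer that takes an input sharded across devices along its feature dimension and produces the full output on every device. Each device multiplies its input shard by the matching slice of the weight, the partial results are summed across devices with an all-reduce, and the bias is then added once.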
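The forward pass described above can be sketched in plain Python by simulating several devices inside one process. This is an illustrative sketch only: the helper names (`matvec`, `shard_columns`, `all_sum`, `sharded_to_all_forward`) are hypothetical, `all_sum` merely simulates an all-reduce with a local sum, and the weight is assumed to be stored row-major with shape `(out_features, in_features)`.

```python
from typing import List

Matrix = List[List[float]]
Vector = List[float]


def matvec(w: Matrix, x: Vector) -> Vector:
    """Dense y = W @ x for a weight of shape (out_features, in_features)."""
    return [sum(wij * xj for wij, xj in zip(row, x)) for row in w]


def shard_columns(w: Matrix, n: int) -> List[Matrix]:
    """Split W along the input (column) dimension into n equal shards."""
    cols = len(w[0])
    if cols % n != 0:
        raise ValueError("in_features must be divisible by the number of devices")
    k = cols // n
    return [[row[r * k:(r + 1) * k] for row in w] for r in range(n)]


def shard_vector(x: Vector, n: int) -> List[Vector]:
    """Split the input features into n equal, contiguous slices."""
    k = len(x) // n
    return [x[r * k:(r + 1) * k] for r in range(n)]


def all_sum(partials: List[Vector]) -> Vector:
    """Simulated all-reduce (sum): the element-wise sum every device receives."""
    return [sum(vals) for vals in zip(*partials)]


def sharded_to_all_forward(w: Matrix, b: Vector, x: Vector, n: int) -> List[Vector]:
    """Return the output seen by each of n simulated devices."""
    if len(x) != len(w[0]):
        raise ValueError("input length must equal in_features")
    w_shards = shard_columns(w, n)
    x_shards = shard_vector(x, n)
    # Each device multiplies only its own input slice by its weight shard.
    partials = [matvec(w_r, x_r) for w_r, x_r in zip(w_shards, x_shards)]
    total = all_sum(partials)
    # The bias is added once, after the reduction, so it is not counted n times.
    y = [t + bi for t, bi in zip(total, b)]
    # After the all-reduce, every device holds an identical copy of the output.
    return [list(y) for _ in range(n)]


W = [[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]]
b = [0.5, -0.5]
x = [1.0, 1.0, 1.0, 1.0]
outputs = sharded_to_all_forward(W, b, x, 2)
```

With two simulated devices, each device ends up holding `[10.5, 25.5]`, which is exactly the dense result `W @ x + b`. Summing the partial products works because splitting the inner (input) dimension of a matrix product turns it into a sum of smaller products.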

Abstract Signature:

ShardedToAllLinear(in_features: int, out_features: int)
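The abstract signature only names the global sizes. A minimal sketch of the per-device bookkeeping it implies might look like the following, assuming an even split of `in_features` across `world_size` devices and a weight stored as `(out_features, in_features)`. The `local_shapes` helper, `LocalShapes` type, and the `world_size`/`rank` parameters are hypothetical and are not part of the abstract signature.

```python
from typing import NamedTuple, Tuple


class LocalShapes(NamedTuple):
    weight: Tuple[int, int]       # (out_features, in_features // world_size)
    bias: Tuple[int]              # full (out_features,), added after the sum
    input_slice: Tuple[int, int]  # [start, stop) of the input features this rank reads


def local_shapes(in_features: int, out_features: int,
                 world_size: int, rank: int) -> LocalShapes:
    """Per-rank parameter shapes for a layer sharded along its input dimension."""
    if world_size <= 0 or not 0 <= rank < world_size:
        raise ValueError("rank must be in [0, world_size)")
    if in_features % world_size != 0:
        raise ValueError("in_features must be divisible by world_size")
    k = in_features // world_size
    return LocalShapes(
        weight=(out_features, k),
        bias=(out_features,),
        input_slice=(rank * k, (rank + 1) * k),
    )


shapes = local_shapes(in_features=1024, out_features=256, world_size=4, rank=1)
```

Only the input dimension shrinks per device; the output dimension stays at its full size because every device computes a partial value for every output feature before the reduction.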

PyTorch

API: None
Strategy: Custom / Partial

Apple MLX

API: mlx.nn.ShardedToAllLinear
Strategy: Direct Mapping