AllToShardedLinear
Distributed linear layer whose result is sharded across the group: each member of the group holds only its own portion of the output.
Abstract Signature:
AllToShardedLinear(in_features: int, out_features: int, bias: bool = True, group)
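The sharding described above can be illustrated with a single-process simulation. This is only a sketch under stated assumptions, not the implementation used by any framework: it assumes the weight is split along `out_features` into contiguous, equally sized slices (one per group member), and the helper names `linear` and `shard_rows` are hypothetical.

```python
# Single-process simulation of an "all-to-sharded" linear layer.
# Assumption: the weight (and bias) are split along out_features,
# one contiguous slice per group member. Real layouts may differ.

def linear(x, weight, bias=None):
    """Compute y[j] = sum_i x[i] * weight[j][i] (+ bias[j]); weight is [out][in]."""
    out = [sum(xi * wi for xi, wi in zip(x, row)) for row in weight]
    if bias is not None:
        out = [o + b for o, b in zip(out, bias)]
    return out

def shard_rows(rows, rank, group_size):
    """Return the contiguous slice of rows owned by `rank` (hypothetical layout)."""
    assert len(rows) % group_size == 0, "out_features must divide evenly"
    per_rank = len(rows) // group_size
    return rows[rank * per_rank:(rank + 1) * per_rank]

in_features, out_features, group_size = 3, 4, 2
x = [1.0, 2.0, 3.0]                      # every rank sees the full input
weight = [[1.0, 0.0, 0.0],
          [0.0, 1.0, 0.0],
          [0.0, 0.0, 1.0],
          [1.0, 1.0, 1.0]]               # shape [out_features][in_features]
bias = [0.5, 0.5, 0.5, 0.5]

# Reference: the unsharded layer.
full = linear(x, weight, bias)

# Each simulated rank computes only its slice of the output features.
shards = [
    linear(x, shard_rows(weight, r, group_size), shard_rows(bias, r, group_size))
    for r in range(group_size)
]

# Concatenating the per-rank slices recovers the unsharded result.
gathered = [v for shard in shards for v in shard]
```

In this layout no communication is needed in the forward pass: every rank reads the same input and produces `out_features // group_size` outputs, so a gather would only be required if a later layer needs the full result.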
PyTorch
API:
Strategy: Custom / Partial
Flax NNX
API:
Strategy: Custom / Partial