TransformerFeedForwardMoe

A sharded Mixture-of-Experts (MoE) feed-forward layer.

Abstract Signature:

TransformerFeedForwardMoe(input_dims: int, hidden_dims: int, num_experts: int, num_groups: int)

PyTorch

API: — (no direct equivalent)
Strategy: Custom / Partial
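
Since PyTorch ships no one-to-one API, here is a minimal sketch of what a custom module matching the abstract signature could look like. The top-1 routing, the two-layer ReLU expert MLP, and all internal names are assumptions for illustration; num_groups drives sharding in the real layer and is only stored here:

```python
import torch
import torch.nn as nn


class TransformerFeedForwardMoe(nn.Module):
    """Hypothetical MoE feed-forward block (illustrative sketch only)."""

    def __init__(self, input_dims: int, hidden_dims: int,
                 num_experts: int, num_groups: int):
        super().__init__()
        self.num_groups = num_groups  # sharding groups; not modeled in this sketch
        self.router = nn.Linear(input_dims, num_experts, bias=False)
        self.experts = nn.ModuleList(
            nn.Sequential(
                nn.Linear(input_dims, hidden_dims),
                nn.ReLU(),
                nn.Linear(hidden_dims, input_dims),
            )
            for _ in range(num_experts)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Flatten [batch, seq, dims] into a token list for per-token routing.
        tokens = x.reshape(-1, x.shape[-1])
        # Top-1 routing: each token goes to its highest-probability expert.
        weights, expert_ids = self.router(tokens).softmax(-1).max(-1)
        out = torch.zeros_like(tokens)
        for i, expert in enumerate(self.experts):
            mask = expert_ids == i
            if mask.any():
                out[mask] = weights[mask].unsqueeze(-1) * expert(tokens[mask])
        return out.reshape(x.shape)


# Illustrative usage with the signature's four hyperparameters.
moe = TransformerFeedForwardMoe(input_dims=512, hidden_dims=2048,
                                num_experts=8, num_groups=1)
y = moe(torch.randn(2, 16, 512))  # -> shape [2, 16, 512]
```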

PaxML / Praxis

API: praxis.layers.TransformerFeedForwardMoe
Strategy: Direct Mapping
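
On the Praxis side the mapping is direct. A minimal configuration sketch, assuming the standard pax_fiddle.Config / base_layer.instantiate workflow; the dimension and expert counts are illustrative values, not defaults:

```python
from praxis import base_layer, pax_fiddle
from praxis import layers

# Configure the layer with the same four hyperparameters as the
# abstract signature (values chosen for illustration).
moe_p = pax_fiddle.Config(
    layers.TransformerFeedForwardMoe,
    name='moe_ffw',
    input_dims=512,
    hidden_dims=2048,
    num_experts=8,
    num_groups=1,
)
moe_layer = base_layer.instantiate(moe_p)
```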