CausalMask

Computes and returns a causal mask.

Abstract Signature:

CausalMask(input_t: Tensor)

PyTorch

API: torch.nn.Transformer.generate_square_subsequent_mask
Strategy: Direct Mapping
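
A minimal sketch of how the PyTorch mapping might be used. The abstract signature takes a tensor, while generate_square_subsequent_mask takes an integer size, so the sketch reads the sequence length from the input. The [batch, seq_len, dim] layout of input_t and the example sizes are assumptions, not something this page specifies, and the call assumes a recent PyTorch release where the method can be called statically.

```python
import torch

# Hypothetical input; the [batch, seq_len, dim] layout is an assumption.
input_t = torch.zeros(2, 4, 8)
seq_len = input_t.shape[1]

# Build a square additive mask of shape [seq_len, seq_len]:
# 0.0 on and below the diagonal, -inf strictly above it.
mask = torch.nn.Transformer.generate_square_subsequent_mask(seq_len)

print(mask)
```

The result is an additive float mask: adding it to attention scores before the softmax leaves positions at or before the current step unchanged and suppresses later positions.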

PaxML / Praxis

API: praxis.layers.causal_mask
Strategy: Direct Mapping