x → Embedding (+ Positional Encoding) → MultiHeadAttention → Concat heads → Project back to embed_dim →
→ Add(x) → LayerNorm → FFN → Add → LayerNorm
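The flow above can be sketched as a single post-LN encoder block. This is a minimal sketch with hypothetical sizes (embed_dim=64, num_heads=4, ffn_dim=256); note that `nn.MultiheadAttention` already performs the per-head split, concat, and output projection back to embed_dim internally.

```python
import torch
import torch.nn as nn

class EncoderBlock(nn.Module):
    """One post-LN transformer encoder block matching the flow above."""
    def __init__(self, embed_dim=64, num_heads=4, ffn_dim=256):
        super().__init__()
        # Splits embed_dim across heads, concatenates the head outputs,
        # and projects back to embed_dim with the output weight matrix.
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
        self.ln1 = nn.LayerNorm(embed_dim)
        self.ffn = nn.Sequential(
            nn.Linear(embed_dim, ffn_dim),
            nn.ReLU(),
            nn.Linear(ffn_dim, embed_dim),
        )
        self.ln2 = nn.LayerNorm(embed_dim)

    def forward(self, x):
        attn_out, _ = self.attn(x, x, x)   # self-attention over the sequence
        x = self.ln1(x + attn_out)         # Add(x) → LayerNorm
        x = self.ln2(x + self.ffn(x))      # FFN → Add → LayerNorm
        return x

x = torch.randn(2, 10, 64)                 # Batch x Seq Len x embed_dim
y = EncoderBlock()(x)
print(y.shape)                             # same shape in and out
```

The residual adds require the attention output, FFN output, and input to share the embed_dim size, which is why the projection goes back to embed_dim rather than a lower one.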
Vocab to embedding
torch.nn.Embedding(vocab_size, embed_dim)
Input is a LongTensor of token indices: Batch X Seq Len → Batch X Seq Len X embed_dim
(equivalent to multiplying a one-hot Batch X Seq Len X Vocab tensor by the embedding matrix)
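A minimal sketch of the lookup, with hypothetical sizes (vocab_size=1000, embed_dim=64):

```python
import torch
import torch.nn as nn

vocab_size, embed_dim = 1000, 64           # hypothetical sizes
emb = nn.Embedding(vocab_size, embed_dim)  # learnable Vocab x embed_dim table

# nn.Embedding takes integer token indices, not one-hot vectors.
tokens = torch.randint(0, vocab_size, (2, 10))  # Batch x Seq Len
out = emb(tokens)
print(out.shape)                                # Batch x Seq Len x embed_dim
```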
PE = Seq Len X embed_dim, broadcast over Batch and added to the embeddings