Transformer Coding Details
A Simple Implementation
In this article, I’d like to discuss the coding details of my Transformer implementation, based on the architecture from Attention Is All You Need by Ashish Vaswani et al. There are already several implementations out there; I have listed some of them in the references section. There are tricks not written in the paper that I came to learn while reading those source codes.
In what follows, I summarize the coding details for each component using PyTorch and ultimately assemble all the modules into a model based on the Transformer architecture.
1 Embedding Layer
An embedding layer converts token indices into embedding vectors. Once tokens are in vector format, we can perform dot-products and other linear algebra operations on them.
Since PyTorch has nn.Embedding for this exact purpose, we can use that.
import math
import torch.nn as nn
from torch import Tensor

class Embedding(nn.Module):
    def __init__(self, vocab_size: int, dim_embed: int) -> None:
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, dim_embed)
        self.sqrt_dim_embed = math.sqrt(dim_embed)

    def forward(self, x: Tensor) -> Tensor:
        x = self.embedding(x.long())
        x = x * self.sqrt_dim_embed
        return x
The paper mentions that they multiply the embedding values by the square root of the embedding dimension dim_embed.
The embedding layer will transform the shape of an input batch from (batch_size, max_sequence_length) to (batch_size, max_sequence_length, dim_embed).
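As a quick sanity check, the snippet below shows that shape change. The sizes (vocab_size=1000, dim_embed=512) and the random batch are arbitrary choices for illustration, and it assumes the Embedding module defined above is in scope:

import torch

embed = Embedding(vocab_size=1000, dim_embed=512)
tokens = torch.randint(0, 1000, (2, 7))   # (batch_size=2, max_sequence_length=7)
vectors = embed(tokens)
print(vectors.shape)                      # torch.Size([2, 7, 512])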
Note: we calculate max_sequence_length per batch and add padding to shorter sentences. The attention mechanism will later ignore these padded positions using a mask.
I’ve written an article on word embedding, which you can read here to learn more.
2 Positional Encoding
The idea of positional encoding is to encode token position information into embedding vectors so that we don’t need recurrence to handle sequences.
We calculate positional encoding values according to the formula below for each position pos and embedding vector dimensions 2i and 2i+1, for i = 0 ... d_model//2 - 1.
\[ \begin{aligned} \text{PE}(\text{pos}, 2i) &= \sin\left(\dfrac{\text{pos}}{10000^{\frac{2i}{d_\text{model}}}}\right) \\ \text{PE}(\text{pos}, 2i+1) &= \cos\left(\dfrac{\text{pos}}{10000^{\frac{2i}{d_\text{model}}}}\right) \end{aligned} \]
An example code to calculate positional encoding values is as follows:
import math
import torch

pe = torch.zeros(max_positions, dim_embed)

for pos in range(max_positions):
    for i in range(0, dim_embed, 2):
        theta = pos / (10000 ** (i / dim_embed))
        pe[pos, i    ] = math.sin(theta)
        pe[pos, i + 1] = math.cos(theta)
However, the double loop runs pretty slowly when the maximum number of positions is large, so we should take advantage of parallelized vector operations. The code below is inspired by PyTorch’s tutorial article:
import math
import torch
import torch.nn as nn
from torch import Tensor

class PositionalEncoding(nn.Module):
    def __init__(self, max_positions: int, dim_embed: int, drop_prob: float) -> None:
        super().__init__()
        assert dim_embed % 2 == 0

        # Inspired by https://pytorch.org/tutorials/beginner/transformer_tutorial.html
        position = torch.arange(max_positions).unsqueeze(1)
        dim_pair = torch.arange(0, dim_embed, 2)
        div_term = torch.exp(dim_pair * (-math.log(10000.0) / dim_embed))

        pe = torch.zeros(max_positions, dim_embed)
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)

        # Add a batch dimension: (1, max_positions, dim_embed)
        pe = pe.unsqueeze(0)

        # Register as non-learnable parameters
        self.register_buffer('pe', pe)

        self.dropout = nn.Dropout(p=drop_prob)

    def forward(self, x: Tensor) -> Tensor:
        # Max sequence length within the current batch
        max_sequence_length = x.size(1)

        # Add positional encoding up to the max sequence length
        x = x + self.pe[:, :max_sequence_length]
        x = self.dropout(x)
        return x
The line div_term = torch.exp(dim_pair * (-math.log(10000.0) / dim_embed)) might look complicated. We can spell it out mathematically:
\[ \begin{aligned} \exp\left(-\dfrac{2i}{d_\text{model}}\log(10000)\right) &= \exp\left(\log\left(10000^{-\frac{2i}{d_\text{model}}}\right)\right) \\ &= 10000^{-\frac{2i}{d_\text{model}}} \\ &= \dfrac{1}{10000^{\frac{2i}{d_\text{model}}}} \end{aligned} \]
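As a quick numerical check of this identity (dim_embed=8 is an arbitrary choice for illustration), div_term matches the direct computation:

import math
import torch

dim_embed = 8
dim_pair = torch.arange(0, dim_embed, 2)
div_term = torch.exp(dim_pair * (-math.log(10000.0) / dim_embed))
direct = 1.0 / (10000.0 ** (dim_pair / dim_embed))
print(torch.allclose(div_term, direct))   # True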
In the end, we apply dropout as the paper says:
"we apply dropout to the sums of the embeddings and the positional encodings in both the encoder and decoder stacks." (Section 5.4 Regularization)
I’ve written an article on positional encoding, which you can read here for more details.
3 Scaled Dot-Product Attention
Scaled dot-product attention uses the dot-product of two embedding vectors to see how related they are. “Scaled” indicates the part where we divide the dot-product results by the square root of the number of embedding vector dimensions, which they did to reduce the magnitude of dot-products. Please see section 3.2.1 of the paper for details.
We use matrix multiplication to handle dot-product calculation.
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor

def attention(query: Tensor, key: Tensor, value: Tensor, mask: Tensor = None) -> Tensor:
    sqrt_dim_head = query.shape[-1]**0.5

    # Dot-products between queries and keys, scaled by the square root of the head dimension
    scores = torch.matmul(query, key.transpose(-2, -1))
    scores = scores / sqrt_dim_head

    # Give masked positions a large negative value so softmax assigns them ~zero weight
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)

    weight = F.softmax(scores, dim=-1)
    return torch.matmul(weight, value)
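As a quick shape check (the tensor sizes below are arbitrary and follow the multi-head layout used later), the output keeps the shape of value:

import torch

batch_size, num_heads, seq_len, dim_head = 2, 8, 5, 64
q = torch.randn(batch_size, num_heads, seq_len, dim_head)
k = torch.randn(batch_size, num_heads, seq_len, dim_head)
v = torch.randn(batch_size, num_heads, seq_len, dim_head)
print(attention(q, k, v).shape)   # torch.Size([2, 8, 5, 64])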
4 Self-Attention and Padding Mask
Conceptually, self-attention produces an attention matrix, as shown below.
We apply softmax for each row to get attention weights. The red color indicates strong relevance.
We deal with a batch of sequences during training and need to add padding to shorter sequences. However, we do not want the padded positions to participate in the attention calculation.
An attention mask has zeros for padding positions after the end of a sequence.
['Hello', ',', 'world', '!'] => [1, 1, 1, 1, 0, 0]
We give a large negative value to masked positions so that the softmax will assign them effectively zero probability, as illustrated below. The masked positions will not contribute to the final attention values.
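A small, self-contained illustration (the score values are made up) shows how the mask drives the padded positions to effectively zero weight:

import torch
import torch.nn.functional as F

scores = torch.tensor([[2.0, 1.0, 0.5, 3.0, 0.1, 0.2]])
mask = torch.tensor([[1, 1, 1, 1, 0, 0]])
scores = scores.masked_fill(mask == 0, -1e9)
weight = F.softmax(scores, dim=-1)
print(weight)   # the last two (padded) positions get ~0.0 weight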
5 Target-Source Attention and Padding Mask
Similarly, target-source attention produces attention weights with a padding mask.
Note: the attention weight matrix may or may not be square, depending on the maximum sequence lengths of the query Q and the key K.
6 Subsequent Mask for Decoder Input
We also use a mask to hide subsequent positions in the decoder input.
[<SOS>, 'Bonjour', 'le', 'monde', '!']
<SOS> => [1, 0, 0, 0, 0]
'Bonjour' => [1, 1, 0, 0, 0]
'le' => [1, 1, 1, 0, 0]
'monde' => [1, 1, 1, 1, 0]
'!' => [1, 1, 1, 1, 1]
For example, the attention calculation for the position of 'Bonjour' should not use scores from the positions of 'le', 'monde', and '!'. Hence, its subsequent attention mask is [1, 1, 0, 0, 0].
So, a mask for the decoder is a combination of a padding mask and a subsequent mask. A subsequent mask hides the upper triangle, excluding the diagonal. It creates an effect as though part of the decoder input were invisible.
When calculating attention values for the first position, we ignore the second position and after. When calculating attention values for the second position, we ignore the third position and the rest. And so on.
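For a length-5 target like the 'Bonjour' example above, torch.tril on a matrix of ones produces exactly that lower-triangular pattern:

import torch

print(torch.tril(torch.full((5, 5), 1)))
# tensor([[1, 0, 0, 0, 0],
#         [1, 1, 0, 0, 0],
#         [1, 1, 1, 0, 0],
#         [1, 1, 1, 1, 0],
#         [1, 1, 1, 1, 1]])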
from typing import Tuple

import torch
from torch import Tensor

# PAD_IDX: index of the padding token (defined elsewhere in the project)

def create_masks(src_batch: Tensor, tgt_batch: Tensor) -> Tuple[Tensor, Tensor]:
    # ----------------------------------------------------------------------
    # [1] padding mask
    # ----------------------------------------------------------------------
    # (batch_size, 1, max_src_seq_len)
    src_pad_mask = (src_batch != PAD_IDX).unsqueeze(1)

    # (batch_size, 1, max_tgt_seq_len)
    tgt_pad_mask = (tgt_batch != PAD_IDX).unsqueeze(1)

    # ----------------------------------------------------------------------
    # [2] subsequent mask for decoder inputs
    # ----------------------------------------------------------------------
    max_tgt_sequence_length = tgt_batch.shape[1]
    tgt_attention_square = (max_tgt_sequence_length, max_tgt_sequence_length)

    # full attention
    full_mask = torch.full(tgt_attention_square, 1)

    # subsequent sequence should be invisible to each token position
    subsequent_mask = torch.tril(full_mask)

    # add a batch dim (1, max_tgt_seq_len, max_tgt_seq_len)
    subsequent_mask = subsequent_mask.unsqueeze(0)

    return src_pad_mask, tgt_pad_mask & subsequent_mask
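A toy usage sketch (PAD_IDX=0 and the token indices below are made up for illustration) shows the resulting mask shapes:

import torch

PAD_IDX = 0
src = torch.tensor([[5, 6, 7, 0, 0]])   # (1, 5), last two positions are padding
tgt = torch.tensor([[1, 8, 9, 0]])      # (1, 4), last position is padding
src_mask, tgt_mask = create_masks(src, tgt)
print(src_mask.shape)                   # torch.Size([1, 1, 5])
print(tgt_mask.shape)                   # torch.Size([1, 4, 4])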
I have written an article on the self-attention mechanism, which you can read for more details.
7 Multi-Head Attention
Multi-head attention employs multiple scaled dot-product attention calculations to capture various relationships between words.
The below figure from the paper visualizes outputs from two attention heads:
The encoder uses it for the self-attention mechanism. The decoder uses it for both self-attention and target-source attention, where it calculates attention scores between the encoder outputs (features) and the decoder embedding vectors, i.e., the decoder inputs (shifted-right outputs) after they pass through the masked multi-head attention.
We use the term masked multi-head attention because the decoder input comes with subsequent masks that effectively hide the future positions (subsequent sequence of each token). So, for the decoder, mask
combines padding and subsequent masks.
Three arrows go into multi-head attention for value V, key K, and query Q, respectively.
Regarding self-attention, V, K, and Q are all the same embedding vectors. For target-source attention, however, V and K are the encoder outputs (features), and Q is the decoder embedding vectors.
Conceptually, we independently perform multiple scaled dot-product attention calculations, one for each head. For that, we first apply eight separate linear operations on V, K, Q
so that attention values will capture different relationships for each attention head.
In actual implementation, we can perform one linear operation for each of V, K, Q
and reshape them into multiple heads as shown in the code below:
# continuation of attention.py
class MultiHeadAttention(nn.Module):
    def __init__(self, num_heads: int, dim_embed: int, drop_prob: float) -> None:
        super().__init__()
        assert dim_embed % num_heads == 0

        self.num_heads = num_heads
        self.dim_embed = dim_embed
        self.dim_head = dim_embed // num_heads

        self.query = nn.Linear(dim_embed, dim_embed)
        self.key   = nn.Linear(dim_embed, dim_embed)
        self.value = nn.Linear(dim_embed, dim_embed)
        self.output = nn.Linear(dim_embed, dim_embed)

        self.dropout = nn.Dropout(drop_prob)

    def forward(self, x: Tensor, y: Tensor, mask: Tensor = None) -> Tensor:
        query = self.query(x)
        key   = self.key  (y)
        value = self.value(y)

        batch_size = x.size(0)
        query = query.view(batch_size, -1, self.num_heads, self.dim_head)
        key   = key  .view(batch_size, -1, self.num_heads, self.dim_head)
        value = value.view(batch_size, -1, self.num_heads, self.dim_head)

        # Into the number of heads (batch_size, num_heads, -1, dim_head)
        query = query.transpose(1, 2)
        key   = key  .transpose(1, 2)
        value = value.transpose(1, 2)

        if mask is not None:
            mask = mask.unsqueeze(1)

        attn = attention(query, key, value, mask)
        attn = attn.transpose(1, 2).contiguous().view(batch_size, -1, self.dim_embed)

        out = self.dropout(self.output(attn))
        return out
The shape of Q, K, and V is changed from (batch_size, max_sequence_length, dim_embed) to (batch_size, num_heads, max_sequence_length, dim_head), where dim_head = dim_embed // num_heads.
For example, if dim_embed = 512 and num_heads = 8, then dim_head = 64.
Also, note that we determine max_sequence_length per batch.
We restore the original shape by reshaping the outputs from the scaled dot-product attention to (batch_size, max_sequence_length, dim_embed). So, the embedding vectors keep the same number of dimensions throughout the process.
We give the mask one extra dimension to be broadcastable across multiple heads. Then, masked_fill
in the scaled dot-product attention will perform the attention calculation independently for each head using the same mask.
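As a quick shape check of the whole module (num_heads=8, dim_embed=512, and the batch sizes below are arbitrary), the embedding dimensions are preserved end to end:

import torch

mha = MultiHeadAttention(num_heads=8, dim_embed=512, drop_prob=0.1)
x = torch.randn(2, 7, 512)    # (batch_size, max_sequence_length, dim_embed)
mask = torch.ones(2, 1, 7)    # padding mask with no masked positions
print(mha(x, x, mask).shape)  # torch.Size([2, 7, 512])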
In the end, we apply dropout as the paper says:
"We apply dropout [33] to the output of each sub-layer, before it is added to the sub-layer input and normalized." (Section 5.4 Regularization)
8 Position-wise Feed-Forward
The position-wise feed-forward layers inject non-linearity into embedding vectors.
Embedding vectors have the shape of (batch_size, max_sequence_length, dim_embed). As we do not flatten the embedding vectors, the linear operations are applied to each position independently and identically, which is why it is called position-wise feed-forward.
import torch.nn as nn
from torch import Tensor

class PositionwiseFeedForward(nn.Module):
    def __init__(self, dim_embed: int, dim_pffn: int, drop_prob: float) -> None:
        super().__init__()
        self.pffn = nn.Sequential(
            nn.Linear(dim_embed, dim_pffn),
            nn.ReLU(inplace=True),
            nn.Dropout(drop_prob),
            nn.Linear(dim_pffn, dim_embed),
            nn.Dropout(drop_prob),
        )

    def forward(self, x: Tensor) -> Tensor:
        return self.pffn(x)
The first linear operation expands the dimensions. My understanding is that this expansion means ReLU will not lose too much information.
The second linear operation restores the original dimensions. So, we can continue with the same embedding vector shape to perform the remaining process.
For example, in the base Transformer model, the embedding vector dimensions are increased from 512 to 2048 and restored from 2048 to 512.
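To see the position-wise behavior concretely, here is a small check (arbitrary sizes; dropout disabled via eval()) that running a single position alone gives the same result as running the whole sequence:

import torch

ffn = PositionwiseFeedForward(dim_embed=512, dim_pffn=2048, drop_prob=0.1)
ffn.eval()                                  # disable dropout for a deterministic comparison
x = torch.randn(2, 7, 512)
full = ffn(x)                               # all positions at once
single = ffn(x[:, 3:4, :])                  # position 3 alone
print(torch.allclose(full[:, 3:4, :], single, atol=1e-6))   # True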
In the end, we apply dropout as the paper says:
"We apply dropout [33] to the output of each sub-layer, before it is added to the sub-layer input and normalized." (Section 5.4 Regularization)
The paper does not mention dropout between the two linear layers. However, the reference implementations have it, which makes sense to me, so I decided to do the same.
9 Encoder
The encoder uses multiple encoder blocks.
import torch.nn as nn
from torch import Tensor

from .attention import MultiHeadAttention
from .feed_forward import PositionwiseFeedForward

class Encoder(nn.Module):
    def __init__(self,
                 num_blocks: int,
                 num_heads:  int,
                 dim_embed:  int,
                 dim_pffn:   int,
                 drop_prob:  float) -> None:
        super().__init__()
        self.blocks = nn.ModuleList(
            [EncoderBlock(num_heads, dim_embed, dim_pffn, drop_prob)
             for _ in range(num_blocks)]
        )
        self.layer_norm = nn.LayerNorm(dim_embed)

    def forward(self, x: Tensor, x_mask: Tensor) -> Tensor:
        for block in self.blocks:
            x = block(x, x_mask)
        x = self.layer_norm(x)
        return x
Ultimately, it applies layer normalization before passing features to the decoder.
10 Encoder Block
An encoder block uses:
- Multi-head attention for self-attention
- Position-wise feed-forward to inject non-linearity
We have a residual connection and a layer normalization between those sub-layers.
# continuation of encoder.py
class EncoderBlock(nn.Module):
    def __init__(self,
                 num_heads: int,
                 dim_embed: int,
                 dim_pwff:  int,
                 drop_prob: float) -> None:
        super().__init__()

        # Self-attention
        self.self_atten = MultiHeadAttention(num_heads, dim_embed, drop_prob)
        self.layer_norm1 = nn.LayerNorm(dim_embed)

        # Point-wise feed-forward
        self.feed_forward = PositionwiseFeedForward(dim_embed, dim_pwff, drop_prob)
        self.layer_norm2 = nn.LayerNorm(dim_embed)

    def forward(self, x: Tensor, x_mask: Tensor) -> Tensor:
        x = x + self.sub_layer1(x, x_mask)
        x = x + self.sub_layer2(x)
        return x

    def sub_layer1(self, x: Tensor, x_mask: Tensor) -> Tensor:
        x = self.layer_norm1(x)
        x = self.self_atten(x, x, x_mask)
        return x

    def sub_layer2(self, x: Tensor) -> Tensor:
        x = self.layer_norm2(x)
        x = self.feed_forward(x)
        return x
In the paper, they add a residual connection around each sub-layer (multi-head attention or position-wise feed-forward), followed by layer normalization: LayerNorm(x + Sublayer(x)).
However, the reference implementations apply layer normalization before a sub-layer: x + Sublayer(LayerNorm(x)).
According to this paper and this paper, layer normalization before a sub-layer is more stable, so I took the same approach.
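The difference is easiest to see side by side. In this minimal sketch, an nn.Linear stands in for a sub-layer (an arbitrary stand-in, not the actual attention or feed-forward module), and dim_embed=512 is an arbitrary size:

import torch
import torch.nn as nn

dim_embed = 512
layer_norm = nn.LayerNorm(dim_embed)
sublayer = nn.Linear(dim_embed, dim_embed)   # stand-in for attention / feed-forward
x = torch.randn(2, 7, dim_embed)

post_ln = layer_norm(x + sublayer(x))        # LayerNorm(x + Sublayer(x)): as in the paper
pre_ln = x + sublayer(layer_norm(x))         # x + Sublayer(LayerNorm(x)): as implemented here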
11 Decoder
The decoder uses multiple decoder blocks.
import torch.nn as nn
from torch import Tensor

from .attention import MultiHeadAttention
from .feed_forward import PositionwiseFeedForward

class Decoder(nn.Module):
    def __init__(self,
                 num_blocks: int,
                 num_heads:  int,
                 dim_embed:  int,
                 dim_pffn:   int,
                 drop_prob:  float) -> None:
        super().__init__()
        self.blocks = nn.ModuleList(
            [DecoderBlock(num_heads, dim_embed, dim_pffn, drop_prob)
             for _ in range(num_blocks)]
        )
        self.layer_norm = nn.LayerNorm(dim_embed)

    def forward(self, x: Tensor, x_mask: Tensor, y: Tensor, y_mask: Tensor) -> Tensor:
        for block in self.blocks:
            y = block(y, y_mask, x, x_mask)
        y = self.layer_norm(y)
        return y
Ultimately, it applies layer normalization before giving outputs to the final linear (projection) layer.
12 Decoder Block
A decoder block uses:
- Masked multi-head attention for self-attention
- Multi-head attention for target-source attention
- Position-wise feed-forward to inject non-linearity
In between those sub-layers, we have a residual connection and layer normalization.
Like the encoder, the reference implementations apply layer normalization before each sub-layer: x + Sublayer(LayerNorm(x)). So, I did the same here, too.
# continuation of decoder.py
class DecoderBlock(nn.Module):
    def __init__(self,
                 num_heads: int,
                 dim_embed: int,
                 dim_pwff:  int,
                 drop_prob: float) -> None:
        super().__init__()

        # Self-attention
        self.self_attn = MultiHeadAttention(num_heads, dim_embed, drop_prob)
        self.layer_norm1 = nn.LayerNorm(dim_embed)

        # Target-source
        self.tgt_src_attn = MultiHeadAttention(num_heads, dim_embed, drop_prob)
        self.layer_norm2 = nn.LayerNorm(dim_embed)

        # Position-wise
        self.feed_forward = PositionwiseFeedForward(dim_embed, dim_pwff, drop_prob)
        self.layer_norm3 = nn.LayerNorm(dim_embed)

    def forward(self, y: Tensor, y_mask: Tensor, x: Tensor, x_mask: Tensor) -> Tensor:
        y = y + self.sub_layer1(y, y_mask)
        y = y + self.sub_layer2(y, x, x_mask)
        y = y + self.sub_layer3(y)
        return y

    def sub_layer1(self, y: Tensor, y_mask: Tensor) -> Tensor:
        y = self.layer_norm1(y)
        y = self.self_attn(y, y, y_mask)
        return y

    def sub_layer2(self, y: Tensor, x: Tensor, x_mask: Tensor) -> Tensor:
        y = self.layer_norm2(y)
        y = self.tgt_src_attn(y, x, x_mask)
        return y

    def sub_layer3(self, y: Tensor) -> Tensor:
        y = self.layer_norm3(y)
        y = self.feed_forward(y)
        return y
13 Transformer
The Transformer glues all modules together into one model. Unlike the above diagram, the implementation applies layer normalization before each sub-layer.
import torch.nn as nn
from torch import Tensor

from ..modules import Embedding, PositionalEncoding, Encoder, Decoder

class Transformer(nn.Module):
    def __init__(self,
                 input_vocab_size:  int,
                 output_vocab_size: int,
                 max_positions:     int,
                 num_blocks:        int,
                 num_heads:         int,
                 dim_embed:         int,
                 dim_pffn:          int,
                 drop_prob:         float) -> None:
        super().__init__()

        # Input embeddings, positional encoding, and encoder
        self.input_embedding = Embedding(input_vocab_size, dim_embed)
        self.input_pos_encoding = PositionalEncoding(max_positions, dim_embed, drop_prob)
        self.encoder = Encoder(num_blocks, num_heads, dim_embed, dim_pffn, drop_prob)

        # Output embeddings, positional encoding, decoder, and projection
        # to vocab size dimension
        self.output_embedding = Embedding(output_vocab_size, dim_embed)
        self.output_pos_encoding = PositionalEncoding(max_positions, dim_embed, drop_prob)
        self.decoder = Decoder(num_blocks, num_heads, dim_embed, dim_pffn, drop_prob)
        self.projection = nn.Linear(dim_embed, output_vocab_size)

        # Initialize parameters
        for param in self.parameters():
            if param.dim() > 1:
                nn.init.xavier_uniform_(param)

    def forward(self, x: Tensor, y: Tensor,
                x_mask: Tensor = None, y_mask: Tensor = None) -> Tensor:
        x = self.encode(x, x_mask)
        y = self.decode(x, y, x_mask, y_mask)
        return y

    def encode(self, x: Tensor, x_mask: Tensor = None) -> Tensor:
        x = self.input_embedding(x)
        x = self.input_pos_encoding(x)
        x = self.encoder(x, x_mask)
        return x

    def decode(self, x: Tensor, y: Tensor,
               x_mask: Tensor = None, y_mask: Tensor = None) -> Tensor:
        y = self.output_embedding(y)
        y = self.output_pos_encoding(y)
        y = self.decoder(x, x_mask, y, y_mask)
        return self.projection(y)
The last linear layer projects the embedding dimensions to the vocabulary size. We can apply softmax on the projected values to assign a probability to each token index.
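As a quick end-to-end check, the toy forward pass below maps token indices to per-position vocabulary logits. The hyperparameters follow the base model from the paper, while the vocabulary sizes and batch contents are made up, and it assumes the classes above are in scope:

import torch

model = Transformer(input_vocab_size=10000, output_vocab_size=10000,
                    max_positions=512, num_blocks=6, num_heads=8,
                    dim_embed=512, dim_pffn=2048, drop_prob=0.1)

src = torch.randint(1, 10000, (2, 7))   # (batch_size, src_seq_len)
tgt = torch.randint(1, 10000, (2, 6))   # (batch_size, tgt_seq_len)
logits = model(src, tgt)                # masks omitted for this toy check
print(logits.shape)                     # torch.Size([2, 6, 10000])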
Please look at this article for more details of the encoder-decoder architecture.
14 Greedy Translator
We don’t always need to use softmax. If all we want to know is the most probable token, we can use argmax instead, as shown in the diagram below.
A Greedy Translator would do the following:
1. The encoder extracts features from an input language sentence.
2. The first decoder input is SOS (the start-of-sentence token).
3. The decoder outputs the last projection.
4. The most probable token joins the decoder input sequence at the end.
5. Steps 3 and 4 repeat.
6. Finish when the most probable next token is EOS (end-of-sentence).
# Greedy Translator Pseudo Code

# Input sentence
input_text = "Zwei junge weiße Männer sind im Freien in der Nähe vieler Büsche."

# Convert input text tokens into token indices
encoder_input = ...tokenize input_text into token indices...

# Encoder output
features = model.encode(encoder_input)

# The first input
decoder_input = torch.LongTensor([[SOS_IDX]])

for _ in range(max_output_length):
    # Projection (1, current_sequence_length, vocab_size)
    projections = model.decode(features, decoder_input)

    # Greedy selection of the last token
    token_index = torch.argmax(projections[:, -1], keepdim=True)

    # Exit if EOS
    if token_index.item() == EOS_IDX:
        break

    # Add the most probable token to the decoder input
    decoder_input = torch.cat([decoder_input, token_index], dim=1)

# Remove SOS from the token indices
decoder_output = decoder_input[0, 1:]

# Convert token indices to text tokens
output_text = ...token indices => text tokens => a sentence text...
print(output_text)

# Expected: Two young, White males are outside near many bushes.
15 References
- The Annotated Transformer, Harvard NLP
- How to code The Transformer in Pytorch, Samuel Lynn-Evans
- The Illustrated Transformer, Jay Alammar
- Transformer Architecture: The Positional Encoding, Amirhossein Kazemnejad
- Tensor2Tensor, TensorFlow
- PyTorch Transformer, PyTorch
- Language Modeling with nn.Transformer and Torchtext, PyTorch
- On Layer Normalization in the Transformer Architecture, Ruibin Xiong, Yunchang Yang, Di He, Kai Zheng, Shuxin Zheng, Chen Xing, Huishuai Zhang, Yanyan Lan, Liwei Wang, Tie-Yan Liu
- Transformers without Tears: Improving the Normalization of Self-Attention, Toan Q. Nguyen, Julian Salazar