Transformer’s Evaluation Details: Greedy and Beam Search Translators

January 30, 2022

In the previous article, we discussed the training details for my implementation of the Transformer architecture from Attention Is All You Need by Ashish Vaswani et al.

This article is the last in the series. We discuss the evaluation details:

1 Encoder Features Generation

Before discussing the greedy and beam search translators, we need an encoder that extracts features from input sentences, just like the original Transformer’s encoder does.

First, we tokenize an input sentence. I’ve explained the process in this article, where I built a Vocab class.

source_vocab = Vocab(...)
target_vocab = Vocab(...)

Suppose we are translating from German (source) to English (target).

For example, an input sentence is as follows:

input_sentence = 'Der braune Hund steht auf dem Sandstrand.'

We tokenize it into a list of numbers:

input_tokens = source_vocab(input_sentence)

print(input_tokens)

Below is the output using the Multi30k ('de', 'en') train dataset.

[88, 530, 33, 30, 11, 26, 1357, 4]

Then, we convert input_tokens into an input tensor for the encoder.

# A batch of one input for Encoder
encoder_input = torch.Tensor([input_tokens]).long()

Finally, we generate encoder features encoder_output:

model.eval()

with torch.no_grad():
    encoder_output = model.encode(encoder_input)

Please look at this article for the details of the encode method.

We use the features generated by the encoder as part of decoder inputs.

2 Greedy Translator

A translator is a decoder that converts the features from the encoder into target-language sequences. I built such a decoder based on the original Transformer’s decoder.

We feed the SOS_IDX (start-of-sequence) token to the decoder to initiate the translation process:

# Start with SOS
decoder_input = torch.Tensor([[SOS_IDX]]).long()

Note: in my Vocab implementation, SOS_IDX is defined as 2.

So, the first decoder input is a batch of one input, which is [[SOS_IDX]].

We don’t know the length of the translated sequence in advance, so we define a maximum output size as the input length plus an extra 50.

# Maximum output size
max_output_length = encoder_input.shape[-1] + 50

We feed both encoder_output and decoder_input to the decoder.

# Autoregressive
for _ in range(max_output_length):
    # Decoder prediction
    logits = model.decode(encoder_output, decoder_input)

    ....

For the decode method details, please look at this article.

logits contains a value for every token index in the target vocabulary. A token with a bigger logit value is more probable than tokens with smaller values.

Greedy decoding selects the most probable token for the next iteration.

    # Greedy selection
    token_index = torch.argmax(logits[:, -1], dim=1, keepdim=True)

If the token_index is EOS_IDX (end-of-sequence), we exit the loop and complete the translation.

    # EOS is most probable => Exit
    if token_index.item()==EOS_IDX:
        break

Otherwise, we append the token_index to the decoder_input and continue to the next iteration.

    # Next Input to Decoder
    decoder_input = torch.cat([decoder_input, token_index], dim=1)

When completing the translation, we remove the SOS_IDX:

# Exclude SOS at the beginning.
decoder_output = decoder_input[0, 1:].numpy()

Finally, we convert token indices to text tokens:

# Convert token indices to token texts
output_texts = [target_vocab.tokens[i] for i in decoder_output]

So, the greedy translator code looks like the following:

import torch
from torch import Tensor

# Create source and target vocab objects
source_vocab = ...
target_vocab = ...

# Input sentence
input_text = '....input language sentence...'

# Tokenization
input_tokens = source_vocab(input_text.strip())

# A batch of one input for Encoder
encoder_input = torch.Tensor([input_tokens]).long()

# Generate encoded features
model.eval()
with torch.no_grad():
    encoder_output = model.encode(encoder_input)

# Start decoding with SOS
decoder_input = torch.Tensor([[SOS_IDX]]).long()

# Maximum output size
max_output_length = encoder_input.shape[-1] + 50 # give some extra length

# Autoregressive
for _ in range(max_output_length):
    # Decoder prediction
    logits = model.decode(encoder_output, decoder_input)

    # Greedy selection
    token_index = torch.argmax(logits[:, -1], dim=1, keepdim=True)
    
    # EOS is most probable => Exit
    if token_index.item()==EOS_IDX:
        break

    # Next Input to Decoder
    decoder_input = torch.cat([decoder_input, token_index], dim=1)

# Decoder input is a batch of one entry, 
# and we also exclude SOS at the beginning.
decoder_output = decoder_input[0, 1:].numpy()

# Convert token indices to token texts
output_texts = [target_vocab.tokens[i] for i in decoder_output]
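If the tokenizer splits on whitespace, we can join the text tokens to get the final translation string. This last step is only an illustration (it assumes whitespace tokenization) and is not part of the original listing:

# Join text tokens into a sentence (assuming whitespace tokenization)
translated_sentence = ' '.join(output_texts)
print(translated_sentence)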

3 Beam Search Translator

The beam search translator follows the same process as the greedy translator, except that we keep track of multiple translation sequences (paths).

Please look at this article for more details on the beam search algorithm.

We call the number of paths beam_size:

beam_size = 3

Like the greedy translator, we start with one sequence containing only SOS_IDX. We also define scores, which initially holds a single score of 0 for the start sequence.

# Start with SOS
decoder_input = torch.Tensor([[SOS_IDX]]).long()
scores = torch.Tensor([0.])

We feed both encoder_output and decoder_input to the decoder.

for i in range(max_output_length):
    # Decoder prediction
    logits = model.decode(encoder_output, decoder_input)

Unlike the greedy translator, we calculate log_softmax to add to scores:

    # Log softmax
    log_probs = torch.log_softmax(logits[:, -1], dim=1)

log_softmax has the range (-inf, 0] since softmax probabilities have the range [0, 1].
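As a quick standalone check (a toy tensor, not the model’s logits), we can verify that log_softmax only produces non-positive values:

import torch

# Toy logits, not produced by the model
toy_logits = torch.tensor([[2.0, 0.5, -1.0]])
toy_log_probs = torch.log_softmax(toy_logits, dim=1)
print(toy_log_probs)             # all values are <= 0
print(toy_log_probs.exp().sum()) # the probabilities sum to 1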

We add these log probabilities to the score of the sequence. However, that means a longer sequence accumulates more negative values. So, we apply a penalty for the sequence length as follows:

def sequence_length_penalty(length: int, alpha: float=0.6) -> float:
    return ((5 + length) / (5 + 1)) ** alpha

The details of the sequence length penalty are in Google’s Neural Machine Translation System paper (see References).

The penalty gets bigger as the sequence grows longer, so the negative score added at each step becomes smaller in magnitude.

    log_probs = log_probs / sequence_length_penalty(i+1, alpha)
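To get a feel for the numbers, here is a quick standalone check (not part of the translator loop) using the sequence_length_penalty function above with its default alpha of 0.6; the printed values are approximate:

# Penalty values for a few sequence lengths
for length in [1, 5, 10, 20]:
    print(length, round(sequence_length_penalty(length), 2))
# prints approximately: 1 1.0, 5 1.36, 10 1.73, 20 2.35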

We set log_probs to zero for paths that have already reached EOS_IDX so that their scores do not change any further.

    # Set score to zero where EOS has been reached
    paths_EOS_reached = decoder_input[:, -1]==EOS_IDX

    log_probs[paths_EOS_reached, : ] = 0

Now, we can add log_probs to scores.

    scores = scores.unsqueeze(1) + log_probs

Note: scores has the shape (beam_size,), whereas log_probs has the shape (beam_size, vocab_size). So, we add an extra dimension to scores by unsqueeze. The resulting scores has the shape (beam_size, vocab_size).
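A tiny shape check illustrates the broadcasting (toy sizes, pretending beam_size is 3 and vocab_size is 5):

import torch

# Broadcasting check with toy sizes
toy_scores = torch.zeros(3)        # shape (3,)
toy_log_probs = torch.randn(3, 5)  # shape (3, 5)
print((toy_scores.unsqueeze(1) + toy_log_probs).shape)  # torch.Size([3, 5])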

We now have scores for all token indices per beam path. We need to select the top scores for the beam size. We flatten scores and select the top scores:

    scores, indices = torch.topk(scores.reshape(-1), beam_size)

When beam_size = 3, we have the top 3 scores and indices, which we divide by vocab_size to obtain the beam path indices:

    beam_indices = torch.divide(indices, vocab_size,  
                                rounding_mode='floor')

Note: vocab_size = len(target_vocab).

We also calculate the remainder of the division to obtain the token indices:

    token_indices = torch.remainder(indices, vocab_size)
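As a toy illustration of how the flat indices split into beam and token indices (pretending vocab_size is 10, far smaller than a real vocabulary):

import torch

# Flat indices as topk might return them, with a pretend vocab_size of 10
flat_indices = torch.tensor([23, 5, 17])
print(torch.divide(flat_indices, 10, rounding_mode='floor'))  # tensor([2, 0, 1]) -> beam indices
print(torch.remainder(flat_indices, 10))                      # tensor([3, 5, 7]) -> token indices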

We iterate through pairs of (beam index, token index) to compose the successive decoder inputs:

    next_decoder_input = []

    for beam_index, token_index in zip(beam_indices, token_indices):
        prev_decoder_input = decoder_input[beam_index]

        if prev_decoder_input[-1]==EOS_IDX:
            token_index = EOS_IDX # once EOS, always EOS

        token_index = torch.LongTensor([token_index])
        next_decoder_input.append(
            torch.cat([prev_decoder_input, token_index])
        )

    decoder_input = torch.vstack(next_decoder_input)

Note: For a path that has already reached EOS_IDX, we ensure the next token is also EOS_IDX so that the path’s score will remain the same.

If all beam paths have EOS_IDX, we exit the loop:

    if (decoder_input[:, -1]==EOS_IDX).sum() == beam_size:
        break

In the first iteration, decoder_input had only one input. But from the second iteration, the number of inputs to the decoder becomes the beam size. So, we expand encoder_output as follows:

    if i==0:
        encoder_output = encoder_output.expand(
                             beam_size, 
                             *encoder_output.shape[1:])

In other words, encoder_output changes from a batch of one input to three identical inputs. We do this expansion only once since we keep track of the same number of paths after that.
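For example, assuming the encoder output has the shape (batch, source length, d_model), the expansion looks like this (toy sizes):

import torch

# Toy shape check for the expansion (expand creates a view, no data is copied)
x = torch.randn(1, 8, 512)     # (1, src_len, d_model)
y = x.expand(3, *x.shape[1:])  # (3, src_len, d_model)
print(y.shape)                 # torch.Size([3, 8, 512])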

When the loop exits, we choose the best path based on the scores:

decoder_output, _ = max(zip(decoder_input, scores),
                        key=lambda x: x[1])

Then, we remove SOS_IDX:

decoder_output = decoder_output[1:].numpy() # remove SOS

Finally, we convert token indices to text tokens:

output_texts = [target_vocab.tokens[i] for i in decoder_output \
                                       if i != EOS_IDX]

Note: we exclude EOS_IDX as the loop exits only when all beam paths have EOS_IDX.

So, the beam search translator code looks like the following:

import torch
from torch import Tensor

# Create source and target vocab objects
source_vocab = ...
target_vocab = ...

# Beam size and penalty alpha
beam_size = 3
alpha = 0.6

# Input sentence
input_text = '....input language sentence...'

# Tokenization
input_tokens = source_vocab(input_text.strip())

# A batch of one input for Encoder
encoder_input = torch.Tensor([input_tokens]).long()

# Generate encoded features
model.eval()
with torch.no_grad():
    encoder_output = model.encode(encoder_input)

# Start with SOS
decoder_input = torch.Tensor([[SOS_IDX]]).long()

# Maximum output size
max_output_length = encoder_input.shape[-1] + 50 # give some extra length

scores = torch.Tensor([0.])
vocab_size = len(target_vocab)

for i in range(max_output_length):
    # Decoder prediction
    logits = model.decode(encoder_output, decoder_input)

    # Log softmax
    log_probs = torch.log_softmax(logits[:, -1], dim=1)
    log_probs = log_probs / sequence_length_penalty(i+1, alpha)

    # Set score to zero where EOS has been reached
    log_probs[decoder_input[:, -1]==EOS_IDX, :] = 0
                                         
    # scores [beam_size, 1], log_probs [beam_size, vocab_size]
    scores = scores.unsqueeze(1) + log_probs

    # Flatten scores from [beams, vocab_size] to [beams * vocab_size] to get top k, 
    # and reconstruct beam indices and token indices
    scores, indices = torch.topk(scores.reshape(-1), beam_size)
    beam_indices  = torch.divide   (indices, vocab_size, rounding_mode='floor') # indices // vocab_size
    token_indices = torch.remainder(indices, vocab_size)                        # indices %  vocab_size

    # Build the next decoder input
    next_decoder_input = []
    for beam_index, token_index in zip(beam_indices, token_indices):
        prev_decoder_input = decoder_input[beam_index]
        if prev_decoder_input[-1]==EOS_IDX:
            token_index = EOS_IDX # once EOS, always EOS
        token_index = torch.LongTensor([token_index])
        next_decoder_input.append(torch.cat([prev_decoder_input, token_index]))
    decoder_input = torch.vstack(next_decoder_input)

    # If all beams are finished, exit
    if (decoder_input[:, -1]==EOS_IDX).sum() == beam_size:
        break

    # Encoder output expansion from the second time step to the beam size
    if i==0:
        encoder_output = encoder_output.expand(beam_size, *encoder_output.shape[1:])
        
# convert the top scored sequence to a list of text tokens
decoder_output, _ = max(zip(decoder_input, scores), key=lambda x: x[1])
decoder_output = decoder_output[1:].numpy() # remove SOS

output_text_tokens = [target_vocab.tokens[i] for i in decoder_output if i != EOS_IDX] # remove EOS if exists

4 BLEU Score Calculation

I used a test dataset from Multi30k to calculate the BLEU score.

test_dataset = load_dataset('Multi30k', ('de', 'en'), split='test')

We collect model-predicted text tokens and target text tokens:

outputs = []
targets = []

for source_text, target_text in tqdm(test_dataset):
    output = translator(source_text)
    outputs.append(output)

    target = [target_vocab.tokenize(target_text)]
    targets.append(target)

In this case, we only have one reference sequence per prediction. If we had more reference sequences per prediction, targets would be a list of lists of sequences. Please look at this article for more details on the BLEU score calculation.

torchtext provides a bleu_score function, so I used it:

from torchtext.data.metrics import bleu_score

score = bleu_score(outputs, targets)
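To make the expected input format concrete, below is the toy example from the torchtext documentation: two candidate token sequences, each paired with its own list of reference sequences.

from torchtext.data.metrics import bleu_score

candidate_corpus = [['My', 'full', 'pytorch', 'test'], ['Another', 'Sentence']]
references_corpus = [[['My', 'full', 'pytorch', 'test'],
                      ['Completely', 'Different']],
                     [['No', 'Match']]]
print(bleu_score(candidate_corpus, references_corpus))  # approximately 0.84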

I got a BLEU score of 0.347 from the model I trained in the previous article.

5 References

  • The Annotated Transformer
    Harvard NLP
  • Google’s Neural Machine Translation System
    Yonghui Wu, Mike Schuster, Zhifeng Chen, Quoc V. Le, Mohammad Norouzi, Wolfgang Macherey, Maxim Krikun, Yuan Cao, Qin Gao, Klaus Macherey, Jeff Klingner, Apurva Shah, Melvin Johnson, Xiaobing Liu, Łukasz Kaiser, Stephan Gouws, Yoshikiyo Kato, Taku Kudo, Hideto Kazawa, Keith Stevens, George Kurian, Nishant Patil, Wei Wang, Cliff Young, Jason Smith, Jason Riesa, Alex Rudnick, Oriol Vinyals, Greg Corrado, Macduff Hughes, Jeffrey Dean