Why Can GPT Learn In-Context?
Language Models Secretly Perform Gradient Descent as Meta-Optimizers
GPT-3 has shown surprising In-Context Learning (ICL) ability, which the paper "Why Can GPT Learn In-Context? Language Models Secretly Perform Gradient Descent as Meta-Optimizers" explains as a kind of implicit fine-tuning.
With ICL, GPT-3 can learn from a few demonstrations (input-label pairs) and predict the labels for unseen inputs. It can do so without additional parameter updates.
But how does it do that?
The paper hypothesizes that:
- GPT first produces meta-gradients according to the demonstration examples.
- Then, it applies the meta-gradients to the original GPT to build an ICL model.
So, let’s dive into the paper to see how GPT learns in-context.
1 Meta-Gradients
The paper explains that ICL and explicit fine-tuning are both forms of gradient descent. The difference is that explicit fine-tuning uses gradients computed by back-propagation, while ICL uses meta-gradients produced through forward computation.
The figure below shows the difference and the similarity.
When we fine-tune a language model, we feed input-label pairs to the model to obtain a loss value. We then calculate the gradients of the loss by back-propagation. After that, we perform gradient descent to update the model parameters.
The upper part of the figure shows the fine-tuning process.
The lower part shows the ICL process, which uses meta-gradients instead of gradients.
In ICL, we feed demonstration examples (input-label pairs) to the model, which calculates meta-gradients through forward computation. Then, it applies the meta-gradients to the original GPT to build an ICL model on the fly to answer the query in the last sentence.
Both fine-tuning and ICL apply gradient-descent-style updates to the model. The difference is that ICL does not modify the original model parameters. Instead, it builds a new model on the fly based on the original model. As such, ICL works like a kind of implicit fine-tuning.
But so far, we don’t exactly know what meta-gradients are. Also, what does it mean to build an ICL model through forward computation?
The paper mathematically formulates what ICL is and then explains why ICL works like implicit fine-tuning.
2 In-Context Learning Classification Formulation
The paper focuses on ICL for classification with GPT models, which consist of $L$ stacked identical Transformer decoder layers. Each decoder layer has an attention module and a feed-forward network.
Given a query input text $x$ and a candidate answer set $Y = \{y_1, y_2, \ldots, y_m\}$, the goal of the classification is to choose the correct answer $\hat{y} \in Y$ for $x$.

Let $C = \{(x'_1, y'_1), (x'_2, y'_2), \ldots, (x'_n, y'_n)\}$ be the $n$ demonstration examples (input-label pairs).

The final answer is the candidate with the highest probability conditioned on the demonstrations and the query:

$$\hat{y} = \arg\max_{y_j \in Y} P(y_j \mid C, x)$$

Also, let $T(\cdot)$ be a template that formats each example: $T(x'_i, y'_i)$ for a demonstration and $T(x)$ for the query.

Therefore, a formatted input text is the concatenation:

$$I = \big[T(x'_1, y'_1);\ T(x'_2, y'_2);\ \ldots;\ T(x'_n, y'_n);\ T(x)\big]$$

The first $n$ parts are the formatted demonstrations, and the last part $T(x)$ is the formatted query.

We feed $I$ into the GPT model and compute the logit of each candidate answer as:

$$\text{logit}_j = h^{\top} e(y_j)$$

where $h$ is the output hidden state at the last token position, $e(y_j)$ is the word embedding of candidate answer $y_j$, and $\text{logit}_j$ is the logit corresponding to the $j$-th candidate answer.
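To make the formulation concrete, here is a minimal sketch of ICL classification in Python. It assumes a Hugging Face GPT-2 model as a stand-in for the paper's GPT models and a hypothetical "Sentence/Sentiment" template; only the mechanics of building $I$ and scoring the candidates follow the formulation above.

```python
# Minimal sketch of ICL classification with a causal LM.
# Assumptions: GPT-2 stands in for the paper's GPT models, and the
# "Sentence: ... / Sentiment: ..." template is hypothetical.
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")
model.eval()

demonstrations = [("the movie was wonderful", "positive"),
                  ("a dull and tedious plot", "negative")]
query = "an absolute delight from start to finish"
candidates = ["positive", "negative"]        # the candidate answer set Y

def template(x, y=""):
    # T(x, y); with y omitted, the text ends right after "Sentiment:"
    return f"Sentence: {x}\nSentiment: {y}".rstrip()

# I = [T(x'_1, y'_1); ...; T(x'_n, y'_n); T(x)]
prompt = "\n".join(template(x, y) for x, y in demonstrations) + "\n" + template(query)

inputs = tokenizer(prompt, return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs).logits          # (1, seq_len, vocab)

# logit_j = h^T e(y_j); with GPT-2's tied embeddings this is simply the
# LM logit of the first token of candidate y_j at the last position.
last = logits[0, -1]
scores = {y: last[tokenizer(" " + y)["input_ids"][0]].item() for y in candidates}
prediction = max(scores, key=scores.get)
print(scores, "->", prediction)
```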
So far, we’ve formally defined how ICL classification works.
Next, let’s see how we can understand gradient descent as a linear attention process.
3 Gradient Descent as Attention Layer
The paper compares a linear layer optimized by gradient descent with an attention layer.
A linear layer optimized by gradient descent computes:

$$F(x) = (W_0 + \Delta W)\, x$$

where $x$ is the input, $W_0$ is the initial weight matrix, and $\Delta W$ is the parameter update accumulated by gradient descent.

In back-propagation, $\Delta W$ is the sum of the outer products of the historic output gradients and their corresponding inputs:

$$\Delta W = \sum_i e_i \otimes x_i'^{\top}$$

where $e_i$ is the gradient of the loss with respect to the output of the linear layer at the $i$-th historic input, and $x_i'^{\top}$ is the transpose of the $i$-th historic input.
If it helps, you can think of a linear layer with two neurons and one historic input vector with three values as follows:
We feed the input $x'$ forward through the linear layer to get the output $y = W_0\, x'$ and compute a loss $L$ on that output.

Therefore, the gradient of the loss with respect to the initial weights $W_0$ is the outer product of the output gradient and the input:

$$\frac{\partial L}{\partial W_0} = e \otimes x'^{\top}$$

where $e = \frac{\partial L}{\partial y}$ is the gradient of the loss with respect to the layer output.

So, accumulating over all historic inputs gives the update $\Delta W = \sum_i e_i \otimes x_i'^{\top}$ from above.
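As a numeric sanity check of the outer-product form, here is a small NumPy sketch of the two-neuron, three-value example above; the squared-error loss is an assumption chosen only to make the gradient concrete.

```python
# Gradient of a linear layer is the outer product of the output gradient
# and the input: dL/dW0 = e ⊗ x'^T. The squared-error loss is an assumption.
import numpy as np

W0 = np.random.randn(2, 3)      # linear layer with 2 neurons, 3 inputs
x_hist = np.random.randn(3)     # one historic input vector
target = np.random.randn(2)

y = W0 @ x_hist                 # feed-forward
e = y - target                  # e = dL/dy for L = 0.5 * ||y - target||^2
analytic = np.outer(e, x_hist)  # dL/dW0 = e ⊗ x'^T

# Finite-difference check of the same gradient.
eps = 1e-6
numeric = np.zeros_like(W0)
L_base = 0.5 * np.sum((y - target) ** 2)
for i in range(W0.shape[0]):
    for j in range(W0.shape[1]):
        W_pert = W0.copy()
        W_pert[i, j] += eps
        L_pert = 0.5 * np.sum((W_pert @ x_hist - target) ** 2)
        numeric[i, j] = (L_pert - L_base) / eps

print(np.allclose(analytic, numeric, atol=1e-4))   # True
```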
Therefore, we can transform the gradient descent update into a linear attention operation as follows:

$$F(x) = (W_0 + \Delta W)\, x = W_0\, x + \sum_i e_i \,\big(x_i'^{\top} x\big) = W_0\, x + \text{LinearAttn}(E, X', x)$$

where, in attention terms:

- The historic output gradients $e_i$ are the values $V$.
- The historic inputs $x_i'$ are the keys $K$.
- The current input $x$ is the query $q$.
So, the gradient descent update on a linear layer is nothing but a linear attention operation that weights the historic output gradients by how similar each historic input is to the current query input.
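Here is a minimal NumPy sketch of this dual form: applying the accumulated update $\Delta W$ to a new input gives exactly the same result as a linear attention over the historic output gradients keyed by the historic inputs.

```python
# Dual form: (W0 + ΔW) x  ==  W0 x + Σ_i e_i (x'_i^T x)   (linear attention)
import numpy as np

d_out, d_in, n_hist = 2, 3, 5
W0 = np.random.randn(d_out, d_in)          # initial weights
E = np.random.randn(n_hist, d_out)         # historic output gradients (values)
X_hist = np.random.randn(n_hist, d_in)     # historic inputs (keys)
x = np.random.randn(d_in)                  # current input (query)

# Gradient-descent view: accumulate outer products, then apply the update.
delta_W = sum(np.outer(E[i], X_hist[i]) for i in range(n_hist))
gd_view = (W0 + delta_W) @ x

# Linear-attention view: weight each historic gradient by key/query similarity.
attn_view = W0 @ x + sum(E[i] * (X_hist[i] @ x) for i in range(n_hist))

print(np.allclose(gd_view, attn_view))     # True
```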
We can do a similar analysis on the Transformer attention layer to see how it performs gradient descent.
4 Transformer Attention as Meta-Optimization
Now, we’ll see how the Transformer attention mechanism performs meta-optimization.
In the ICL setting (defined earlier), the result of an attention head for the query token $t$ is:

$$F_{\text{ICL}}(q) = \text{Attn}(V, K, q) = W_V\, [X'; X]\ \text{softmax}\!\left(\frac{(W_K\, [X'; X])^{\top} q}{\sqrt{d}}\right)$$

The above is nothing but the standard attention mechanism with the following parameters:

- $W_V$ is the value projection matrix.
- $W_K$ is the key projection matrix.
- $W_Q$ is the query projection matrix.

The inputs to the attention layer are:

- $X'$ is the demonstration input matrix (the representations of the demonstration tokens).
- $X$ is the input representations of the query tokens before position $t$.
- $x_t$ is the input representation of the query token at position $t$.

The key, value, and query are as follows:

- $[X'; X]$ is the matrix concatenation of $X'$ and $X$.
- $V = W_V\, [X'; X]$ is the value matrix.
- $K = W_K\, [X'; X]$ is the key matrix.
- $q = W_Q\, x_t$ is the attention query vector.

Note: $\sqrt{d}$ is the scaling factor, where $d$ is the dimension of the attention head.
As per the paper, we relax the attention layer by removing the softmax and the scaling factor for simplicity:

$$\tilde{F}_{\text{ICL}}(q) \approx W_V\, [X'; X]\,\big(W_K\, [X'; X]\big)^{\top} q = W_V X\, (W_K X)^{\top} q + W_V X'\, (W_K X')^{\top} q$$

The first term only involves the query text tokens, so it is what the model computes in the zero-shot setting (no demonstrations). The second term $W_V X'\, (W_K X')^{\top} q$ is a linear attention over the demonstration tokens.

We rewrite the above as follows:

$$\tilde{F}_{\text{ICL}}(q) = W_{\text{ZSL}}\, q + \text{LinearAttn}\big(W_V X',\ W_K X',\ q\big) = \big(W_{\text{ZSL}} + \Delta W_{\text{ICL}}\big)\, q$$

So, the linear attention mechanism creates the zero-shot learning parameters $W_{\text{ZSL}} = W_V X\, (W_K X)^{\top}$ from the query text alone and adds a parameter update $\Delta W_{\text{ICL}} = W_V X'\, (W_K X')^{\top}$ obtained from the demonstration tokens purely through forward computation.

If we think of $W_V X'$ as the meta-gradients, playing the role of the historic output gradients $E$ in the dual form of gradient descent, then $\Delta W_{\text{ICL}}$ is the counterpart of $\Delta W$: ICL applies these meta-gradients to the original model to build an ICL model on the fly.

In summary, ICL keeps the original language model's zero-shot capability ($W_{\text{ZSL}}$) and adds an implicit update ($\Delta W_{\text{ICL}}$) learned from the demonstration context to answer the query input $x$.
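The decomposition can also be checked numerically. Below is a small NumPy sketch that splits the relaxed (softmax-free) attention output into the zero-shot part $W_{\text{ZSL}}\, q$ and the demonstration-induced update $\Delta W_{\text{ICL}}\, q$; the dimensions are arbitrary.

```python
# Relaxed linear attention splits into a zero-shot part and an ICL update:
# W_V [X'; X] (W_K [X'; X])^T q  ==  W_ZSL q + ΔW_ICL q
import numpy as np

d, n_demo, n_query = 8, 6, 4
W_V = np.random.randn(d, d)
W_K = np.random.randn(d, d)
X_demo = np.random.randn(d, n_demo)    # X': demonstration token representations
X_query = np.random.randn(d, n_query)  # X : query token representations before t
q = np.random.randn(d)                 # attention query vector W_Q x_t

X_all = np.concatenate([X_demo, X_query], axis=1)     # [X'; X]
full = W_V @ X_all @ (W_K @ X_all).T @ q              # relaxed attention output

W_zsl = W_V @ X_query @ (W_K @ X_query).T             # zero-shot parameters
delta_W_icl = W_V @ X_demo @ (W_K @ X_demo).T         # meta-gradient update
split = (W_zsl + delta_W_icl) @ q

print(np.allclose(full, split))   # True
```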
5 ICL vs. Fine-Tuning Settings
So far, we understand how ICL works as a meta-optimization mechanism in theory.
The paper also conducted experiments to compare ICL's meta-optimization with explicit optimization (fine-tuning) by designing a specific fine-tuning setting.
When comparing the attention mechanism with fine-tuning, we feed only the query input text to the model (no demonstrations), the same as in the zero-shot setting, and we restrict fine-tuning to update only the key and value projection matrices.

Therefore, in the linear attention form, the result of a fine-tuned attention head is as follows:

$$\tilde{F}_{\text{FT}}(q) = (W_V + \Delta W_V)\, X\, \big((W_K + \Delta W_K)\, X\big)^{\top} q = \big(W_{\text{ZSL}} + \Delta W_{\text{FT}}\big)\, q$$

- $\Delta W_V$ is the parameter update to the value projection matrix obtained by back-propagation.
- $\Delta W_K$ is the parameter update to the key projection matrix obtained by back-propagation.
- $W_{\text{ZSL}} = W_V X\, (W_K X)^{\top}$ is the zero-shot learning part (i.e., the attention result computed with the original language model's weights).
- $\Delta W_{\text{FT}}$ is the parameter update to $W_{\text{ZSL}}$ introduced by fine-tuning.
They add the following conditions to the fine-tuning to make the comparison with ICL fair:
- The training examples are the same as ICL’s demonstration examples
- They train each example for only one step in the same order as demonstrated for ICL
- They format each training example using the same template as ICL (a minimal sketch follows below)
- They use the causal language modeling objective for fine-tuning
Note: causal language modeling is the task of predicting the token following a sequence of tokens.
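As a concrete illustration of the last two conditions, here is a minimal sketch that formats one training example with the same hypothetical template used in the ICL sketch above and takes a single causal language modeling step on it. Hugging Face GPT-2 again stands in for the paper's fairseq models, and the restriction to key/value projections is omitted for brevity.

```python
# One fine-tuning step on a single demonstration example, formatted with the
# same (hypothetical) template as ICL and trained with the causal LM objective.
# Note: the paper additionally restricts updates to the key and value
# projection matrices; that restriction is omitted here for brevity.
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

x, y = "the movie was wonderful", "positive"      # one demonstration example
text = f"Sentence: {x}\nSentiment: {y}"           # same template as ICL

inputs = tokenizer(text, return_tensors="pt")
outputs = model(**inputs, labels=inputs["input_ids"])  # causal LM loss
outputs.loss.backward()
optimizer.step()                                   # a single update step
optimizer.zero_grad()
```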
Based on the above settings, ICL has many properties in common with fine-tuning:
- Both perform gradient descent ($\Delta W_{\text{ICL}}$ and $\Delta W_{\text{FT}}$)
- Same training information (training examples = demonstration examples)
- Same causal order of training examples (only one epoch, with the training examples in the same order as the ICL demonstrations)
- Both update only the key and value projection matrices
Given all these similarities, it is reasonable to expect that ICL and fine-tuning will perform similarly provided ICL indeed performs implicit fine-tuning.
The paper’s experiments confirm this hypothesis.
6 Experiments
They compared the performance of ICL and fine-tuning with the following tasks:
- Sentiment classification (SST-2, SST-5, MR, Subj)
- Topic classification (AGNews)
- Natural language inference (CB)
They used two GPT-like pre-trained language models with 1.3B and 2.7B parameters, respectively, released by fairseq. They call them GPT 1.3B and GPT 2.7B for short.
The prediction processes of zero-shot learning (ZSL) and fine-tuning are the same as ICL, just without demonstration examples.
- For ICL,
  - The number of demonstration examples is 32.
  - They adjusted the random seed for each task to find a set of demonstration examples that achieves the best validation performance.
- For fine-tuning,
  - They used the same training examples as ICL, in the same order as the ICL demonstrations.
  - Fine-tuning runs for only one epoch.
  - They used an SGD optimizer and adjusted the learning rate to achieve the best validation performance.
Below is the validation accuracy in ZSL, FT (fine-tuning), and ICL on six classification tasks.
Overall, both FT and ICL achieve much better performance than ZSL. In other words, FT and ICL make helpful optimizations to the original language model. Moreover, ICL performs better than FT in these few-shot scenarios, where only a few examples are available.
So, the accuracy of ICL is comparable (or even better) to fine-tuning in this comparison, which is strong evidence that ICL performs implicit fine-tuning.
7 Measuring Similarity between ICL and Fine-Tuning
They designed three metrics to measure the similarity between ICL and fine-tuning at three different levels:
7.1 Recall to Fine-tuning Predictions (Rec2FTP)
At the prediction level, Rec2FTP measures how much of fine-tuning's prediction behavior ICL can cover.

Let $N_{\text{FT} > \text{ZSL}}$ be the number of query examples that FT can predict correctly but ZSL cannot, and $N_{(\text{FT} \cap \text{ICL}) > \text{ZSL}}$ be the number of those examples that ICL can also predict correctly. Then:

$$\text{Rec2FTP} = \frac{N_{(\text{FT} \cap \text{ICL}) > \text{ZSL}}}{N_{\text{FT} > \text{ZSL}}}$$

A higher Rec2FTP score suggests that ICL covers more of fine-tuning's behavior at the prediction level.
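Assuming predictions are stored as per-example correctness flags, a small helper for Rec2FTP might look like this:

```python
# Rec2FTP: among examples FT gets right but ZSL gets wrong, the fraction
# that ICL also gets right. Inputs are per-example correctness booleans.
def rec2ftp(zsl_correct, ft_correct, icl_correct):
    ft_over_zsl = [i for i in range(len(ft_correct))
                   if ft_correct[i] and not zsl_correct[i]]
    if not ft_over_zsl:
        return 0.0
    covered = sum(1 for i in ft_over_zsl if icl_correct[i])
    return covered / len(ft_over_zsl)

# Toy usage with made-up flags:
print(rec2ftp([0, 0, 1, 0], [1, 1, 1, 0], [1, 0, 1, 1]))  # 0.5
```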
7.2 Similarity of Attention Output Updates (SimAOU)
At the representation level, SimAOU measures how much ICL updates the attention output representation in the same direction as fine-tuning.

Let $h_X^{(l)}$ be the output representation of the last token at the $l$-th attention layer in setting $X \in \{\text{ZSL}, \text{ICL}, \text{FT}\}$.

We can measure the direction of the updates made by ICL and FT as their differences from the zero-shot representation: $h_{\text{ICL}}^{(l)} - h_{\text{ZSL}}^{(l)}$ and $h_{\text{FT}}^{(l)} - h_{\text{ZSL}}^{(l)}$.

Then, we take the cosine similarity between these two updates:

$$\text{SimAOU}^{(l)} = \cos\!\left(h_{\text{ICL}}^{(l)} - h_{\text{ZSL}}^{(l)},\ h_{\text{FT}}^{(l)} - h_{\text{ZSL}}^{(l)}\right)$$

A higher SimAOU score suggests that ICL updates the attention output representation in the same direction as fine-tuning.
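Assuming we have the last-token attention output at a given layer for each setting, a sketch of SimAOU is:

```python
# SimAOU: cosine similarity between the ICL update and the FT update of the
# attention output representation, both measured relative to ZSL.
import numpy as np

def sim_aou(h_zsl, h_icl, h_ft):
    icl_update = h_icl - h_zsl
    ft_update = h_ft - h_zsl
    return float(np.dot(icl_update, ft_update) /
                 (np.linalg.norm(icl_update) * np.linalg.norm(ft_update)))

# Toy usage with random hidden states of dimension 16:
h_zsl, h_icl, h_ft = (np.random.randn(16) for _ in range(3))
print(sim_aou(h_zsl, h_icl, h_ft))
```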
7.3 Similarity of Attention Map (SimAM)
At the attention behavior level, SimAM measures the similarity between the attention maps of ICL and fine-tuning.

Let $m_X^{(l,h)}$ be the attention weights (before softmax) of the last token at the $h$-th attention head in the $l$-th attention layer in setting $X \in \{\text{ZSL}, \text{ICL}, \text{FT}\}$.

For ICL, they only monitor the attention weights to the query input tokens, since FT has no demonstration tokens (every training example is a separate query input).

Then, they take the cosine similarity between these two attention maps:

$$\text{SimAM}^{(l,h)} = \cos\!\left(m_{\text{ICL}}^{(l,h)},\ m_{\text{FT}}^{(l,h)}\right)$$

A higher SimAM score suggests that ICL attends to the query tokens more like fine-tuning does.
7.4 Results
Below are the results of the three similarity metrics on six classification tasks.
The Rec2FTP results show that ICL can cover most of the correct predictions made by fine-tuning.
The SimAOU and SimAM scores are the average of all examples and layers. There are two baselines:
- Random SimAOU is a baseline metric that measures the similarity between the attention output updates of ICL and randomly generated updates.
- ZSL SimAM is a baseline metric that measures the similarity between the attention weights of ICL and ZSL.
Compared with the baseline metrics, ICL behaves more similarly to fine-tuning at the representation and attention behavior levels.
Overall, ICL exhibits behavior similar to fine-tuning on all tasks.
8 Similarity at Each Layer
Below are the SimAOU and SimAM scores at each layer. They drew box plots based on 50 randomly sampled validation examples from each dataset.
They observed that:
- Both SimAOU and SimAM fluctuate at lower layers
- SimAOU and SimAM increase toward higher layers
The above phenomenon suggests that ICL’s meta-optimization has forward-accumulated effects. As such, ICL behaves more similarly to fine-tuning at higher layers.
9 Optimization Algorithm vs. Transformer Architecture
ICL performs implicit fine-tuning through attention. Then, can we use momentum to improve Transformer attention, just as momentum improves SGD optimization in explicit fine-tuning?
They apply an Exponential Moving Average (EMA) to the attention values to build momentum-based attention:

$$\text{MoAttn}(V, K, q_t) = \text{Attn}(V, K, q_t) + \text{EMA}(V) = V\, \text{softmax}\!\left(\frac{K^{\top} q_t}{\sqrt{d}}\right) + \sum_{i=1}^{t-1} \eta^{\,t-i}\, v_i$$

where $\eta$ is the momentum coefficient (a hyperparameter) and $v_i$ is the $i$-th attention value vector.
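A minimal NumPy sketch of momentum-based attention, assuming a single head and pre-computed key and value vectors:

```python
# Momentum-based attention: standard attention plus an exponential moving
# average over the past value vectors, MoAttn = Attn(V, K, q_t) + Σ η^(t-i) v_i.
import numpy as np

def softmax(z):
    z = z - z.max()
    e = np.exp(z)
    return e / e.sum()

def momentum_attention(V, K, q_t, eta=0.9):
    # V, K: (t, d) past value / key vectors; q_t: (d,) current query vector.
    t, d = V.shape
    attn = V.T @ softmax(K @ q_t / np.sqrt(d))          # standard attention
    ema = sum(eta ** (t - i) * V[i] for i in range(t))  # momentum (EMA) term
    return attn + ema

V = np.random.randn(5, 8)
K = np.random.randn(5, 8)
q_t = np.random.randn(8)
print(momentum_attention(V, K, q_t).shape)   # (8,)
```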
9.1 Experiments in Language Modeling
They trained two GPT models with 350M parameters from scratch. One is the vanilla Transformer, and the other is the momentum-based Transformer.
Below is the perplexity (lower is better) of the two models on three validation sets with input lengths of 256, 512, and 1024, respectively.
The momentum-based Transformer outperforms the vanilla Transformer on all validation sets.
9.2 Experiments on In-Context Learning
They also evaluated the momentum-based Transformer on the in-context learning tasks. For the six datasets, they used 32 examples as demonstrations. The momentum-based Transformer outperforms the vanilla Transformer on all datasets.
In conclusion, the momentum-based Transformer outperforms the vanilla Transformer in both language modeling and in-context learning, showing that introducing momentum into attention is an effective strategy.
10 References
- Why Can GPT Learn In-Context? Language Models Secretly Perform Gradient Descent as Meta-Optimizers. Damai Dai, Yutao Sun, Li Dong, Yaru Hao, Zhifang Sui, Furu Wei
- GPT-3: In-Context Few-Shot Learner (2020)