キカベン
機械学習でより便利な世の中へ
G検定対策
お問い合わせ
   

GAN(Generative Adversarial Network):敵対的生成ネットワーク

thumb image

この記事では、GAN(Generative Adversarial Network)の仕組みをPyTorchでコードを書きながら少しづつ積み上げていく形式で解説していきます。

GANを知らない人は「DCGANでファッション画像を生成してみる」に例があるので参考にしてください。

ざっくり言うと、自動でランダムな画像の生成ができるモデルなのですが、今回はそれをどう訓練するのかに焦点を置いています。

このGANはIan GoodfellowYoshua Bengioらによって2014年に提唱されました。

Ian GoodfellowといえばGANを思い浮かべる人も多いと思います。

ちなみに、当時、Ian Goodfellowはモントリオール大学で博士課程の学生でした。

その後、GoogleやOpenAIを経て2019年からはAppleで機械学習部門のディレクターとして活躍しています。

また、GANについて、Yann LeCun(Facebookのチーフ・AIサイエンティスト)が「過去10年間で最も興味深いアイデア」と高い評価をしています。

何がそんなに「興味深いアイデア」なのでしょうか?

では、PyTorchでコードを書きながら、解説していきます。

1. Python環境を作る🔝

まずは、Pythonの環境を作りましょう。Virtualenvで以下のように実行します。

# プロジェクトのフォルダを作成し移動
mkdir gan
cd gan

# VirtualenvでPython環境を作りアクティベートする
python3 -m venv venv
source venv/bin/activate

# 一応、pipをアップデートしておく
pip install --upgrade pip

# 必要なライブラリをインストール
pip install torch torchvision matplotlib tqdm

Condaが好みの方は、Mini Condaなどで環境を作ってください。

まず今回使用するデータセットを説明します。

2. MNIST🔝

MNISTはYann LeCunが畳み込みニューラルネットワークの訓練のために用意した数字のモノクロ画像のデータセットです。

2.1. MINISTのデータ読み込み🔝

以下のようにデータを読み込みます。

from torchvision import datasets, transforms


transform = transforms.ToTensor()

dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)

訓練用とテスト用のデータがあるのですが、訓練用のデータの方がデータ数が大きいので訓練用のデータを使用します。

ただし、以下で説明するように教師データとして使用しないので本来のラベル(labels)は無視します。

また、トランスフォームでは、Numpyのデータをtorch.Tensorに変更するだけにしています。

これで0-255の整数型データ値が0.0-1.0の浮動小数点型のデータ値になります。

2.2. サンプル画像の表示🔝

では、画像をいくつか見て行きましょう。

DataLoaderを使ってバッチとしてデータを取り出し表示します。

from torch.utils.data import DataLoader

dataloader = DataLoader(dataset, batch_size=64, drop_last=True)

ちなみに、drop_last=Trueとしたのはバッチサイズに満たない最後の方のデータは切り捨てるという意味で、バッチサイズを固定で扱いコードを簡単にするために設定しました。

では、1つのバッチを取り出しましょう。

images, labels = next(iter(dataloader))

print(images.shape)
print(labels)

バッチのシェイプはtorch.Size([64, 1, 28, 28])で一つの画像のサイズが縦28x横28ピクセルであるのが分かります。

また、モノクロなのでチャンネル数は1です。

ラベル(labels)は使いませんが以下のようにバッチサイズ分の数字がプリントされます。

tensor([5, 0, 4, 1, 9, 2, 1, 3, 1, 4, 3, 5, 3, 6, 1, 7, 2, 8, 6, 9, 4, 0, 9, 1,
        1, 2, 4, 3, 2, 7, 3, 8, 6, 9, 0, 5, 6, 0, 7, 6, 1, 8, 7, 9, 3, 9, 8, 5,
        9, 3, 3, 0, 7, 4, 9, 8, 0, 9, 4, 1, 4, 4, 6, 0])

画像を表示する関数を設定して表示してみます。

import torch
from torchvision.utils import make_grid
import matplotlib.pyplot as plt


def show_image_grid(images: torch.Tensor, nrow: int):
    image_grid = make_grid(images, nrow)     # 複数イメージをグリッド上にまとめた画像
    image_grid = image_grid.permute(1, 2, 0) # チャンネルの次元を一番最後に移動
    image_grid = image_grid.cpu().numpy()    # Numpyへと変換

    plt.imshow(image_grid)
    plt.xticks([])
    plt.yticks([])
    plt.show()


show_image_grid(images, nrow=8)

以下のようにラベル(labels)に対応した64個の数字が表示されました。

MNISTのそもそもの目的は、モノクロ画像に描かれている数字を予測するモデルを訓練することですが、今回は本物の数字画像のデータセットとして画像だけを使用します。

これに対して、GANでは画像を生成するネットワークを訓練し、より本物に近い画像が生成できるようにします。

3. 生成ネットワーク🔝

前述の通り、GANでは本物に近い画像をランダムに作成することができます。

この画像を生成するものを生成ネットワーク(Generator)と呼び、乱数を入力することでランダムな画像を生成します。

では、生成ネットワークを作りましょう。

3.1. 入力乱数🔝

まず、入力する乱数の数を決める必要があります。

この数が多いほど画像として生成される数字の形のバリエーションが増えることになりますが、あまり大きいと訓練が長引いたりうまくいかなくなる可能性があるので、とりあえずは100個にしておきます。

例えば、100個の乱数を正規分布を使って以下のように生成します。

z = torch.randn(1, 100)
print(z)

これでバッチサイズが1の100個の乱数が出来ました。

tensor([[-1.3960,  0.6781, -0.7288, -1.6172,  0.6932, -0.3400,  0.4452,  0.3502,
         -0.7701, -0.9442, -0.5733,  1.2326, -0.6679, -0.6649,  0.3949,  0.1918,
          0.1251, -0.2411, -1.3586,  1.3770,  0.4997, -0.6191, -1.9267,  0.3402,
          1.2111, -0.3573,  1.0328, -0.1675,  0.0841,  0.1866, -0.1482,  0.2603,
         -0.1154, -0.6616, -1.5474, -0.3432, -0.1312,  0.1223,  1.0606,  1.2120,
         -1.3338,  1.1965,  0.0041,  0.6383, -0.1143, -0.8992,  0.6415, -0.6786,
         -0.0174, -1.1782, -0.6206,  1.2067,  0.2221,  0.3988, -0.7581,  1.4411,
         -0.1658,  0.2643,  1.8042,  0.4923,  0.1234, -0.1523,  2.0511,  1.0947,
         -0.9983, -0.7883, -0.1812, -0.1829,  1.4517, -1.2220,  0.3964,  0.0781,
          0.2261,  0.5814,  1.5786, -0.3531,  0.3415,  3.2510, -0.5528, -0.4402,
         -1.9231, -0.3097, -0.0519, -0.8633,  0.6243, -1.4232, -1.7594,  1.2454,
         -0.3119, -0.8461,  0.8073,  1.0772, -1.1928, -1.5024, -0.5267, -0.1670,
          0.5459,  0.0671, -0.4532, -0.8306]])

このデータを入力値として縦28x横28のモノクロ画像を生成するネットワークを定義します。

3.2. 潜在変数🔝

ちょっと話はそれますが、論文などでは、この乱数を潜在変数(latent variable)と呼びます。

画像などの観測できるデータに秘められた直接観測できないが重要な情報といった意味があります。

逆に、潜在変数には、観測できる画像として復元できるほどの情報が秘められているということになります。

ただし、この記事では単純にランダムな数字の羅列という捉え方をします。

つまり、訓練されたネットワークは乱数から画像への変換ができるという考え方です。

乱数を潜在変数として扱えるようになると言ってもいいかもしれません。

もちろん、研究テーマとして潜在変数の構造などを突き詰めていく方向は大事ではあります。

しかし、潜在変数という言葉を深追いしなくてもGANの訓練の仕組みを理解することは可能です。

では、乱数を与えるとネットワークが本物っぽい画像をランダムに生成できるように訓練していきましょう。

3.3. ネットワーク定義🔝

今回は単純なネットワークを定義しました。

from torch import nn


generator = nn.Sequential(nn.Linear(100, 128),
                          nn.LeakyReLU(0.01),
                          nn.Linear(128, 784),
                          nn.Sigmoid())

最後にSigmoidがあるのは出力値を0.0-1.0に限定し、MNISTからロードしたデータと同じ範囲にするためです。

また、出力のチャンネル数が784 = 28 x 28なので、reshapeを施せば(1, 28, 28)にすることが出来ます。

では、バッチサイズ64で乱数インプットを生成し、生成ネットワークの出力を表示してみます。

def generate_images():
    # 乱数インプットを生成
    z = torch.randn(64, 100)

    # 生成ネットワークで出力する
    output = generator(z)

    # アウトプットを縦28x横28に変換
    generated_images = output.reshape(64, 1, 28, 28)

    return generated_images


# 画像を生成
generated_images = generate_images()

# 画像を表示
show_image_grid(generated_images, nrow=8)

以下のように、生成された画像は全てノイズにしか見えません。

この生成ネットワークを訓練して本物のMNISTの画像のように数字を生み出すように訓練するのですが、GANでは訓練の方法に特徴があります。

Yann LeCunが指摘した「最も興味深いアイデアである」この訓練方法では、本物の画像と偽物の画像を見極めることが必要不可欠となります。

4. 識別ネットワーク🔝

識別ネットワーク(Discriminator)は与えられた画像を見て本物(Real)か偽物(Fake)かを判断します。

本物を1、偽物を0と判定すると、これは二項分類(Binary Classification)の問題として扱うことができます。

つまり、識別ネットワークは二項分類を行うネットワークであり、そのために訓練されます。

4.1. ネットワーク定義🔝

今回は以下のような単純なものを準備します。

discriminator = nn.Sequential(nn.Linear(784, 128),
                              nn.LeakyReLU(0.01),
                              nn.Linear(128, 1))

縦28x横28の画像をフラット(784)にして入力すると、識別ネットワークが一つの画像に対して一つの数値を返します。

この数値が大きいほど「画像が本物らしい」ということになります。

では、先ほど生成ネットワークが生成した画像を与えてみます。

discriminator.eval()
with torch.no_grad():
    prediction = discriminator(generated_images.reshape(-1, 784))

print(prediction)

ここで出力される値がバラバラで、識別ネットワークが訓練されていないので意味がありません。

また、generated_imagesの代わりに、MNISTからのバッチを入力しても同様です。

つまり、この段階では識別ネットワークは与えられた画像がMNISTからの本物か生成ネットワークからの偽物なのかは判別できていません。

4.2. 損失関数🔝

訓練するためには、本物と偽物の画像と正解不正解を合わせた教師データを使います。

まず正解不正解のターゲットを用意します。

true_targets = torch.ones(64, 1)
false_targets = torch.zeros(64, 1)

バッチサイズはデータローダーに合わせて64にしました。

MINISTからの画像を入力した場合はtrue_targetsを使い、生成ネットワークからの画像はfalse_targetsを使うことで識別ネットワークを訓練して行きます。

これらを使って損失関数を定義します。

def calculate_loss(images: torch.Tensor, targets: torch.Tensor):
    prediction = discriminator(images.reshape(-1, 784))
    loss = F.binary_cross_entropy_with_logits(prediction, targets)
    return loss

4.3. 訓練ループ🔝

まず、オプティマイザーを設定しておきます。

from torch.nn import functional as F

d_optimizer = torch.optim.Adam(discriminator.parameters(), lr=1.0e-4)

訓練では識別ネットワークが与えられた画像が本物か偽物かを判断し、損失関数から得たロス値を使って誤差逆伝播を行いネットワークのウェイト値をオプティマイザーが調整します。

# 訓練モード
discriminator.train()

for epoch in range(100):
    for images, labels in dataloader:
        # MNISTからの画像(本物)なら正解はtrue_targets
        d_loss = calculate_loss(images, true_targets)
  
        # 生成ネットワークからの画像(偽物)なら正解はfalse_targets
        generated_images = generate_images()
        d_loss += calculate_loss(generated_images, false_targets)

        # 識別ネットワークのアップデート
        d_optimizer.zero_grad()
        d_loss.backward()
        d_optimizer.step()

この辺りは通常の教師あり学習と同じです。

識別ネットワークは本物と偽物が区別できるようになっていきます。

でも、このままだが生成ネットワークは訓練されていませんね。

どうすれば良いのでしょうか?

5. 敵対的生成ネットワークの訓練🔝

実は、GANの訓練では識別ネットワークと生成ネットワークを同時に訓練します。

まずは、生成ネットワークを訓練する方法を見てみましょう。

5.1. 生成ネットワークの訓練🔝

生成ネットワークを訓練する用にオプティマイザーを設定します。

g_optimizer = torch.optim.Adam(generator.parameters(), lr=1.0e-4)

このオプティマイザーではgeneratorのパラメータをアップデートするように設定しています。

生成ネットワークが本物らしい画像を生成するように正解はtrue_targetsを使用し、識別ネットワークの訓練と交互に訓練するようにします。

# 訓練モード
discriminator.train()

for epoch in range(100):
    for images, labels in dataloader:
        # 識別ネットワークの訓練

        ... 省略 ...

        # 生成ネットワークからの画像を本物として訓練する
        generated_images = generate_images()
        g_loss = calculate_loss(generated_images, true_targets)

        # 生成ネットワークのアップデート
        g_optimizer.zero_grad()
        g_loss.backward()
        g_optimizer.step()
        g_losses.append(loss.item())

生成ネットワークと識別ネットワークが両方とも進歩していけばWin Winなのですが、GANの訓練は難しいので自分でやってみると結構苦労します。

5.2. 全部まとめる🔝

ソースコードをリファクタリングして全部まとめたのが以下になります。

import numpy as np
import matplotlib.pyplot as plt
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.utils import make_grid
from tqdm import tqdm


# コンフィグ
epochs      = 100
batch_size  = 64
sample_size = 100    # 乱数サンプルの次元
g_lr        = 1.0e-4 # 生成ネットワークの学習率
d_lr        = 1.0e-4 # 識別ネットワークの学習率


# MNIST用のデータローダー
transform = transforms.ToTensor()
dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
dataloader = DataLoader(dataset, batch_size=batch_size, drop_last=True)


# 生成ネットワーク
class Generator(nn.Sequential):
    def __init__(self, sample_size: int):
        super().__init__(
            nn.Linear(sample_size, 128),
            nn.LeakyReLU(0.01),
            nn.Linear(128, 784),
            nn.Sigmoid())

        # 乱数生成用
        self.sample_size = sample_size

    def forward(self, batch_size: int):
        # 乱数インプットを生成
        z = torch.randn(batch_size, self.sample_size)

        # 生成ネットワークで出力する
        output = super().forward(z)

        # アウトプットを縦28x横28に変換
        generated_images = output.reshape(batch_size, 1, 28, 28)
        return generated_images


# 識別ネットワーク
class Discriminator(nn.Sequential):
    def __init__(self):
        super().__init__(
            nn.Linear(784, 128),
            nn.LeakyReLU(0.01),
            nn.Linear(128, 1))

    def forward(self, images: torch.Tensor, targets: torch.Tensor):
        prediction = super().forward(images.reshape(-1, 784))
        loss = F.binary_cross_entropy_with_logits(prediction, targets)
        return loss


# 画像グリッドをセーブするため
def save_image_grid(epoch: int, images: torch.Tensor, nrow: int):
    image_grid = make_grid(images, nrow)     # 複数イメージをグリッド上にまとめた画像
    image_grid = image_grid.permute(1, 2, 0) # チャンネルの次元を一番最後に移動
    image_grid = image_grid.cpu().numpy()    # Numpyへと変換

    plt.imshow(image_grid)
    plt.xticks([])
    plt.yticks([])
    plt.savefig(f'generated_{epoch:03d}.jpg')
    plt.close()


# 本物・偽物
true_targets = torch.ones(batch_size, 1)
false_targets = torch.zeros(batch_size, 1)


# 生成ネットワークと識別ネットワーク
generator = Generator(sample_size)
discriminator = Discriminator()


# オプティマイザー
d_optimizer = torch.optim.Adam(discriminator.parameters(), lr=d_lr)
g_optimizer = torch.optim.Adam(generator.parameters(), lr=g_lr)


# 訓練ループ
for epoch in range(epochs):

    d_losses = []
    g_losses = []

    for images, labels in tqdm(dataloader):
        #===============================
        # 識別ネットワーク訓練
        #===============================

        # MNISTからの画像(本物)なら正解はtrue_targets
        d_loss = discriminator(images, true_targets)

        # 生成ネットワークからの画像を偽物として識別ネットワークを訓練する
        d_loss += discriminator(generator(batch_size), false_targets)

        # 識別ネットワークのアップデート
        d_optimizer.zero_grad()
        d_loss.backward()
        d_optimizer.step()

        #===============================
        # 生成ネットワーク訓練
        #===============================

        # 生成ネットワークからの画像を本物として生成ネットワークを訓練する
        g_loss = discriminator(generator(batch_size), true_targets)

        # 生成ネットワークのアップデート
        g_optimizer.zero_grad()
        g_loss.backward()
        g_optimizer.step()

        # 識別ネットワークと生成ネットワークのロスをログ用に貯めておく
        d_losses.append(d_loss.item())
        g_losses.append(g_loss.item())

    # ロス値をプリント
    print(epoch, np.mean(d_losses), np.mean(g_losses))

    # 生成画像をセーブしておく
    save_image_grid(epoch, generator(batch_size), nrow=8)

GANの訓練では生成ネットワークと識別ネットワークがまるでライバルであるかのように扱われています。

敵対的生成ネットワークという名前の由来がまさにそれであり、またYann LeCunが「最も興味深い」と言ったのもこのことでした。

6. まとめ🔝

6.1. 実行結果🔝

1エポックで生成画像はこんな感じでした。蚕の集まりにしか見えません。

50エポック目でこんな感じ。いくつか数字っぽいのが出現していますね。

100エポックの後で生成画像はこんな感じになりました。

数字に見えるものと見えないものがありますが、最初はノイズだけだったことを考えると良い方向に向かっているようです。

6.2. 改善するには🔝

今回はとても単純なネットワークを使ったので、層を増やしたりしてより複雑にすると効果がありかもしれません。

また、画像を扱うので畳み込みを取り入れたネットワークにする方が良い性能が出るでしょう。

DCGANなどはまさに畳み込みを取り入れて性能を良くしたものになります。

GANの訓練では生成ネットワークと識別ネットワークの両方をバランスよく訓練する必要があり学習率を少し調節しただけで大きな影響が出たりします。

なので、訓練をするのがとても難しいのですが、それでもオプティマイザーを変えてみたり学習率を調節したり、色々と試せることはあると思います。

その際、損失値のグラフを描いて分析したりするのも一考です。

GANにはたくさんの種類があるので、色々と試してみると何か発見があるかもしれません。

7. 参照🔝

7.1. Generative Adversarial Networks🔝

Ian J. Goodfellow, Jean Pouget-Abadie, Mehdi Mirza, Bing Xu, David Warde-Farley, Sherjil Ozair, Aaron Courville, Yoshua Bengio

https://arxiv.org/abs/1406.2661

How to Train a GAN? Tips and tricks to make GANs work

Facebook AI Research: Soumith Chintala, Emily Denton, Martin Arjovsky, Michael Mathieu

https://github.com/soumith/ganhacks

7.3. Understanding Generative Adversarial Networks🔝

Naoki Shibuya

https://naokishibuya.medium.com/understanding-generative-adversarial-networks-4dafc963f2ef



コメントを残す

メールアドレスは公開されません。