キカベン
機械学習でより便利な世の中へ
G検定対策
お問い合わせ
   

CGAN (Conditional GAN):条件付き敵対的生成ネットワークで出したい画像を指定

thumb image

これまでシンプルなGAN畳み込みを加えたGANを解説してきましたが、どれも生成される画像はランダムに決定されるものでした。

CGAN(Conditional GAN、条件付きGAN)では、ラベル情報を追加のパラメータとして生成ネットワークに与えることで出したい画像を指定することができます。

これまで同様、MNISTのデータを使用し、モデルのコードをなるべくシンプルに保ちながら改良を加えています。

1番のポイントは、今まで使用していなかったラベルを条件として活用する点です。

1. 条件を特徴量に変換する🔝

1.1. ラベルのone-hotエンコーディング🔝

ラベル(labels)の値は0から9なので、これをベクトルに変換しネットワークへの入力として扱いやすいようにします。

PyTorchのF.one_hotを利用します。

import torch
from torch.nn import functional as F


# ラベルの定義。例えば、1と3が入っている。
labels = torch.LongTensor([1, 3])

# ラベルをone-hotエンコーディングする
encoded = F.one_hot(labels, num_classes=10)

print(encoded)

結果は、tensor([[0, 1, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 1, 0, 0, 0, 0, 0, 0]])と表示され、1と3がそれぞれone-hotエンコーディングされているのがわかります。

num_classes=10と指定しているのは、0から9までの10個の数字を扱うからでベクトルの中に入る数字の数を指定していることになります。

one-hotエンコーディングでは、ベクトルの中の一つだけが1で他はすべて0になっています。

0-9をエンコーディングする場合は、どこに1があるかで数が決まるようになります。

1.2. ラベル条件を特徴量に変換する🔝

one-hotエンコーディングされたラベルを特徴量として扱えるように学習するためのクラスを定義します。

ここでは、乱数を特徴量にしたのと同じ方法で、全結合層を使います。

# 条件を特徴量としてベクトル化する
class Condition(nn.Module):
    def __init__(self, alpha: float):
        super().__init__()

        # one-hotエンコーディングから特徴量へ: 10 => 784
        self.fc = nn.Sequential(
            nn.Linear(10, 784),
            nn.BatchNorm1d(784),
            nn.LeakyReLU(alpha))
        
    def forward(self, labels: torch.Tensor):
        # 条件であるラベルをone-hotエンコーディングする
        x = F.one_hot(labels, num_classes=10)

        # 整数型(Long)から浮動小数点型(Float)に
        x = x.float()

        # 特徴量へ変換
        return self.fc(x)

このラベルの特徴量を条件として生成ネットワークに入力します。

2. ラベルを条件として与えることの意味🔝

識別ネットワークにラベルを条件として与えると、本物か偽物かの判断がラベルにそったものになるように訓練されます。

例えば、1というラベルが条件として与えられた時に、識別ネットワークに与えられるのは1の画像だけなので、それを本物として判断できるように訓練されるうちに、1というラベルに対応した画像の真偽を判断できるようになります。

生成ネットワークにも識別ネットワークと共通の条件を与えて訓練します。

よって、生成ネットワークが識別ネットワークから良い評価を得るためには、1というラベルが条件の時は本物の1の画像に似たものを生成するように学習する必要があるわけです。

逆に、1というラベルの条件が与えられているのに他の数字の画像を生成したら損失値が大きくなり、そこでさらに学習が行われます。

0から9の全てのラベルで条件付けして訓練していけば生成ネットワークは条件に沿った画像を生成できるようになります。

犬が「お手」や「お座り」を餌を与えながら訓練されるのとちょっと似ているかもしれません。

3. 生成ネットワークに条件を与える🔝

生成ネットワークでは、以前と同様に、乱数の特徴量を生成します。

さらに、条件となるラベル(labels)を特徴量に変換します。

この二つの特徴量を生成ネットワークに与えることで、ラベルに対応した数字の画像をランダムに生成するように訓練することができます。

では、先ほど定義したConditionを生成ネットワークであるGeneratorに組み込みましょう。

# 生成ネットワーク
class Generator(nn.Module):
    def __init__(self, sample_size: int, alpha: float):
        super().__init__()

        # sample_size => 784
        self.fc = nn.Sequential(
            nn.Linear(sample_size, 784),
            nn.BatchNorm1d(784),
            nn.LeakyReLU(alpha))

        # 784 => 16 x 7 x 7 
        self.reshape = Reshape(16, 7, 7)

        # 16 x 7 x 7 => 32 x 14 x 14
        self.conv1 = nn.Sequential(
            nn.ConvTranspose2d(16, 32, kernel_size=5, stride=2, padding=2,
                                       output_padding=1, bias=False),
            nn.BatchNorm2d(32),
            nn.LeakyReLU(alpha))

        # 32 x 14 x 14 => 1 x 28 x 28
        self.conv2 = nn.Sequential(
            nn.ConvTranspose2d(32, 1, kernel_size=5, stride=2, padding=2,
                                      output_padding=1, bias=False),
            nn.Sigmoid())
            
        # 乱数生成用
        self.sample_size = sample_size

        # ラベルを条件特徴量に変換する用
        self.cond = Condition(alpha)

    def forward(self, labels: torch.Tensor):
        # ラベル条件の特徴量
        c = self.cond(labels)

        # バッチサイズはラベルの数で決まる
        batch_size = len(labels)

        # 乱数インプットを生成
        z = torch.randn(batch_size, self.sample_size)

        # 乱数の特徴量にラベル条件の特徴量を足して推論
        x = self.fc(z)        # => 784
        x = self.reshape(x+c) # => 16 x 7 x 7
        x = self.conv1(x)     # => 32 x 14 x 14
        x = self.conv2(x)     # => 1 x 28 x 28
        return x

このコードでは、ラベルの特徴量(c)を乱数の特徴量(x)と同じサイズのベクトルにして乱数の特徴量と要素ごとの加算(x+c)を行なっています。

これは、「乱数の特徴量もラベルの特徴量も画像を生成するための情報を同じ空間に持つ」との考えからです。

そこの部分だけ抜き出してみます。

# ラベルの特徴量(784)
c = self.cond(labels)

# 乱数の特徴量(784)
z = torch.randn(batch_size, self.sample_size)
x = self.fc(z)

# たしあわせて16x7x7に変換
x = self.reshape(x+c)

なお、ラベルの特徴量だけでも数字の画像は生成できるように訓練することは可能ですが、それだとバリエーションがないので乱数の特徴量を混ぜているという解釈もできます。

この辺りは、実装によって異なるアプローチが取られています。

例えば、乱数とラベルのone-hotエンコーディングされたものを連結(concatenate)するのも一つの方法です。

この場合は、乱数とラベルは別物だとして、100個の乱数とone-hotエンコーディングから10個の数字を110個の数として連結して扱う形になります。

あるいは、乱数の特徴量とラベルの特徴量を別物だと考え連結するのも一つの方法です。

この場合、特徴量のサイズが同じである必要はありません。

特徴量のサイズが両方とも784ままだと連結する際に大きすぎると考えられるので、特徴量のサイズを両方ともあるいは片方だけ縮小しても良いかもしれません。

今回の実装では、乱数もラベルも特徴量としては同じ空間にあり数字の画像を生成するのに必要な情報を含んでいるとの考えを取っており、前回からの実装の変更も最小限にするために上記のコードのような方法にしました。

4. 識別ネットワークに条件を与える🔝

識別ネットワークでも生成ネットワークと同様に考えてラベルにそった条件を加えます。

これによって識別ネットワークの判断が常に与えられたラベルの条件のもとに行われるように訓練することができます。

# 識別ネットワーク
class Discriminator(nn.Module):
    def __init__(self, alpha: float):
        super().__init__()
        
        # 1 x 28 x 28 => 32 x 14 x 14
        self.conv1 = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=5, stride=2, padding=2, bias=False),
            nn.LeakyReLU(alpha))

        # 32 x 14 x 14 => 16 x 7 x 7
        self.conv2 = nn.Sequential(
            nn.Conv2d(32, 16, kernel_size=5, stride=2, padding=2, bias=False),
            nn.BatchNorm2d(16),
            nn.LeakyReLU(alpha))

        # 16 x 7 x 7 => 784
        self.fc = nn.Sequential(
            nn.Flatten(),
            nn.Linear(784, 784),
            nn.BatchNorm1d(784),
            nn.LeakyReLU(alpha),
            nn.Linear(784, 1))

        # ラベル条件の特徴量: 784 => 16 x 7 x 7 
        self.cond = nn.Sequential(
            Condition(alpha),
            Reshape(16, 7, 7))

    def forward(self, images: torch.Tensor, labels: torch.Tensor, targets: torch.Tensor):
        # ラベル条件の特徴量
        c = self.cond(labels)

        # 画像の特徴量にラベル条件の特徴量を足して推論
        x = self.conv1(images)    # => 32 x 14 x 14
        x = self.conv2(x)         # => 16 x 7 x 7
        prediction = self.fc(x+c) # => 1

        loss = F.binary_cross_entropy_with_logits(prediction, targets)
        return loss

5. CGANの訓練🔝

訓練を行うときには、生成ネットワークと識別ネットワークの両方に同じ条件を与えて訓練していきます。

5.1. 識別ネットワークの訓練🔝

識別ネットワークの訓練は以前とほぼ同じで本物と偽物を交互に与えますが、distriminatorgeneratorの両方に同じlabelsを与えている点が以前は行われなかった点で異なります。

# 訓練ループ
for epoch in range(100):

    for images, labels in tqdm(dataloader):

        # MNISTからの画像(本物)なら正解はtrue_targets
        d_loss = discriminator(images, labels, true_targets)
       
        # 生成ネットワークからの画像(偽物)なら正解はfalse_targets
        d_loss += discriminator(generator(labels), labels, false_targets)

        ...

5.2. 生成ネットワークの訓練🔝

生成ネットワークの訓練でも同様にlabelsを活用しています。

# 訓練ループ
for epoch in range(100):

    for images, labels in tqdm(dataloader):

        ...

        # 生成ネットワークからの画像を本物として訓練する
        g_loss = discriminator(generator(labels), labels, true_targets)

        ...

全てをまとめたコードは最後に紹介します。

その前に、実行結果を見ていきましょう。

6. 条件付きGANで画像を生成🔝

6.1. テスト入力値🔝

0から9までの数字を8個ずつ指定して生成しました。

# 0から9の数字のリスト
labels = list(range(10))

# ラベルは整数型にする必要がある
labels = torch.LongTensor(labels)

# 同じデータを8回繰り返す
labels = labels.repeat(8)

# 上記のままだと、10x8の配列になるので一列にする
labels = labels.flatten()

# 画像の生成
generated_images = generator(labels)

# 結果をセーブ
save_image_grid(epoch, generated_images, nrow=10)

上記コメントにもありますが、以下のようにテスト用の条件ラベルを作ります。

  • list(range(10))で0から9の数字が入ったリストを作る。
  • torch.LongTensorで整数型のテンソルにする。
  • これをrepeat(8)で8回繰り返し合計80個のラベルデータを作成。
  • 最後にflatten()でデータローダーからのラベルと同様の形式にする。

6.2. 実行結果🔝

100個のエポックで訓練を実行し、エポックごとに結果をセーブしました。

1エポックで生成画像は既に数字っぽくなっています。条件を与えたことで学習がしやすくなったようです。

50エポック目では生成ネットワークがほぼ出来上がっています。

100エポック後でも50エポックの時とさほど変わらない感じがする。

以上、条件付きGANをMNISTで実験した結果でした。

これでMNISTっぽい画像を好きなだけ思いのままに生成でき、無限の訓練データを作れますね。 😁

7. ソースコード🔝

以前のコードとほぼ同じですが、条件を付け加える部分だけ変更しています。


import numpy as np
import matplotlib.pyplot as plt
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.utils import make_grid
from tqdm import tqdm


# 一般コンフィグ
batch_size  = 64

# 生成ネットワークのコンフィグ
sample_size = 100    # 乱数サンプルの次元
g_alpha     = 0.01   # LeakyReLU alpha
g_lr        = 1.0e-4 # 学習率

# 識別ネットワークのコンフィグ
d_alpha     = 0.01   # LeakyReLU alpha
d_lr        = 1.0e-4 # 学習率


# MNIST用のデータローダー
transform = transforms.ToTensor()
dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
dataloader = DataLoader(dataset, batch_size=batch_size, drop_last=True)


# 条件を特徴量としてベクトル化する
class Condition(nn.Module):
    def __init__(self, alpha: float):
        super().__init__()

        # one-hotエンコーディングから特徴量へ: 10 => 784
        self.fc = nn.Sequential(
            nn.Linear(10, 784),
            nn.BatchNorm1d(784),
            nn.LeakyReLU(alpha))
        
    def forward(self, labels: torch.Tensor):
        # 条件であるラベルをone-hotエンコーディングする
        x = F.one_hot(labels, num_classes=10)

        # 整数型(Long)から浮動小数点型(Float)に
        x = x.float()

        # 特徴量へ変換
        return self.fc(x)


# シェイプ変更のヘルパー
class Reshape(nn.Module):
    def __init__(self, *shape):
        super().__init__()

        self.shape = shape

    def forward(self, x):
        return x.reshape(-1, *self.shape)


# 生成ネットワーク
class Generator(nn.Module):
    def __init__(self, sample_size: int, alpha: float):
        super().__init__()

        # sample_size => 784
        self.fc = nn.Sequential(
            nn.Linear(sample_size, 784),
            nn.BatchNorm1d(784),
            nn.LeakyReLU(alpha))

        # 784 => 16 x 7 x 7 
        self.reshape = Reshape(16, 7, 7)

        # 16 x 7 x 7 => 32 x 14 x 14
        self.conv1 = nn.Sequential(
            nn.ConvTranspose2d(16, 32, kernel_size=5, stride=2, padding=2,
                                       output_padding=1, bias=False),
            nn.BatchNorm2d(32),
            nn.LeakyReLU(alpha))

        # 32 x 14 x 14 => 1 x 28 x 28
        self.conv2 = nn.Sequential(
            nn.ConvTranspose2d(32, 1, kernel_size=5, stride=2, padding=2,
                                      output_padding=1, bias=False),
            nn.Sigmoid())
            
        # 乱数生成用
        self.sample_size = sample_size

        # ラベルを条件特徴量に変換する用
        self.cond = Condition(alpha)

    def forward(self, labels: torch.Tensor):
        # ラベル条件の特徴量
        c = self.cond(labels)

        # バッチサイズはラベルの数で決まる
        batch_size = len(labels)

        # 乱数インプットを生成
        z = torch.randn(batch_size, self.sample_size)

        # 乱数の特徴量にラベル条件の特徴量を足して推論
        x = self.fc(z)        # => 784
        x = self.reshape(x+c) # => 16 x 7 x 7
        x = self.conv1(x)     # => 32 x 14 x 14
        x = self.conv2(x)     # => 1 x 28 x 28
        return x


# 識別ネットワーク
class Discriminator(nn.Module):
    def __init__(self, alpha: float):
        super().__init__()
        
        # 1 x 28 x 28 => 32 x 14 x 14
        self.conv1 = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=5, stride=2, padding=2, bias=False),
            nn.LeakyReLU(alpha))

        # 32 x 14 x 14 => 16 x 7 x 7
        self.conv2 = nn.Sequential(
            nn.Conv2d(32, 16, kernel_size=5, stride=2, padding=2, bias=False),
            nn.BatchNorm2d(16),
            nn.LeakyReLU(alpha))

        # 16 x 7 x 7 => 784
        self.fc = nn.Sequential(
            nn.Flatten(),
            nn.Linear(784, 784),
            nn.BatchNorm1d(784),
            nn.LeakyReLU(alpha),
            nn.Linear(784, 1))

        # ラベル条件の特徴量: 784 => 16 x 7 x 7 
        self.cond = nn.Sequential(
            Condition(alpha),
            Reshape(16, 7, 7))

    def forward(self, images: torch.Tensor, labels: torch.Tensor, targets: torch.Tensor):
        # ラベル条件の特徴量
        c = self.cond(labels)

        # 画像の特徴量にラベル条件の特徴量を足して推論
        x = self.conv1(images)    # => 32 x 14 x 14
        x = self.conv2(x)         # => 16 x 7 x 7
        prediction = self.fc(x+c) # => 1

        loss = F.binary_cross_entropy_with_logits(prediction, targets)
        return loss


# 画像グリッドをセーブするため
def save_image_grid(epoch: int, images: torch.Tensor, nrow: int):
    image_grid = make_grid(images, nrow)     # 複数イメージをグリッド上にまとめた画像
    image_grid = image_grid.permute(1, 2, 0) # チャンネルの次元を一番最後に移動
    image_grid = image_grid.cpu().numpy()    # Numpyへと変換

    plt.imshow(image_grid)
    plt.xticks([])
    plt.yticks([])
    plt.savefig(f'generated_{epoch:03d}.jpg')
    plt.close()


# 本物・偽物
true_targets = torch.ones(batch_size, 1)
false_targets = torch.zeros(batch_size, 1)
    

# 生成ネットワークと識別ネットワーク
generator = Generator(sample_size, g_alpha)
discriminator = Discriminator(d_alpha)


# オプティマイザー
d_optimizer = torch.optim.Adam(discriminator.parameters(), lr=d_lr)
g_optimizer = torch.optim.Adam(generator.parameters(), lr=g_lr)


# 訓練ループ
for epoch in range(100):

    d_losses = []
    g_losses = []

    for images, labels in tqdm(dataloader):

        #===============================
        # 識別ネットワーク訓練
        #===============================

        # MNISTからの画像(本物)なら正解はtrue_targets
        d_loss = discriminator(images, labels, true_targets)
       
        # 生成ネットワークからの画像(偽物)なら正解はfalse_targets
        d_loss += discriminator(generator(labels), labels, false_targets)

        # 識別ネットワークのアップデート
        d_optimizer.zero_grad()
        d_loss.backward()
        d_optimizer.step()

        #===============================
        # 生成ネットワーク訓練
        #===============================

        # 生成ネットワークからの画像を本物として訓練する
        g_loss = discriminator(generator(labels), labels, true_targets)

        # 生成ネットワークのアップデート
        g_optimizer.zero_grad()
        g_loss.backward()
        g_optimizer.step()

        # 識別ネットワークと生成ネットワークのロスをログ用に貯めておく
        d_losses.append(d_loss.item())
        g_losses.append(g_loss.item())
        
    # ロス値をプリント
    print(epoch, np.mean(d_losses), np.mean(g_losses))

    # 生成画像をセーブしておく
    labels = torch.LongTensor(list(range(10))).repeat(8).flatten()
    save_image_grid(epoch, generator(labels), nrow=10)

8. 参照🔝

8.1. Conditional Generative Adversarial Nets🔝

Mehdi Mirza、Simon Osindero

https://arxiv.org/pdf/1411.1784.pdf



コメントを残す

メールアドレスは公開されません。