キカベン
機械学習でより便利な世の中へ
G検定対策
お問い合わせ
   

DCGANでMNISTの画像を自動生成してみたらクオリティーがグンと上がった話

thumb image

DCGANでは畳み込みを使うことで生成画像のクオリティーをあげているので、このアイデアを取り入れて「GAN(Generative Adversarial Network):敵対的生成ネットワークとは?」で作ったシンプルなGANモデルを改良したものを紹介します。

1. 参考モデル:DCGAN🔝

PyTorchのDCGANのチュートリアルを参考にして畳み込みを導入することにしました。

dcgan_generator
https://pytorch.org/tutorials/_images/dcgan_generator.png

上記のDCGANのネットワークが生成するのはカラー画像であり、また生成される画像のサイズがMNISTよりも大きいです。

よって、MNIST用にモデルを改良し、MNISTサイズのモノクロ画像を制作するようにしました。

1番のポイントは上記にあるように畳み込み層(TransposeConv2dConv2d)を取り入たことです。

理由としては畳み込みが画像内の位置関係を学習するのに全結合層よりも適しているからです。

2. 生成ネットワークの改良🔝

前回の生成ネットワークは、以下のように単純なもので、生成した乱数を縦28x横28のモノクロ画像へと全結合層を通して変換するものでした。

# 生成ネットワーク
class Generator(nn.Sequential):
    def __init__(self, sample_size: int):
        super().__init__(
            nn.Linear(sample_size, 128),
            nn.LeakyReLU(0.01),
            nn.Linear(128, 784),
            nn.Sigmoid())

        # 乱数生成用
        self.sample_size = sample_size

    def forward(self, batch_size: int):
        # 乱数インプットを生成
        z = torch.randn(batch_size, self.sample_size)

        # 生成ネットワークで出力する
        output = super().forward(z)

        # アウトプットを縦28x横28に変換
        generated_images = output.reshape(batch_size, 1, 28, 28)
        return generated_images

このモデルでは最後に出力のreshapeを行なっていますが、畳み込みを取り入れたモデルに変更する場合はreshapeを最初の方に行っておいてから画像を拡大していく手順を踏みます。

以下に改良した生成ネットワークを掲載します。

# 生成ネットワーク
class Generator(nn.Module):
    def __init__(self, sample_size: int, alpha: float):
        super().__init__()

        # sample_size => 784 
        self.fc = nn.Sequential(
            nn.Linear(sample_size, 784),
            nn.BatchNorm1d(784),
            nn.LeakyReLU(alpha))

        # 784 => 16 x 7 x 7
        self.reshape = Reshape(16, 7, 7)

        # 16 x 7 x 7 => 32 x 14 x 14
        self.conv1 = nn.Sequential(
            nn.ConvTranspose2d(16, 32, kernel_size=5, stride=2, padding=2,
                                       output_padding=1, bias=False),
            nn.BatchNorm2d(32),
            nn.LeakyReLU(alpha))

        # 32 x 14 x 14 => 1 x 28 x 28
        self.conv2 = nn.Sequential(
            nn.ConvTranspose2d(32, 1, kernel_size=5, stride=2, padding=2,
                                      output_padding=1, bias=False),
            nn.Sigmoid())
            
        # 乱数生成用
        self.sample_size = sample_size

    def forward(self, batch_size: int):
        # 乱数インプットを生成
        z = torch.randn(batch_size, self.sample_size)

        x = self.fc(z)      # => 784 
        x = self.reshape(x) # => 16 x 7 x 7
        x = self.conv1(x)   # => 32 x 14 x 14
        x = self.conv2(x)   # => 1 x 28 x 28
        return x

DCGAN同様に、ConvTranspose2dを使って画像サイズを7x7から28x28へと拡大していく構図になっています。

ConvTranspose2dは画像の拡大だけでなく、学習したウェイトを使って画像の生成にも関わっています。

例えるならば、Conv2dの逆の操作だと捉えると理解しやすいかもしれません。

また、Batch Normalizationを使ってネットワークが効率よく学習できるようにしています。

なお、Reshapeは別途用意したものを使っており、全結合層からの1次元の入力をConvTranspose2dで扱う次元に変換しています。

# シェイプ変更のヘルパー
class Reshape(nn.Module):
    def __init__(self, *shape):
        super().__init__()

        self.shape = shape

    def forward(self, x):
        return x.reshape(-1, *self.shape)

入力乱数のサイズ(例えば100)に関わらず、次元が以下のように変化して行きます。

100 
=> 784
=> 16 x 7 x 7   # reshapeが適用される
=> 32 x 14 x 14  # 画像の拡大
=> 1 x 28 x 28   # 画像の拡大

なお、バッチサイズの次元は省略しています。

以上によってGeneratorはモノクロ画像を生成します。

3. 識別ネットワークの改良🔝

前回の識別ネットワークは、以下のように単純なものでした。

# 識別ネットワーク
class Discriminator(nn.Sequential):
    def __init__(self):
        super().__init__(
            nn.Linear(784, 128),
            nn.LeakyReLU(0.01),
            nn.Linear(128, 1))

    def forward(self, images: torch.Tensor, targets: torch.Tensor):
        prediction = super().forward(images.reshape(-1, 784))
        loss = F.binary_cross_entropy_with_logits(prediction, targets)
        return loss

与えられた画像を全結合層で一つの値にまとめて画像が本物か偽物かを判断しています。

識別ネットワークは訓練で使う損失関数の値を返しています。

改良した識別ネットワークでは畳み込みを導入し画像から特徴量を引き出してから全結合層で一つの値にまで落とし込みます。

# 識別ネットワーク
class Discriminator(nn.Module):
    def __init__(self, alpha: float):
        super().__init__()
        
        # 1 x 28 x 28 => 32 x 14 x 14
        self.conv1 = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=5, stride=2, padding=2, bias=False),
            nn.LeakyReLU(alpha))

        # 32 x 14 x 14 => 16 x 7 x 7
        self.conv2 = nn.Sequential(
            nn.Conv2d(32, 16, kernel_size=5, stride=2, padding=2, bias=False),
            nn.BatchNorm2d(16),
            nn.LeakyReLU(alpha))

        # 16 x 7 x 7 => 784
        self.fc = nn.Sequential(
            nn.Flatten(),
            nn.Linear(784, 784),
            nn.BatchNorm1d(784),
            nn.LeakyReLU(alpha),
            nn.Linear(784, 1))

    def forward(self, images: torch.Tensor, targets: torch.Tensor):
        x = self.conv1(images)  # => 32 x 14 x 14
        x = self.conv2(x)       # => 16 x 7 x 7
        prediction = self.fc(x) # => 1

        loss = F.binary_cross_entropy_with_logits(prediction, targets)
        return loss

上記のソースコードにあるようにConv2dを使って画像サイズを28x28から7x7へと縮小し、最後は全結合層で1つの値にまとめます。

生成ネットワークと同様で、識別ネットワークもBatch Normalizationを使ってネットワークが効率よく学習できるようにしています。

4. ビフォーアフター🔝

前回は、100エポック後にこんな生成画像が出来ました。

generated_099

今回は、100エポック後の結果は以下の通りです。

ビフォーアフターで結果と比べると違いが一目瞭然ですね。

5. ソースコード🔝

GeneratorDiscriminator以外は前回とほぼ同じですが、まとめたものを参考に掲載しておきます。


import numpy as np
import matplotlib.pyplot as plt
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.utils import make_grid
from tqdm import tqdm


# 一般コンフィグ
batch_size  = 64

# 生成ネットワークのコンフィグ
sample_size = 100    # 乱数サンプルの次元
g_alpha     = 0.01   # LeakyReLU alpha
g_lr        = 1.0e-4 # 学習率

# 識別ネットワークのコンフィグ
d_alpha     = 0.01   # LeakyReLU alpha
d_lr        = 1.0e-4 # 学習率



# MNIST用のデータローダー
transform = transforms.ToTensor()
dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
dataloader = DataLoader(dataset, batch_size=batch_size, drop_last=True)


# シェイプ変更のヘルパー
class Reshape(nn.Module):
    def __init__(self, *shape):
        super().__init__()

        self.shape = shape

    def forward(self, x):
        return x.reshape(-1, *self.shape)


# 生成ネットワーク
class Generator(nn.Module):
    def __init__(self, sample_size: int, alpha: float):
        super().__init__()

        # sample_size => 784
        self.fc = nn.Sequential(
            nn.Linear(sample_size, 784),
            nn.BatchNorm1d(784),
            nn.LeakyReLU(alpha))

        # 784 => 16 x 7 x 7 
        self.reshape = Reshape(16, 7, 7)

        # 16 x 7 x 7 => 32 x 14 x 14
        self.conv1 = nn.Sequential(
            nn.ConvTranspose2d(16, 32, kernel_size=5, stride=2, padding=2,
                                       output_padding=1, bias=False),
            nn.BatchNorm2d(32),
            nn.LeakyReLU(alpha))

        # 32 x 14 x 14 => 1 x 28 x 28
        self.conv2 = nn.Sequential(
            nn.ConvTranspose2d(32, 1, kernel_size=5, stride=2, padding=2,
                                      output_padding=1, bias=False),
            nn.Sigmoid())
            
        # 乱数生成用
        self.sample_size = sample_size

    def forward(self, batch_size: int):
        # 乱数インプットを生成
        z = torch.randn(batch_size, self.sample_size)

        # 乱数の特徴量で推論
        x = self.fc(z)        # => 784 
        x = self.reshape(x)   # => 16 x 7 x 7
        x = self.conv1(x)     # => 32 x 14 x 14
        x = self.conv2(x)     # => 1 x 28 x 28
        return x


# 識別ネットワーク
class Discriminator(nn.Module):
    def __init__(self, alpha: float):
        super().__init__()
        
        # 1 x 28 x 28 => 32 x 14 x 14
        self.conv1 = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=5, stride=2, padding=2, bias=False),
            nn.LeakyReLU(alpha))

        # 32 x 14 x 14 => 16 x 7 x 7
        self.conv2 = nn.Sequential(
            nn.Conv2d(32, 16, kernel_size=5, stride=2, padding=2, bias=False),
            nn.BatchNorm2d(16),
            nn.LeakyReLU(alpha))

        # 16 x 7 x 7 => 784
        self.fc = nn.Sequential(
            nn.Flatten(),
            nn.Linear(784, 784),
            nn.BatchNorm1d(784),
            nn.LeakyReLU(alpha),
            nn.Linear(784, 1))

    def forward(self, images: torch.Tensor, targets: torch.Tensor):
        # 画像の特徴量で推論
        x = self.conv1(images)    # => 32 x 14 x 14
        x = self.conv2(x)         # => 16 x 7 x 7
        prediction = self.fc(x)   # => 1

        loss = F.binary_cross_entropy_with_logits(prediction, targets)
        return loss


# 画像グリッドをセーブするため
def save_image_grid(epoch: int, images: torch.Tensor, nrow: int):
    image_grid = make_grid(images, nrow)     # 複数イメージをグリッド上にまとめた画像
    image_grid = image_grid.permute(1, 2, 0) # チャンネルの次元を一番最後に移動
    image_grid = image_grid.cpu().numpy()    # Numpyへと変換

    plt.imshow(image_grid)
    plt.xticks([])
    plt.yticks([])
    plt.savefig(f'generated_{epoch:03d}.jpg')
    plt.close()


# 本物・偽物
true_targets = torch.ones(batch_size, 1)
false_targets = torch.zeros(batch_size, 1)
    

# 生成ネットワークと識別ネットワーク
generator = Generator(sample_size, g_alpha)
discriminator = Discriminator(d_alpha)

# オプティマイザー
d_optimizer = torch.optim.Adam(discriminator.parameters(), lr=d_lr)
g_optimizer = torch.optim.Adam(generator.parameters(), lr=g_lr)


# 訓練ループ
for epoch in range(100):

    d_losses = []
    g_losses = []

    for images, labels in dataloader:

        #===============================
        # 識別ネットワーク訓練
        #===============================

        # MNISTからの画像(本物)なら正解はtrue_targets
        d_loss = discriminator(images, true_targets)
       
        # 生成ネットワークからの画像(偽物)なら正解はfalse_targets
        d_loss += discriminator(generator(batch_size), false_targets)

        # 識別ネットワークのアップデート
        d_optimizer.zero_grad()
        d_loss.backward()
        d_optimizer.step()

        #===============================
        # 生成ネットワーク訓練
        #===============================

        # 生成ネットワークからの画像を本物として訓練する
        g_loss = discriminator(generator(batch_size), true_targets)

        # 生成ネットワークのアップデート
        g_optimizer.zero_grad()
        g_loss.backward()
        g_optimizer.step()

        # 識別ネットワークと生成ネットワークのロスをログ用に貯めておく
        d_losses.append(d_loss.item())
        g_losses.append(g_loss.item())
        
    # ロス値をプリント
    print(epoch, np.mean(d_losses), np.mean(g_losses))

    # 生成画像をセーブしておく
    save_image_grid(epoch, generator(batch_size), nrow=8)

6. 参照🔝

6.1. PyTorch – DCGAN TUTORIAL🔝

Nathan Inkawhich

https://pytorch.org/tutorials/beginner/dcgan_faces_tutorial.html

6.2. Up-sampling with Transposed Convolution🔝

Naoki Shibuya

https://naokishibuya.medium.com/up-sampling-with-transposed-convolution-9ae4f2df52d0



コメントを残す

メールアドレスは公開されません。