キカベン
機械学習でより便利な世の中へ
G検定対策
お問い合わせ
   

DCGANでファッション画像を生成してみる

thumb image

今回は、Torch HubからダウンロードできるDCGANを試します。

DCGANのDCはDeep Convolutionalで、GANはGenerative Adversarial Networkの略です。

論文は、Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networksで、PyTorchの生みの親であるSoumith Chintalaも執筆者の一人に入っています。

このDCGANでは乱数をインプットとして与えることでリアルなファッション関連の画像を生成することができます。

GANというと難しそうな印象を受けるかもしれませんが、使うだけなら簡単です。

1. Python環境を作る🔝

今回は、VirtualenvでPythonの環境を作りましょう。

# プロジェクトのフォルダを作成し移動
mkdir dcgan
cd dcgan

# VirtualenvでPython環境を作りアクティベートする
python3 -m venv venv
source venv/bin/activate

# 一応、pipをアップデートしておく
pip install --upgrade pip

# 必要なライブラリをインストール
pip install torch torchvision matplotlib

Condaが好みの方は、Mini Condaなどで環境を作ってください。

2. DCGANをダウンロード🔝

それでは、DCGANをTorch Hubからダウンロードしましょう。

import torch

use_gpu = True if torch.cuda.is_available() else False

model = torch.hub.load('facebookresearch/pytorch_GAN_zoo:hub',
                       'DCGAN',
                       pretrained=True,
                       useGPU=use_gpu)

これでモデルがダウンロードされ、~/.cache/torch/hubに格納されます。初回は、ダウンロードに時間がかかりますが、次回からは速くなります。

3. 画像を生成する🔝

画像を生成するには、ノイズを生成して入力データを作成する必要があります。

ここでは、64個の入力データを作ります。

noise, _ = model.buildNoiseData(num_images)

print(noise.shape)

torch.Size([64, 120])と出るので、一つのノイズデータは120の乱数が含まれているのがわかります。

では、画像を生成してみます。

with torch.no_grad():
    generated_images = model.test(noise)

生成された画像を表示しましょう。

torchvision.utilsにはmake_gridという便利な関数があり、複数の画像をグリッド上に配置し一つの画像としてまとめることができます。

import matplotlib.pyplot as plt
from torchvision.utils import make_grid

image_grid = make_grid(generate_images)  # 複数イメージをグリッド上にまとめた画像
image_grid = image_grid.permute(1, 2, 0) # チャンネルの次元を一番最後に移動
image_grid = image_grid.cpu().numpy()    # Numpyへと変換

plt.imshow(image_grid)
plt.xticks([])
plt.yticks([])
plt.show()

64個のノイズデータから64個のファッション画像が生成されました。

入力値である乱数の値によって生成される画像もバリエーションがたくさん生まれるのが確認できるますね。

4. ソースコード🔝

import torch
from torchvision.utils import make_grid
import matplotlib.pyplot as plt


# モデルをダウンロード
use_gpu = True if torch.cuda.is_available() else False

model = torch.hub.load('facebookresearch/pytorch_GAN_zoo:hub', 'DCGAN', pretrained=True, useGPU=use_gpu)
 
# 乱数のインプットデータを作成
num_images = 64
noise, _ = model.buildNoiseData(num_images)

print('noise shape', noise.shape)
print('noise data', noise[0])

# 画像を生成する
with torch.no_grad():
    generated_images = model.test(noise)

# 画像を表示する
plt.imshow(make_grid(generated_images).permute(1, 2, 0).cpu().numpy())
plt.xticks([])
plt.yticks([])
plt.show()

今日はここまで。

5. 参照🔝

DCGAN on FashionGen | PyTorch: https://pytorch.org/hub/facebookresearch_pytorch-gan-zoo_dcgan/



コメントを残す

メールアドレスは公開されません。