Fashion GAN
Generating Colorful Fashion Images with DCGAN
This article is a short follow-up on DCGAN (Deep Convolutional Generative Adversarial Network), applied to color fashion images (aka Fashion GAN). We’ll download a pre-trained DCGAN model from Torch Hub and use it to generate fashion images. The model was introduced in the paper Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks. One of its authors, Soumith Chintala, is a co-creator of the PyTorch open-source project.
We can feed random inputs into this DCGAN to generate realistic fashion images. Training a GAN can be tricky, but using a pre-trained one is straightforward. So, let’s do it!
1 Python Environment Setup
First of all, we create a Python environment. We’ll use venv as follows:
# Create a project folder and move there
mkdir dcgan
cd dcgan
# Create and activate a Python environment using venv
python3 -m venv venv
source venv/bin/activate
# We should always upgrade pip, as the bundled version is often outdated
# and has older information about libraries
pip install --upgrade pip
# We install required libraries under the virtual environment
pip install torch torchvision matplotlib tqdm
The versions of installed libraries are as follows:
matplotlib==3.5.1
torch==1.11.0
torchvision==0.12.0
tqdm==4.64.0
If you prefer conda, you can create an environment with it instead. Please make sure to install the required libraries.
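To confirm that your environment (venv or conda) matches the versions listed above, a small check like the following can help. It is a minimal sketch that only reads installed package metadata through the standard library's importlib.metadata (Python 3.8+); it does not import torch itself, so it also works when a package is missing.

```python
from importlib import metadata
from typing import Dict, Iterable, Optional

# Libraries installed in the setup step above.
REQUIRED = ["torch", "torchvision", "matplotlib", "tqdm"]


def installed_versions(packages: Iterable[str]) -> Dict[str, Optional[str]]:
    """Map each distribution name to its installed version, or None if missing."""
    versions: Dict[str, Optional[str]] = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None  # not installed in this environment
    return versions


if __name__ == "__main__":
    for name, version in installed_versions(REQUIRED).items():
        print(f"{name}=={version}" if version else f"{name}: NOT installed")
```

Newer versions than the ones listed will likely work too, but if something breaks, this is a quick first thing to compare.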
2 Download the Pre-Trained Fashion GAN
We’ll download the pre-trained DCGAN model from the Torch Hub.
import torch
use_gpu = torch.cuda.is_available()
model = torch.hub.load('facebookresearch/pytorch_GAN_zoo:hub',
                       'DCGAN',
                       pretrained=True,
                       useGPU=use_gpu)
The model can make use of a GPU. If your machine has one, use_gpu will be True. However, no GPU is required if you only want to generate fashion images. If you plan to do transfer learning, which involves training, you may want to use a machine with a GPU, but that is beyond the scope of this article.
The first time you call torch.hub.load, it takes a while to download the pre-trained model, which is stored under ~/.cache/torch/hub. After that, the model is always loaded from the local cache, which is fast. If you run into problems you cannot otherwise resolve, removing the downloaded model from the cache (so that it is downloaded again) may be worthwhile. I do not expect any such problems when generating images with the downloaded model.
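If you do need to clear the cache, the sketch below locates it and removes one cached repository folder. It is a minimal sketch, assuming you have not called torch.hub.set_dir() or set the deprecated TORCH_HUB variable: default_hub_dir mirrors the documented default rules of torch.hub.get_dir() ($TORCH_HOME/hub, else $XDG_CACHE_HOME/torch/hub, else ~/.cache/torch/hub). With torch installed, torch.hub.get_dir() is the authoritative answer. The folder name facebookresearch_pytorch_GAN_zoo_hub is an assumption based on Torch Hub's owner_repo_branch naming, and downloaded weight files typically sit in a separate checkpoints subfolder; list the directory to confirm both. The helper names are hypothetical, and dry_run defaults to True so nothing is deleted by accident.

```python
import os
import shutil
from pathlib import Path


def default_hub_dir() -> Path:
    """Mirror Torch Hub's documented default cache location (assumes no set_dir())."""
    torch_home = os.environ.get("TORCH_HOME")
    if torch_home:
        return Path(torch_home).expanduser() / "hub"
    xdg_cache = os.environ.get("XDG_CACHE_HOME", "~/.cache")
    return Path(xdg_cache).expanduser() / "torch" / "hub"


def remove_cached_repo(hub_dir: Path, repo_folder: str, dry_run: bool = True) -> bool:
    """Delete one cached repo folder under hub_dir; return True if it existed."""
    target = Path(hub_dir) / repo_folder
    if not target.is_dir():
        return False
    if not dry_run:
        shutil.rmtree(target)  # irreversible: the next torch.hub.load re-downloads it
    return True


if __name__ == "__main__":
    hub_dir = default_hub_dir()
    # Assumed folder name, following the owner_repo_branch convention.
    folder = "facebookresearch_pytorch_GAN_zoo_hub"
    found = remove_cached_repo(hub_dir, folder, dry_run=True)
    print(f"{hub_dir / folder}: {'found' if found else 'not found'}")
```

Passing dry_run=False actually deletes the folder; the next torch.hub.load call then fetches the repository again.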
3 Generate Fashion Images
As with any GAN model, we need to feed it randomly generated inputs. The model provides a convenient method for generating noise data (buildNoiseData). The method is part of the BaseGAN class, which DCGAN inherits from. The PyTorch GAN Zoo GitHub repository has other GAN models that also inherit from BaseGAN.
Here, we create 64 noise inputs.
num_images = 64
noise, _ = model.buildNoiseData(num_images)
print(noise.shape)
The print statement outputs the shape of the noise data, torch.Size([64, 120]), meaning there are 64 inputs, each consisting of 120 random values.
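To make that shape concrete, here is a toy sketch of the pattern described above: a base class owns the noise-generating method, and a DCGAN-like subclass simply inherits it. ToyBaseGAN, ToyDCGAN, and build_noise_data are hypothetical names, not the GAN Zoo API, and plain Python lists stand in for tensors. Drawing every value from a standard normal distribution reflects my understanding that BaseGAN uses torch.randn for its latent vectors; treat that detail as an assumption.

```python
import random
from typing import List, Optional, Tuple

Matrix = List[List[float]]


class ToyBaseGAN:
    """Hypothetical stand-in for a GAN base class (not the real BaseGAN)."""

    def __init__(self, noise_dim: int, seed: Optional[int] = None):
        self.noise_dim = noise_dim
        self._rng = random.Random(seed)

    def build_noise_data(self, n_samples: int) -> Tuple[Matrix, None]:
        # One row per image; each value drawn from a standard normal N(0, 1).
        latent = [
            [self._rng.gauss(0.0, 1.0) for _ in range(self.noise_dim)]
            for _ in range(n_samples)
        ]
        # The second element plays the role of the ignored "_" in the article.
        return latent, None


class ToyDCGAN(ToyBaseGAN):
    """Hypothetical subclass: inherits build_noise_data unchanged."""

    def __init__(self, seed: Optional[int] = None):
        super().__init__(noise_dim=120, seed=seed)


noise, _ = ToyDCGAN(seed=0).build_noise_data(64)
print(len(noise), len(noise[0]))  # 64 120
```

Because the method lives in the base class, any subclass only has to declare its noise dimension to get correctly shaped inputs.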
Let’s generate fashion images.
with torch.no_grad():
    generated_images = model.test(noise)
We use make_grid from torchvision.utils to show the images in a grid layout.
import matplotlib.pyplot as plt
from torchvision.utils import make_grid
image_grid = make_grid(generated_images, nrow=8)  # Arrange images into a grid, 8 per row
image_grid = image_grid.permute(1, 2, 0)  # Move channels last: (C, H, W) -> (H, W, C)
image_grid = image_grid.cpu().numpy()  # Convert into a NumPy array for matplotlib
plt.imshow(image_grid)
plt.xticks([])
plt.yticks([])
plt.show()
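The grid layout can be worked out by hand. The sketch below reproduces make_grid's sizing rule (images per row = min(nrow, n), every cell padded by padding pixels, plus one extra padding border) and shows what permute(1, 2, 0) does to the shape. The 64×64 RGB image size in the example call is an assumption for illustration; print generated_images.shape to check the real one. grid_shape and channels_last are hypothetical helper names.

```python
import math
from typing import Tuple


def grid_shape(n_images: int, channels: int, height: int, width: int,
               nrow: int = 8, padding: int = 2) -> Tuple[int, int, int]:
    """(C, H, W) of the tensor torchvision's make_grid returns for a batch."""
    out_c = 3 if channels == 1 else channels  # make_grid repeats grayscale to 3 channels
    if n_images == 1:
        return out_c, height, width  # a single image is returned without padding
    cols = min(nrow, n_images)  # images per row
    rows = math.ceil(n_images / cols)  # rows needed for all images
    return out_c, (height + padding) * rows + padding, (width + padding) * cols + padding


def channels_last(shape: Tuple[int, int, int]) -> Tuple[int, int, int]:
    """Effect of permute(1, 2, 0) on a (C, H, W) shape: (H, W, C) for imshow."""
    c, h, w = shape
    return h, w, c


shape = grid_shape(64, 3, 64, 64)  # assumes 64x64 RGB outputs
print(shape, channels_last(shape))  # (3, 530, 530) (530, 530, 3)
```

The permute step is needed because matplotlib's imshow expects the channel dimension last, while PyTorch keeps it first.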
Below is the result: we fed 64 noise vectors into the model to generate 64 random fashion images.
Below is the output of another run.
As you can see, different random inputs generate different fashion images.
4 Source Code
import torch
from torchvision.utils import make_grid
import matplotlib.pyplot as plt
# Download the model
use_gpu = torch.cuda.is_available()
model = torch.hub.load('facebookresearch/pytorch_GAN_zoo:hub', 'DCGAN', pretrained=True, useGPU=use_gpu)
# Generate random inputs
num_images = 64
noise, _ = model.buildNoiseData(num_images)
print('noise shape', noise.shape)
print('noise data', noise[0])
# Generate fashion images
with torch.no_grad():
    generated_images = model.test(noise)
# Display images
plt.imshow(make_grid(generated_images).permute(1, 2, 0).cpu().numpy())
plt.xticks([])
plt.yticks([])
plt.show()
That’s it!
5 References
- Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks, Alec Radford, Luke Metz, Soumith Chintala
- DCGAN on FashionGen, PyTorch