# Source: src/colab/gan_10g10o.py
# -*- coding: utf-8 -*-
"""gan_10g10o.ipynb
Automatically generated by Colaboratory.
Original file is located at
https://colab.research.google.com/drive/1A9JfUcjwARehNw5prHP1fi1g_SY60FT9
"""
import numpy as np
import matplotlib.pyplot as plt
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.utils import make_grid
from tqdm import tqdm
# Training hyperparameters
epochs = 50
batch_size = 64
sample_size = 100  # Length of the random noise vector fed to each generator
g_lr = 1.0e-3  # Learning rate used for every generator
d_lr = 1.0e-3  # Learning rate used for the discriminator

# MNIST training split, converted to tensors by ToTensor.
# drop_last=True keeps every batch at exactly `batch_size` samples, which the
# fixed-size target tensors built later in this file rely on.
transform = transforms.ToTensor()
dataset = datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)
dataloader = DataLoader(dataset, batch_size=batch_size, drop_last=True)
# Generator Network
class Generator(nn.Sequential):
    """Fully connected generator that turns random noise into 1x28x28 images.

    The layer stack is Linear(sample_size, 128) -> LeakyReLU(0.01) ->
    Linear(128, 784) -> Sigmoid. Unlike a plain ``nn.Sequential``, calling
    the module takes a batch size (not an input tensor) because the noise is
    sampled inside ``forward``.
    """

    def __init__(self, sample_size: int):
        """Build the layers.

        Args:
            sample_size: Length of the random noise vector drawn per image.
        """
        super().__init__(
            nn.Linear(sample_size, 128),
            nn.LeakyReLU(0.01),
            nn.Linear(128, 784),
            nn.Sigmoid(),
        )
        # Random value vector size
        self.sample_size = sample_size

    def forward(self, batch_size: int) -> torch.Tensor:
        """Sample noise and generate a batch of images.

        Args:
            batch_size: Number of images to generate.

        Returns:
            Tensor of shape (batch_size, 1, 28, 28) holding sigmoid outputs.
        """
        # Sample the noise with the same device and dtype as the first layer's
        # weights so the module keeps working after .to(device) / .double();
        # a default CPU float32 tensor would fail in the first matmul.
        reference = self[0].weight
        z = torch.randn(
            batch_size,
            self.sample_size,
            device=reference.device,
            dtype=reference.dtype,
        )  # batch_size x sample_size

        # Generator output
        output = super().forward(z)  # batch_size x 784

        # Convert the output into a greyscale image (1x28x28)
        return output.reshape(batch_size, 1, 28, 28)
# Discriminator Network
class Discriminator(nn.Sequential):
    """Fully connected classifier that returns a training loss directly.

    The layer stack is Linear(784, 128) -> LeakyReLU(0.01) ->
    Linear(128, num_classes). ``forward`` does not return logits: it one-hot
    encodes the integer labels and returns the binary cross-entropy (with
    logits) between the predictions and those targets.
    """

    def __init__(self, num_classes: int = 11):
        """Build the layers.

        Args:
            num_classes: Number of output logits and width of the one-hot
                targets. Defaults to 11, the value previously hard-coded in
                both places.
        """
        super().__init__(
            nn.Linear(784, 128),
            nn.LeakyReLU(0.01),
            nn.Linear(128, num_classes),
        )
        # Kept on the instance so the one-hot width can never drift from the
        # size of the final layer.
        self.num_classes = num_classes

    def forward(self, images: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        """Compute the loss of the discriminator on a labelled batch.

        Args:
            images: Tensor that can be reshaped to rows of 784 values.
            labels: Integer class indices in ``[0, num_classes)``, one per
                image row.

        Returns:
            Scalar loss tensor (default mean reduction of
            ``binary_cross_entropy_with_logits``).
        """
        prediction = super().forward(images.reshape(-1, 784))
        # Targets follow the logits' dtype (float32 by default, as before).
        targets = F.one_hot(labels, self.num_classes).to(prediction.dtype)
        return F.binary_cross_entropy_with_logits(prediction, targets)
# To save images in grid layout
def display_image_grid(epoch: int, images: torch.Tensor, ncol: int):
    """Show a batch of images as one grid figure.

    Args:
        epoch: Epoch index. Not used by the current body (the save-to-file
            variant that used it is disabled); kept so existing callers
            keep working.
        images: Batch of images handed to ``make_grid``.
            NOTE(review): presumably (N, C, H, W) as produced by Generator.
        ncol: Number of images placed in each row of the grid.
    """
    # Detach first: Tensor.numpy() raises on a tensor that requires grad, and
    # callers may pass generator output created with autograd enabled.
    image_grid = make_grid(images.detach(), nrow=ncol)  # Images in a grid
    # Move channel last, then convert to NumPy for matplotlib
    image_grid = image_grid.permute(1, 2, 0).cpu().numpy()

    plt.imshow(image_grid)
    plt.xticks([])
    plt.yticks([])
    plt.show()
    plt.close()
# Real and fake labels
# Targets are class indices for the 11-way discriminator: a real image of
# digit d uses index d (0-9), and index 10 is reserved for generated images.
def real_targets(digit):
    """Return a length-``batch_size`` label vector filled with ``digit``."""
    label = torch.tensor(digit)
    return label.repeat(batch_size)


fake_targets = real_targets(10)
# Generator and Discriminator networks
# Ten independent generators (one per digit) share a single discriminator.
generators = [Generator(sample_size) for _ in range(10)]
discriminator = Discriminator()

# Optimizers
# The discriminator has one Adam optimizer; each generator gets its own so
# that it can be stepped independently of the others.
d_optimizer = torch.optim.Adam(discriminator.parameters(), lr=d_lr)
g_optimizers = [
    torch.optim.Adam(generator.parameters(), lr=g_lr)
    for generator in generators
]
# Training loop
# Each batch performs one discriminator update (real images plus the output
# of every generator) followed by one update of each generator.
for epoch in range(epochs):
    d_losses = []
    g_losses = []

    for images, labels in tqdm(dataloader):
        #===============================
        # Discriminator Network Training
        #===============================

        # Loss with MNIST image inputs and their digit labels as targets
        discriminator.train()
        d_loss = discriminator(images, labels)

        # Generate images in eval mode
        for generator in generators:
            generator.eval()

        for generator in generators:
            # no_grad: the generators are not being trained in this phase,
            # so no graph needs to be built through them.
            with torch.no_grad():
                generated_images = generator(batch_size)

            # Loss with generated image inputs and fake_targets as labels
            d_loss = d_loss + discriminator(generated_images, fake_targets)

        # Optimizer updates the discriminator parameters
        d_optimizer.zero_grad()
        d_loss.backward()
        d_optimizer.step()

        #===============================
        # Generator Network Training
        #===============================

        # Generate images in train mode
        for generator in generators:
            generator.train()
        discriminator.eval()  # eval but we still need gradients

        batch_g_losses = []
        for digit, (generator, g_optimizer) in enumerate(zip(generators, g_optimizers)):
            generated_images = generator(batch_size)

            # Loss with generated image inputs and this generator's digit as
            # the target class
            g_loss = discriminator(generated_images, real_targets(digit))

            # Optimizer updates this generator's parameters only. backward()
            # also leaves gradients on the discriminator; those are cleared
            # by d_optimizer.zero_grad() before the next discriminator step.
            g_optimizer.zero_grad()
            g_loss.backward()
            g_optimizer.step()
            batch_g_losses.append(g_loss.item())

        # Keep losses for logging; the generator entry is the mean over all
        # generators so no single generator dominates the report.
        d_losses.append(d_loss.item())
        g_losses.append(np.mean(batch_g_losses))

    # Print average losses
    print(epoch, np.mean(d_losses), np.mean(g_losses))

    # Display images
    if epoch % 10 == 9:
        for digit, generator in enumerate(generators):
            print(f"Generate fake {digit}")
            # Sampling for display needs no gradients; the next training
            # iteration sets the train/eval modes again.
            generator.eval()
            with torch.no_grad():
                samples = generator(batch_size)
            display_image_grid(epoch, samples, ncol=8)