# Source: src/colab/simple_autoencoder_conv.py
# -*- coding: utf-8 -*-
"""simple_autoencoder_conv.ipynb
Automatically generated by Colaboratory.
Original file is located at
https://colab.research.google.com/drive/1tXwRSKN6ISEJUFh3AuZ9GJC9yEaQRTd5
This notebook is inspired by the following notebook :
https://colab.research.google.com/github/smartgeometry-ucl/dl4g/blob/master/variational_autoencoder.ipynb
Simple Autoencoder with Convolution
======
"""
# Commented out IPython magic to ensure Python compatibility.
# install pytorch (http://pytorch.org/) if run from Google Colaboratory
import sys
if 'google.colab' in sys.modules and 'torch' not in sys.modules:
    # Only reached on Google Colaboratory when PyTorch has not been imported yet.
    import re
    import subprocess
    from os.path import exists

    # from wheel.pep425tags import get_abbr_impl, get_impl_ver, get_abi_tag
    # platform = '{}{}-{}'.format(get_abbr_impl(), get_impl_ver(), get_abi_tag())

    # Plain-Python equivalent of the notebook shell magic
    #   !ldconfig -p|grep cudart.so|sed -e 's/.*\.\([0-9]*\)\.\([0-9]*\)$/cu\1\2/'
    # which is a syntax error outside of IPython.
    try:
        ldconfig_listing = subprocess.run(
            ['ldconfig', '-p'],
            stdout=subprocess.PIPE,
            stderr=subprocess.DEVNULL,
            universal_newlines=True,
            check=False,
        ).stdout
    except OSError:
        # ldconfig could not be started: behave as if no CUDA runtime was listed.
        ldconfig_listing = ''

    # Keep only the cudart entries and rewrite a trailing ".<major>.<minor>"
    # into "cu<major><minor>", exactly like the grep | sed pipeline did.
    cuda_output = [
        re.sub(r'.*\.([0-9]*)\.([0-9]*)$', r'cu\1\2', line, count=1)
        for line in ldconfig_listing.splitlines()
        if re.search(r'cudart.so', line)
    ]

    # Use the CUDA tag only when a GPU device node exists AND a cudart entry
    # was found; the emptiness check avoids an IndexError on cuda_output[0].
    accelerator = cuda_output[0] if cuda_output and exists('/dev/nvidia0') else 'cpu'
    #!pip3 install https://download.pytorch.org/whl/{accelerator}/torch-1.1.0-{platform}-linux_x86_64.whl
    #!pip3 install https://download.pytorch.org/whl/{accelerator}/torchvision-0.3.0-{platform}-linux_x86_64.whl
# %matplotlib inline
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
"""Parameter Settings
-------------------
"""
# 2-d latent space, parameter count in same order of magnitude
# as in the original VAE paper (VAE paper has about 3x as many)
# --- model size ---
latent_dims = 2       # width of the bottleneck; 2 is what the 2D latent plot expects
capacity = 64         # channel count of the first convolution (the second uses 2x)

# --- optimisation ---
num_epochs = 20
batch_size = 128
learning_rate = 1e-3

# --- hardware ---
use_gpu = True        # request CUDA; device selection falls back to CPU if unavailable
# # 10-d latent space, for comparison with non-variational auto-encoder
# latent_dims = 10
# num_epochs = 100
# batch_size = 128
# capacity = 64
# learning_rate = 1e-3
# use_gpu = True
"""MNIST Data Loading
-------------------
MNIST images show digits from 0-9 in 28x28 grayscale images. We do not center them at 0, because we will be using a binary cross-entropy loss that treats pixel values as probabilities in [0,1]. We create both a training set and a test set.
"""
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
# Only convert images to tensors; no mean/std normalisation is applied
# because the BCE loss used later treats pixel values as probabilities.
img_transform = transforms.Compose([transforms.ToTensor()])

# Training split (downloaded into ./data/MNIST on first use).
train_dataset = MNIST(
    root='./data/MNIST',
    train=True,
    transform=img_transform,
    download=True,
)
train_dataloader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
)

# Test split, served in shuffled batches of the same size.
test_dataset = MNIST(
    root='./data/MNIST',
    train=False,
    transform=img_transform,
    download=True,
)
test_dataloader = DataLoader(
    test_dataset,
    batch_size=batch_size,
    shuffle=True,
)
"""Autoencoder Definition
-----------------------
We use a convolutional encoder and decoder, which generally gives better performance than fully connected versions that have the same number of parameters.
In convolution layers, we increase the channels as we approach the bottleneck, but note that the total number of features still decreases, since the channels increase by a factor of 2 in each convolution, but the spatial size decreases by a factor of 4.
Kernel size 4 is used to avoid biasing problems described here: https://distill.pub/2016/deconv-checkerboard/
"""
class Encoder(nn.Module):
    """Convolutional encoder mapping 1 x 28 x 28 images to latent vectors.

    Two stride-2 convolutions halve the spatial size twice (28 -> 14 -> 7),
    then a linear layer projects the flattened feature maps to the latent
    space.

    Args:
        num_channels: Output channels of the first convolution; the second
            convolution produces twice as many. ``None`` (the default) uses
            the module-level ``capacity``.
        num_latent: Size of the latent vector. ``None`` (the default) uses
            the module-level ``latent_dims``.
    """

    def __init__(self, num_channels=None, num_latent=None):
        super().__init__()
        # Fall back to the notebook-level settings at construction time so a
        # plain ``Encoder()`` call behaves exactly as before.
        c = capacity if num_channels is None else num_channels
        latent = latent_dims if num_latent is None else num_latent
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=c, kernel_size=4, stride=2, padding=1)  # out: c x 14 x 14
        self.conv2 = nn.Conv2d(in_channels=c, out_channels=c*2, kernel_size=4, stride=2, padding=1)  # out: 2c x 7 x 7
        self.fc_mu = nn.Linear(in_features=c*2*7*7, out_features=latent)

    def forward(self, x):
        """Return the latent vectors for a batch of images ``x``."""
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        # flatten batch of multi-channel feature maps to a batch of feature vectors
        x = x.view(x.size(0), -1)
        return self.fc_mu(x)
class Decoder(nn.Module):
    """Convolutional decoder mapping latent vectors back to 1 x 28 x 28 images.

    A linear layer expands the latent vector to ``2c x 7 x 7`` feature maps,
    which two stride-2 transposed convolutions upsample (7 -> 14 -> 28).

    Args:
        num_channels: Channel count ``c`` fed to the last transposed
            convolution; the first one receives ``2c`` channels. ``None``
            (the default) uses the module-level ``capacity``.
        num_latent: Size of the latent vector. ``None`` (the default) uses
            the module-level ``latent_dims``.
    """

    def __init__(self, num_channels=None, num_latent=None):
        super().__init__()
        # Fall back to the notebook-level settings at construction time so a
        # plain ``Decoder()`` call behaves exactly as before.
        c = capacity if num_channels is None else num_channels
        latent = latent_dims if num_latent is None else num_latent
        self.fc = nn.Linear(in_features=latent, out_features=c*2*7*7)
        self.conv2 = nn.ConvTranspose2d(in_channels=c*2, out_channels=c, kernel_size=4, stride=2, padding=1)
        self.conv1 = nn.ConvTranspose2d(in_channels=c, out_channels=1, kernel_size=4, stride=2, padding=1)

    def forward(self, x):
        """Return reconstructed images for a batch of latent vectors ``x``."""
        x = self.fc(x)
        # Unflatten the feature vectors into multi-channel feature maps. The
        # channel count is read from the layer itself rather than from the
        # global ``capacity``, so the reshape always matches this instance.
        x = x.view(x.size(0), self.conv2.in_channels, 7, 7)
        x = F.relu(self.conv2(x))
        # last layer before output is sigmoid, since we are using BCE as reconstruction loss
        x = torch.sigmoid(self.conv1(x))
        return x
class Autoencoder(nn.Module):
    """Encoder followed by decoder.

    ``forward`` returns a pair: the reconstruction produced by the decoder
    and the latent code produced by the encoder.
    """

    def __init__(self):
        super().__init__()
        self.encoder = Encoder()
        self.decoder = Decoder()

    def forward(self, x):
        code = self.encoder(x)
        return self.decoder(code), code
def ae_loss(recon_x, x, mu=None):
    """Return the summed binary cross-entropy between reconstruction and input.

    Args:
        recon_x: Reconstructed values in [0, 1].
        x: Target values in [0, 1], with the same number of elements as
            ``recon_x``.
        mu: Latent codes. Accepted for call-site compatibility but not used
            by this (non-variational) loss.

    Returns:
        A scalar tensor: the BCE summed over every element of the batch.
    """
    # With reduction='sum' the result does not depend on how the elements are
    # arranged, so flatten completely instead of assuming 784 pixels per image.
    # reshape (unlike view) also accepts non-contiguous tensors.
    return F.binary_cross_entropy(recon_x.reshape(-1), x.reshape(-1), reduction='sum')
# Build the model and move it to the GPU when one is requested and available.
ae = Autoencoder()
if use_gpu and torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
ae = ae.to(device)

# Count only the parameters that will actually be trained.
num_params = sum(param.numel() for param in ae.parameters() if param.requires_grad)
print('Number of parameters: %d' % num_params)
"""Train autoencoder
--------
"""
optimizer = torch.optim.Adam(ae.parameters(), lr=learning_rate, weight_decay=1e-5)

# set to training mode
ae.train()

# one entry per epoch: summed batch losses divided by the number of batches
train_loss_avg = []

print('Training ...')
for epoch in range(num_epochs):
    epoch_loss_sum = 0
    num_batches = 0

    for image_batch, _ in train_dataloader:
        image_batch = image_batch.to(device)

        # forward pass: reconstruction and its error
        image_batch_recon, latent_mu = ae(image_batch)
        loss = ae_loss(image_batch_recon, image_batch, latent_mu)

        # clear stale gradients, backpropagate, then take one optimizer step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        epoch_loss_sum += loss.item()
        num_batches += 1

    train_loss_avg.append(epoch_loss_sum / num_batches)
    print('Epoch [%d / %d] average reconstruction error: %f' % (epoch+1, num_epochs, train_loss_avg[-1]))
"""Plot Training Curve
--------------------
"""
import matplotlib.pyplot as plt
# Turn on matplotlib's interactive mode before drawing.
plt.ion()
# Start a fresh figure for the loss curve.
fig = plt.figure()
# train_loss_avg holds one value per epoch, so the x axis is the epoch index.
plt.plot(train_loss_avg)
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.show()
"""Evaluate on the Test Set
-------------------------
"""
# set to evaluation mode
# set to evaluation mode
ae.eval()

test_loss_avg, num_batches = 0, 0
# No gradients are needed while measuring the test error.
with torch.no_grad():
    for image_batch, _ in test_dataloader:
        image_batch = image_batch.to(device)

        # autoencoder reconstruction and its error
        image_batch_recon, latent_mu = ae(image_batch)
        loss = ae_loss(image_batch_recon, image_batch, latent_mu)

        test_loss_avg += loss.item()
        num_batches += 1

# summed batch losses divided by the number of batches
test_loss_avg /= num_batches
print('average reconstruction error: %f' % (test_loss_avg))
"""Visualize Reconstructions
--------------------------
"""
import numpy as np
import matplotlib.pyplot as plt
# Turn on matplotlib's interactive mode for the figures below.
plt.ion()
import torchvision.utils
# Put the autoencoder in evaluation mode before visualising its outputs.
ae.eval()
# visualise_output (defined below) takes as input the images to reconstruct
# and the model with which the reconstructions are performed.
def to_img(x):
    """Return ``x`` with every value clamped to the range [0, 1]."""
    return x.clamp(min=0, max=1)
def show_image(img):
    """Draw ``img`` with matplotlib after clamping its values to [0, 1].

    The axes are permuted with (1, 2, 0), i.e. the first axis of ``img`` is
    moved to the end before the array is handed to ``plt.imshow``.
    """
    clamped = to_img(img)
    as_array = clamped.numpy()
    plt.imshow(np.transpose(as_array, axes=(1, 2, 0)))
def visualise_output(images, model):
    """Reconstruct ``images`` with ``model`` and show the results as a grid.

    ``model`` is called on the images (moved to the global ``device``) and is
    expected to return a pair whose first element is the reconstruction.
    Entries 1..49 of the reconstructed batch are plotted.
    """
    with torch.no_grad():
        reconstructions, _ = model(images.to(device))
        reconstructions = to_img(reconstructions.cpu())
        # make_grid(tensor, nrow, padding): 10 images per row, 5 px padding
        grid = torchvision.utils.make_grid(reconstructions[1:50], 10, 5)
        np_imagegrid = grid.numpy()
        plt.imshow(np.transpose(np_imagegrid, (1, 2, 0)))
        plt.show()
# Take one batch from the test loader (next(iter(...)) replaces the old
# iterator .next() method).
images, labels = next(iter(test_dataloader))

# First visualise the original images
print('Original images')
original_grid = torchvision.utils.make_grid(images[1:50], 10, 5)
show_image(original_grid)
plt.show()

# Reconstruct and visualise the same batch with the autoencoder
print('VAE reconstruction:')
visualise_output(images, ae)
"""Interpolate in Latent Space
----------------------------
"""
# Put the autoencoder in evaluation mode before interpolating in latent space.
ae.eval()
def interpolation(lambda1, model, img1, img2):
    """Decode a weighted blend of the latent codes of two images.

    Both images are moved to the global ``device`` and encoded with
    ``model.encoder``. The codes are mixed as
    ``lambda1 * z1 + (1 - lambda1) * z2`` and the result is passed through
    ``model.decoder``. The decoded image is returned on the CPU.
    """
    with torch.no_grad():
        # latent vectors of the first and second image, in that order
        latent_1, latent_2 = [model.encoder(img.to(device)) for img in (img1, img2)]

        # interpolation of the two latent vectors
        inter_latent = lambda1 * latent_1 + (1 - lambda1) * latent_2

        # reconstruct interpolated image
        inter_image = model.decoder(inter_latent).cpu()
    return inter_image
# sort part of test set by digit
# digits[k] collects single-image batches whose label is k
digits = [[] for _ in range(10)]
for img_batch, label_batch in test_dataloader:
    for img, label in zip(img_batch, label_batch):
        # unsqueeze(0) keeps a leading batch axis of size one
        digits[label].append(img.unsqueeze(0))
    # stop once at least 1000 images have been collected
    if sum(map(len, digits)) >= 1000:
        break
# interpolation lambdas
lambda_range = np.linspace(0, 1, 10)

# 2 x 5 panel layout, flattened so panels can be walked in order
fig, axs = plt.subplots(2, 5, figsize=(15, 6))
fig.subplots_adjust(hspace=.5, wspace=.001)
axs = axs.ravel()

# one panel per lambda, blending the first stored "7" with the first stored "1"
for ax, lam in zip(axs, lambda_range):
    blended = interpolation(float(lam), ae, digits[7][0], digits[1][0])
    blended = to_img(blended)
    pixels = blended.numpy()
    ax.imshow(pixels[0, 0, :, :], cmap='gray')
    ax.set_title('lambda_val=' + str(round(lam, 1)))
plt.show()
"""Sample Latent Vector from Prior (Autoencoder as Generator)
-------------------------------------------------
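A plain autoencoder does not learn a prior over its latent space. Here we nevertheless try to generate new digits by decoding latent vectors drawn from a standard normal distribution, so the samples are not guaranteed to look like digits.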
"""
ae.eval()

with torch.no_grad():
    # draw 128 latent vectors from a standard normal distribution
    latent = torch.randn(128, latent_dims, device=device)

    # decode them and bring the images back to the CPU
    img_recon = ae.decoder(latent).cpu()

    fig, ax = plt.subplots(figsize=(5, 5))
    # show the first 100 decoded images, 10 per row
    sample_grid = torchvision.utils.make_grid(img_recon.data[:100], 10, 5)
    show_image(sample_grid)
    plt.show()
"""Show 2D Latent Space
---------------------
"""
# load a network that was trained with a 2d latent space
if latent_dims != 2:
    # The grid below is two-dimensional; decoding it with a model that
    # expects another latent size would fail, so only print the hint.
    print('Please change the parameters to two latent dimensions.')
else:
    with torch.no_grad():
        # create a sample grid in 2d latent space
        latent_x = np.linspace(-1.5, 1.5, 20)
        latent_y = np.linspace(-1.5, 1.5, 20)

        # grid_x[j, i] == latent_x[i] and grid_y[j, i] == latent_y[j]; stacking
        # them gives an array of shape (len(latent_y), len(latent_x), 2).
        grid_x, grid_y = np.meshgrid(latent_x, latent_y)
        latents = torch.from_numpy(np.stack([grid_x, grid_y], axis=-1)).float()
        latents = latents.view(-1, 2)  # flatten grid into a batch

        # reconstruct images from the latent vectors
        latents = latents.to(device)
        image_recon = ae.decoder(latents)
        image_recon = image_recon.cpu()

        fig, ax = plt.subplots(figsize=(10, 10))
        # 20 images per row, so each row shares one latent_y value
        show_image(torchvision.utils.make_grid(image_recon.data[:400], 20, 5))
        plt.show()