# File: src/colab/simple_autoencoder_linear.py
# -*- coding: utf-8 -*-
"""simple_autoencoder_linear.ipynb
Automatically generated by Colaboratory.
Original file is located at
https://colab.research.google.com/drive/1689ZNocHM5MPTU8kcpry5UC7saE0DqAa
This notebook is inspired by the following notebook:
https://colab.research.google.com/github/smartgeometry-ucl/dl4g/blob/master/variational_autoencoder.ipynb
Simple Autoencoders with Linear Layers
======
"""
# Install PyTorch (http://pytorch.org/) if run from Google Colaboratory.
# The notebook's IPython shell magic (`!ldconfig ...`) is expressed with
# `subprocess` here so that this file is valid plain Python.
import sys
if 'google.colab' in sys.modules and 'torch' not in sys.modules:
    import subprocess
    from os.path import exists
    # from wheel.pep425tags import get_abbr_impl, get_impl_ver, get_abi_tag
    # platform = '{}{}-{}'.format(get_abbr_impl(), get_impl_ver(), get_abi_tag())
    # Look up libcudart.so.X.Y in the dynamic-linker cache and rewrite its
    # version into a wheel tag such as "cu101" (one tag per output line).
    cuda_output = subprocess.run(
        r"ldconfig -p|grep cudart.so|sed -e 's/.*\.\([0-9]*\)\.\([0-9]*\)$/cu\1\2/'",
        shell=True, stdout=subprocess.PIPE, universal_newlines=True,
    ).stdout.splitlines()
    # Use the CPU build unless there is both a GPU device node and a CUDA
    # runtime tag to pick (indexing an empty result would raise IndexError).
    accelerator = cuda_output[0] if cuda_output and exists('/dev/nvidia0') else 'cpu'
    #!pip3 install https://download.pytorch.org/whl/{accelerator}/torch-1.1.0-{platform}-linux_x86_64.whl
    #!pip3 install https://download.pytorch.org/whl/{accelerator}/torchvision-0.3.0-{platform}-linux_x86_64.whl
# %matplotlib inline
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
"""Parameter Settings
-------------------
"""
# 2-d latent space, parameter count in same order of magnitude
# as in the original VAE paper (VAE paper has about 3x as many)
latent_dims = 2
num_epochs = 20
batch_size = 128
capacity = 64
learning_rate = 1e-3
use_gpu = True
# # 10-d latent space, for comparison with non-variational auto-encoder
# latent_dims = 10
# num_epochs = 100
# batch_size = 128
# capacity = 64
# learning_rate = 1e-3
# use_gpu = True
"""MNIST Data Loading
-------------------
MNIST images show digits from 0-9 in 28x28 grayscale images. We do not center them at 0, because we will be using a binary cross-entropy loss that treats pixel values as probabilities in [0,1]. We create both a training set and a test set.
"""
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
img_transform = transforms.Compose([
transforms.ToTensor()
])
train_dataset = MNIST(root='./data/MNIST', download=True, train=True, transform=img_transform)
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_dataset = MNIST(root='./data/MNIST', download=True, train=False, transform=img_transform)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True)
"""Autoencoder Definition
-----------------------
Encoder : 28*28=784 -> n1 -> n2 -> n3 -> 2
Decoder : 2 -> n3 -> n2 -> n1 -> 28*28=784
"""
n1 = 512
n2 = 256
n3 = 128
class Encoder(nn.Module):
    """Map a batch of 28x28 images to ``latent_dims``-dimensional codes.

    Layers: 784 -> n1 -> n2 -> n3 (ReLU after each) -> latent_dims (linear,
    no activation, so the code is unconstrained).
    """

    def __init__(self):
        super().__init__()
        self.linear1 = nn.Linear(28*28, n1)
        self.linear2 = nn.Linear(n1, n2)
        self.linear3 = nn.Linear(n2, n3)
        # The name is inherited from the variational autoencoder this is based on
        # (and is part of the state_dict keys); here it yields the latent code
        # itself, not the mean of a distribution.
        self.fc_mu = nn.Linear(n3, latent_dims)

    def forward(self, x):
        """Encode ``x``; each sample must contain 28*28 values, e.g. (batch, 1, 28, 28)."""
        # Flatten all but the batch dimension (same as nn.Flatten()) without
        # constructing a new module on every call.
        x = torch.flatten(x, start_dim=1)
        x = F.relu(self.linear1(x))
        x = F.relu(self.linear2(x))
        x = F.relu(self.linear3(x))
        return self.fc_mu(x)
class Decoder(nn.Module):
    """Map ``latent_dims``-dimensional codes back to (batch, 1, 28, 28) images.

    Mirrors the encoder: latent_dims -> n3 -> n2 -> n1 (ReLU after each
    hidden layer) -> 784 (sigmoid), reshaped to 1x28x28.
    """

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(latent_dims, n3)
        self.linear3 = nn.Linear(n3, n2)
        self.linear2 = nn.Linear(n2, n1)
        self.linear1 = nn.Linear(n1, 28*28)

    def forward(self, x):
        # ReLU after the first layer too: without it, fc and linear3 compose into
        # a single affine map and the n3-wide hidden layer adds no capacity.
        x = F.relu(self.fc(x))
        x = F.relu(self.linear3(x))
        x = F.relu(self.linear2(x))
        # Sigmoid keeps outputs in (0, 1), as required by the BCE reconstruction loss.
        x = torch.sigmoid(self.linear1(x))
        return x.view(x.size(0), 1, 28, 28)
class Autoencoder(nn.Module):
    """An ``Encoder`` followed by a ``Decoder``.

    ``forward`` returns ``(reconstruction, latent_code)``.
    """

    def __init__(self):
        super().__init__()
        self.encoder = Encoder()
        self.decoder = Decoder()

    def forward(self, x):
        code = self.encoder(x)
        return self.decoder(code), code
def ae_loss(recon_x, x, mu):
    """Binary cross-entropy between a reconstruction and its target, summed.

    Args:
        recon_x: reconstructed pixel probabilities in [0, 1]; must have the
            same number of elements as ``x``.
        x: target pixel values in [0, 1].
        mu: latent code; unused (a plain autoencoder has no KL term) but
            accepted so existing call sites keep working.

    Returns:
        Scalar tensor: BCE summed over every element of the batch.
    """
    # With reduction='sum' the layout of the tensors does not matter, so flatten
    # completely instead of assuming 784-pixel rows; reshape (unlike view) also
    # accepts non-contiguous tensors.
    return F.binary_cross_entropy(recon_x.reshape(-1), x.reshape(-1), reduction='sum')
# Pick the GPU only when requested and actually available.
device = torch.device("cuda:0" if use_gpu and torch.cuda.is_available() else "cpu")
ae = Autoencoder().to(device)
# Count only the parameters the optimizer will update.
num_params = 0
for param in ae.parameters():
    if param.requires_grad:
        num_params += param.numel()
print('Number of parameters: %d' % num_params)
"""Train VAE
--------
"""
optimizer = torch.optim.Adam(params=ae.parameters(), lr=learning_rate, weight_decay=1e-5)
# set to training mode
ae.train()
# One entry per epoch: mean over batches of the per-batch summed BCE.
train_loss_avg = []
print('Training ...')
for epoch in range(num_epochs):
    epoch_loss, batch_count = 0.0, 0
    for batch, _ in train_dataloader:
        batch = batch.to(device)
        # forward pass: reconstruction and latent code
        recon, code = ae(batch)
        loss = ae_loss(recon, batch, code)
        # backpropagate and take one optimizer step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
        batch_count += 1
    train_loss_avg.append(epoch_loss / batch_count)
    print('Epoch [%d / %d] average reconstruction error: %f' % (epoch+1, num_epochs, train_loss_avg[-1]))
"""Plot Training Curve
--------------------
"""
import matplotlib.pyplot as plt
plt.ion()
fig = plt.figure()
plt.plot(train_loss_avg)
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.show()
"""Evaluate on the Test Set
-------------------------
"""
# set to evaluation mode
ae.eval()
test_loss_avg, num_batches = 0, 0
for image_batch, _ in test_dataloader:
with torch.no_grad():
image_batch = image_batch.to(device)
# vae reconstruction
image_batch_recon, latent_mu = ae(image_batch)
# reconstruction error
loss = ae_loss(image_batch_recon, image_batch, latent_mu)
test_loss_avg += loss.item()
num_batches += 1
test_loss_avg /= num_batches
print('average reconstruction error: %f' % (test_loss_avg))
"""Visualize Reconstructions
--------------------------
"""
import numpy as np
import matplotlib.pyplot as plt
plt.ion()
import torchvision.utils
ae.eval()
# This function takes as an input the images to reconstruct
# and the name of the model with which the reconstructions
# are performed
def to_img(x):
    """Return ``x`` with every value clamped into the displayable range [0, 1]."""
    return x.clamp(min=0, max=1)
def show_image(img):
    """Draw a channels-first image tensor with matplotlib, clamped to [0, 1]."""
    pixels = to_img(img).numpy()
    # matplotlib wants channels last: (C, H, W) -> (H, W, C)
    plt.imshow(pixels.transpose(1, 2, 0))
def visualise_output(images, model):
    """Reconstruct ``images`` with ``model`` and show items 1..49 as a grid.

    ``model`` must return a ``(reconstruction, latent)`` pair, as the
    ``Autoencoder`` class does. The grid has 10 images per row and 5px padding.
    """
    with torch.no_grad():
        recon, _ = model(images.to(device))
        recon = to_img(recon.cpu())
        grid = torchvision.utils.make_grid(recon[1:50], 10, 5)
        plt.imshow(np.transpose(grid.numpy(), (1, 2, 0)))
        plt.show()
# Take one batch of test images (the test loader shuffles, so it differs per run).
images, labels = next(iter(test_dataloader))
# First visualise the original images
print('Original images')
show_image(torchvision.utils.make_grid(images[1:50], 10, 5))
plt.show()
# Reconstruct and visualise the same images using the autoencoder
print('Autoencoder reconstruction:')
visualise_output(images, ae)
"""Interpolate in Latent Space
----------------------------
"""
ae.eval()
def interpolation(lambda1, model, img1, img2):
with torch.no_grad():
# latent vector of first image
img1 = img1.to(device)
latent_1 = model.encoder(img1)
# latent vector of second image
img2 = img2.to(device)
latent_2 = model.encoder(img2)
# interpolation of the two latent vectors
inter_latent = lambda1* latent_1 + (1- lambda1) * latent_2
# reconstruct interpolated image
inter_image = model.decoder(inter_latent)
inter_image = inter_image.cpu()
return inter_image
# Bucket test images by digit label, stopping after the first whole batch
# that brings the total to at least 1000 images.
digits = [[] for _ in range(10)]
for img_batch, label_batch in test_dataloader:
    for idx, label in enumerate(label_batch):
        digits[label].append(img_batch[idx:idx+1])
    if sum(map(len, digits)) >= 1000:
        break
# Ten evenly spaced blend weights from 0 to 1, morphing a "1" into a "7".
lambda_range = np.linspace(0, 1, 10)
fig, axs = plt.subplots(2, 5, figsize=(15, 6))
fig.subplots_adjust(hspace=.5, wspace=.001)
axs = axs.ravel()
for panel, lam in zip(axs, lambda_range):
    frame = to_img(interpolation(float(lam), ae, digits[7][0], digits[1][0])).numpy()
    panel.imshow(frame[0, 0, :, :], cmap='gray')
    panel.set_title('lambda_val=' + str(round(lam, 1)))
plt.show()
"""Sample Latent Vector from Prior (Autoencoder as Generator)
-------------------------------------------------
An autoencoder can generate new digits by drawing latent vectors from the prior distribution.
"""
ae.eval()
with torch.no_grad():
    # Draw 128 latent vectors from a standard normal distribution and decode them.
    samples = ae.decoder(torch.randn(128, latent_dims, device=device)).cpu()
    fig, ax = plt.subplots(figsize=(5, 5))
    # Show the first 100 samples, 10 per row.
    show_image(torchvision.utils.make_grid(samples[:100], 10, 5))
    plt.show()
"""Show 2D Latent Space
---------------------
"""
# load a network that was trained with a 2d latent space
if latent_dims != 2:
print('Please change the parameters to two latent dimensions.')
with torch.no_grad():
# create a sample grid in 2d latent space
latent_x = np.linspace(-1.5,1.5,20)
latent_y = np.linspace(-1.5,1.5,20)
latents = torch.FloatTensor(len(latent_y), len(latent_x), 2)
for i, lx in enumerate(latent_x):
for j, ly in enumerate(latent_y):
latents[j, i, 0] = lx
latents[j, i, 1] = ly
latents = latents.view(-1, 2) # flatten grid into a batch
# reconstruct images from the latent vectors
latents = latents.to(device)
image_recon = ae.decoder(latents)
image_recon = image_recon.cpu()
fig, ax = plt.subplots(figsize=(10, 10))
show_image(torchvision.utils.make_grid(image_recon.data[:400],20,5))
plt.show()