# -*- coding: utf-8 -*-

"""This notebook is inspired by the following notebook:

Simple Autoencoders with Linear Layers
"""

# Commented out IPython magic to ensure Python compatibility.
# Install PyTorch when running on Google Colaboratory.  The former IPython
# shell magic `cuda_output = !ldconfig -p|grep|sed ...` is not valid Python,
# and its grep had lost its pattern, so the same lookup is done with
# subprocess + re below.
import re
import sys


def parse_cuda_tags(ldconfig_listing):
    """Return a ``cuXY`` tag for every libcudart line of ``ldconfig -p`` output.

    Reproduces the old ``grep | sed`` pipeline: each line mentioning
    ``cudart`` whose text ends in ``.X.Y`` is replaced by ``cuXY``; other
    matching lines are returned unchanged (as sed would leave them).

    NOTE(review): filtering on 'cudart' is assumed -- the original grep
    pattern was missing from the export.
    """
    version_suffix = re.compile(r'.*\.([0-9]*)\.([0-9]*)$')
    return [version_suffix.sub(r'cu\1\2', line)
            for line in ldconfig_listing.splitlines() if 'cudart' in line]


if 'google.colab' in sys.modules and 'torch' not in sys.modules:
    import subprocess
    from os.path import exists
    # from wheel.pep425tags import get_abbr_impl, get_impl_ver, get_abi_tag
    # platform = '{}{}-{}'.format(get_abbr_impl(), get_impl_ver(), get_abi_tag())
    try:
        ldconfig_listing = subprocess.run(
            ['ldconfig', '-p'], capture_output=True, text=True, check=False
        ).stdout
    except OSError:  # ldconfig unavailable: behave as if no CUDA runtime was found
        ldconfig_listing = ''
    cuda_output = parse_cuda_tags(ldconfig_listing)
    # Fall back to CPU when there is no GPU device node or no libcudart entry
    # (the old code raised IndexError on an empty cuda_output).
    accelerator = cuda_output[0] if exists('/dev/nvidia0') and cuda_output else 'cpu'

    #!pip3 install{accelerator}/torch-1.1.0-{platform}-linux_x86_64.whl
    #!pip3 install{accelerator}/torchvision-0.3.0-{platform}-linux_x86_64.whl

# %matplotlib inline
import os

import torch
import torch.nn as nn
import torch.nn.functional as F

"""Parameter Settings


# Settings for a 2-d latent space.  The parameter count is in the same order
# of magnitude as in the original VAE paper (which has about 3x as many).
latent_dims = 2        # bottleneck size; 2 makes the latent-space plot possible
capacity = 64          # NOTE(review): copied into a local in Encoder/Decoder but never used there
batch_size = 128
num_epochs = 20
learning_rate = 1e-3   # Adam step size
use_gpu = True         # use CUDA when torch.cuda.is_available() also holds

# # Alternative: 10-d latent space, for comparison with a non-variational auto-encoder
# latent_dims = 10
# num_epochs = 100
# batch_size = 128
# capacity = 64
# learning_rate = 1e-3
# use_gpu = True

"""MNIST Data Loading

MNIST images show digits from 0-9 in 28x28 grayscale images. We do not center them at 0, because we will be using a binary cross-entropy loss that treats pixel values as probabilities in [0,1]. We create both a training set and a test set.

import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST

# ToTensor() converts images to float tensors scaled to [0, 1]; no
# mean-centering, because the BCE reconstruction loss needs values in [0, 1].
img_transform = transforms.Compose([
    transforms.ToTensor(),
])

train_dataset = MNIST(root='./data/MNIST', download=True, train=True, transform=img_transform)
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

test_dataset = MNIST(root='./data/MNIST', download=True, train=False, transform=img_transform)
# shuffle=True: next(iter(test_dataloader)) yields a different batch on each run.
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True)

"""Autoencoder Definition

Encoder : 28*28=784 -> n1 -> n2 -> n3 -> 2

Decoder : 2 -> n3 -> n2 -> n1 -> 28*28=784


# Hidden-layer widths shared by Encoder (784 -> n1 -> n2 -> n3) and
# Decoder (n3 -> n2 -> n1 -> 784).
n1, n2, n3 = 512, 256, 128

class Encoder(nn.Module):
    """Map a batch of 28x28 images to ``latent_dims``-dimensional codes.

    Layers: 784 -> n1 -> n2 -> n3 -> latent_dims, with ReLU after each hidden
    layer and a plain linear output layer (``fc_mu``).
    """

    def __init__(self):
        super().__init__()
        # Attribute names are kept unchanged so existing state_dicts still load.
        self.linear1 = nn.Linear(28 * 28, n1)
        self.linear2 = nn.Linear(n1, n2)
        self.linear3 = nn.Linear(n2, n3)
        self.fc_mu = nn.Linear(n3, latent_dims)

    def forward(self, x):
        """Return the latent code for ``x``.

        ``x`` is flattened from dim 1 onward, so each sample must hold
        28*28 values after flattening (e.g. (B, 1, 28, 28) or (B, 784)).
        """
        # Same as nn.Flatten()(x), without building a new module on every call.
        x = torch.flatten(x, start_dim=1)
        x = F.relu(self.linear1(x))
        x = F.relu(self.linear2(x))
        x = F.relu(self.linear3(x))
        x_mu = self.fc_mu(x)
        return x_mu

class Decoder(nn.Module):
    """Map ``latent_dims``-dimensional codes back to (B, 1, 28, 28) images.

    Layers: latent_dims -> n3 -> n2 -> n1 -> 784, mirroring ``Encoder``, with
    ReLU after each hidden layer and a sigmoid on the output.
    """

    def __init__(self):
        super().__init__()
        # Attribute names are kept unchanged so existing state_dicts still load.
        self.fc = nn.Linear(latent_dims, n3)
        self.linear3 = nn.Linear(n3, n2)
        self.linear2 = nn.Linear(n2, n1)
        self.linear1 = nn.Linear(n1, 28 * 28)

    def forward(self, x):
        """Decode latent codes ``x`` (shape (B, latent_dims)) into images."""
        # ReLU after fc: without it, fc and linear3 are two back-to-back linear
        # maps that collapse into one, so the n3 layer would add parameters
        # but no expressive power.
        x = F.relu(self.fc(x))
        x = F.relu(self.linear3(x))
        x = F.relu(self.linear2(x))
        # Sigmoid keeps outputs in (0, 1), as the BCE reconstruction loss requires.
        x = torch.sigmoid(self.linear1(x))
        x = x.view(x.size(0), 1, 28, 28)
        return x

class Autoencoder(nn.Module):
    """Encoder followed by Decoder.

    ``forward`` returns a ``(reconstruction, latent_code)`` tuple.
    """

    def __init__(self):
        super().__init__()
        self.encoder = Encoder()
        self.decoder = Decoder()

    def forward(self, x):
        code = self.encoder(x)
        return self.decoder(code), code

def ae_loss(recon_x, x, mu):
    """Summed binary cross-entropy between a reconstruction and its target.

    Args:
        recon_x: reconstructed pixel probabilities in [0, 1].
        x: target pixel values in [0, 1], with the same number of elements
            as ``recon_x`` (layouts may differ, e.g. (B, 1, 28, 28) vs (B, 784)).
        mu: latent code; accepted for call-site compatibility but not used.

    Returns:
        Scalar tensor: BCE summed over every pixel of every sample.
    """
    # reshape(-1) instead of view(-1, 784): it also works on non-contiguous
    # tensors and images of any size.  The loss is a sum over all elements,
    # so the flattening layout does not change the result.
    recon_loss = F.binary_cross_entropy(recon_x.reshape(-1), x.reshape(-1), reduction='sum')
    return recon_loss

ae = Autoencoder()

device = torch.device("cuda:0" if use_gpu and torch.cuda.is_available() else "cpu")
# Move the parameters to the selected device; batches are moved there
# explicitly wherever the model is called.
ae = ae.to(device)

num_params = sum(p.numel() for p in ae.parameters() if p.requires_grad)
print('Number of parameters: %d' % num_params)

"""Train VAE

optimizer = torch.optim.Adam(params=ae.parameters(), lr=learning_rate, weight_decay=1e-5)

# set to training mode
ae.train()

# train_loss_avg[e] = mean per-batch (summed) reconstruction loss in epoch e
train_loss_avg = []

print('Training ...')
for epoch in range(num_epochs):
    # start a new accumulator for this epoch (indexing [-1] below needs it)
    train_loss_avg.append(0)
    num_batches = 0

    for image_batch, _ in train_dataloader:

        image_batch = image_batch.to(device)

        # autoencoder reconstruction
        image_batch_recon, latent_mu = ae(image_batch)

        # reconstruction error
        loss = ae_loss(image_batch_recon, image_batch, latent_mu)

        # backpropagation (clear gradients left over from the previous step first)
        optimizer.zero_grad()
        loss.backward()

        # one step of the optimizer (using the gradients from backpropagation)
        optimizer.step()

        train_loss_avg[-1] += loss.item()
        num_batches += 1

    train_loss_avg[-1] /= num_batches
    print('Epoch [%d / %d] average reconstruction error: %f' % (epoch+1, num_epochs, train_loss_avg[-1]))

"""Plot Training Curve

import matplotlib.pyplot as plt

# Plot the per-epoch average training loss collected in the training loop.
fig = plt.figure()
plt.plot(train_loss_avg)
plt.xlabel('Epochs')
plt.ylabel('Reconstruction error')
plt.show()

"""Evaluate on the Test Set

# set to evaluation mode
ae.eval()

test_loss_avg, num_batches = 0, 0
# Gradients are never needed during evaluation, so disable them for the whole loop.
with torch.no_grad():
    for image_batch, _ in test_dataloader:

        image_batch = image_batch.to(device)

        # autoencoder reconstruction
        image_batch_recon, latent_mu = ae(image_batch)

        # reconstruction error
        loss = ae_loss(image_batch_recon, image_batch, latent_mu)

        test_loss_avg += loss.item()
        num_batches += 1

test_loss_avg /= num_batches
print('average reconstruction error: %f' % (test_loss_avg))

"""Visualize Reconstructions

import numpy as np
import matplotlib.pyplot as plt

import torchvision.utils


def to_img(x):
    """Return a copy of tensor ``x`` with every value clipped into [0, 1]."""
    return x.clamp(min=0, max=1)

def show_image(img):
    """Clip a (C, H, W) CPU tensor to [0, 1] and draw it with ``plt.imshow``.

    The channel axis is moved last (CHW -> HWC) because that is the layout
    ``imshow`` expects.
    """
    chw = to_img(img).numpy()
    hwc = np.transpose(chw, (1, 2, 0))
    plt.imshow(hwc)

def visualise_output(images, model):
    """Reconstruct ``images`` with ``model`` and draw them as an image grid.

    Args:
        images: batch of input images (on any device; moved to ``device``).
        model: autoencoder whose forward returns ``(reconstruction, latent)``.

    Draws reconstructions ``[1:50]`` (the same slice used for the
    "Original images" grid), 10 per row with 5-pixel padding.
    """
    with torch.no_grad():

        images = images.to(device)
        images, _ = model(images)
        images = images.cpu()
        images = to_img(images)
        np_imagegrid = torchvision.utils.make_grid(images[1:50], 10, 5).numpy()
        plt.imshow(np.transpose(np_imagegrid, (1, 2, 0)))

images, labels = next(iter(test_dataloader))

# First visualise the original images (same [1:50] slice as the reconstructions)
print('Original images')
show_image(torchvision.utils.make_grid(images[1:50], 10, 5))
plt.show()

# Reconstruct and visualise the images using the autoencoder
print('VAE reconstruction:')
visualise_output(images, ae)
plt.show()

"""Interpolate in Latent Space


def interpolation(lambda1, model, img1, img2):
    """Decode a linear blend of the latent codes of two images.

    Args:
        lambda1: weight of ``img1``'s code; 1.0 reproduces img1's code and
            0.0 reproduces img2's.
        model: autoencoder exposing ``encoder`` and ``decoder`` sub-modules.
        img1, img2: image tensors accepted by ``model.encoder``.

    Returns:
        The decoded interpolated image(s) as a CPU tensor.
    """
    with torch.no_grad():

        # latent vector of first image
        img1 = img1.to(device)
        latent_1 = model.encoder(img1)

        # latent vector of second image
        img2 = img2.to(device)
        latent_2 = model.encoder(img2)

        # interpolation of the two latent vectors
        inter_latent = lambda1 * latent_1 + (1 - lambda1) * latent_2

        # reconstruct interpolated image
        inter_image = model.decoder(inter_latent)
        inter_image = inter_image.cpu()

        return inter_image

# Sort part of the test set by digit: digits[d] collects images labelled d.
# Stop once at least 1000 images have been collected.
digits = [[] for _ in range(10)]
for img_batch, label_batch in test_dataloader:
    for i in range(img_batch.size(0)):
        # slicing i:i+1 keeps the batch dimension, so each entry can be fed
        # straight to the encoder
        digits[int(label_batch[i])].append(img_batch[i:i + 1])
    if sum(len(d) for d in digits) >= 1000:
        break

# interpolation lambdas: 10 evenly spaced weights from 0 (first digit-1
# sample) to 1 (first digit-7 sample)
lambda_range = np.linspace(0, 1, 10)

fig, axs = plt.subplots(2, 5, figsize=(15, 6))
fig.subplots_adjust(hspace=.5, wspace=.001)
axs = axs.ravel()

for ind, l in enumerate(lambda_range):
    inter_image = interpolation(float(l), ae, digits[7][0], digits[1][0])

    inter_image = to_img(inter_image)

    image = inter_image.numpy()

    axs[ind].imshow(image[0, 0, :, :], cmap='gray')
    axs[ind].set_title('lambda_val=' + str(round(float(l), 1)))
plt.show()

"""Sample Latent Vector from Prior (Autoencoder as Generator)

An autoencoder can generate new digits by drawing latent vectors from the prior distribution.



ae.eval()

with torch.no_grad():

    # sample latent vectors from the standard normal distribution
    latent = torch.randn(128, latent_dims, device=device)

    # reconstruct images from the latent vectors
    img_recon = ae.decoder(latent)
    img_recon = img_recon.cpu()

    # show the first 100 samples, 10 per row with 5-pixel padding
    fig, ax = plt.subplots(figsize=(5, 5))
    show_image(torchvision.utils.make_grid(img_recon[:100], 10, 5))
    plt.show()

"""Show 2D Latent Space

# The grid below decodes 2-d latent vectors, so it only works for a model
# trained with a 2-d latent space; skip it otherwise instead of crashing.
if latent_dims != 2:
    print('Please change the parameters to two latent dimensions.')
else:
    with torch.no_grad():

        # create a 20x20 sample grid in 2d latent space: latents[j, i] = (x_i, y_j)
        latent_x = np.linspace(-1.5, 1.5, 20)
        latent_y = np.linspace(-1.5, 1.5, 20)
        latents = torch.empty(len(latent_y), len(latent_x), 2, dtype=torch.float32)
        for i, lx in enumerate(latent_x):
            for j, ly in enumerate(latent_y):
                latents[j, i, 0] = lx
                latents[j, i, 1] = ly
        latents = latents.view(-1, 2)  # flatten grid into a batch

        # reconstruct images from the latent vectors
        latents = latents.to(device)
        image_recon = ae.decoder(latents)
        image_recon = image_recon.cpu()

        # one grid row per latent_y value, 20 images per row
        fig, ax = plt.subplots(figsize=(10, 10))
        show_image(torchvision.utils.make_grid(image_recon[:400], 20, 5))
        plt.show()