
# -*- coding: utf-8 -*-
"""Copie de Probabilistic Layers VAE

Automatically generated by Colab.

Original file is located at
    https://colab.research.google.com/drive/1jM-CEYHNqCkqcci_U_I3pn0UVqnVnaFk

https://colab.research.google.com/github/tensorflow/probability/blob/master/tensorflow_probability/examples/jupyter_notebooks/Probabilistic_Layers_VAE.ipynb

##### Copyright 2019 The TensorFlow Probability Authors.

Licensed under the Apache License, Version 2.0 (the "License");
"""

#@title Licensed under the Apache License, Version 2.0 (the "License"); { display-mode: "form" }
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""# TFP Probabilistic Layers: Variational Auto Encoder

In this example we show how to fit a Variational Autoencoder using TFP's
"probabilistic layers."

### Dependencies & Prerequisites
"""

#@title Import { display-mode: "form" }
import numpy as np
import tensorflow as tf
import tf_keras as tfk
import tensorflow_datasets as tfds
import tensorflow_probability as tfp

tfkl = tfk.layers
tfpl = tfp.layers
tfd = tfp.distributions

"""### Make things Fast!

Before we dive in, let's make sure we're using a GPU for this demo.

To do this, select "Runtime" -> "Change runtime type" -> "Hardware accelerator" -> "GPU".

The following snippet will verify that we have access to a GPU.
"""

if tf.test.gpu_device_name() != '/device:GPU:0':
  print('WARNING: GPU device not found.')
else:
  print('SUCCESS: Found GPU: {}'.format(tf.test.gpu_device_name()))

"""Note: if for some reason you cannot access a GPU, this colab will still work.
(Training will just take longer.)

### Load Dataset
"""

datasets, datasets_info = tfds.load(name='mnist',
                                    with_info=True,
                                    as_supervised=False)

def _preprocess(sample):
  image = tf.cast(sample['image'], tf.float32) / 255.  # Scale to unit interval.
  image = image < tf.random.uniform(tf.shape(image))   # Randomly binarize.
  return image, image

train_dataset = (datasets['train']
                 .map(_preprocess)
                 .batch(256)
                 .prefetch(tf.data.AUTOTUNE)
                 .shuffle(int(10e3)))
eval_dataset = (datasets['test']
                .map(_preprocess)
                .batch(256)
                .prefetch(tf.data.AUTOTUNE))

"""Note that `_preprocess()` above returns `image, image` rather than just
`image` because Keras is set up for discriminative models with an
(example, label) input format, i.e. $p_\theta(y|x)$. Since the goal of the VAE
is to recover the input x from x itself (i.e. $p_\theta(x|x)$), the data pair
is (example, example).

### VAE Code Golf

#### Specify model.
"""

input_shape = datasets_info.features['image'].shape
encoded_size = 16
base_depth = 32

prior = tfd.Independent(tfd.Normal(loc=tf.zeros(encoded_size), scale=1),
                        reinterpreted_batch_ndims=1)

encoder = tfk.Sequential([
    tfkl.InputLayer(input_shape=input_shape),
    tfkl.Lambda(lambda x: tf.cast(x, tf.float32) - 0.5),
    tfkl.Conv2D(base_depth, 5, strides=1,
                padding='same', activation=tf.nn.leaky_relu),
    tfkl.Conv2D(base_depth, 5, strides=2,
                padding='same', activation=tf.nn.leaky_relu),
    tfkl.Conv2D(2 * base_depth, 5, strides=1,
                padding='same', activation=tf.nn.leaky_relu),
    tfkl.Conv2D(2 * base_depth, 5, strides=2,
                padding='same', activation=tf.nn.leaky_relu),
    tfkl.Conv2D(4 * encoded_size, 7, strides=1,
                padding='valid', activation=tf.nn.leaky_relu),
    tfkl.Flatten(),
    tfkl.Dense(tfpl.MultivariateNormalTriL.params_size(encoded_size),
               activation=None),
    tfpl.MultivariateNormalTriL(
        encoded_size,
        activity_regularizer=tfpl.KLDivergenceRegularizer(prior)),
])

decoder = tfk.Sequential([
    tfkl.InputLayer(input_shape=[encoded_size]),
    tfkl.Reshape([1, 1, encoded_size]),
    tfkl.Conv2DTranspose(2 * base_depth, 7, strides=1,
                         padding='valid', activation=tf.nn.leaky_relu),
    tfkl.Conv2DTranspose(2 * base_depth, 5, strides=1,
                         padding='same', activation=tf.nn.leaky_relu),
    tfkl.Conv2DTranspose(2 * base_depth, 5, strides=2,
                         padding='same', activation=tf.nn.leaky_relu),
    tfkl.Conv2DTranspose(base_depth, 5, strides=1,
                         padding='same', activation=tf.nn.leaky_relu),
    tfkl.Conv2DTranspose(base_depth, 5, strides=2,
                         padding='same', activation=tf.nn.leaky_relu),
    tfkl.Conv2DTranspose(base_depth, 5, strides=1,
                         padding='same', activation=tf.nn.leaky_relu),
    tfkl.Conv2D(filters=1, kernel_size=5, strides=1,
                padding='same', activation=None),
    tfkl.Flatten(),
    tfpl.IndependentBernoulli(input_shape, tfd.Bernoulli.logits),
])

vae = tfk.Model(inputs=encoder.inputs,
                outputs=decoder(encoder.outputs[0]))

"""#### Do inference."""

negloglik = lambda x, rv_x: -rv_x.log_prob(x)

vae.compile(optimizer=tfk.optimizers.Adam(learning_rate=1e-3),
            loss=negloglik)

_ = vae.fit(train_dataset,
            epochs=15,
            validation_data=eval_dataset)

"""### Look Ma, No ~~Hands~~Tensors!"""

# We'll just examine ten random digits.
x = next(iter(eval_dataset))[0][:10]
xhat = vae(x)
assert isinstance(xhat, tfd.Distribution)

#@title Image Plot Util
import matplotlib.pyplot as plt

def display_imgs(x, y=None):
  if not isinstance(x, (np.ndarray, np.generic)):
    x = np.array(x)
  plt.ioff()
  n = x.shape[0]
  fig, axs = plt.subplots(1, n, figsize=(n, 1))
  if y is not None:
    fig.suptitle(np.argmax(y, axis=1))
  for i in range(n):
    axs.flat[i].imshow(x[i].squeeze(), interpolation='none', cmap='gray')
    axs.flat[i].axis('off')
  plt.show()
  plt.close()
  plt.ion()

print('Originals:')
display_imgs(x)

print('Decoded Random Samples:')
display_imgs(xhat.sample())

print('Decoded Modes:')
display_imgs(xhat.mode())

print('Decoded Means:')
display_imgs(xhat.mean())

# Now, let's generate ten never-before-seen digits.
z = prior.sample(10)
xtilde = decoder(z)
assert isinstance(xtilde, tfd.Distribution)

print('Randomly Generated Samples:')
display_imgs(xtilde.sample())

print('Randomly Generated Modes:')
display_imgs(xtilde.mode())

print('Randomly Generated Means:')
display_imgs(xtilde.mean())
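"""As a closing sanity check: the loss used in `vae.compile` above is the
negative ELBO, where `negloglik` supplies the reconstruction term and the
`KLDivergenceRegularizer` on the encoder's final layer adds the KL term. The
sketch below estimates both terms by hand on one evaluation batch, using a
single-sample Monte Carlo estimate of the KL divergence (an illustrative
approximation, not necessarily how the regularizer computes it internally)."""

x_batch, _ = next(iter(eval_dataset))
approx_posterior = encoder(x_batch)   # q(z|x): a MultivariateNormalTriL distribution.
z_sample = approx_posterior.sample()  # One latent sample per image, shape [batch, encoded_size].
likelihood = decoder(z_sample)        # p(x|z): an IndependentBernoulli distribution.

# Reconstruction term: average negative log-likelihood of the binarized inputs.
recon_nll = -tf.reduce_mean(likelihood.log_prob(tf.cast(x_batch, tf.float32)))

# KL term: single-sample Monte Carlo estimate of KL(q(z|x) || p(z)).
kl_term = tf.reduce_mean(
    approx_posterior.log_prob(z_sample) - prior.log_prob(z_sample))

print('Reconstruction term: {:.2f}   KL term: {:.2f}   Negative ELBO: {:.2f}'.format(
    float(recon_nll), float(kl_term), float(recon_nll) + float(kl_term)))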