The goal of this notebook is to learn how to code a variational autoencoder (VAE) in Keras. We will discuss hyperparameters, training, and loss functions. In addition, we will familiarize ourselves with the Keras Sequential API, as well as how to visualize results and make predictions using a VAE with a small number of latent dimensions.
This notebook teaches the reader how to build a Variational Autoencoder (VAE) with Keras. The code is a minimally modified, stripped-down version of the code from Louis Tiao's wonderful blog post, which the reader is strongly encouraged to also read.
Our VAE will have Gaussian latent variables and a Gaussian posterior distribution $q_\phi({\mathbf z}|{\mathbf x})$ with a diagonal covariance matrix.
Recall that a VAE consists of four essential elements: (i) an encoder that parametrizes the variational posterior $q_\phi({\bf z}|{\bf x})$, (ii) a prior $p({\bf z})$ over the latent variables (here a standard Gaussian), (iii) a decoder that parametrizes the likelihood $p_\theta({\bf x}|{\bf z})$, and (iv) a cost function, the negative ELBO, consisting of a reconstruction error and the Kullback-Leibler divergence between the posterior and the prior. For a Gaussian posterior with diagonal covariance and a standard Gaussian prior, the KL term can be computed analytically:
$$-D_{KL}(q_\phi({\bf z}|{\bf x})|p({\bf z}))={1 \over 2} \sum_{j=1}^J \left (1+\log{\sigma_j^2({\bf x})}-\mu_j^2({\bf x}) -\sigma_j^2({\bf x})\right). $$
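As a quick check on this closed-form expression, the short snippet below (not part of the original notebook; the values of mu and sigma are arbitrary illustrations, not model parameters) compares the analytic per-dimension KL divergence with a direct numerical integration for a single latent dimension.
# Sanity check of the analytic KL divergence for one latent dimension.
# mu and sigma are arbitrary illustrative values, not model parameters.
import numpy as np
from scipy.stats import norm
from scipy.integrate import quad

mu, sigma = 0.7, 1.3

# Closed form: D_KL( N(mu, sigma^2) || N(0, 1) )
kl_analytic = -0.5 * (1 + np.log(sigma**2) - mu**2 - sigma**2)

# Direct numerical integration of q(z) * log( q(z) / p(z) )
integrand = lambda z: norm.pdf(z, mu, sigma) * (norm.logpdf(z, mu, sigma)
                                                - norm.logpdf(z, 0, 1))
kl_numeric, _ = quad(integrand, -20, 20)

print(kl_analytic, kl_numeric)  # the two numbers should agree closely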
In the next section of code, we import the data and specify the hyperparameters. The MNIST data are grayscale images, with each pixel taking integer values from 0 to 255. We normalize this range to lie between 0 and 1.
The hyperparameters we need to specify the architecture and train the VAE are: intermediate_dim (the size of the hidden layers of the encoder and decoder), latent_dim (the dimension of the latent space), epsilon_std (the standard deviation of the Gaussian noise used in the reparametrization trick), batch_size (the mini-batch size used during training), and epochs (the number of training epochs).
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import norm
from keras import backend as K
from keras.layers import (Input, InputLayer, Dense, Lambda, Layer,
Add, Multiply)
from keras.models import Model, Sequential
from keras.datasets import mnist
import pandas as pd
# Load data and map grayscale values (0-255) to the interval [0, 1]
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train = np.expand_dims(x_train, axis=-1) / 255.
x_test = np.expand_dims(x_test, axis=-1) / 255.
print(x_train.shape)
# Find dimensions of input images
img_rows, img_cols, img_chns = x_train.shape[1:]
# Specify hyperparameters
original_dim = img_rows * img_cols
intermediate_dim = 256
latent_dim = 2
batch_size = 100
epochs = 3
epsilon_std = 1.0
Here we specify the loss function. The first block of code is the reconstruction error, which is given by the cross-entropy. The second block of code calculates the KL divergence analytically and adds it to the model loss with the line self.add_loss. It represents the KL divergence as just another layer in the neural network, with inputs equal to outputs: the means and log-variances produced by the variational encoder (i.e. $\boldsymbol{\mu}({\bf x})$ and $\log\boldsymbol{\sigma}^2({\bf x})$).
def nll(y_true, y_pred):
    """ Negative log likelihood (Bernoulli). """
    # keras.losses.binary_crossentropy gives the mean
    # over the last axis. We require the sum.
    return K.sum(K.binary_crossentropy(y_true, y_pred), axis=-1)


class KLDivergenceLayer(Layer):
    """ Identity transform layer that adds KL divergence
    to the final model loss.
    """

    def __init__(self, *args, **kwargs):
        self.is_placeholder = True
        super(KLDivergenceLayer, self).__init__(*args, **kwargs)

    def call(self, inputs):
        mu, log_var = inputs
        kl_batch = - .5 * K.sum(1 + log_var -
                                K.square(mu) -
                                K.exp(log_var), axis=-1)
        self.add_loss(K.mean(kl_batch), inputs=inputs)
        return inputs
The following specifies both the encoder and the decoder. The encoder is an MLP with three layers that maps ${\bf x}$ to $\boldsymbol{\mu}({\bf x})$ and $\boldsymbol{\sigma}^2({\bf x})$, followed by the generation of a latent variable using the reparametrization trick (see main text). The decoder is an MLP specified as a single Keras Sequential model.
# Encoder
x = Input(shape=(original_dim,))
h = Dense(intermediate_dim, activation='relu')(x)
z_mu = Dense(latent_dim)(h)
z_log_var = Dense(latent_dim)(h)
z_mu, z_log_var = KLDivergenceLayer()([z_mu, z_log_var])
# Reparametrization trick
z_sigma = Lambda(lambda t: K.exp(.5*t))(z_log_var)
eps = Input(tensor=K.random_normal(shape=(K.shape(x)[0], latent_dim)))
z_eps = Multiply()([z_sigma, eps])
z = Add()([z_mu, z_eps])
# This defines the Encoder which takes noise and input and outputs
# the latent variable z
encoder = Model(inputs=[x, eps], outputs=z)
# Decoder is MLP specified as single Keras Sequential Layer
decoder = Sequential([
    Dense(intermediate_dim, input_dim=latent_dim, activation='relu'),
    Dense(original_dim, activation='sigmoid')
])
x_pred = decoder(z)
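Before training, it is worth convincing ourselves that the reparametrization trick does what we claim. The standalone NumPy sketch below (illustrative only; mu and sigma here are arbitrary numbers, not the network outputs) shows that samples of the form mu + sigma * eps, with eps drawn from a standard normal, have the desired mean and standard deviation, while all the randomness lives in eps so that gradients can flow through mu and sigma.
# Standalone illustration of the reparametrization trick (not used by the model).
rng = np.random.RandomState(0)
mu, sigma = 1.5, 0.4          # arbitrary illustrative values for one latent dimension
eps_samples = rng.randn(100000)        # eps ~ N(0, 1)
z_samples = mu + sigma * eps_samples   # z ~ N(mu, sigma^2)
print(z_samples.mean(), z_samples.std())  # should be close to 1.5 and 0.4
One can also call encoder.summary() and decoder.summary() at this point to verify that the encoder maps the 784-dimensional inputs to 2-dimensional latent codes and that the decoder maps them back.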
We now train the model. Even though the loss function passed to compile is only the negative log likelihood (the cross-entropy reconstruction error), recall that the KL layer adds the analytic form of the KL divergence to the total loss as well. We also have to reshape the data into vectors and specify an optimizer.
vae = Model(inputs=[x, eps], outputs=x_pred, name='vae')
vae.compile(optimizer='rmsprop', loss=nll)
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train = x_train.reshape(-1, original_dim) / 255.
x_test = x_test.reshape(-1, original_dim) / 255.
hist = vae.fit(
x_train,
x_train,
shuffle=True,
epochs=epochs,
batch_size=batch_size,
validation_data=(x_test, x_test)
)
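As an optional check (not in the original code), the test-set value of the objective after training can be obtained directly with evaluate; since the KLDivergenceLayer registers its contribution via add_loss, the reported number includes both the reconstruction term and the KL term.
# Optional: test-set loss (reconstruction error plus the KL term added via add_loss)
test_nelbo = vae.evaluate(x_test, x_test, batch_size=batch_size, verbose=0)
print('test NELBO estimate:', test_nelbo)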
We can visualize the loss as a function of the training epoch using the history returned by the standard Keras fitting interface.
%matplotlib inline
#for pretty plots
golden_size = lambda width: (width, 2. * width / (1 + np.sqrt(5)))
fig, ax = plt.subplots(figsize=golden_size(6))
hist_df = pd.DataFrame(hist.history)
hist_df.plot(ax=ax)
ax.set_ylabel('NELBO')
ax.set_xlabel('# epochs')
ax.set_ylim(.99*hist_df[1:].values.min(),
1.1*hist_df[1:].values.max())
plt.show()
Since our latent space is two-dimensional, we can think of our encoder as defining a dimensionality reduction of the original 784-dimensional space to just two dimensions! We can visualize the structure of this mapping by plotting the MNIST test set in the latent space, with each point colored by its digit label $0,1,\ldots,9$.
x_test_encoded = encoder.predict(x_test, batch_size=batch_size)
plt.figure(figsize=golden_size(6))
plt.scatter(x_test_encoded[:, 0], x_test_encoded[:, 1], c=y_test, cmap='nipy_spectral')
plt.colorbar()
plt.savefig('VAE_MNIST_latent.pdf')
plt.show()
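Besides mapping inputs to the latent plane, the trained model can also be used to reconstruct digits. The short sketch below (an optional visualization, not part of the original code) plots a few test digits above their VAE reconstructions.
# Compare a few test digits (top row) with their reconstructions (bottom row).
n_show = 8
x_recon = vae.predict(x_test, batch_size=batch_size)

fig, axes = plt.subplots(2, n_show, figsize=(2 * n_show, 4))
for i in range(n_show):
    axes[0, i].imshow(x_test[i].reshape(img_rows, img_cols), cmap='gray')
    axes[1, i].imshow(x_recon[i].reshape(img_rows, img_cols), cmap='gray')
    axes[0, i].axis('off')
    axes[1, i].axis('off')
plt.show()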
One of the nice things about VAEs is that they are generative models. Thus, we can generate new examples, or fantasy particles, much like we did for RBMs and DBMs. We will generate these samples in two different ways: by decoding a uniform grid of points in the latent space, and by decoding a grid constructed from the inverse CDF of the Gaussian prior.
# display a 2D manifold of the images
n = 5  # figure with an n x n grid of images
quantile_min = 0.01
quantile_max = 0.99
# Linear Sampling
# we will sample n points per axis, uniformly spaced in [-5, 5]
z1_u = np.linspace(5, -5, n)
z2_u = np.linspace(5, -5, n)
z_grid = np.dstack(np.meshgrid(z1_u, z2_u))
x_pred_grid = decoder.predict(z_grid.reshape(n*n, latent_dim)) \
.reshape(n, n, img_rows, img_cols)
# Plot figure
fig, ax = plt.subplots(figsize=golden_size(10))
ax.imshow(np.block(list(map(list, x_pred_grid))), cmap='gray')
ax.set_xticks(np.arange(0, n*img_rows, img_rows) + .5 * img_rows)
ax.set_xticklabels(map('{:.2f}'.format, z1_u), rotation=90)
ax.set_yticks(np.arange(0, n*img_cols, img_cols) + .5 * img_cols)
ax.set_yticklabels(map('{:.2f}'.format, z2_u))
ax.set_xlabel('$z_1$')
ax.set_ylabel('$z_2$')
ax.set_title('Uniform')
ax.grid(False)
plt.savefig('VAE_MNIST_fantasy_uniform.pdf')
plt.show()
# Inverse CDF sampling
z1 = norm.ppf(np.linspace(quantile_min, quantile_max, n))
z2 = norm.ppf(np.linspace(quantile_max, quantile_min, n))
z_grid2 = np.dstack(np.meshgrid(z1, z2))
x_pred_grid2 = decoder.predict(z_grid2.reshape(n*n, latent_dim)) \
.reshape(n, n, img_rows, img_cols)
# Plot figure Inverse CDF sampling
fig, ax = plt.subplots(figsize=golden_size(10))
ax.imshow(np.block(list(map(list, x_pred_grid2))), cmap='gray')
ax.set_xticks(np.arange(0, n*img_rows, img_rows) + .5 * img_rows)
ax.set_xticklabels(map('{:.2f}'.format, z1), rotation=90)
ax.set_yticks(np.arange(0, n*img_cols, img_cols) + .5 * img_cols)
ax.set_yticklabels(map('{:.2f}'.format, z2))
ax.set_xlabel('$z_1$')
ax.set_ylabel('$z_2$')
ax.set_title('Inverse CDF')
ax.grid(False)
plt.savefig('VAE_MNIST_fantasy_invCDF.pdf')
plt.show()