Variational Autoencoders

Variational autoencoders view autoencoding from a statistical perspective. Like classical autoencoders, they encode a dataset into a lower-dimensional latent space. In addition, though, variational autoencoders constrain the encoded vectors to roughly follow a chosen probability distribution, e.g. a normal distribution. Here’s an example of a variational autoencoder applied to the same 1D sequence-to-sequence monochromatic signal encoding problem.

Variational Autoencoder

This notebook uses the same toy problem as the autoencoding notebook. Here we demonstrate the use of a variational autoencoder.

In [1]:
import numpy as np 
import matplotlib.pyplot as plt
import scipy.stats

from keras.models import Model, load_model
from keras.layers import Input, Dense, LeakyReLU, Lambda
from keras.callbacks import EarlyStopping, ModelCheckpoint
from keras.utils import plot_model
from keras.losses import mse
from keras import backend as K

from sklearn.model_selection import train_test_split

import petname
Using TensorFlow backend.
In [2]:
# generate training, test, and validation data
n = 4096
nt = 128
f = 3.0                  # frequency in Hz
t = np.linspace(0,1,nt)  # time stamps in s
x = np.zeros((n,nt))
phase = np.random.uniform(-np.pi, np.pi, size=n)
for i in range(n):
    x[i,:] = np.sin(2*np.pi*f*t + phase[i] )
In [3]:
# QC: generated data should be phase-shifted copies of a single-frequency sinusoid
plt.figure(figsize=(8,2))
for i in range(3):
    plt.plot(t, x[np.random.randint(0, n), :])
plt.show()
In [4]:
# QC generated phase in [-pi,pi]
plt.figure(figsize=(8,2))
plt.hist(phase,bins=31)
plt.xlabel('phase')
plt.ylabel('number of occurrences')
plt.show()
In [5]:
# split into test, validation, and training sets
x_temp, x_test, _, _ = train_test_split(x, x, test_size=0.05)
x_train, x_valid, _, _ = train_test_split(x_temp,
                                          x_temp,
                                          test_size=0.1)
n_train = len(x_train)
n_valid = len(x_valid)
n_test = len(x_test)
In [6]:
# specify training parameters and callback functions

# batch size for stochastic solver  
batch_size = 16

# number of times entire dataset is considered in stochastic solver
epochs = 100

# unique name for the network for saving
unique_name = petname.name()
model_filename = 'aen_sin_%03dHz_n=%05d_'%(int(f),nt)+unique_name+'.h5'

# training history file name
history_filename = 'results_'+unique_name+'.npz'

# stop early (verbosely) if val_loss shows no improvement for `patience` epochs;
# with patience equal to the epoch count, early stopping will not trigger here
earlystopper = EarlyStopping(patience=100, verbose=1)

# checkpoint and save model when improvement occurs 
checkpointer = ModelCheckpoint(model_filename, verbose=1, save_best_only=True)

# consolidate callback functions for convenience 
callbacks = [earlystopper, checkpointer]

Now things get a bit different from a vanilla autoencoder. First, we set the dimensionality of the latent space. For this example we can get away with a single dimension: intuitively, since the only difference between training examples is the phase, one latent dimension is enough to encode it.

In [7]:
# encoding dimension; i.e. dimensionality of the latent space
encoding_dim=1

Next we define a function to draw samples from a Gaussian, given the mean and the log of the standard deviation; this sampling step produces the encoding in the latent space. Because the randomness is confined to an auxiliary noise term, the function still lets us backpropagate through the mean and log standard deviation even though the operation is stochastic (this is the "reparameterization trick").

In [8]:
# define a function to sample from a Gaussian, given the mean and log standard deviation
def sampling(args):
    z_mean, z_log_sigma = args
    epsilon = K.random_normal(shape=(encoding_dim,))
    return z_mean + K.exp(z_log_sigma) * epsilon

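As a quick sanity check (not part of the original notebook), plain NumPy shows that reparameterized draws z = z_mean + exp(z_log_sigma) * epsilon have the intended mean and standard deviation, with all of the randomness isolated in epsilon; that isolation is what lets gradients flow through z_mean and z_log_sigma.

In [ ]:
# optional sanity check of the reparameterization trick (not run in the original notebook)
mu, log_sigma = 0.7, -0.5
eps = np.random.normal(size=100000)
z_samples = mu + np.exp(log_sigma) * eps
print(z_samples.mean())                     # close to mu = 0.7
print(z_samples.std(), np.exp(log_sigma))   # both close to exp(-0.5) ~= 0.61
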
The network structure is similar to the plain autoencoder. The main difference is in the middle of the network: the z_mean and z_log_sigma layers encode a mean and log(std) that define the pdf from which the latent-space encoding is drawn. The "Lambda" layer then draws a sample from that pdf, and z is the encoded signal in the latent space.

In [9]:
# construct variational autoencoder network structure

# input layer is full time series of length nt
inputs = Input((nt,))

# encoder hidden layers
encoded = Dense(64)(inputs) 
encoded = LeakyReLU(alpha=0.2)(encoded)
encoded = Dense(32)(encoded)
encoded = LeakyReLU(alpha=0.2)(encoded)
z_mean = Dense(encoding_dim)(encoded)
z_log_sigma = Dense(encoding_dim)(encoded)
z = Lambda(sampling,output_shape=(encoding_dim,))([z_mean,z_log_sigma])

# decoder hidden layers
# (the decoder model is defined later by indexing the trained layers of vae)

decoded = Dense(32)(z)
decoded = LeakyReLU(alpha=0.2)(decoded)
decoded = Dense(64)(decoded)
decoded = LeakyReLU(alpha=0.2)(decoded)
# output layer is same length as input
outputs = Dense(nt,activation='tanh')(decoded)

# consolidate to define autoencoder model inputs and outputs
vae = Model(inputs=inputs, outputs=outputs)

# specify encoder and decoder model for easy encoding and decoding later
encoder = Model(inputs=inputs, outputs=[z_mean,z_log_sigma,z],
                    name='encoder')
# create a placeholder for an encoded input
encoded_input = Input(shape=(encoding_dim,))
# retrieve the last layers of the autoencoder model
decoded_output = vae.layers[-5](encoded_input)
decoded_output = vae.layers[-4](decoded_output)
decoded_output = vae.layers[-3](decoded_output)
decoded_output = vae.layers[-2](decoded_output)
decoded_output = vae.layers[-1](decoded_output)
# create the decoder model
decoder = Model(inputs=encoded_input, outputs=decoded_output, name='decoder')
In [10]:
print('Full autoencoder')
print(vae.summary())
#print('\n Encoder portion of autoencoder')
#print(encoder.summary())
#print('\n Decoder portion of autoencoder')
#print(decoder.summary())
Full autoencoder
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_1 (InputLayer)            (None, 128)          0                                            
__________________________________________________________________________________________________
dense_1 (Dense)                 (None, 64)           8256        input_1[0][0]                    
__________________________________________________________________________________________________
leaky_re_lu_1 (LeakyReLU)       (None, 64)           0           dense_1[0][0]                    
__________________________________________________________________________________________________
dense_2 (Dense)                 (None, 32)           2080        leaky_re_lu_1[0][0]              
__________________________________________________________________________________________________
leaky_re_lu_2 (LeakyReLU)       (None, 32)           0           dense_2[0][0]                    
__________________________________________________________________________________________________
dense_3 (Dense)                 (None, 1)            33          leaky_re_lu_2[0][0]              
__________________________________________________________________________________________________
dense_4 (Dense)                 (None, 1)            33          leaky_re_lu_2[0][0]              
__________________________________________________________________________________________________
lambda_1 (Lambda)               (None, 1)            0           dense_3[0][0]                    
                                                                 dense_4[0][0]                    
__________________________________________________________________________________________________
dense_5 (Dense)                 (None, 32)           64          lambda_1[0][0]                   
__________________________________________________________________________________________________
leaky_re_lu_3 (LeakyReLU)       (None, 32)           0           dense_5[0][0]                    
__________________________________________________________________________________________________
dense_6 (Dense)                 (None, 64)           2112        leaky_re_lu_3[0][0]              
__________________________________________________________________________________________________
leaky_re_lu_4 (LeakyReLU)       (None, 64)           0           dense_6[0][0]                    
__________________________________________________________________________________________________
dense_7 (Dense)                 (None, 128)          8320        leaky_re_lu_4[0][0]              
==================================================================================================
Total params: 20,898
Trainable params: 20,898
Non-trainable params: 0
__________________________________________________________________________________________________
None

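plot_model was imported above but not used; if the optional pydot and graphviz dependencies are installed, it can render the same structure as a graph image (the filename below is arbitrary).

In [ ]:
# optional: save a graph rendering of the network (requires pydot and graphviz)
plot_model(vae, to_file='vae_'+unique_name+'.png', show_shapes=True)
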
The loss function is another key difference between standard autoencoders and variational autoencoders. A standard autoencoder minimizes only the reconstruction loss. A variational autoencoder minimizes the sum of a reconstruction loss and a KL divergence term. The KL divergence measures how much two probability distributions differ; minimizing it here encourages the latent-space encodings to follow a standard normal distribution. The regularization parameter balances reconstruction quality against enforcing that normal distribution in the latent space.

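For reference, the KL divergence between a Gaussian N(mu, sigma^2) and the standard normal N(0, 1) has the closed form 0.5*(mu^2 + sigma^2 - 1 - log(sigma^2)), which is the expression the kl_loss term below is based on, with z_mean and z_log_sigma supplying the Gaussian parameters. As an optional check (not part of the original notebook), the closed form can be compared against a Monte Carlo estimate using scipy.stats:

In [ ]:
# optional check: closed-form KL( N(mu, sigma^2) || N(0, 1) ) vs a Monte Carlo estimate
mu, sigma = 0.5, 1.5
kl_closed = 0.5 * (mu**2 + sigma**2 - 1.0 - np.log(sigma**2))
samples = np.random.normal(mu, sigma, size=200000)
kl_mc = np.mean(scipy.stats.norm.logpdf(samples, mu, sigma)
                - scipy.stats.norm.logpdf(samples, 0, 1))
print(kl_closed, kl_mc)   # the two estimates should agree closely
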
In [11]:
# specify loss
# regularization balances signal reconstruction with
# a Gaussian distribution in the latent space 
regularization = 10
def vae_loss(input_img, output):
    # reconstruction loss: sum of squared errors over all time samples
    reconstruction_loss = K.sum(K.square(output-input_img))
    # KL-divergence-style penalty that pushes the latent distribution toward N(0, 1)
    kl_loss = - 0.5 * K.sum(1 + z_log_sigma - K.square(z_mean) - K.square(K.exp(z_log_sigma)), axis=-1)
    # average the combined, weighted loss over the batch
    total_loss = K.mean(reconstruction_loss + regularization*kl_loss)
    return total_loss

vae.compile(optimizer='adam', loss=vae_loss, metrics=['mse'])
In [12]:
# train variational autoencoder
results = vae.fit(x_train,x_train,
                      shuffle=True,
                      batch_size = batch_size, 
                      epochs = epochs,
                      validation_data = (x_valid,x_valid),
                      callbacks = callbacks)
Train on 3501 samples, validate on 390 samples
Epoch 1/100
3501/3501 [==============================] - 1s 213us/step - loss: 572.6987 - mean_squared_error: 0.2705 - val_loss: 494.7158 - val_mean_squared_error: 0.2357

Epoch 00001: val_loss improved from inf to 494.71583, saving model to aen_sin_003Hz_n=00128_corgi.h5
Epoch 2/100
3501/3501 [==============================] - 0s 86us/step - loss: 375.1624 - mean_squared_error: 0.1734 - val_loss: 239.9433 - val_mean_squared_error: 0.1059

Epoch 00002: val_loss improved from 494.71583 to 239.94333, saving model to aen_sin_003Hz_n=00128_corgi.h5
Epoch 3/100
3501/3501 [==============================] - 0s 85us/step - loss: 188.5630 - mean_squared_error: 0.0773 - val_loss: 196.9777 - val_mean_squared_error: 0.0851

Epoch 00003: val_loss improved from 239.94333 to 196.97767, saving model to aen_sin_003Hz_n=00128_corgi.h5
Epoch 4/100
3501/3501 [==============================] - 0s 85us/step - loss: 150.1980 - mean_squared_error: 0.0625 - val_loss: 122.2335 - val_mean_squared_error: 0.0504

Epoch 00004: val_loss improved from 196.97767 to 122.23345, saving model to aen_sin_003Hz_n=00128_corgi.h5
Epoch 5/100
3501/3501 [==============================] - 0s 95us/step - loss: 107.6601 - mean_squared_error: 0.0433 - val_loss: 83.0565 - val_mean_squared_error: 0.0319

Epoch 00005: val_loss improved from 122.23345 to 83.05646, saving model to aen_sin_003Hz_n=00128_corgi.h5
Epoch 6/100
3501/3501 [==============================] - 0s 95us/step - loss: 103.4990 - mean_squared_error: 0.0420 - val_loss: 90.4268 - val_mean_squared_error: 0.0360

Epoch 00006: val_loss did not improve from 83.05646
Epoch 7/100
3501/3501 [==============================] - 0s 91us/step - loss: 94.0652 - mean_squared_error: 0.0377 - val_loss: 69.3914 - val_mean_squared_error: 0.0264

Epoch 00007: val_loss improved from 83.05646 to 69.39142, saving model to aen_sin_003Hz_n=00128_corgi.h5
Epoch 8/100
3501/3501 [==============================] - 0s 90us/step - loss: 90.3890 - mean_squared_error: 0.0359 - val_loss: 76.0256 - val_mean_squared_error: 0.0288

Epoch 00008: val_loss did not improve from 69.39142
Epoch 9/100
3501/3501 [==============================] - 0s 95us/step - loss: 77.4509 - mean_squared_error: 0.0297 - val_loss: 96.6011 - val_mean_squared_error: 0.0389

Epoch 00009: val_loss did not improve from 69.39142
Epoch 10/100
3501/3501 [==============================] - 0s 93us/step - loss: 77.4043 - mean_squared_error: 0.0296 - val_loss: 63.8580 - val_mean_squared_error: 0.0227

Epoch 00010: val_loss improved from 69.39142 to 63.85800, saving model to aen_sin_003Hz_n=00128_corgi.h5
Epoch 11/100
3501/3501 [==============================] - 0s 87us/step - loss: 70.3015 - mean_squared_error: 0.0260 - val_loss: 55.9432 - val_mean_squared_error: 0.0197

Epoch 00011: val_loss improved from 63.85800 to 55.94323, saving model to aen_sin_003Hz_n=00128_corgi.h5
Epoch 12/100
3501/3501 [==============================] - 0s 84us/step - loss: 71.4479 - mean_squared_error: 0.0262 - val_loss: 61.9059 - val_mean_squared_error: 0.0214

Epoch 00012: val_loss did not improve from 55.94323
Epoch 13/100
3501/3501 [==============================] - 0s 84us/step - loss: 80.7295 - mean_squared_error: 0.0305 - val_loss: 77.4515 - val_mean_squared_error: 0.0290

Epoch 00013: val_loss did not improve from 55.94323
Epoch 14/100
3501/3501 [==============================] - 0s 86us/step - loss: 78.9191 - mean_squared_error: 0.0291 - val_loss: 59.7105 - val_mean_squared_error: 0.0202

Epoch 00014: val_loss did not improve from 55.94323
Epoch 15/100
3501/3501 [==============================] - 0s 82us/step - loss: 70.0246 - mean_squared_error: 0.0251 - val_loss: 63.7124 - val_mean_squared_error: 0.0223

Epoch 00015: val_loss did not improve from 55.94323
Epoch 16/100
3501/3501 [==============================] - 0s 85us/step - loss: 62.7666 - mean_squared_error: 0.0216 - val_loss: 76.1868 - val_mean_squared_error: 0.0290

Epoch 00016: val_loss did not improve from 55.94323
Epoch 17/100
3501/3501 [==============================] - 0s 85us/step - loss: 63.8457 - mean_squared_error: 0.0221 - val_loss: 60.5681 - val_mean_squared_error: 0.0233

Epoch 00017: val_loss did not improve from 55.94323
Epoch 18/100
3501/3501 [==============================] - 0s 81us/step - loss: 62.1890 - mean_squared_error: 0.0214 - val_loss: 59.6436 - val_mean_squared_error: 0.0207

Epoch 00018: val_loss did not improve from 55.94323
Epoch 19/100
3501/3501 [==============================] - 0s 86us/step - loss: 56.6178 - mean_squared_error: 0.0189 - val_loss: 48.2721 - val_mean_squared_error: 0.0182

Epoch 00019: val_loss improved from 55.94323 to 48.27212, saving model to aen_sin_003Hz_n=00128_corgi.h5
Epoch 20/100
3501/3501 [==============================] - 0s 82us/step - loss: 62.5801 - mean_squared_error: 0.0214 - val_loss: 110.6859 - val_mean_squared_error: 0.0466

Epoch 00020: val_loss did not improve from 48.27212
Epoch 21/100
3501/3501 [==============================] - 0s 91us/step - loss: 73.1901 - mean_squared_error: 0.0259 - val_loss: 96.0680 - val_mean_squared_error: 0.0349

Epoch 00021: val_loss did not improve from 48.27212
Epoch 22/100
3501/3501 [==============================] - 0s 82us/step - loss: 77.1672 - mean_squared_error: 0.0264 - val_loss: 88.1257 - val_mean_squared_error: 0.0343

Epoch 00022: val_loss did not improve from 48.27212
Epoch 23/100
3501/3501 [==============================] - 0s 93us/step - loss: 95.5942 - mean_squared_error: 0.0357 - val_loss: 88.5442 - val_mean_squared_error: 0.0318

Epoch 00023: val_loss did not improve from 48.27212
Epoch 24/100
3501/3501 [==============================] - 0s 85us/step - loss: 77.6427 - mean_squared_error: 0.0288 - val_loss: 75.1405 - val_mean_squared_error: 0.0284

Epoch 00024: val_loss did not improve from 48.27212
Epoch 25/100
3501/3501 [==============================] - 0s 85us/step - loss: 74.1847 - mean_squared_error: 0.0278 - val_loss: 61.6067 - val_mean_squared_error: 0.0214

Epoch 00025: val_loss did not improve from 48.27212
Epoch 26/100
3501/3501 [==============================] - 0s 98us/step - loss: 70.3399 - mean_squared_error: 0.0255 - val_loss: 71.1625 - val_mean_squared_error: 0.0267

Epoch 00026: val_loss did not improve from 48.27212
Epoch 27/100
3501/3501 [==============================] - 0s 97us/step - loss: 58.3960 - mean_squared_error: 0.0190 - val_loss: 63.5285 - val_mean_squared_error: 0.0209

Epoch 00027: val_loss did not improve from 48.27212
Epoch 28/100
3501/3501 [==============================] - 0s 84us/step - loss: 69.1214 - mean_squared_error: 0.0247 - val_loss: 64.5801 - val_mean_squared_error: 0.0220

Epoch 00028: val_loss did not improve from 48.27212
Epoch 29/100
3501/3501 [==============================] - 0s 84us/step - loss: 57.8912 - mean_squared_error: 0.0191 - val_loss: 71.4037 - val_mean_squared_error: 0.0234

Epoch 00029: val_loss did not improve from 48.27212
Epoch 30/100
3501/3501 [==============================] - 0s 84us/step - loss: 63.8197 - mean_squared_error: 0.0210 - val_loss: 56.9146 - val_mean_squared_error: 0.0180

Epoch 00030: val_loss did not improve from 48.27212
Epoch 31/100
3501/3501 [==============================] - 0s 86us/step - loss: 56.7003 - mean_squared_error: 0.0185 - val_loss: 44.6653 - val_mean_squared_error: 0.0125

Epoch 00031: val_loss improved from 48.27212 to 44.66530, saving model to aen_sin_003Hz_n=00128_corgi.h5
Epoch 32/100
3501/3501 [==============================] - 0s 84us/step - loss: 68.5283 - mean_squared_error: 0.0245 - val_loss: 65.1528 - val_mean_squared_error: 0.0232

Epoch 00032: val_loss did not improve from 44.66530
Epoch 33/100
3501/3501 [==============================] - 0s 80us/step - loss: 66.3796 - mean_squared_error: 0.0239 - val_loss: 41.1790 - val_mean_squared_error: 0.0116

Epoch 00033: val_loss improved from 44.66530 to 41.17902, saving model to aen_sin_003Hz_n=00128_corgi.h5
Epoch 34/100
3501/3501 [==============================] - 0s 88us/step - loss: 56.4764 - mean_squared_error: 0.0191 - val_loss: 46.3807 - val_mean_squared_error: 0.0152

Epoch 00034: val_loss did not improve from 41.17902
Epoch 35/100
3501/3501 [==============================] - 0s 89us/step - loss: 53.3491 - mean_squared_error: 0.0176 - val_loss: 68.8576 - val_mean_squared_error: 0.0241

Epoch 00035: val_loss did not improve from 41.17902
Epoch 36/100
3501/3501 [==============================] - 0s 82us/step - loss: 59.1346 - mean_squared_error: 0.0199 - val_loss: 42.5398 - val_mean_squared_error: 0.0122

Epoch 00036: val_loss did not improve from 41.17902
Epoch 37/100
3501/3501 [==============================] - 0s 94us/step - loss: 59.4343 - mean_squared_error: 0.0201 - val_loss: 106.8294 - val_mean_squared_error: 0.0415

Epoch 00037: val_loss did not improve from 41.17902
Epoch 38/100
3501/3501 [==============================] - 0s 91us/step - loss: 67.6754 - mean_squared_error: 0.0214 - val_loss: 81.3590 - val_mean_squared_error: 0.0301

Epoch 00038: val_loss did not improve from 41.17902
Epoch 39/100
3501/3501 [==============================] - 0s 88us/step - loss: 71.4328 - mean_squared_error: 0.0257 - val_loss: 65.9904 - val_mean_squared_error: 0.0234

Epoch 00039: val_loss did not improve from 41.17902
Epoch 40/100
3501/3501 [==============================] - 0s 82us/step - loss: 53.6078 - mean_squared_error: 0.0167 - val_loss: 56.5115 - val_mean_squared_error: 0.0181

Epoch 00040: val_loss did not improve from 41.17902
Epoch 41/100
3501/3501 [==============================] - 0s 92us/step - loss: 52.1668 - mean_squared_error: 0.0159 - val_loss: 61.3523 - val_mean_squared_error: 0.0209

Epoch 00041: val_loss did not improve from 41.17902
Epoch 42/100
3501/3501 [==============================] - 0s 88us/step - loss: 67.1410 - mean_squared_error: 0.0232 - val_loss: 53.4885 - val_mean_squared_error: 0.0151

Epoch 00042: val_loss did not improve from 41.17902
Epoch 43/100
3501/3501 [==============================] - 0s 83us/step - loss: 60.4681 - mean_squared_error: 0.0200 - val_loss: 55.0059 - val_mean_squared_error: 0.0159

Epoch 00043: val_loss did not improve from 41.17902
Epoch 44/100
3501/3501 [==============================] - 0s 87us/step - loss: 57.7472 - mean_squared_error: 0.0193 - val_loss: 45.8398 - val_mean_squared_error: 0.0138

Epoch 00044: val_loss did not improve from 41.17902
Epoch 45/100
3501/3501 [==============================] - 0s 91us/step - loss: 53.6203 - mean_squared_error: 0.0175 - val_loss: 98.3426 - val_mean_squared_error: 0.0347

Epoch 00045: val_loss did not improve from 41.17902
Epoch 46/100
3501/3501 [==============================] - 0s 89us/step - loss: 60.2016 - mean_squared_error: 0.0195 - val_loss: 43.7904 - val_mean_squared_error: 0.0122

Epoch 00046: val_loss did not improve from 41.17902
Epoch 47/100
3501/3501 [==============================] - 0s 89us/step - loss: 45.7843 - mean_squared_error: 0.0137 - val_loss: 65.1359 - val_mean_squared_error: 0.0231

Epoch 00047: val_loss did not improve from 41.17902
Epoch 48/100
3501/3501 [==============================] - 0s 95us/step - loss: 52.4230 - mean_squared_error: 0.0160 - val_loss: 52.3943 - val_mean_squared_error: 0.0153

Epoch 00048: val_loss did not improve from 41.17902
Epoch 49/100
3501/3501 [==============================] - 0s 90us/step - loss: 60.3948 - mean_squared_error: 0.0205 - val_loss: 67.5678 - val_mean_squared_error: 0.0249

Epoch 00049: val_loss did not improve from 41.17902
Epoch 50/100
3501/3501 [==============================] - 0s 87us/step - loss: 54.3028 - mean_squared_error: 0.0173 - val_loss: 34.3592 - val_mean_squared_error: 0.0087

Epoch 00050: val_loss improved from 41.17902 to 34.35917, saving model to aen_sin_003Hz_n=00128_corgi.h5
Epoch 51/100
3501/3501 [==============================] - 0s 92us/step - loss: 50.0578 - mean_squared_error: 0.0156 - val_loss: 48.7010 - val_mean_squared_error: 0.0151

Epoch 00051: val_loss did not improve from 34.35917
Epoch 52/100
3501/3501 [==============================] - 0s 96us/step - loss: 55.7035 - mean_squared_error: 0.0182 - val_loss: 65.1687 - val_mean_squared_error: 0.0238

Epoch 00052: val_loss did not improve from 34.35917
Epoch 53/100
3501/3501 [==============================] - 0s 88us/step - loss: 47.7227 - mean_squared_error: 0.0145 - val_loss: 48.0725 - val_mean_squared_error: 0.0130

Epoch 00053: val_loss did not improve from 34.35917
Epoch 54/100
3501/3501 [==============================] - 0s 94us/step - loss: 45.5192 - mean_squared_error: 0.0134 - val_loss: 67.9560 - val_mean_squared_error: 0.0258

Epoch 00054: val_loss did not improve from 34.35917
Epoch 55/100
3501/3501 [==============================] - 0s 89us/step - loss: 52.0394 - mean_squared_error: 0.0165 - val_loss: 56.2932 - val_mean_squared_error: 0.0166

Epoch 00055: val_loss did not improve from 34.35917
Epoch 56/100
3501/3501 [==============================] - 0s 96us/step - loss: 56.6267 - mean_squared_error: 0.0185 - val_loss: 54.9323 - val_mean_squared_error: 0.0178

Epoch 00056: val_loss did not improve from 34.35917
Epoch 57/100
3501/3501 [==============================] - 0s 92us/step - loss: 50.9460 - mean_squared_error: 0.0164 - val_loss: 38.6571 - val_mean_squared_error: 0.0108

Epoch 00057: val_loss did not improve from 34.35917
Epoch 58/100
3501/3501 [==============================] - 0s 89us/step - loss: 51.7611 - mean_squared_error: 0.0169 - val_loss: 55.8965 - val_mean_squared_error: 0.0190

Epoch 00058: val_loss did not improve from 34.35917
Epoch 59/100
3501/3501 [==============================] - 0s 99us/step - loss: 46.6769 - mean_squared_error: 0.0136 - val_loss: 52.6628 - val_mean_squared_error: 0.0161

Epoch 00059: val_loss did not improve from 34.35917
Epoch 60/100
3501/3501 [==============================] - 0s 92us/step - loss: 49.5038 - mean_squared_error: 0.0157 - val_loss: 49.0351 - val_mean_squared_error: 0.0148

Epoch 00060: val_loss did not improve from 34.35917
Epoch 61/100
3501/3501 [==============================] - 0s 100us/step - loss: 44.5845 - mean_squared_error: 0.0130 - val_loss: 41.3265 - val_mean_squared_error: 0.0111

Epoch 00061: val_loss did not improve from 34.35917
Epoch 62/100
3501/3501 [==============================] - 0s 89us/step - loss: 44.9681 - mean_squared_error: 0.0133 - val_loss: 50.5358 - val_mean_squared_error: 0.0150

Epoch 00062: val_loss did not improve from 34.35917
Epoch 63/100
3501/3501 [==============================] - 0s 96us/step - loss: 45.5372 - mean_squared_error: 0.0137 - val_loss: 73.5032 - val_mean_squared_error: 0.0270

Epoch 00063: val_loss did not improve from 34.35917
Epoch 64/100
3501/3501 [==============================] - 0s 95us/step - loss: 54.0177 - mean_squared_error: 0.0179 - val_loss: 106.0356 - val_mean_squared_error: 0.0440

Epoch 00064: val_loss did not improve from 34.35917
Epoch 65/100
3501/3501 [==============================] - 0s 89us/step - loss: 47.1428 - mean_squared_error: 0.0146 - val_loss: 95.2780 - val_mean_squared_error: 0.0395

Epoch 00065: val_loss did not improve from 34.35917
Epoch 66/100
3501/3501 [==============================] - 0s 90us/step - loss: 59.2134 - mean_squared_error: 0.0204 - val_loss: 68.6873 - val_mean_squared_error: 0.0244

Epoch 00066: val_loss did not improve from 34.35917
Epoch 67/100
3501/3501 [==============================] - 0s 96us/step - loss: 48.6599 - mean_squared_error: 0.0149 - val_loss: 50.7604 - val_mean_squared_error: 0.0158

Epoch 00067: val_loss did not improve from 34.35917
Epoch 68/100
3501/3501 [==============================] - 0s 95us/step - loss: 49.0834 - mean_squared_error: 0.0157 - val_loss: 94.0569 - val_mean_squared_error: 0.0375

Epoch 00068: val_loss did not improve from 34.35917
Epoch 69/100
3501/3501 [==============================] - 0s 94us/step - loss: 44.7533 - mean_squared_error: 0.0137 - val_loss: 55.5269 - val_mean_squared_error: 0.0185

Epoch 00069: val_loss did not improve from 34.35917
Epoch 70/100
3501/3501 [==============================] - 0s 90us/step - loss: 45.2486 - mean_squared_error: 0.0135 - val_loss: 39.4384 - val_mean_squared_error: 0.0112

Epoch 00070: val_loss did not improve from 34.35917
Epoch 71/100
3501/3501 [==============================] - 0s 93us/step - loss: 43.2798 - mean_squared_error: 0.0127 - val_loss: 50.4140 - val_mean_squared_error: 0.0162

Epoch 00071: val_loss did not improve from 34.35917
Epoch 72/100
3501/3501 [==============================] - 0s 85us/step - loss: 40.9322 - mean_squared_error: 0.0114 - val_loss: 45.7334 - val_mean_squared_error: 0.0137

Epoch 00072: val_loss did not improve from 34.35917
Epoch 73/100
3501/3501 [==============================] - 0s 86us/step - loss: 42.1467 - mean_squared_error: 0.0123 - val_loss: 57.9899 - val_mean_squared_error: 0.0226

Epoch 00073: val_loss did not improve from 34.35917
Epoch 74/100
3501/3501 [==============================] - 0s 94us/step - loss: 43.6932 - mean_squared_error: 0.0130 - val_loss: 37.1455 - val_mean_squared_error: 0.0100

Epoch 00074: val_loss did not improve from 34.35917
Epoch 75/100
3501/3501 [==============================] - 0s 90us/step - loss: 41.5662 - mean_squared_error: 0.0118 - val_loss: 36.2863 - val_mean_squared_error: 0.0117

Epoch 00075: val_loss did not improve from 34.35917
Epoch 76/100
3501/3501 [==============================] - 0s 92us/step - loss: 46.4817 - mean_squared_error: 0.0140 - val_loss: 45.2458 - val_mean_squared_error: 0.0143

Epoch 00076: val_loss did not improve from 34.35917
Epoch 77/100
3501/3501 [==============================] - 0s 93us/step - loss: 44.4898 - mean_squared_error: 0.0133 - val_loss: 33.5857 - val_mean_squared_error: 0.0093

Epoch 00077: val_loss improved from 34.35917 to 33.58571, saving model to aen_sin_003Hz_n=00128_corgi.h5
Epoch 78/100
3501/3501 [==============================] - 0s 103us/step - loss: 48.6680 - mean_squared_error: 0.0153 - val_loss: 53.6324 - val_mean_squared_error: 0.0177

Epoch 00078: val_loss did not improve from 33.58571
Epoch 79/100
3501/3501 [==============================] - 0s 90us/step - loss: 45.3376 - mean_squared_error: 0.0134 - val_loss: 43.4611 - val_mean_squared_error: 0.0130

Epoch 00079: val_loss did not improve from 33.58571
Epoch 80/100
3501/3501 [==============================] - 0s 93us/step - loss: 46.1038 - mean_squared_error: 0.0140 - val_loss: 39.0098 - val_mean_squared_error: 0.0101

Epoch 00080: val_loss did not improve from 33.58571
Epoch 81/100
3501/3501 [==============================] - 0s 93us/step - loss: 52.8147 - mean_squared_error: 0.0169 - val_loss: 36.5853 - val_mean_squared_error: 0.0086

Epoch 00081: val_loss did not improve from 33.58571
Epoch 82/100
3501/3501 [==============================] - 0s 93us/step - loss: 47.0342 - mean_squared_error: 0.0145 - val_loss: 36.4819 - val_mean_squared_error: 0.0085

Epoch 00082: val_loss did not improve from 33.58571
Epoch 83/100
3501/3501 [==============================] - 0s 96us/step - loss: 40.4257 - mean_squared_error: 0.0111 - val_loss: 48.5772 - val_mean_squared_error: 0.0154

Epoch 00083: val_loss did not improve from 33.58571
Epoch 84/100
3501/3501 [==============================] - 0s 93us/step - loss: 42.8431 - mean_squared_error: 0.0122 - val_loss: 27.9591 - val_mean_squared_error: 0.0072

Epoch 00084: val_loss improved from 33.58571 to 27.95906, saving model to aen_sin_003Hz_n=00128_corgi.h5
Epoch 85/100
3501/3501 [==============================] - 0s 93us/step - loss: 35.2204 - mean_squared_error: 0.0086 - val_loss: 35.1120 - val_mean_squared_error: 0.0119

Epoch 00085: val_loss did not improve from 27.95906
Epoch 86/100
3501/3501 [==============================] - 0s 100us/step - loss: 36.7214 - mean_squared_error: 0.0092 - val_loss: 37.8777 - val_mean_squared_error: 0.0104

Epoch 00086: val_loss did not improve from 27.95906
Epoch 87/100
3501/3501 [==============================] - 0s 94us/step - loss: 42.0163 - mean_squared_error: 0.0119 - val_loss: 37.8636 - val_mean_squared_error: 0.0102

Epoch 00087: val_loss did not improve from 27.95906
Epoch 88/100
3501/3501 [==============================] - 0s 86us/step - loss: 36.6429 - mean_squared_error: 0.0097 - val_loss: 34.5471 - val_mean_squared_error: 0.0107

Epoch 00088: val_loss did not improve from 27.95906
Epoch 89/100
3501/3501 [==============================] - 0s 94us/step - loss: 35.6986 - mean_squared_error: 0.0089 - val_loss: 35.7121 - val_mean_squared_error: 0.0091

Epoch 00089: val_loss did not improve from 27.95906
Epoch 90/100
3501/3501 [==============================] - 0s 96us/step - loss: 47.0230 - mean_squared_error: 0.0144 - val_loss: 66.8385 - val_mean_squared_error: 0.0234

Epoch 00090: val_loss did not improve from 27.95906
Epoch 91/100
3501/3501 [==============================] - 0s 91us/step - loss: 47.1317 - mean_squared_error: 0.0142 - val_loss: 43.3211 - val_mean_squared_error: 0.0132

Epoch 00091: val_loss did not improve from 27.95906
Epoch 92/100
3501/3501 [==============================] - 0s 95us/step - loss: 44.9765 - mean_squared_error: 0.0133 - val_loss: 49.7922 - val_mean_squared_error: 0.0167

Epoch 00092: val_loss did not improve from 27.95906
Epoch 93/100
3501/3501 [==============================] - 0s 102us/step - loss: 43.3008 - mean_squared_error: 0.0126 - val_loss: 34.1182 - val_mean_squared_error: 0.0089

Epoch 00093: val_loss did not improve from 27.95906
Epoch 94/100
3501/3501 [==============================] - 0s 94us/step - loss: 40.7637 - mean_squared_error: 0.0113 - val_loss: 52.4609 - val_mean_squared_error: 0.0177

Epoch 00094: val_loss did not improve from 27.95906
Epoch 95/100
3501/3501 [==============================] - 0s 85us/step - loss: 38.3485 - mean_squared_error: 0.0102 - val_loss: 35.5032 - val_mean_squared_error: 0.0097

Epoch 00095: val_loss did not improve from 27.95906
Epoch 96/100
3501/3501 [==============================] - 0s 87us/step - loss: 38.5916 - mean_squared_error: 0.0102 - val_loss: 32.2546 - val_mean_squared_error: 0.0102

Epoch 00096: val_loss did not improve from 27.95906
Epoch 97/100
3501/3501 [==============================] - 0s 92us/step - loss: 46.0180 - mean_squared_error: 0.0136 - val_loss: 33.4016 - val_mean_squared_error: 0.0048

Epoch 00097: val_loss did not improve from 27.95906
Epoch 98/100
3501/3501 [==============================] - 0s 92us/step - loss: 33.3350 - mean_squared_error: 0.0071 - val_loss: 35.8605 - val_mean_squared_error: 0.0088

Epoch 00098: val_loss did not improve from 27.95906
Epoch 99/100
3501/3501 [==============================] - 0s 86us/step - loss: 47.4055 - mean_squared_error: 0.0145 - val_loss: 67.7614 - val_mean_squared_error: 0.0238

Epoch 00099: val_loss did not improve from 27.95906
Epoch 100/100
3501/3501 [==============================] - 0s 91us/step - loss: 62.3434 - mean_squared_error: 0.0192 - val_loss: 41.4646 - val_mean_squared_error: 0.0117

Epoch 00100: val_loss did not improve from 27.95906
In [13]:
# QC training and validation curves (they should track each other)
plt.figure(figsize=(8,2))
plt.plot(results.history['val_loss'], label='val')
plt.plot(results.history['loss'], label='train')
plt.xlabel('epoch index')
plt.ylabel('total loss')
plt.legend()
plt.show()
In [14]:
# encoder.predict returns [z_mean, z_log_sigma, z]; stacking gives shape (3, n_test, encoding_dim)
encoded_test = np.array(encoder.predict(x_test))
vae_test = vae.predict(x_test)

First, let's check how well the full autoencoder reconstructs a test signal.

In [15]:
plt.plot(x_test[0],label='true signal')
plt.plot(vae_test[0],label='encoded-decoded signal')
plt.legend()
plt.show()

Now let's look at the distribution of the sampled encodings in the latent space, to see how Gaussian it is.

In [16]:
# histogram of the sampled latent values z, compared against a standard normal pdf
counts, bins, patches = plt.hist(encoded_test[2,:,:].flatten(), bins=10, density=True)
plt.plot(bins, scipy.stats.norm.pdf(bins, 0, 1))
plt.show()
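
A histogram gives a qualitative picture; as an optional, more quantitative check (not in the original notebook), a Shapiro-Wilk test from scipy.stats asks whether the sampled latent values are consistent with some normal distribution (it does not test for zero mean and unit variance specifically). A very small p-value would indicate a clearly non-Gaussian latent distribution.

In [ ]:
# optional: Shapiro-Wilk normality test on the sampled latent values z
z_test = encoded_test[2,:,:].flatten()
stat, pval = scipy.stats.shapiro(z_test)
print('Shapiro-Wilk statistic = %.3f, p-value = %.3f' % (stat, pval))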

Now let's see how the decoded signal depends on the latent vector.

In [17]:
from matplotlib import cm
plt.figure(figsize=(10,4))
# sweep the 1D latent space from -3 to 3 and decode each latent value
latent_inputs = np.repeat(np.linspace(-3,3,11)[:,np.newaxis], encoding_dim, axis=1)
decoded_latent_inputs = decoder.predict(latent_inputs)
colors = cm.viridis(np.linspace(0,1,len(latent_inputs)))
for i,latent_input in enumerate(decoded_latent_inputs):
    plt.plot(t, latent_input, color=colors[i])
labels = ["{0:.1f}".format(l) for l in latent_inputs[:,0]]
plt.legend(labels)
plt.show()

Interestingly, the autoencoder has learned to use the latent vector (a single value in this case, since we specified an encoding dimension of 1) as a proxy for phase. As the latent value changes, the phase of the decoded signal changes. Latent values near zero reproduce a sine wave well, while values far from zero produce signals that are not quite sinusoidal.

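We can probe that interpretation a little further. The train/test split above did not keep the generating phases aligned with x_test, but each test signal's phase can be re-estimated by projecting it onto sine and cosine at the known frequency; plotting that estimate against the encoded mean should then show a clear (if wrapped and nonlinear) relationship between phase and the latent value. This is an optional sketch that was not part of the original notebook.

In [ ]:
# optional: estimate the phase of each test signal and compare it with the encoded mean
# x = sin(2*pi*f*t + phi) = cos(phi)*sin(2*pi*f*t) + sin(phi)*cos(2*pi*f*t)
a = 2*np.mean(x_test*np.sin(2*np.pi*f*t), axis=1)   # approximately cos(phi)
b = 2*np.mean(x_test*np.cos(2*np.pi*f*t), axis=1)   # approximately sin(phi)
phase_est = np.arctan2(b, a)
plt.figure(figsize=(4,4))
plt.scatter(phase_est, encoded_test[0,:,:].flatten(), s=4)
plt.xlabel('estimated phase of test signal')
plt.ylabel('encoded z_mean')
plt.show()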