Variational Autoencoders
Variational autoencoders view autoencoding from a statistical perspective. Like classical autoencoders, they encode a dataset into a lower-dimensional latent space. Additionally, though, variational autoencoders constrain the encoded vectors to roughly follow a probability distribution, e.g. a normal distribution. Here’s an example of a variational autoencoder for the same 1D sequence-to-sequence monochromatic signal encoding problem.
Variational Autoencoder
This notebook uses the same toy problem as the autoencoding notebook. Here we demonstrate the use of a variational autoencoder.
import numpy as np
import matplotlib.pyplot as plt
import scipy.stats
from keras.models import Input, Model, load_model
from keras.layers import Dense, LeakyReLU, Lambda
from keras.callbacks import EarlyStopping, ModelCheckpoint
from keras.utils import plot_model
from keras.losses import mse
from keras import backend as K
from sklearn.model_selection import train_test_split
import petname
Using TensorFlow backend.
# generate training, test, and validation data
n = 4096
nt = 128
f = 3.0 # frequency in Hz
t = np.linspace(0,1,nt) # time stamps in s
x = np.zeros((n,nt))
phase = np.random.uniform(-np.pi, np.pi, size=n)
for i in range(n):
    x[i,:] = np.sin(2*np.pi*f*t + phase[i])
# QC: generated data is phase shifted but has a single frequency
plt.figure(figsize=(8,2))
for i in range(3):
    plt.plot(t, x[np.random.randint(0, n), :])
plt.show()
# QC generated phase in [-pi,pi]
plt.figure(figsize=(8,2))
plt.hist(phase,bins=31)
plt.xlabel('phase')
plt.ylabel('number of occurrences')
plt.show()
# split into test, validation, and training sets
x_temp, x_test, _, _ = train_test_split(x, x, test_size=0.05)
x_train, x_valid, _, _ = train_test_split(x_temp,
x_temp,
test_size=0.1)
n_train = len(x_train)
n_valid = len(x_valid)
n_test = len(x_test)
# specify training parameters and callback functions
# batch size for stochastic solver
batch_size = 16
# number of times entire dataset is considered in stochastic solver
epochs = 100
# unique name for the network for saving
unique_name = petname.name()
model_filename = 'aen_sin_%03dHz_n=%05d_'%(int(f),nt)+unique_name+'.h5'
# training history file name
history_filename = 'results_'+unique_name+'.npz'
# stop early if validation loss does not improve for `patience` epochs and be verbose
earlystopper = EarlyStopping(patience=100, verbose=1)
# checkpoint and save model when improvement occurs
checkpointer = ModelCheckpoint(model_filename, verbose=1, save_best_only=True)
# consolidate callback functions for convenience
callbacks = [earlystopper, checkpointer]
Now things get a bit different from a vanilla autoencoder. First, we set the dimensionality of the latent space. For this example we can get away with only one dimension: intuitively, since the only difference between training examples is the phase, a single latent dimension is enough to encode it.
# encoding dimension; i.e. dimensionality of the latent space
encoding_dim=1
Next we define a function that draws samples from a Gaussian, given the mean and log standard deviation. We draw a sample from this distribution to get the encoding in the latent space. Because the randomness is isolated in an auxiliary noise variable, this formulation lets us backpropagate through the mean and standard deviation even though sampling is a random operation (this is the "reparameterization trick").
# define a function to sample from a Gaussian, given the mean and log standard deviation
def sampling(args):
    z_mean, z_log_sigma = args
    epsilon = K.random_normal(shape=(encoding_dim,))
    return z_mean + K.exp(z_log_sigma) * epsilon
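To see the trick concretely, the same computation can be written in plain NumPy outside of Keras. This is just an illustration; the mean and log standard deviation values below are made-up examples.
# reparameterization written out in NumPy: all of the randomness lives in eps,
# so z is a deterministic (and hence differentiable) function of the mean and log std
mu_example, log_sigma_example = 0.5, -1.0        # hypothetical example values
eps = np.random.normal(size=(encoding_dim,))     # noise drawn independently of the parameters
z_example = mu_example + np.exp(log_sigma_example) * eps
print(z_example)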
The network structure is similar to the autoencoder. The main difference is in the middle of the network: z_mean and z_log_sigma. These layers encode the mean and log standard deviation of the distribution from which the latent-space encoding is drawn. The "Lambda" layer then draws a sample from that distribution, and z is the encoded signal in the latent space.
# construct variational autoencoder network structure
# input layer is full time series of length nt
inputs = Input((nt,))
# encoder hidden layers
encoded = Dense(64)(inputs)
encoded = LeakyReLU(alpha=0.2)(encoded)
encoded = Dense(32)(encoded)
encoded = LeakyReLU(alpha=0.2)(encoded)
z_mean = Dense(encoding_dim)(encoded)
z_log_sigma = Dense(encoding_dim)(encoded)
z = Lambda(sampling,output_shape=(encoding_dim,))([z_mean,z_log_sigma])
# decoder hidden layers
decoded = Dense(32)(z)
decoded = LeakyReLU(alpha=0.2)(decoded)
decoded = Dense(64)(decoded)
decoded = LeakyReLU(alpha=0.2)(decoded)
# output layer is same length as input
outputs = Dense(nt,activation='tanh')(decoded)
# consolidate to define autoencoder model inputs and outputs
vae = Model(inputs=inputs, outputs=outputs)
# specify encoder and decoder model for easy encoding and decoding later
encoder = Model(inputs=inputs, outputs=[z_mean,z_log_sigma,z],
name='encoder')
# create a placeholder for an encoded input
encoded_input = Input(shape=(encoding_dim,))
# retrieve the last layers of the autoencoder model
decoded_output = vae.layers[-5](encoded_input)
decoded_output = vae.layers[-4](decoded_output)
decoded_output = vae.layers[-3](decoded_output)
decoded_output = vae.layers[-2](decoded_output)
decoded_output = vae.layers[-1](decoded_output)
# create the decoder model
decoder = Model(inputs=encoded_input, outputs=decoded_output, name='decoder')
print('Full autoencoder')
print(vae.summary())
#print('\n Encoder portion of autoencoder')
#print(encoder.summary())
#print('\n Decoder portion of autoencoder')
#print(decoder.summary())
Full autoencoder
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to
==================================================================================================
input_1 (InputLayer)            (None, 128)          0
__________________________________________________________________________________________________
dense_1 (Dense)                 (None, 64)           8256        input_1[0][0]
__________________________________________________________________________________________________
leaky_re_lu_1 (LeakyReLU)       (None, 64)           0           dense_1[0][0]
__________________________________________________________________________________________________
dense_2 (Dense)                 (None, 32)           2080        leaky_re_lu_1[0][0]
__________________________________________________________________________________________________
leaky_re_lu_2 (LeakyReLU)       (None, 32)           0           dense_2[0][0]
__________________________________________________________________________________________________
dense_3 (Dense)                 (None, 1)            33          leaky_re_lu_2[0][0]
__________________________________________________________________________________________________
dense_4 (Dense)                 (None, 1)            33          leaky_re_lu_2[0][0]
__________________________________________________________________________________________________
lambda_1 (Lambda)               (None, 1)            0           dense_3[0][0]
                                                                 dense_4[0][0]
__________________________________________________________________________________________________
dense_5 (Dense)                 (None, 32)           64          lambda_1[0][0]
__________________________________________________________________________________________________
leaky_re_lu_3 (LeakyReLU)       (None, 32)           0           dense_5[0][0]
__________________________________________________________________________________________________
dense_6 (Dense)                 (None, 64)           2112        leaky_re_lu_3[0][0]
__________________________________________________________________________________________________
leaky_re_lu_4 (LeakyReLU)       (None, 64)           0           dense_6[0][0]
__________________________________________________________________________________________________
dense_7 (Dense)                 (None, 128)          8320        leaky_re_lu_4[0][0]
==================================================================================================
Total params: 20,898
Trainable params: 20,898
Non-trainable params: 0
__________________________________________________________________________________________________
None
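Since plot_model was imported above, one can optionally render the architecture as a diagram as well. This is a sketch; writing the PNG requires the pydot and graphviz packages to be installed.
# optional: save a graphical diagram of the network structure (requires pydot/graphviz)
plot_model(vae, to_file='vae_' + unique_name + '.png', show_shapes=True)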
The loss function is another key difference between standard autoencoders and variational autoencoders. A standard autoencoder simply minimizes reconstruction loss. A variational autoencoder minimizes both the reconstruction loss and the KL divergence between the distribution of latent encodings and a chosen prior (here a standard normal). The KL divergence is a measure of how much two probability distributions differ, so minimizing it encourages the latent-space encodings to follow a normal distribution. The regularization parameter balances reconstruction loss against enforcing a normal distribution in the latent space.
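For reference, the KL divergence between the encoder's Gaussian $\mathcal{N}(\mu, \sigma^2)$ and a standard normal prior $\mathcal{N}(0, 1)$ has the closed form

$$D_{KL} = -\frac{1}{2}\sum_i \left(1 + \log\sigma_i^2 - \mu_i^2 - \sigma_i^2\right),$$

and the kl_loss term in the cell below is based on this expression, with z_mean and z_log_sigma standing in for $\mu$ and $\log\sigma$.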
# specify loss
# regularization balances signal reconstruction with
# a Gaussian distribution in the latent space
regularization = 10
def vae_loss(input_img, output):
    # reconstruction loss: sum of squared errors between input and output
    reconstruction_loss = K.sum(K.square(output - input_img))
    # KL divergence between the encoded distribution and a standard normal
    kl_loss = - 0.5 * K.sum(1 + z_log_sigma - K.square(z_mean) - K.square(K.exp(z_log_sigma)), axis=-1)
    # total loss: reconstruction plus weighted KL term, averaged over the batch
    total_loss = K.mean(reconstruction_loss + regularization*kl_loss)
    return total_loss
vae.compile(optimizer='adam', loss=vae_loss, metrics=['mse'])
# train variational autoencoder
results = vae.fit(x_train,x_train,
shuffle=True,
batch_size = batch_size,
epochs = epochs,
validation_data = (x_valid,x_valid),
callbacks = callbacks)
Train on 3501 samples, validate on 390 samples
Epoch 1/100
3501/3501 [==============================] - 1s 213us/step - loss: 572.6987 - mean_squared_error: 0.2705 - val_loss: 494.7158 - val_mean_squared_error: 0.2357
Epoch 00001: val_loss improved from inf to 494.71583, saving model to aen_sin_003Hz_n=00128_corgi.h5
Epoch 2/100
3501/3501 [==============================] - 0s 86us/step - loss: 375.1624 - mean_squared_error: 0.1734 - val_loss: 239.9433 - val_mean_squared_error: 0.1059
Epoch 00002: val_loss improved from 494.71583 to 239.94333, saving model to aen_sin_003Hz_n=00128_corgi.h5
...
Epoch 84/100
3501/3501 [==============================] - 0s 93us/step - loss: 42.8431 - mean_squared_error: 0.0122 - val_loss: 27.9591 - val_mean_squared_error: 0.0072
Epoch 00084: val_loss improved from 33.58571 to 27.95906, saving model to aen_sin_003Hz_n=00128_corgi.h5
...
Epoch 100/100
3501/3501 [==============================] - 0s 91us/step - loss: 62.3434 - mean_squared_error: 0.0192 - val_loss: 41.4646 - val_mean_squared_error: 0.0117
Epoch 00100: val_loss did not improve from 27.95906
# QC training and validation curves (they should track each other)
plt.figure(figsize=(8,2))
plt.plot(results.history['val_loss'], label='val')
plt.plot(results.history['loss'], label='train')
plt.xlabel('epoch index')
plt.ylabel('loss value')
plt.legend()
plt.show()
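Note that training ends with the final-epoch weights in memory, while ModelCheckpoint saved the weights with the best validation loss to model_filename. To evaluate that checkpoint instead, one option (a sketch, assuming the checkpoint file exists) is to load those weights back into the model; keras.models.load_model would also work here but needs the custom loss passed via custom_objects.
# optionally restore the best checkpointed weights before evaluating on the test set
vae.load_weights(model_filename)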
encoded_test = np.array(encoder.predict(x_test))
vae_test = vae.predict(x_test)
First, let's check how well the trained autoencoder reconstructs a test signal.
plt.plot(x_test[0],label='true signal')
plt.plot(vae_test[0],label='encoded-decoded signal')
plt.legend()
plt.show()
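The single example above looks reasonable; a quick overall measure along the same lines is the mean squared reconstruction error over the whole test set.
# mean squared reconstruction error across all test signals
test_mse = np.mean((vae_test - x_test)**2)
print('test reconstruction MSE: %.4f' % test_mse)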
Now let's look at the distribution of samples in the latent space, to see how Gaussian it is.
n,bins,patches=plt.hist(encoded_test[2,:,:].flatten(),bins=10,density=True)
plt.plot(bins,scipy.stats.norm.pdf(bins,0,1))
plt.show()
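The histogram gives a qualitative picture. For a rough quantitative check, one could also compare the encoded samples against a standard normal with a Kolmogorov-Smirnov test (a sketch using scipy.stats; a large p-value means normality is not rejected).
# Kolmogorov-Smirnov test of the encoded test samples against a standard normal
ks_stat, p_value = scipy.stats.kstest(encoded_test[2,:,:].flatten(), 'norm')
print('KS statistic = %.3f, p-value = %.3f' % (ks_stat, p_value))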
Now let's see how the decoded signal depends on the latent vector.
from matplotlib.pylab import cm
plt.figure(figsize=(10,4))
latent_inputs = np.repeat(np.linspace(-3,3,11)[:,np.newaxis],encoding_dim,axis=1)
decoded_latent_inputs = decoder.predict(latent_inputs)
colors = cm.viridis(np.linspace(0,1,len(latent_inputs)))
for i,latent_input in enumerate(decoded_latent_inputs):
    plt.plot(t,latent_input,color=colors[i])
labels = ["{0:.1f}".format(l) for l in latent_inputs[:,0]]
plt.legend(labels)
plt.show()
Interestingly, the autoencoder has learned to use the latent vector (a single value in this case, since we specified an encoding dimension of 1) as a proxy for phase. As the latent value changes, the phase of the decoded signal changes. Latent values near zero reproduce a sine wave well, while values far from zero produce signals that aren't exactly sinusoidal.
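One way to see this relationship directly is to estimate each test signal's phase from its first two samples and plot it against the encoded mean. The sketch below uses a forward-difference approximation of the derivative at t=0, so est_phase (a name introduced here) is only approximate.
# estimate phase from x(0) = sin(phase) and x'(0) = 2*pi*f*cos(phase),
# approximating the derivative with a forward difference, then compare to z_mean
dt = t[1] - t[0]
est_phase = np.arctan2(x_test[:,0], (x_test[:,1] - x_test[:,0])/(2*np.pi*f*dt))
plt.figure(figsize=(8,2))
plt.scatter(est_phase, encoded_test[0].flatten(), s=5)
plt.xlabel('estimated phase of test signal')
plt.ylabel('encoded z_mean')
plt.show()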