|
a |
|
b/Sine Generation/train.py |
|
|
1 |
""" |
|
|
2 |
GAN with Generator: LSTM/BiLSTM, Discriminator: Convolutional NN with MBD |
|
|
3 |
Sine Data |
|
|
4 |
Introduction |
|
|
5 |
------------ |
|
|
6 |
The aim of this script is to use a convolutional neural network with |
|
|
7 |
a max pooling layer in the discriminator. |
|
|
8 |
This was found to work well with the Physionet ECG data in a paper. |
|
|
9 |
They used two convolutional NN so we will compare the difference between the |
|
|
10 |
images generated using a single layer of CNN in the discriminator and 2 CNN layers |
|
|
11 |
to see if this improves the quality of series generated. |
|
|
12 |
""" |
|
|
13 |
""" |
|
|
14 |
Bringing in required dependencies as defined in the GitHub repo: |
|
|
15 |
https://github.com/josipd/torch-two-sample/blob/master/torch_two_sample/permutation_test.pyx""" |
|
|
16 |
|
|
|
17 |
|
|
|
18 |
from __future__ import division |
|
|
19 |
|
|
|
20 |
import torch |
|
|
21 |
from tqdm import tqdm |
|
|
22 |
import numpy as np |
|
|
23 |
from matplotlib import pyplot as plt |
|
|
24 |
import seaborn as sns |
|
|
25 |
|
|
|
26 |
from torchvision import transforms |
|
|
27 |
from torch.autograd.variable import Variable |
|
|
28 |
sns.set(rc={'figure.figsize':(11, 4)}) |
|
|
29 |
|
|
|
30 |
import datetime |
|
|
31 |
from datetime import date |
|
|
32 |
today = date.today() |
|
|
33 |
|
|
|
34 |
import random |
|
|
35 |
import json as js |
|
|
36 |
import pickle |
|
|
37 |
import os |
|
|
38 |
|
|
|
39 |
from data import ECGData, PD_to_Tensor |
|
|
40 |
from Model import Generator, Discriminator , noise |
|
|
41 |
|
|
|
42 |
|
|
|
43 |
# Pick the first CUDA device when one is available, otherwise fall back to the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# A torch.device never compares equal to a plain string, so the original
# `device == 'cuda:0'` test was always False; check the device type instead.
if device.type == 'cuda':
    print('Using GPU : ')
    print(torch.cuda.get_device_name(device))
else:
    print('Using CPU')
|
|
50 |
|
|
|
51 |
"""#MMD Evaluation Metric Definition |
|
|
52 |
Using MMD to determine the similarity between distributions |
|
|
53 |
PDIST code comes from torch-two-sample utils code: |
|
|
54 |
https://github.com/josipd/torch-two-sample/blob/master/torch_two_sample/util.py |
|
|
55 |
""" |
|
|
56 |
|
|
|
57 |
def pdist(sample_1, sample_2, norm=2, eps=1e-5):
    r"""Compute the matrix of all pairwise ``l_p`` distances.

    The distances are *not* squared. For ``norm == 2`` each entry is
    ``sqrt(eps + |squared euclidean distance|)``. For any other norm it is
    ``(eps + sum_k |x_k - y_k| ** p) ** (1 / p)``.

    Arguments
    ---------
    sample_1 : torch.Tensor or Variable
        The first sample, should be of shape ``(n_1, d)``.
    sample_2 : torch.Tensor or Variable
        The second sample, should be of shape ``(n_2, d)``.
    norm : float
        The l_p norm to be used.
    eps : float
        Small constant added before taking the root. It keeps the result (and
        its gradient) finite when two points coincide.

    Returns
    -------
    torch.Tensor or Variable
        Matrix of shape (n_1, n_2). The [i, j]-th entry is, up to ``eps``,
        equal to ``|| sample_1[i, :] - sample_2[j, :] ||_p``."""
    n_1, n_2 = sample_1.size(0), sample_2.size(0)
    norm = float(norm)

    if norm == 2.:
        # Expand ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b so that one matrix
        # product replaces the (n_1, n_2, d) difference tensor.
        norms_1 = torch.sum(sample_1**2, dim=1, keepdim=True)
        norms_2 = torch.sum(sample_2**2, dim=1, keepdim=True)
        norms = (norms_1.expand(n_1, n_2) +
                 norms_2.transpose(0, 1).expand(n_1, n_2))
        distances_squared = norms - 2 * sample_1.mm(sample_2.t())
        # abs() guards against tiny negative values caused by cancellation
        # in the expansion above.
        return torch.sqrt(eps + torch.abs(distances_squared))
    else:
        # General p-norm: materialise every pairwise difference explicitly.
        dim = sample_1.size(1)
        expanded_1 = sample_1.unsqueeze(1).expand(n_1, n_2, dim)
        expanded_2 = sample_2.unsqueeze(0).expand(n_1, n_2, dim)
        differences = torch.abs(expanded_1 - expanded_2) ** norm
        inner = torch.sum(differences, dim=2, keepdim=False)
        return (eps + inner) ** (1. / norm)
|
|
89 |
|
|
|
90 |
def permutation_test_mat(matrix,
                         n_1, n_2, n_permutations,
                         a00=1, a11=1, a01=0):
    r"""Compute the p-value of the following statistic (rejects when high)

        \sum_{i,j} a_{\pi(i), \pi(j)} matrix[i, j].

    The first ``n_1`` rows/columns of ``matrix`` belong to the first sample and
    the remaining ``n_2`` to the second. The observed statistic uses that
    labelling. Each of the ``n_permutations`` following draws shuffles the
    labels and recomputes the statistic.

    Arguments
    ---------
    matrix : numpy.ndarray
        Square ``(n_1 + n_2, n_1 + n_2)`` matrix, e.g. a kernel matrix.
    n_1, n_2 : int
        Sizes of the two samples.
    n_permutations : int
        Number of random label permutations to draw; must be positive.
    a00, a11, a01 : float
        Weights for within-sample-1, within-sample-2 and cross-sample pairs.

    Returns
    -------
    float
        Fraction of permutations whose statistic is at least the observed one.

    Raises
    ------
    ValueError
        If ``n_permutations`` is not positive.
    """
    if n_permutations <= 0:
        raise ValueError("n_permutations must be a positive integer")

    n = n_1 + n_2
    # Group labels: 0 for the first n_1 indices, 1 for the remaining n_2.
    pi = np.zeros(n, dtype=np.int8)
    pi[n_1:] = 1

    larger = 0.
    for sample_n in range(1 + n_permutations):
        count = 0.
        for i in range(n):
            for j in range(i, n):
                # Visit each unordered pair once and add both orientations.
                mij = matrix[i, j] + matrix[j, i]
                if pi[i] == pi[j] == 0:
                    count += a00 * mij
                elif pi[i] == pi[j] == 1:
                    count += a11 * mij
                else:
                    count += a01 * mij
        if sample_n == 0:
            # First pass uses the true labelling: this is the observed statistic.
            statistic = count
        elif statistic <= count:
            larger += 1

        np.random.shuffle(pi)

    return larger / n_permutations
|
|
122 |
|
|
|
123 |
"""Code from Torch-Two-Samples at https://torch-two-sample.readthedocs.io/en/latest/#""" |
|
|
124 |
|
|
|
125 |
class MMDStatistic:
    r"""The *unbiased* MMD test of :cite:`gretton2012kernel`.

    The kernel used is equal to:

    .. math ::
        k(x, x') = \sum_{j=1}^k e^{-\alpha_j\|x - x'\|^2},

    for the :math:`\alpha_j` provided in :py:meth:`~.MMDStatistic.__call__`.

    Arguments
    ---------
    n_1: int
        The number of points in the first sample (at least 2).
    n_2: int
        The number of points in the second sample (at least 2).

    Raises
    ------
    ValueError
        If either sample has fewer than 2 points. The unbiased estimator
        divides by ``n * (n - 1)``."""

    def __init__(self, n_1, n_2):
        if n_1 < 2 or n_2 < 2:
            raise ValueError(
                "MMDStatistic needs at least 2 points per sample, got "
                "n_1=%r, n_2=%r" % (n_1, n_2))
        self.n_1 = n_1
        self.n_2 = n_2

        # The three constants used in the test.
        self.a00 = 1. / (n_1 * (n_1 - 1))
        self.a11 = 1. / (n_2 * (n_2 - 1))
        self.a01 = - 1. / (n_1 * n_2)

    def __call__(self, sample_1, sample_2, alphas, ret_matrix=False):
        r"""Evaluate the statistic.

        The kernel used is

        .. math::
            k(x, x') = \sum_{j=1}^k e^{-\alpha_j \|x - x'\|^2},

        for the provided ``alphas``.

        Arguments
        ---------
        sample_1: :class:`torch:torch.autograd.Variable`
            The first sample, of size ``(n_1, d)``.
        sample_2: variable of shape (n_2, d)
            The second sample, of size ``(n_2, d)``.
        alphas : list of :class:`float`
            The kernel parameters; must not be empty.
        ret_matrix: bool
            If set, the call will also return a second variable: the full
            kernel matrix. It can then be used to compute a p-value with
            :py:meth:`~.MMDStatistic.pval`.

        Returns
        -------
        :class:`float`
            The test statistic.
        :class:`torch:torch.autograd.Variable`
            Returned only if ``ret_matrix`` was set to true.

        Raises
        ------
        ValueError
            If ``alphas`` is empty."""
        sample_12 = torch.cat((sample_1, sample_2), 0)
        distances = pdist(sample_12, sample_12, norm=2)

        kernels = None
        for alpha in alphas:
            kernels_a = torch.exp(- alpha * distances ** 2)
            if kernels is None:
                kernels = kernels_a
            else:
                kernels = kernels + kernels_a
        if kernels is None:
            raise ValueError("alphas must contain at least one kernel parameter")

        # Within-sample and cross-sample blocks of the joint kernel matrix.
        k_1 = kernels[:self.n_1, :self.n_1]
        k_2 = kernels[self.n_1:, self.n_1:]
        k_12 = kernels[:self.n_1, self.n_1:]

        # The unbiased estimator excludes the i == j terms, hence the traces.
        mmd = (2 * self.a01 * k_12.sum() +
               self.a00 * (k_1.sum() - torch.trace(k_1)) +
               self.a11 * (k_2.sum() - torch.trace(k_2)))
        if ret_matrix:
            return mmd, kernels
        else:
            return mmd

    def pval(self, distances, n_permutations=1000):
        r"""Compute a p-value using a permutation test.

        Arguments
        ---------
        distances: :class:`torch:torch.autograd.Variable`
            The kernel matrix returned as the second value of
            :py:meth:`~.MMDStatistic.__call__` with ``ret_matrix=True``.
        n_permutations: int
            The number of random draws from the permutation null.

        Returns
        -------
        float
            The estimated p-value."""
        if isinstance(distances, Variable):
            distances = distances.data
        return permutation_test_mat(distances.cpu().numpy(),
                                    self.n_1, self.n_2,
                                    n_permutations,
                                    a00=self.a00, a11=self.a11, a01=self.a01)
|
|
213 |
|
|
|
214 |
""" |
|
|
215 |
This paper |
|
|
216 |
https://arxiv.org/pdf/1611.04488.pdf says that the most common way to |
|
|
217 |
calculate sigma is to use the median pairwise distances between the joint data. |
|
|
218 |
""" |
|
|
219 |
|
|
|
220 |
def pairwisedistances(X, Y, norm=2):
    """Return the median of all pairwise ``norm``-distances between rows of X and Y.

    In this script the result is used as the kernel scale for the MMD
    evaluation (median heuristic).

    Arguments
    ---------
    X : torch.Tensor
        First sample, shape ``(n_1, d)``.
    Y : torch.Tensor
        Second sample, shape ``(n_2, d)``.
    norm : float
        The l_p norm passed to ``pdist``.

    Returns
    -------
    numpy scalar
        ``np.median`` of the ``(n_1, n_2)`` distance matrix.
    """
    dist = pdist(X, Y, norm)
    # .numpy() refuses tensors that require grad or live on a GPU, so detach
    # and move to host memory first.
    return np.median(dist.detach().cpu().numpy())
|
|
223 |
|
|
|
224 |
|
|
|
225 |
""" |
|
|
226 |
Function for loading Sine Data |
|
|
227 |
""" |
|
|
228 |
|
|
|
229 |
def GetSineData(source_file):
    """Load the sine-wave dataset stored in ``source_file``.

    Each item is converted to a tensor by the ``PD_to_Tensor`` transform.

    Arguments
    ---------
    source_file : str
        Path of the data file (this script passes CSV files).

    Returns
    -------
    SineData
        The dataset object, usable with ``torch.utils.data.DataLoader``.
    """
    # NOTE(review): SineData was referenced here but never imported. The module
    # only imports ECGData and PD_to_Tensor from ``data``, so this raised
    # NameError. Confirm data.py defines SineData; otherwise ECGData may be the
    # intended loader.
    from data import SineData

    compose = transforms.Compose([PD_to_Tensor()])
    return SineData(source_file, transform=compose)
|
|
234 |
|
|
|
235 |
""" |
|
|
236 |
Creating the training set of sine signals |
|
|
237 |
""" |
|
|
238 |
|
|
|
239 |
# --- Training set of sine signals ---
source_filename = './sinedata_v2.csv'
sine_data = GetSineData(source_file=source_filename)

# Batch size shared by both data loaders and by the noise creator function.
sample_size = 50
data_loader = torch.utils.data.DataLoader(dataset=sine_data,
                                          batch_size=sample_size,
                                          shuffle=True)
num_batches = len(data_loader)  # batches per training epoch

# --- Test set, batched exactly like the training set ---
test_filename = './sinedata_test_v2.csv'
sine_data_test = GetSineData(source_file=test_filename)
data_loader_test = torch.utils.data.DataLoader(dataset=sine_data_test,
                                               batch_size=sample_size,
                                               shuffle=True)

# --- Derived parameters ---
# Length of one training example (number of features per series).
seq_length = sine_data[0].size(0)
|
|
255 |
|
|
|
256 |
# ---- Generator settings ----
hidden_nodes_g = 50
# NOTE(review): `layers` is written to settings.json as 'num_layers' but is not
# passed to the Generator constructor in this script — confirm intent.
layers = 2
tanh_layer, bidir = False, True

# ---- Training schedule ----
# Discriminator / generator update rounds per batch, number of epochs, Adam LR.
D_rounds, G_rounds = 3, 1
num_epoch = 120
learning_rate = 2e-4

# ---- Discriminator settings ----
# minibatch_layer is only a default here; the sweep loop later rebinds it.
minibatch_layer = 0
minibatch_normal_init_ = True
num_cvs = 1
# First convolution block (presumably out-channels / kernel / stride) followed
# by its pooling kernel / stride — confirm against Model.Discriminator.
cv1_out, cv1_k, cv1_s = 10, 3, 1
p1_k, p1_s = 3, 2
# Second convolution block, same layout as the first.
cv2_out, cv2_k, cv2_s = 5, 3, 1
p2_k, p2_s = 3, 2
|
|
282 |
|
|
|
283 |
"""# Evaluation of GAN with 1 CNN Layer in Discriminator |
|
|
284 |
##Generator and Discriminator training phase |
|
|
285 |
""" |
|
|
286 |
|
|
|
287 |
def _save_settings(run_path, settings):
    """Write ``settings`` as JSON to ``<run_path>/settings.json``."""
    with open(os.path.join(run_path, "settings.json"), "w") as fp:
        fp.write(js.dumps(settings))


def _generate_test_sample(generator, n_batches):
    """Generate ``n_batches`` batches of ``sample_size`` series with ``generator``.

    Each batch is squeezed, and the batches are concatenated along dim 0. The
    result is the generated set compared against the test data. Runs without
    gradient tracking. Returns an empty ``(0, seq_length)`` tensor when
    ``n_batches`` is 0.
    """
    chunks = []
    with torch.no_grad():
        for _ in range(n_batches):
            h_g = generator.init_hidden()
            noise_sample = noise(sample_size, seq_length).to(device)
            chunks.append(generator.forward(noise_sample, h_g).detach().squeeze())
    if not chunks:
        return torch.zeros(0, seq_length)
    return torch.cat(chunks, dim=0)


def _mmd_against_test_set(generated):
    """Return the unbiased MMD between ``sine_data_test`` and ``generated`` as a float."""
    real = sine_data_test[:].type(torch.DoubleTensor)
    fake = generated.type(torch.DoubleTensor).squeeze()
    # NOTE(review): the median pairwise distance is passed directly as the
    # kernel parameter alpha in exp(-alpha * d**2). The median heuristic
    # usually defines a bandwidth sigma, i.e. alpha = 1 / (2 * sigma**2).
    # Kept unchanged so MMD values stay comparable with earlier runs; confirm
    # which is intended.
    alphas = [pairwisedistances(real, fake)]
    statistic = MMDStatistic(len(real), generated.size(0))
    return statistic(real, fake, alphas, ret_matrix=False).item()


def _plot_epoch_samples(run_path, series, mmd_values, per_figure=3):
    """Save figures showing each epoch's generated series with its MMD in the title.

    Each figure holds ``per_figure`` consecutive epochs. A trailing partial
    group is still plotted. ``series[e]`` must be the sample generated after
    epoch ``e``, i.e. aligned index-for-index with ``mmd_values``.
    """
    n_epochs = len(mmd_values)
    for start in range(0, n_epochs, per_figure):
        epochs = range(start, min(start + per_figure, n_epochs))
        fig, ax = plt.subplots(per_figure, 1, constrained_layout=True, squeeze=False)
        fig.suptitle("Generated fake data")
        for row, epoch in enumerate(epochs):
            ax[row, 0].plot(series[epoch])
            ax[row, 0].set_title('Epoch ' + str(epoch) + ', MMD: %.4f' % (mmd_values[epoch]))
        plt.savefig(os.path.join(run_path, 'Training_Epoch_Samples_MMD_' + str(epochs.stop) + '.png'))
        plt.close(fig)


minibatch_out = [0, 3, 5, 8, 10]
for minibatch_layer in minibatch_out:
    # One output directory per minibatch-discrimination setting. The time uses
    # '_' instead of ':' so the name is also valid on Windows. Replace the
    # placeholder prefix with a real, existing location before running.
    path = (".../your_path/Run_" + today.strftime("%d_%m_%Y") + "_"
            + datetime.datetime.now().strftime("%H_%M_%S"))
    os.mkdir(path)

    # Record the settings used for this run. Named `settings` so the
    # built-ins `dict` / `json` are not shadowed.
    # NOTE(review): 'num_layers' is recorded, but `layers` is not passed to
    # Generator below — confirm the Generator default matches.
    settings = {'data': source_filename,
                'sample_size': sample_size,
                'seq_length': seq_length,
                'num_layers': layers,
                'tanh_layer': tanh_layer,
                'bidir': bidir,
                'hidden_dims_generator': hidden_nodes_g,
                'minibatch_layer': minibatch_layer,
                'minibatch_normal_init_': minibatch_normal_init_,
                'num_cvs': num_cvs,
                'cv1_out': cv1_out,
                'cv1_k': cv1_k,
                'cv1_s': cv1_s,
                'p1_k': p1_k,
                'p1_s': p1_s,
                'cv2_out': cv2_out,
                'cv2_k': cv2_k,
                'cv2_s': cv2_s,
                'p2_k': p2_k,
                'p2_s': p2_s,
                'num_epoch': num_epoch,
                'D_rounds': D_rounds,
                'G_rounds': G_rounds,
                'learning_rate': learning_rate
                }
    _save_settings(path, settings)

    # Initialise the generator and discriminator on the selected device
    # rather than hard-coding .cuda(), so a CPU-only host does not crash here.
    generator_1 = Generator(seq_length, sample_size, hidden_dim=hidden_nodes_g,
                            tanh_output=tanh_layer, bidirectional=bidir).to(device)
    discriminator_1 = Discriminator(seq_length, sample_size,
                                    minibatch_normal_init=minibatch_normal_init_,
                                    minibatch=minibatch_layer, num_cv=num_cvs,
                                    cv1_out=cv1_out, cv1_k=cv1_k, cv1_s=cv1_s,
                                    p1_k=p1_k, p1_s=p1_s,
                                    cv2_out=cv2_out, cv2_k=cv2_k, cv2_s=cv2_s,
                                    p2_k=p2_k, p2_s=p2_s).to(device)
    loss_1 = torch.nn.BCELoss()

    generator_1.train()
    discriminator_1.train()

    d_optimizer_1 = torch.optim.Adam(discriminator_1.parameters(), lr=learning_rate)
    g_optimizer_1 = torch.optim.Adam(generator_1.parameters(), lr=learning_rate)

    G_losses = []
    D_losses = []
    mmd_list = []
    # One generated series per epoch, aligned index-for-index with mmd_list.
    # The original code seeded this with a zero row, which shifted every
    # plotted series by one epoch relative to its MMD.
    series_list = []

    for n in tqdm(range(num_epoch)):
        for n_batch, sample_data in enumerate(data_loader):
            batch_len = len(sample_data)
            real_data = sample_data.float().to(device)
            real_labels = torch.ones([batch_len, 1], device=device)
            fake_labels = torch.zeros([batch_len, 1], device=device)

            for _ in range(D_rounds):
                discriminator_1.zero_grad()

                # Train the discriminator on fake data. Detach so no gradient
                # flows back into the generator.
                h_g = generator_1.init_hidden()
                noise_sample = noise(batch_len, seq_length).to(device)
                # If the generator also returns hidden states use:
                # dis_fake_data, (h_g_n, c_g_n) = generator_1.forward(noise_sample, h_g)
                dis_fake_data = generator_1.forward(noise_sample, h_g).detach()
                y_pred_fake = discriminator_1(dis_fake_data)
                loss_fake = loss_1(y_pred_fake, fake_labels)
                loss_fake.backward()

                # Train the discriminator on real data.
                y_pred_real = discriminator_1.forward(real_data)
                loss_real = loss_1(y_pred_real, real_labels)
                loss_real.backward()

                # One update using the gradients from both real and fake passes.
                d_optimizer_1.step()

            for _ in range(G_rounds):
                generator_1.zero_grad()
                h_g = generator_1.init_hidden()
                noise_sample = noise(batch_len, seq_length).to(device)
                # If the generator also returns hidden states use:
                # gen_fake_data, (h_g_n, c_g_n) = generator_1.forward(noise_sample, h_g)
                gen_fake_data = generator_1.forward(noise_sample, h_g)
                y_pred_gen = discriminator_1(gen_fake_data)
                # The generator is rewarded when the discriminator labels its output real.
                error_gen = loss_1(y_pred_gen, real_labels)
                error_gen.backward()
                g_optimizer_1.step()

            if n_batch % 100 == 0:
                print("\nERRORS FOR EPOCH: " + str(n) + "/" + str(num_epoch)
                      + ", batch_num: " + str(n_batch) + "/" + str(num_batches))
                print("Discriminator error: " + str((loss_fake + loss_real).item()))
                print("Generator error: " + str(error_gen.item()))
            if n_batch == num_batches - 1:
                G_losses.append(error_gen.item())
                D_losses.append((loss_real + loss_fake).item())

        # Save the model parameters after every epoch.
        torch.save(generator_1.state_dict(), os.path.join(path, 'generator_state_' + str(n) + '.pt'))
        torch.save(discriminator_1.state_dict(), os.path.join(path, 'discriminator_state_' + str(n) + '.pt'))

        # Snapshot one generated series for the per-epoch figures. Uses a full
        # batch, matching the batches used for the test-set comparison.
        with torch.no_grad():
            h_g = generator_1.init_hidden()
            fake = generator_1(noise(sample_size, seq_length).to(device), h_g).detach().cpu()

        # MMD between the test set and one generated batch per test batch.
        generated_sample = _generate_test_sample(generator_1, len(data_loader_test))
        mmd_list.append(_mmd_against_test_set(generated_sample))

        series_list.append(fake[0].numpy().reshape(seq_length))

    # Dump the losses and MMD values recorded for each training epoch.
    with open(os.path.join(path, 'generator_losses.txt'), 'wb') as fp:
        pickle.dump(G_losses, fp)
    with open(os.path.join(path, 'discriminator_losses.txt'), 'wb') as fp:
        pickle.dump(D_losses, fp)
    with open(os.path.join(path, 'mmd_list.txt'), 'wb') as fp:
        pickle.dump(mmd_list, fp)

    # Loss curves.
    plt.plot(G_losses, '-r', label='Generator Error')
    plt.plot(D_losses, '-b', label='Discriminator Error')
    plt.title('GAN Errors in Training')
    plt.legend()
    plt.savefig(os.path.join(path, 'GAN_errors.png'))
    plt.close()

    # One sample per epoch with its MMD in the subplot title.
    _plot_epoch_samples(path, series_list, mmd_list)

    # Check the diversity of the samples with four randomly chosen series.
    generator_1.eval()
    with torch.no_grad():
        h_g = generator_1.init_hidden()
        test_noise_sample = noise(sample_size, seq_length).to(device)
        gen_data = generator_1.forward(test_noise_sample, h_g).detach()

    plt.title("Generated Sine Waves")
    plt.plot(gen_data[random.randint(0, sample_size - 1)].tolist(), '-b')
    plt.plot(gen_data[random.randint(0, sample_size - 1)].tolist(), '-r')
    plt.plot(gen_data[random.randint(0, sample_size - 1)].tolist(), '-g')
    plt.plot(gen_data[random.randint(0, sample_size - 1)].tolist(), '-', color='orange')
    plt.savefig(os.path.join(path, 'Generated_Data_Sample1.png'))
    plt.close()
|
|
474 |
|