# Sine Generation/train.py
"""
GAN with Generator: LSTM/BiLSTM; Discriminator: convolutional NN with minibatch discrimination (MBD)
Sine Data

Introduction
------------
    The aim of this script is to use a convolutional neural network with
    a max-pooling layer in the discriminator.
    This was found to work well with the PhysioNet ECG data in a paper.
    That paper used two convolutional layers, so we compare the series
    generated using a single CNN layer in the discriminator against two
    CNN layers, to see whether the extra layer improves the quality of
    the generated series.
"""
"""
Bringing in required dependencies as defined in the GitHub repo:
    https://github.com/josipd/torch-two-sample/blob/master/torch_two_sample/permutation_test.pyx
"""

from __future__ import division

import torch
from tqdm import tqdm
import numpy as np
from matplotlib import pyplot as plt
import seaborn as sns

from torchvision import transforms
from torch.autograd import Variable

sns.set(rc={'figure.figsize': (11, 4)})

import datetime
from datetime import date

today = date.today()

import random
import json as js
import pickle
import os

# SineData (rather than ECGData) is what GetSineData below constructs.
from data import SineData, PD_to_Tensor
from Model import Generator, Discriminator, noise


device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# A torch.device never compares equal to a plain string, so check device.type instead.
if device.type == 'cuda':
    print('Using GPU : ')
    print(torch.cuda.get_device_name(device))
else:
    print('Using CPU')

"""#MMD Evaluation Metric Definition
52
Using MMD to determine the similarity between distributions
53
PDIST code comes from torch-two-sample utils code: 
54
    https://github.com/josipd/torch-two-sample/blob/master/torch_two_sample/util.py
55
"""
56
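# For samples X = {x_1, ..., x_n1} and Y = {y_1, ..., y_n2}, the unbiased
# estimator implemented below is
#   MMD^2_u =  1/(n1(n1-1)) * sum_{i != j} k(x_i, x_j)
#            + 1/(n2(n2-1)) * sum_{i != j} k(y_i, y_j)
#            - 2/(n1*n2)    * sum_{i, j}   k(x_i, y_j),
# which is exactly what the a00, a11 and a01 constants in MMDStatistic encode.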

def pdist(sample_1, sample_2, norm=2, eps=1e-5):
    r"""Compute the matrix of all pairwise distances.

    Arguments
    ---------
    sample_1 : torch.Tensor or Variable
        The first sample, should be of shape ``(n_1, d)``.
    sample_2 : torch.Tensor or Variable
        The second sample, should be of shape ``(n_2, d)``.
    norm : float
        The l_p norm to be used.
    Returns
    -------
    torch.Tensor or Variable
        Matrix of shape (n_1, n_2). The [i, j]-th entry is equal to
        ``|| sample_1[i, :] - sample_2[j, :] ||_p``."""
    n_1, n_2 = sample_1.size(0), sample_2.size(0)
    norm = float(norm)

    if norm == 2.:
        # ||x - y||^2 = ||x||^2 + ||y||^2 - 2<x, y>, computed batchwise.
        norms_1 = torch.sum(sample_1**2, dim=1, keepdim=True)
        norms_2 = torch.sum(sample_2**2, dim=1, keepdim=True)
        norms = (norms_1.expand(n_1, n_2) +
                 norms_2.transpose(0, 1).expand(n_1, n_2))
        distances_squared = norms - 2 * sample_1.mm(sample_2.t())
        return torch.sqrt(eps + torch.abs(distances_squared))
    else:
        dim = sample_1.size(1)
        expanded_1 = sample_1.unsqueeze(1).expand(n_1, n_2, dim)
        expanded_2 = sample_2.unsqueeze(0).expand(n_1, n_2, dim)
        differences = torch.abs(expanded_1 - expanded_2) ** norm
        inner = torch.sum(differences, dim=2, keepdim=False)
        return (eps + inner) ** (1. / norm)
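# Quick worked check (hypothetical values): for x = (0, 0) and y = (3, 4),
# the squared distance is 25 and pdist returns sqrt(eps + 25) ~ 5:
#   pdist(torch.tensor([[0., 0.]]), torch.tensor([[3., 4.]]))  ->  tensor([[5.0000]])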

def permutation_test_mat(matrix,
                         n_1, n_2, n_permutations,
                         a00=1, a11=1, a01=0):
    r"""Compute the p-value of the following statistic (rejects when high)

        \sum_{i,j} a_{\pi(i), \pi(j)} matrix[i, j].
    """
    n = n_1 + n_2
    pi = np.zeros(n, dtype=np.int8)
    pi[n_1:] = 1

    larger = 0.
    count = 0

    for sample_n in range(1 + n_permutations):
        count = 0.
        for i in range(n):
            for j in range(i, n):
                mij = matrix[i, j] + matrix[j, i]
                if pi[i] == pi[j] == 0:
                    count += a00 * mij
                elif pi[i] == pi[j] == 1:
                    count += a11 * mij
                else:
                    count += a01 * mij
        if sample_n == 0:
            statistic = count
        elif statistic <= count:
            larger += 1

        np.random.shuffle(pi)

    return larger / n_permutations

"""Code from torch-two-sample at https://torch-two-sample.readthedocs.io/en/latest/"""

class MMDStatistic:
    r"""The *unbiased* MMD test of :cite:`gretton2012kernel`.

    The kernel used is equal to:

    .. math ::
        k(x, x') = \sum_{j=1}^k e^{-\alpha_j\|x - x'\|^2},

    for the :math:`\alpha_j` provided in :py:meth:`~.MMDStatistic.__call__`.

    Arguments
    ---------
    n_1: int
        The number of points in the first sample.
    n_2: int
        The number of points in the second sample."""

    def __init__(self, n_1, n_2):
        self.n_1 = n_1
        self.n_2 = n_2

        # The three constants used in the test.
        self.a00 = 1. / (n_1 * (n_1 - 1))
        self.a11 = 1. / (n_2 * (n_2 - 1))
        self.a01 = - 1. / (n_1 * n_2)

    def __call__(self, sample_1, sample_2, alphas, ret_matrix=False):
        r"""Evaluate the statistic.

        The kernel used is

        .. math::
            k(x, x') = \sum_{j=1}^k e^{-\alpha_j \|x - x'\|^2},

        for the provided ``alphas``.

        Arguments
        ---------
        sample_1: :class:`torch:torch.autograd.Variable`
            The first sample, of size ``(n_1, d)``.
        sample_2: variable of shape (n_2, d)
            The second sample, of size ``(n_2, d)``.
        alphas : list of :class:`float`
            The kernel parameters.
        ret_matrix: bool
            If set, the call will also return a second variable.
            This variable can then be used to compute a p-value using
            :py:meth:`~.MMDStatistic.pval`.

        Returns
        -------
        :class:`float`
            The test statistic.
        :class:`torch:torch.autograd.Variable`
            Returned only if ``ret_matrix`` was set to true."""
        sample_12 = torch.cat((sample_1, sample_2), 0)
        distances = pdist(sample_12, sample_12, norm=2)

        # Sum of Gaussian kernels, one per bandwidth parameter alpha.
        kernels = None
        for alpha in alphas:
            kernels_a = torch.exp(- alpha * distances ** 2)
            if kernels is None:
                kernels = kernels_a
            else:
                kernels = kernels + kernels_a

        k_1 = kernels[:self.n_1, :self.n_1]
        k_2 = kernels[self.n_1:, self.n_1:]
        k_12 = kernels[:self.n_1, self.n_1:]

        mmd = (2 * self.a01 * k_12.sum() +
               self.a00 * (k_1.sum() - torch.trace(k_1)) +
               self.a11 * (k_2.sum() - torch.trace(k_2)))
        if ret_matrix:
            return mmd, kernels
        else:
            return mmd

    def pval(self, distances, n_permutations=1000):
        r"""Compute a p-value using a permutation test.

        Arguments
        ---------
        distances: :class:`torch:torch.autograd.Variable`
            The matrix computed using :py:meth:`~.MMDStatistic.__call__`.
        n_permutations: int
            The number of random draws from the permutation null.

        Returns
        -------
        float
            The estimated p-value."""
        if isinstance(distances, Variable):
            distances = distances.data
        return permutation_test_mat(distances.cpu().numpy(),
                                    self.n_1, self.n_2,
                                    n_permutations,
                                    a00=self.a00, a11=self.a11, a01=self.a01)

"""
The paper https://arxiv.org/pdf/1611.04488.pdf notes that the most common way to
choose the kernel bandwidth sigma is the median pairwise distance between points
of the joint data.
"""

def pairwisedistances(X, Y, norm=2):
    dist = pdist(X, Y, norm)
    return np.median(dist.numpy())
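# A quick sanity check of the MMD pipeline above, on hypothetical toy tensors
# (not part of the training run): two small samples drawn from the same
# distribution should give an MMD statistic close to zero.
_real = torch.randn(20, 30).type(torch.DoubleTensor)
_fake = torch.randn(20, 30).type(torch.DoubleTensor)
_sigma = [pairwisedistances(_real, _fake)]  # median-heuristic bandwidth
_mmd = MMDStatistic(_real.size(0), _fake.size(0))
_stat, _kernels = _mmd(_real, _fake, _sigma, ret_matrix=True)
print('Sanity-check MMD: %.6f, p-value: %.3f'
      % (_stat.item(), _mmd.pval(_kernels, n_permutations=100)))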

"""
Function for loading the sine data
"""

def GetSineData(source_file):
  compose = transforms.Compose([PD_to_Tensor()])
  return SineData(source_file, transform=compose)


"""
Creating the training set of sine signals
"""

source_filename = './sinedata_v2.csv'
sine_data = GetSineData(source_file=source_filename)

sample_size = 50  # batch size needed for the DataLoader and the noise creator function
data_loader = torch.utils.data.DataLoader(sine_data, batch_size=sample_size, shuffle=True)
# Number of batches per epoch
num_batches = len(data_loader)

"""Creating the Test Set"""
248
test_filename =  './sinedata_test_v2.csv'
249
250
sine_data_test = GetSineData(source_file = test_filename)
251
data_loader_test = torch.utils.data.DataLoader(sine_data_test, batch_size=sample_size, shuffle=True)
252
253
"""Defining parameters"""
254
seq_length = sine_data[0].size()[0] #Number of features
255
256
#Params for the generator
257
hidden_nodes_g = 50
258
layers = 2
259
tanh_layer = False
260
bidir = True
261
262
#No. of training rounds per epoch
263
D_rounds = 3
264
G_rounds = 1
265
num_epoch = 120
266
learning_rate = 0.0002
267
    
268
#Params for the Discriminator
269
minibatch_layer = 0
270
minibatch_normal_init_ = True
271
num_cvs = 1
272
cv1_out= 10
273
cv1_k = 3
274
cv1_s = 1
275
p1_k = 3
276
p1_s = 2
277
cv2_out = 5
278
cv2_k = 3
279
cv2_s = 1
280
p2_k = 3
281
p2_s = 2
282
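# As a quick check on the settings above: for a convolution or max-pool layer
# with no padding, the output length is L_out = floor((L_in - k) / s) + 1.
# The helper below is illustrative only (the actual shapes are fixed inside
# Model.Discriminator); it traces a sequence length through conv1 -> pool1,
# and through conv2 -> pool2 when num_cvs == 2.
def trace_feature_length(L_in, conv_pool_params):
    """conv_pool_params: list of (kernel, stride) pairs applied in order."""
    L = L_in
    for k, s in conv_pool_params:
        L = (L - k) // s + 1
    return L

# e.g. a hypothetical 100-step series through the 1-CNN discriminator settings:
# trace_feature_length(100, [(cv1_k, cv1_s), (p1_k, p1_s)])  ->  48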

"""# Evaluation of GAN with 1 CNN Layer in Discriminator

## Generator and Discriminator training phase
"""

minibatch_out = [0, 3, 5, 8, 10]
for minibatch_layer in minibatch_out:
  path = ".../your_path/Run_" + str(today.strftime("%d_%m_%Y")) + "_" + str(datetime.datetime.now().time()).split('.')[0]
  os.mkdir(path)

  settings = {'data': source_filename,
              'sample_size': sample_size,
              'seq_length': seq_length,
              'num_layers': layers,
              'tanh_layer': tanh_layer,
              'bidir': bidir,
              'hidden_dims_generator': hidden_nodes_g,
              'minibatch_layer': minibatch_layer,
              'minibatch_normal_init_': minibatch_normal_init_,
              'num_cvs': num_cvs,
              'cv1_out': cv1_out,
              'cv1_k': cv1_k,
              'cv1_s': cv1_s,
              'p1_k': p1_k,
              'p1_s': p1_s,
              'cv2_out': cv2_out,
              'cv2_k': cv2_k,
              'cv2_s': cv2_s,
              'p2_k': p2_k,
              'p2_s': p2_s,
              'num_epoch': num_epoch,
              'D_rounds': D_rounds,
              'G_rounds': G_rounds,
              'learning_rate': learning_rate}

  # Writing the settings used to file ("settings" rather than "dict",
  # which would shadow the built-in)
  with open(path + "/settings.json", "w") as f:
      f.write(js.dumps(settings))

  # Initialising the generator and discriminator
  generator_1 = Generator(seq_length, sample_size, hidden_dim=hidden_nodes_g, tanh_output=tanh_layer, bidirectional=bidir).cuda()
  discriminator_1 = Discriminator(seq_length, sample_size, minibatch_normal_init=minibatch_normal_init_, minibatch=minibatch_layer, num_cv=num_cvs, cv1_out=cv1_out, cv1_k=cv1_k, cv1_s=cv1_s, p1_k=p1_k, p1_s=p1_s, cv2_out=cv2_out, cv2_k=cv2_k, cv2_s=cv2_s, p2_k=p2_k, p2_s=p2_s).cuda()

  # Loss function
  loss_1 = torch.nn.BCELoss()

  generator_1.train()
  discriminator_1.train()

  # Defining the optimizers
  d_optimizer_1 = torch.optim.Adam(discriminator_1.parameters(), lr=learning_rate)
  g_optimizer_1 = torch.optim.Adam(generator_1.parameters(), lr=learning_rate)

  G_losses = []
  D_losses = []
  mmd_list = []
  series_list = np.zeros((1, seq_length))  # placeholder zero row; one generated series is appended per epoch

  for n in tqdm(range(num_epoch)):
      for n_batch, sample_data in enumerate(data_loader):

        for d in range(D_rounds):
          # Train discriminator on fake data
          discriminator_1.zero_grad()

          h_g = generator_1.init_hidden()

          # Generating the noise and label data
          noise_sample = Variable(noise(len(sample_data), seq_length))

          # Use this line if the generator outputs hidden states: dis_fake_data, (h_g_n, c_g_n) = generator.forward(noise_sample, h_g)
          dis_fake_data = generator_1.forward(noise_sample, h_g).detach()

          y_pred_fake = discriminator_1(dis_fake_data)

          loss_fake = loss_1(y_pred_fake, torch.zeros([len(sample_data), 1]).cuda())
          loss_fake.backward()

          # Train discriminator on real data
          real_data = Variable(sample_data.float()).cuda()
          y_pred_real = discriminator_1.forward(real_data)

          loss_real = loss_1(y_pred_real, torch.ones([len(sample_data), 1]).cuda())
          loss_real.backward()

          d_optimizer_1.step()  # update the weights using the gradients from both the real and fake passes

        # Train generator
        for g in range(G_rounds):
          generator_1.zero_grad()
          h_g = generator_1.init_hidden()

          noise_sample = Variable(noise(len(sample_data), seq_length))

          # Use this line if the generator outputs hidden states: gen_fake_data, (h_g_n, c_g_n) = generator.forward(noise_sample, h_g)
          gen_fake_data = generator_1.forward(noise_sample, h_g)
          y_pred_gen = discriminator_1(gen_fake_data)

          # Non-saturating generator loss: label the fakes as real
          error_gen = loss_1(y_pred_gen, torch.ones([len(sample_data), 1]).cuda())
          error_gen.backward()
          g_optimizer_1.step()

      if n_batch % 100 == 0:
          print("\nERRORS FOR EPOCH: " + str(n) + "/" + str(num_epoch) + ", batch_num: " + str(n_batch) + "/" + str(num_batches))
          print("Discriminator error: " + str(loss_fake + loss_real))
          print("Generator error: " + str(error_gen))

      if n_batch == (num_batches - 1):
          G_losses.append(error_gen.item())
          D_losses.append((loss_real + loss_fake).item())

          # Saving the parameters of the model to file for each epoch
          torch.save(generator_1.state_dict(), path + '/generator_state_' + str(n) + '.pt')
          torch.save(discriminator_1.state_dict(), path + '/discriminator_state_' + str(n) + '.pt')

          # Check how the generator is doing by saving G's output on fresh noise
          with torch.no_grad():
              h_g = generator_1.init_hidden()
              fake = generator_1(noise(len(sample_data), seq_length), h_g).detach().cpu()
              generated_sample = torch.zeros(1, seq_length).cuda()
              testloader = torch.utils.data.DataLoader(sine_data_test, batch_size=sample_size, shuffle=True)

              # Generate one batch of fake data per test batch (fresh loop names
              # so the outer loop's n_batch and sample_data are not shadowed)
              for test_batch, test_data in enumerate(testloader):
                noise_sample_test = noise(sample_size, seq_length)
                h_g = generator_1.init_hidden()
                generated_data = generator_1.forward(noise_sample_test, h_g).detach().squeeze()
                generated_sample = torch.cat((generated_sample, generated_data), dim=0)

              # Getting the MMD statistic for each training epoch
              generated_sample = generated_sample[1:]  # drop the initial zero row
              sigma = [pairwisedistances(sine_data_test[:].type(torch.DoubleTensor), generated_sample.type(torch.DoubleTensor).squeeze())]
              mmd = MMDStatistic(len(sine_data_test[:]), generated_sample.size(0))
              mmd_eval = mmd(sine_data_test[:].type(torch.DoubleTensor), generated_sample.type(torch.DoubleTensor).squeeze(), sigma, ret_matrix=False)
              mmd_list.append(mmd_eval.item())

          series_list = np.append(series_list, fake[0].numpy().reshape((1, seq_length)), axis=0)

  # Dumping the errors and MMD evaluations for each training epoch
  with open(path + '/generator_losses.txt', 'wb') as fp:
      pickle.dump(G_losses, fp)
  with open(path + '/discriminator_losses.txt', 'wb') as fp:
      pickle.dump(D_losses, fp)
  with open(path + '/mmd_list.txt', 'wb') as fp:
      pickle.dump(mmd_list, fp)

  # Plotting the error graph
  plt.plot(G_losses, '-r', label='Generator Error')
  plt.plot(D_losses, '-b', label='Discriminator Error')
  plt.title('GAN Errors in Training')
  plt.legend()
  plt.savefig(path + '/GAN_errors.png')
  plt.close()

  # Plot a figure for every three training epochs, with each epoch's MMD value in the title
  i = 0
  while i < num_epoch:
    if i % 3 == 0:
      fig, ax = plt.subplots(3, 1, constrained_layout=True)
      fig.suptitle("Generated fake data")
    for j in range(0, 3):
      # series_list[0] is the zero placeholder row, so epoch i lives at index i + 1
      ax[j].plot(series_list[i + 1][:])
      ax[j].set_title('Epoch ' + str(i) + ', MMD: %.4f' % (mmd_list[i]))
      i = i + 1

    plt.savefig(path + '/Training_Epoch_Samples_MMD_' + str(i) + '.png')
    plt.close(fig)

  # Checking the diversity of the samples
  generator_1.eval()
  h_g = generator_1.init_hidden()
  test_noise_sample = noise(sample_size, seq_length)
  gen_data = generator_1.forward(test_noise_sample, h_g).detach()

  # Overlay four randomly chosen generated series to eyeball sample diversity
  plt.title("Generated Sine Waves")
  plt.plot(gen_data[random.randint(0, sample_size - 1)].tolist(), '-b')
  plt.plot(gen_data[random.randint(0, sample_size - 1)].tolist(), '-r')
  plt.plot(gen_data[random.randint(0, sample_size - 1)].tolist(), '-g')
  plt.plot(gen_data[random.randint(0, sample_size - 1)].tolist(), '-', color='orange')
  plt.savefig(path + '/Generated_Data_Sample1.png')
  plt.close()
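
  # A minimal sketch of reloading one of the checkpoints saved above to draw
  # fresh samples (illustrative only; `epoch_to_load` is a hypothetical choice
  # and the constructor arguments mirror the call used earlier in this loop):
  #   epoch_to_load = num_epoch - 1
  #   generator_r = Generator(seq_length, sample_size, hidden_dim=hidden_nodes_g,
  #                           tanh_output=tanh_layer, bidirectional=bidir).cuda()
  #   generator_r.load_state_dict(torch.load(path + '/generator_state_' + str(epoch_to_load) + '.pt'))
  #   generator_r.eval()
  #   samples = generator_r(noise(sample_size, seq_length), generator_r.init_hidden()).detach().cpu()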