Diff of /train.py [000000] .. [5a2c8f]

# -*- coding: utf-8 -*-
"""
GAN with Generator: LSTM, Discriminator: Convolutional NN with ECG Data

 Introduction
 ------------
    The aim of this script is to use a convolutional neural network with
    a max pooling layer in the discriminator.
    This architecture was found to work well with the PhysioNet ECG data in a
    previous paper. That paper used two convolutional layers, so we compare
    the series generated with a single convolutional layer in the
    discriminator against those generated with two layers, to see whether the
    extra layer improves the quality of the generated series.

"""
"""
Bringing in required dependencies as defined in the GitHub repo:
    https://github.com/josipd/torch-two-sample/blob/master/torch_two_sample/permutation_test.pyx
"""
from __future__ import division

import torch
from tqdm import tqdm
import numpy as np
from matplotlib import pyplot as plt
import seaborn as sns

from torchvision import transforms
from torch.autograd.variable import Variable
sns.set(rc={'figure.figsize': (11, 4)})

import datetime
from datetime import date
today = date.today()

import random
import json as js
import pickle
import os

from data import ECGData, PD_to_Tensor
from Model import Generator, Discriminator

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# `device` is a torch.device object, so compare its type attribute rather
# than comparing it to the string 'cuda:0' (which is never equal).
if device.type == 'cuda':
    print('Using GPU : ')
    print(torch.cuda.get_device_name(device))
else:
    print('Using CPU')


"""#MMD Evaluation Metric Definition

Using MMD (maximum mean discrepancy) to determine the similarity between distributions.

PDIST code comes from the torch-two-sample utils code:
    https://github.com/josipd/torch-two-sample/blob/master/torch_two_sample/util.py
"""

def pdist(sample_1, sample_2, norm=2, eps=1e-5):
    r"""Compute the matrix of all pairwise distances.

    Arguments
    ---------
    sample_1 : torch.Tensor or Variable
        The first sample, should be of shape ``(n_1, d)``.
    sample_2 : torch.Tensor or Variable
        The second sample, should be of shape ``(n_2, d)``.
    norm : float
        The l_p norm to be used.
    eps : float
        Small constant added for numerical stability.

    Returns
    -------
    torch.Tensor or Variable
        Matrix of shape (n_1, n_2). The [i, j]-th entry is equal to
        ``|| sample_1[i, :] - sample_2[j, :] ||_p``."""
    n_1, n_2 = sample_1.size(0), sample_2.size(0)
    norm = float(norm)

    if norm == 2.:
        norms_1 = torch.sum(sample_1**2, dim=1, keepdim=True)
        norms_2 = torch.sum(sample_2**2, dim=1, keepdim=True)
        norms = (norms_1.expand(n_1, n_2) +
                 norms_2.transpose(0, 1).expand(n_1, n_2))
        distances_squared = norms - 2 * sample_1.mm(sample_2.t())
        return torch.sqrt(eps + torch.abs(distances_squared))
    else:
        dim = sample_1.size(1)
        expanded_1 = sample_1.unsqueeze(1).expand(n_1, n_2, dim)
        expanded_2 = sample_2.unsqueeze(0).expand(n_1, n_2, dim)
        differences = torch.abs(expanded_1 - expanded_2) ** norm
        inner = torch.sum(differences, dim=2, keepdim=False)
        return (eps + inner) ** (1. / norm)
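
# A quick, hypothetical sanity check for pdist (kept in comments so the
# script's behaviour is unchanged; the tensors and numbers are illustrative):
#
#   >>> a = torch.zeros(3, 2)
#   >>> b = torch.ones(4, 2)
#   >>> pdist(a, b).shape
#   torch.Size([3, 4])
#   >>> float(pdist(a, b)[0, 0])   # ~= sqrt(2), the l_2 distance from (0, 0) to (1, 1)
#   1.414...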

def permutation_test_mat(matrix,
                         n_1, n_2, n_permutations,
                         a00=1, a11=1, a01=0):
    r"""Compute the p-value of the following statistic (rejects when high)

        \sum_{i,j} a_{\pi(i), \pi(j)} matrix[i, j].
    """
    n = n_1 + n_2
    pi = np.zeros(n, dtype=np.int8)
    pi[n_1:] = 1

    larger = 0.
    count = 0

    for sample_n in range(1 + n_permutations):
        count = 0.
        for i in range(n):
            for j in range(i, n):
                mij = matrix[i, j] + matrix[j, i]
                if pi[i] == pi[j] == 0:
                    count += a00 * mij
                elif pi[i] == pi[j] == 1:
                    count += a11 * mij
                else:
                    count += a01 * mij
        if sample_n == 0:
            statistic = count
        elif statistic <= count:
            larger += 1

        np.random.shuffle(pi)

    return larger / n_permutations
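
# Reading the return value (illustrative numbers, not from a real run): with
# n_permutations = 1000, a return of 0.04 would mean ~40 of the shuffled label
# assignments produced a statistic at least as large as the first (unshuffled)
# one, i.e. an estimated p-value of 0.04 and evidence that the samples differ.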

"""Code from Torch-Two-Sample at https://torch-two-sample.readthedocs.io/en/latest/#"""

class MMDStatistic:
    r"""The *unbiased* MMD test of :cite:`gretton2012kernel`.

    The kernel used is equal to:

    .. math ::
        k(x, x') = \sum_{j=1}^k e^{-\alpha_j\|x - x'\|^2},

    for the :math:`\alpha_j` provided to :py:meth:`~.MMDStatistic.__call__`.

    Arguments
    ---------
    n_1: int
        The number of points in the first sample.
    n_2: int
        The number of points in the second sample."""

    def __init__(self, n_1, n_2):
        self.n_1 = n_1
        self.n_2 = n_2

        # The three constants used in the test.
        self.a00 = 1. / (n_1 * (n_1 - 1))
        self.a11 = 1. / (n_2 * (n_2 - 1))
        self.a01 = - 1. / (n_1 * n_2)

    def __call__(self, sample_1, sample_2, alphas, ret_matrix=False):
        r"""Evaluate the statistic.

        The kernel used is

        .. math::

            k(x, x') = \sum_{j=1}^k e^{-\alpha_j \|x - x'\|^2},

        for the provided ``alphas``.

        Arguments
        ---------
        sample_1: :class:`torch:torch.autograd.Variable`
            The first sample, of size ``(n_1, d)``.
        sample_2: :class:`torch:torch.autograd.Variable`
            The second sample, of size ``(n_2, d)``.
        alphas : list of :class:`float`
            The kernel parameters.
        ret_matrix: bool
            If set, the call will also return a second variable.

            This variable can then be used to compute a p-value using
            :py:meth:`~.MMDStatistic.pval`.

        Returns
        -------
        :class:`float`
            The test statistic.
        :class:`torch:torch.autograd.Variable`
            Returned only if ``ret_matrix`` was set to true."""
        sample_12 = torch.cat((sample_1, sample_2), 0)
        distances = pdist(sample_12, sample_12, norm=2)

        kernels = None
        for alpha in alphas:
            kernels_a = torch.exp(- alpha * distances ** 2)
            if kernels is None:
                kernels = kernels_a
            else:
                kernels = kernels + kernels_a

        k_1 = kernels[:self.n_1, :self.n_1]
        k_2 = kernels[self.n_1:, self.n_1:]
        k_12 = kernels[:self.n_1, self.n_1:]

        mmd = (2 * self.a01 * k_12.sum() +
               self.a00 * (k_1.sum() - torch.trace(k_1)) +
               self.a11 * (k_2.sum() - torch.trace(k_2)))
        if ret_matrix:
            return mmd, kernels
        else:
            return mmd

    def pval(self, distances, n_permutations=1000):
        r"""Compute a p-value using a permutation test.

        Arguments
        ---------
        distances: :class:`torch:torch.autograd.Variable`
            The matrix computed using :py:meth:`~.MMDStatistic.__call__`.
        n_permutations: int
            The number of random draws from the permutation null.

        Returns
        -------
        float
            The estimated p-value."""
        if isinstance(distances, Variable):
            distances = distances.data
        return permutation_test_mat(distances.cpu().numpy(),
                                    self.n_1, self.n_2,
                                    n_permutations,
                                    a00=self.a00, a11=self.a11, a01=self.a01)
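
# Illustrative usage of MMDStatistic on toy tensors (hypothetical shapes; the
# real evaluation further down follows the same pattern with the ECG test set):
#
#   >>> s1 = torch.randn(100, 20).double()
#   >>> s2 = torch.randn(50, 20).double()
#   >>> sigma = [pairwisedistances(s1, s2)]           # median heuristic, defined below
#   >>> mmd_stat = MMDStatistic(100, 50)
#   >>> tts_mmd, dist_mat = mmd_stat(s1, s2, sigma, ret_matrix=True)
#   >>> p_val = mmd_stat.pval(dist_mat)               # permutation-test p-value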

"""
This paper, https://arxiv.org/pdf/1611.04488.pdf, says that the most common
way to calculate sigma is to use the median of the pairwise distances between
the joint data.
"""

def pairwisedistances(X, Y, norm=2):
    """Median of the pairwise distances between the samples X and Y."""
    dist = pdist(X, Y, norm)
    return np.median(dist.numpy())
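
# Note (an observation about conventions, not a change to the code): in the
# evaluation loop below this median distance is passed straight to
# MMDStatistic as the kernel parameter. Under the Gaussian-kernel
# parameterisation k(x, x') = exp(-alpha * ||x - x'||^2) one would usually set
# alpha = 1 / (2 * sigma**2); this script uses the median value directly.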

"""
Function for loading the ECG data
"""
def GetECGData(source_file, class_id):
  compose = transforms.Compose(
        [PD_to_Tensor()
        ])
  return ECGData(source_file, class_id=class_id, transform=compose)

"""
Creating the training set of ECG signals
"""

# Taking normal ECG data for now (class_id = 0)
source_filename = './mitbih_train.csv'
ecg_data = GetECGData(source_file=source_filename, class_id=0)

sample_size = 119  # batch size needed for the DataLoader and the noise creator function

# Create a loader with the data, so that we can iterate over it
data_loader = torch.utils.data.DataLoader(ecg_data, batch_size=sample_size, shuffle=True)
# Number of batches
num_batches = len(data_loader)
print(num_batches)

"""Creating the Test Set"""
test_filename = './mitbih_test.csv'

ecg_data_test = GetECGData(source_file=test_filename, class_id=0)

# 18088 = 152 * 119, so the truncated test set is an exact multiple of sample_size.
# (The evaluation below indexes ecg_data_test directly rather than using this loader.)
data_loader_test = torch.utils.data.DataLoader(ecg_data_test[:18088], batch_size=sample_size, shuffle=True)

"""##Defining the noise creation function"""
281
282
def noise(batch_size, features):
283
  noise_vec = torch.randn(batch_size, features).to(device)
284
  return noise_vec
285
286
"""#Initialising Parameters"""
287
288
seq_length = ecg_data[0].size()[0] #Number of features
289
290
# Params for the generator
hidden_nodes_g = 50
layers = 2
tanh_layer = False

# Number of training rounds per epoch
D_rounds = 3
G_rounds = 1
num_epoch = 35
learning_rate = 0.0002

# Params for the discriminator
minibatch_layer = 0
minibatch_normal_init_ = True
num_cvs = 2
cv1_out = 10
cv1_k = 3
cv1_s = 1
p1_k = 3
p1_s = 2
cv2_out = 10
cv2_k = 3
cv2_s = 1
p2_k = 3
p2_s = 2
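
# Naming convention for the discriminator parameters above (inferred from how
# they are passed to Discriminator below, so treat this as an assumption about
# Model.py): cv1_/cv2_ are the first/second convolutional layers and p1_/p2_
# the pooling layers that follow them; *_out is the number of output channels,
# *_k the kernel size, and *_s the stride.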

"""# Evaluation of GAN with 2 CNN Layers in the Discriminator

##Generator and Discriminator training phase
"""

# Sweep over several minibatch-discrimination layer sizes
minibatch_out = [0, 3, 5, 8, 10]
for minibatch_layer in minibatch_out:
324
  path = ".../your_path/Run_"+str(today.strftime("%d_%m_%Y"))+"_"+ str(datetime.datetime.now().time()).split('.')[0]
325
  os.mkdir(path)

  # Avoid shadowing the built-in `dict` / `json` names
  settings = {'data': source_filename,
              'sample_size': sample_size,
              'seq_length': seq_length,
              'num_layers': layers,
              'tanh_layer': tanh_layer,
              'hidden_dims_generator': hidden_nodes_g,
              'minibatch_layer': minibatch_layer,
              'minibatch_normal_init_': minibatch_normal_init_,
              'num_cvs': num_cvs,
              'cv1_out': cv1_out,
              'cv1_k': cv1_k,
              'cv1_s': cv1_s,
              'p1_k': p1_k,
              'p1_s': p1_s,
              'cv2_out': cv2_out,
              'cv2_k': cv2_k,
              'cv2_s': cv2_s,
              'p2_k': p2_k,
              'p2_s': p2_s,
              'num_epoch': num_epoch,
              'D_rounds': D_rounds,
              'G_rounds': G_rounds,
              'learning_rate': learning_rate
             }

  with open(path + "/settings.json", "w") as f:
    f.write(js.dumps(settings))

  generator_1 = Generator(seq_length, sample_size, hidden_dim=hidden_nodes_g, tanh_output=tanh_layer).to(device)
  discriminator_1 = Discriminator(seq_length, sample_size, minibatch_normal_init=minibatch_normal_init_, minibatch=minibatch_layer, num_cv=num_cvs, cv1_out=cv1_out, cv1_k=cv1_k, cv1_s=cv1_s, p1_k=p1_k, p1_s=p1_s, cv2_out=cv2_out, cv2_k=cv2_k, cv2_s=cv2_s, p2_k=p2_k, p2_s=p2_s).to(device)

  # Loss function
  loss_1 = torch.nn.BCELoss()

  generator_1.train()
  discriminator_1.train()

  d_optimizer_1 = torch.optim.Adam(discriminator_1.parameters(), lr=learning_rate)
  g_optimizer_1 = torch.optim.Adam(generator_1.parameters(), lr=learning_rate)

  G_losses = []
  D_losses = []
  mmd_list = []
  series_list = np.zeros((1, seq_length))  # row 0 is a zero placeholder; one epoch sample is appended per epoch below

  for n in tqdm(range(num_epoch)):

      for n_batch, sample_data in enumerate(data_loader):

        ### TRAIN DISCRIMINATOR ###
        for d in range(D_rounds):
          discriminator_1.zero_grad()

          h_g = generator_1.init_hidden()

          # Generating the noise and label data
          noise_sample = Variable(noise(len(sample_data), seq_length))

          # Use this line if the generator outputs hidden states:
          # dis_fake_data, (h_g_n, c_g_n) = generator_1.forward(noise_sample, h_g)
          dis_fake_data = generator_1.forward(noise_sample, h_g).detach()

          y_pred_fake = discriminator_1(dis_fake_data)

          loss_fake = loss_1(y_pred_fake, torch.zeros([len(sample_data), 1]).to(device))
          loss_fake.backward()

          # Train discriminator on real data
          real_data = Variable(sample_data.float()).to(device)
          y_pred_real = discriminator_1.forward(real_data)

          loss_real = loss_1(y_pred_real, torch.ones([len(sample_data), 1]).to(device))
          loss_real.backward()

          d_optimizer_1.step()  # Updating the weights based on the predictions for both real and fake data

        ### TRAIN GENERATOR ###
        for g in range(G_rounds):
          generator_1.zero_grad()
          h_g = generator_1.init_hidden()

          noise_sample = Variable(noise(len(sample_data), seq_length))

          # Use this line if the generator outputs hidden states:
          # gen_fake_data, (h_g_n, c_g_n) = generator_1.forward(noise_sample, h_g)
          gen_fake_data = generator_1.forward(noise_sample, h_g)
          y_pred_gen = discriminator_1(gen_fake_data)

          error_gen = loss_1(y_pred_gen, torch.ones([len(sample_data), 1]).to(device))
          error_gen.backward()
          g_optimizer_1.step()

      # At the end of each epoch (i.e. after the final batch), record the
      # losses, checkpoint both models and evaluate the MMD against the test set.
      if n_batch == (num_batches - 1):
          G_losses.append(error_gen.item())
          D_losses.append((loss_real + loss_fake).item())

          torch.save(generator_1.state_dict(), path + '/generator_state_' + str(n) + '.pt')
          torch.save(discriminator_1.state_dict(), path + '/discriminator_state_' + str(n) + '.pt')

          # Check how the generator is doing by saving G's output on a fresh noise sample
          with torch.no_grad():
              h_g = generator_1.init_hidden()
              fake = generator_1(noise(len(sample_data), seq_length), h_g).detach().cpu()
              generated_sample = torch.zeros(1, seq_length).to(device)

              for batch in range(0, int(len(ecg_data_test[:18088]) / sample_size)):
                noise_sample_test = noise(sample_size, seq_length)
                h_g = generator_1.init_hidden()
                generated_data = generator_1.forward(noise_sample_test, h_g).detach().squeeze()
                generated_sample = torch.cat((generated_sample, generated_data), dim=0)

              # Getting the MMD statistic for each training epoch
              generated_sample = generated_sample[1:]  # drop the zero placeholder row
              sigma = [pairwisedistances(ecg_data_test[:18088].type(torch.DoubleTensor), generated_sample.type(torch.DoubleTensor).squeeze())]
              mmd = MMDStatistic(len(ecg_data_test[:18088]), generated_sample.size(0))
              mmd_eval = mmd(ecg_data_test[:18088].type(torch.DoubleTensor), generated_sample.type(torch.DoubleTensor).squeeze(), sigma, ret_matrix=False)
              mmd_list.append(mmd_eval.item())

              series_list = np.append(series_list, fake[0].numpy().reshape((1, seq_length)), axis=0)

  # Dumping the errors and MMD evaluations for each training epoch
  with open(path + '/generator_losses.txt', 'wb') as fp:
      pickle.dump(G_losses, fp)
  with open(path + '/discriminator_losses.txt', 'wb') as fp:
      pickle.dump(D_losses, fp)
  with open(path + '/mmd_list.txt', 'wb') as fp:
      pickle.dump(mmd_list, fp)

  # Plotting the error graph
  plt.plot(G_losses, '-r', label='Generator Error')
  plt.plot(D_losses, '-b', label='Discriminator Error')
  plt.title('GAN Errors in Training')
  plt.legend()
  plt.savefig(path + '/GAN_errors.png')
  plt.close()

  # Plot the generated samples for each training epoch, three per figure,
  # with the MMD value in each subplot title
  i = 0
  while i < num_epoch:
    fig, ax = plt.subplots(3, 1, constrained_layout=True)
    fig.suptitle("Generated fake data")
    for j in range(3):
      if i >= num_epoch:
        # Guard against indexing past the last epoch (an IndexError on
        # mmd_list) when num_epoch is not a multiple of 3, e.g. 35.
        break
      # i + 1 skips the zero placeholder row at the top of series_list
      ax[j].plot(series_list[i + 1][:])
      ax[j].set_title('Epoch ' + str(i) + ', MMD: %.4f' % (mmd_list[i]))
      i = i + 1

    plt.savefig(path + '/Training_Epoch_Samples_MMD_' + str(i) + '.png')
    plt.close(fig)

  # Checking the diversity of the samples
  generator_1.eval()
  h_g = generator_1.init_hidden()
  test_noise_sample = noise(sample_size, seq_length)
  gen_data = generator_1.forward(test_noise_sample, h_g).detach()

  plt.title("Generated ECG Waves")
  plt.plot(gen_data[random.randint(0, sample_size - 1)].tolist(), '-b')
  plt.plot(gen_data[random.randint(0, sample_size - 1)].tolist(), '-r')
  plt.plot(gen_data[random.randint(0, sample_size - 1)].tolist(), '-g')
  plt.plot(gen_data[random.randint(0, sample_size - 1)].tolist(), '-', color='orange')
  plt.savefig(path + '/Generated_Data_Sample1.png')
  plt.close()