--- /dev/null
+++ b/train.py
@@ -0,0 +1,494 @@
+# -*- coding: utf-8 -*-
+"""
+GAN with an LSTM generator and a convolutional discriminator, trained on ECG data.
+
+  Introduction
+  ------------
+  The aim of this script is to use a convolutional neural network with
+  a max pooling layer in the discriminator.
+  This architecture was found to work well with the PhysioNet ECG data in a
+  previous paper. That paper used two convolutional layers, so we compare the
+  series generated with a single convolutional layer in the discriminator
+  against those generated with two layers, to see whether the second layer
+  improves the quality of the generated series.
+"""
+
+"""
+Bringing in required dependencies as defined in the GitHub repo:
+  https://github.com/josipd/torch-two-sample/blob/master/torch_two_sample/permutation_test.pyx
+"""
+from __future__ import division
+
+import datetime
+from datetime import date
+import json as js
+import os
+import pickle
+import random
+
+import numpy as np
+import torch
+import seaborn as sns
+from matplotlib import pyplot as plt
+from torch.autograd.variable import Variable
+from torchvision import transforms
+from tqdm import tqdm
+
+from data import ECGData, PD_to_Tensor
+from Model import Generator, Discriminator
+
+sns.set(rc={'figure.figsize': (11, 4)})
+
+today = date.today()
+
+device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
+
+# Compare the device type rather than the device object against a raw string.
+if device.type == 'cuda':
+    print('Using GPU:')
+    print(torch.cuda.get_device_name(device))
+else:
+    print('Using CPU')
+
+"""# MMD Evaluation Metric Definition
+
+Using MMD to determine the similarity between distributions.
+
+PDIST code comes from the torch-two-sample utils code:
+  https://github.com/josipd/torch-two-sample/blob/master/torch_two_sample/util.py
+"""
+
+def pdist(sample_1, sample_2, norm=2, eps=1e-5):
+    r"""Compute the matrix of all pairwise distances.
+
+    Arguments
+    ---------
+    sample_1 : torch.Tensor or Variable
+        The first sample, should be of shape ``(n_1, d)``.
+    sample_2 : torch.Tensor or Variable
+        The second sample, should be of shape ``(n_2, d)``.
+    norm : float
+        The l_p norm to be used.
+    eps : float
+        Small constant added for numerical stability.
+
+    Returns
+    -------
+    torch.Tensor or Variable
+        Matrix of shape ``(n_1, n_2)``. The ``[i, j]``-th entry is equal to
+        ``|| sample_1[i, :] - sample_2[j, :] ||_p``."""
+    n_1, n_2 = sample_1.size(0), sample_2.size(0)
+    norm = float(norm)
+
+    if norm == 2.:
+        norms_1 = torch.sum(sample_1 ** 2, dim=1, keepdim=True)
+        norms_2 = torch.sum(sample_2 ** 2, dim=1, keepdim=True)
+        norms = (norms_1.expand(n_1, n_2)
+                 + norms_2.transpose(0, 1).expand(n_1, n_2))
+        distances_squared = norms - 2 * sample_1.mm(sample_2.t())
+        return torch.sqrt(eps + torch.abs(distances_squared))
+    else:
+        dim = sample_1.size(1)
+        expanded_1 = sample_1.unsqueeze(1).expand(n_1, n_2, dim)
+        expanded_2 = sample_2.unsqueeze(0).expand(n_1, n_2, dim)
+        differences = torch.abs(expanded_1 - expanded_2) ** norm
+        inner = torch.sum(differences, dim=2, keepdim=False)
+        return (eps + inner) ** (1. / norm)
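+# Illustrative sanity check (an added sketch, not part of the original script):
+# for norm=2 the pdist above should agree with torch.cdist up to the eps
+# smoothing term added under the square root.
+def _check_pdist_against_cdist(n_1=4, n_2=5, d=187):
+    x, y = torch.randn(n_1, d), torch.randn(n_2, d)
+    return torch.allclose(pdist(x, y), torch.cdist(x, y), atol=1e-2)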
+def permutation_test_mat(matrix,
+                         n_1, n_2, n_permutations,
+                         a00=1, a11=1, a01=0):
+    r"""Compute the p-value of the following statistic (rejects when high):
+
+    .. math::
+
+        \sum_{i, j} a_{\pi(i), \pi(j)} \text{matrix}[i, j].
+    """
+    n = n_1 + n_2
+    pi = np.zeros(n, dtype=np.int8)
+    pi[n_1:] = 1
+
+    larger = 0.
+
+    for sample_n in range(1 + n_permutations):
+        count = 0.
+        for i in range(n):
+            for j in range(i, n):
+                mij = matrix[i, j] + matrix[j, i]
+                if pi[i] == pi[j] == 0:
+                    count += a00 * mij
+                elif pi[i] == pi[j] == 1:
+                    count += a11 * mij
+                else:
+                    count += a01 * mij
+        # The first pass uses the identity permutation and gives the observed statistic.
+        if sample_n == 0:
+            statistic = count
+        elif statistic <= count:
+            larger += 1
+
+        np.random.shuffle(pi)
+
+    return larger / n_permutations
+
+"""Code from Torch-Two-Sample at https://torch-two-sample.readthedocs.io/en/latest/#"""
+
+class MMDStatistic:
+    r"""The *unbiased* MMD test of :cite:`gretton2012kernel`.
+
+    The kernel used is equal to:
+
+    .. math::
+
+        k(x, x') = \sum_{j=1}^k e^{-\alpha_j \|x - x'\|^2},
+
+    for the :math:`\alpha_j` provided in :py:meth:`~.MMDStatistic.__call__`.
+
+    Arguments
+    ---------
+    n_1: int
+        The number of points in the first sample.
+    n_2: int
+        The number of points in the second sample."""
+
+    def __init__(self, n_1, n_2):
+        self.n_1 = n_1
+        self.n_2 = n_2
+
+        # The three constants used in the test.
+        self.a00 = 1. / (n_1 * (n_1 - 1))
+        self.a11 = 1. / (n_2 * (n_2 - 1))
+        self.a01 = - 1. / (n_1 * n_2)
+
+    def __call__(self, sample_1, sample_2, alphas, ret_matrix=False):
+        r"""Evaluate the statistic.
+
+        The kernel used is
+
+        .. math::
+
+            k(x, x') = \sum_{j=1}^k e^{-\alpha_j \|x - x'\|^2},
+
+        for the provided ``alphas``.
+
+        Arguments
+        ---------
+        sample_1: :class:`torch:torch.autograd.Variable`
+            The first sample, of size ``(n_1, d)``.
+        sample_2: :class:`torch:torch.autograd.Variable`
+            The second sample, of size ``(n_2, d)``.
+        alphas : list of :class:`float`
+            The kernel parameters.
+        ret_matrix: bool
+            If set, the call will also return a second variable.
+            This variable can then be used to compute a p-value using
+            :py:meth:`~.MMDStatistic.pval`.
+
+        Returns
+        -------
+        :class:`float`
+            The test statistic.
+        :class:`torch:torch.autograd.Variable`
+            Returned only if ``ret_matrix`` was set to true."""
+        sample_12 = torch.cat((sample_1, sample_2), 0)
+        distances = pdist(sample_12, sample_12, norm=2)
+
+        kernels = None
+        for alpha in alphas:
+            kernels_a = torch.exp(-alpha * distances ** 2)
+            if kernels is None:
+                kernels = kernels_a
+            else:
+                kernels = kernels + kernels_a
+
+        k_1 = kernels[:self.n_1, :self.n_1]
+        k_2 = kernels[self.n_1:, self.n_1:]
+        k_12 = kernels[:self.n_1, self.n_1:]
+
+        mmd = (2 * self.a01 * k_12.sum()
+               + self.a00 * (k_1.sum() - torch.trace(k_1))
+               + self.a11 * (k_2.sum() - torch.trace(k_2)))
+        if ret_matrix:
+            return mmd, kernels
+        else:
+            return mmd
+
+    def pval(self, distances, n_permutations=1000):
+        r"""Compute a p-value using a permutation test.
+
+        Arguments
+        ---------
+        distances: :class:`torch:torch.autograd.Variable`
+            The kernel matrix computed using :py:meth:`~.MMDStatistic.__call__`
+            with ``ret_matrix=True``.
+        n_permutations: int
+            The number of random draws from the permutation null.
+
+        Returns
+        -------
+        float
+            The estimated p-value."""
+        if isinstance(distances, Variable):
+            distances = distances.data
+        return permutation_test_mat(distances.cpu().numpy(),
+                                    self.n_1, self.n_2,
+                                    n_permutations,
+                                    a00=self.a00, a11=self.a11, a01=self.a01)
+
+"""
+The paper https://arxiv.org/pdf/1611.04488.pdf notes that the most common way
+to choose the kernel bandwidth sigma is the median pairwise distance between
+points in the joint data.
+"""
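+# Minimal usage sketch (an added illustration, not called during training):
+# for two samples drawn from the same distribution the unbiased MMD estimate
+# should be close to zero. Mapping the median pairwise distance to
+# alpha = 1 / (2 * median**2) is one common convention; that mapping is an
+# assumption here, the training loop below instead passes the median distance
+# in as alpha directly.
+def _mmd_smoke_test(n=100, d=187):
+    sample_a = torch.randn(n, d).double()
+    sample_b = torch.randn(n, d).double()
+    median_dist = np.median(pdist(sample_a, sample_b).numpy())
+    mmd = MMDStatistic(n, n)
+    return mmd(sample_a, sample_b, alphas=[1.0 / (2 * median_dist ** 2)]).item()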
+def pairwisedistances(X, Y, norm=2):
+    """Return the median pairwise distance between two samples."""
+    dist = pdist(X, Y, norm)
+    return np.median(dist.numpy())
+
+"""
+Function for loading ECG data.
+"""
+
+def GetECGData(source_file, class_id):
+    compose = transforms.Compose([PD_to_Tensor()])
+    return ECGData(source_file, class_id=class_id, transform=compose)
+
+"""
+Creating the training set of ECG signals.
+"""
+
+# Taking normal ECG data (class 0) for now.
+source_filename = './mitbih_train.csv'
+ecg_data = GetECGData(source_file=source_filename, class_id=0)
+
+sample_size = 119  # Batch size, needed for the DataLoader and the noise creator function.
+
+# Create a loader so that we can iterate over the data.
+data_loader = torch.utils.data.DataLoader(ecg_data, batch_size=sample_size, shuffle=True)
+num_batches = len(data_loader)
+print(num_batches)
+
+"""Creating the test set."""
+test_filename = './mitbih_test.csv'
+
+ecg_data_test = GetECGData(source_file=test_filename, class_id=0)
+
+data_loader_test = torch.utils.data.DataLoader(ecg_data_test[:18088], batch_size=sample_size, shuffle=True)
+
+"""## Defining the noise creation function"""
+
+def noise(batch_size, features):
+    noise_vec = torch.randn(batch_size, features).to(device)
+    return noise_vec
+
+"""# Initialising Parameters"""
+
+seq_length = ecg_data[0].size()[0]  # Number of features per series.
+
+# Params for the generator.
+hidden_nodes_g = 50
+layers = 2
+tanh_layer = False
+
+# Number of training rounds per batch.
+D_rounds = 3
+G_rounds = 1
+num_epoch = 35
+learning_rate = 0.0002
+
+# Params for the discriminator.
+minibatch_layer = 0
+minibatch_normal_init_ = True
+num_cvs = 2
+cv1_out = 10
+cv1_k = 3
+cv1_s = 1
+p1_k = 3
+p1_s = 2
+cv2_out = 10
+cv2_k = 3
+cv2_s = 1
+p2_k = 3
+p2_s = 2
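+# Shape sanity check (an added sketch, not in the original script): one batch
+# from the loader and one noise draw should both be (sample_size, seq_length),
+# since both real and fake batches are fed through the same discriminator.
+def _check_shapes():
+    real_batch = next(iter(data_loader))
+    fake_batch = noise(sample_size, seq_length)
+    return real_batch.shape, fake_batch.shape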
+"""# Evaluation of GAN with 2 CNN Layers in the Discriminator
+
+## Generator and discriminator training phase
+"""
+
+minibatch_out = [0, 3, 5, 8, 10]
+for minibatch_layer in minibatch_out:
+    path = ".../your_path/Run_" + str(today.strftime("%d_%m_%Y")) + "_" + str(datetime.datetime.now().time()).split('.')[0]
+    os.mkdir(path)
+
+    settings = {'data': source_filename,
+                'sample_size': sample_size,
+                'seq_length': seq_length,
+                'num_layers': layers,
+                'tanh_layer': tanh_layer,
+                'hidden_dims_generator': hidden_nodes_g,
+                'minibatch_layer': minibatch_layer,
+                'minibatch_normal_init_': minibatch_normal_init_,
+                'num_cvs': num_cvs,
+                'cv1_out': cv1_out,
+                'cv1_k': cv1_k,
+                'cv1_s': cv1_s,
+                'p1_k': p1_k,
+                'p1_s': p1_s,
+                'cv2_out': cv2_out,
+                'cv2_k': cv2_k,
+                'cv2_s': cv2_s,
+                'p2_k': p2_k,
+                'p2_s': p2_s,
+                'num_epoch': num_epoch,
+                'D_rounds': D_rounds,
+                'G_rounds': G_rounds,
+                'learning_rate': learning_rate
+                }
+
+    with open(path + "/settings.json", "w") as f:
+        f.write(js.dumps(settings))
+
+    generator_1 = Generator(seq_length, sample_size, hidden_dim=hidden_nodes_g, tanh_output=tanh_layer).to(device)
+    discriminator_1 = Discriminator(seq_length, sample_size,
+                                    minibatch_normal_init=minibatch_normal_init_,
+                                    minibatch=minibatch_layer, num_cv=num_cvs,
+                                    cv1_out=cv1_out, cv1_k=cv1_k, cv1_s=cv1_s,
+                                    p1_k=p1_k, p1_s=p1_s,
+                                    cv2_out=cv2_out, cv2_k=cv2_k, cv2_s=cv2_s,
+                                    p2_k=p2_k, p2_s=p2_s).to(device)
+
+    # Loss function.
+    loss_1 = torch.nn.BCELoss()
+
+    generator_1.train()
+    discriminator_1.train()
+
+    d_optimizer_1 = torch.optim.Adam(discriminator_1.parameters(), lr=learning_rate)
+    g_optimizer_1 = torch.optim.Adam(generator_1.parameters(), lr=learning_rate)
+
+    G_losses = []
+    D_losses = []
+    mmd_list = []
+    series_list = np.zeros((1, seq_length))
+
+    for n in tqdm(range(num_epoch)):
+        for n_batch, sample_data in enumerate(data_loader):
+            ### TRAIN DISCRIMINATOR: D_rounds updates on fake and real data per batch.
+            for d in range(D_rounds):
+                discriminator_1.zero_grad()
+
+                h_g = generator_1.init_hidden()
+
+                # Generate the noise input for the generator.
+                noise_sample = Variable(noise(len(sample_data), seq_length))
+
+                # Use this line if the generator also returns hidden states:
+                # dis_fake_data, (h_g_n, c_g_n) = generator_1.forward(noise_sample, h_g)
+                dis_fake_data = generator_1.forward(noise_sample, h_g).detach()
+
+                y_pred_fake = discriminator_1(dis_fake_data)
+
+                # Fake data is labelled 0.
+                loss_fake = loss_1(y_pred_fake, torch.zeros([len(sample_data), 1]).to(device))
+                loss_fake.backward()
+
+                # Train the discriminator on real data, labelled 1.
+                real_data = Variable(sample_data.float()).to(device)
+                y_pred_real = discriminator_1.forward(real_data)
+
+                loss_real = loss_1(y_pred_real, torch.ones([len(sample_data), 1]).to(device))
+                loss_real.backward()
+
+                # Update the weights based on the predictions for both real and fake data.
+                d_optimizer_1.step()
+
+            ### TRAIN GENERATOR: G_rounds updates per batch.
+            for g in range(G_rounds):
+                generator_1.zero_grad()
+                h_g = generator_1.init_hidden()
+
+                noise_sample = Variable(noise(len(sample_data), seq_length))
+
+                # Use this line if the generator also returns hidden states:
+                # gen_fake_data, (h_g_n, c_g_n) = generator_1.forward(noise_sample, h_g)
+                gen_fake_data = generator_1.forward(noise_sample, h_g)
+                y_pred_gen = discriminator_1(gen_fake_data)
+
+                # The generator is rewarded when the discriminator labels its output as real.
+                error_gen = loss_1(y_pred_gen, torch.ones([len(sample_data), 1]).to(device))
+                error_gen.backward()
+                g_optimizer_1.step()
+
+            # Once per epoch, on the last batch: record losses, save states and evaluate MMD.
+            if n_batch == (num_batches - 1):
+                G_losses.append(error_gen.item())
+                D_losses.append((loss_real + loss_fake).item())
+
+                torch.save(generator_1.state_dict(), path + '/generator_state_' + str(n) + '.pt')
+                torch.save(discriminator_1.state_dict(), path + '/discriminator_state_' + str(n) + '.pt')
+
+                # Check how the generator is doing by saving its output on a fresh noise sample.
+                with torch.no_grad():
+                    h_g = generator_1.init_hidden()
+                    fake = generator_1(noise(len(sample_data), seq_length), h_g).detach().cpu()
+                    generated_sample = torch.zeros(1, seq_length).to(device)
+
+                    for batch_idx in range(0, int(len(ecg_data_test[:18088]) / sample_size)):
+                        noise_sample_test = noise(sample_size, seq_length)
+                        h_g = generator_1.init_hidden()
+                        generated_data = generator_1.forward(noise_sample_test, h_g).detach().squeeze()
+                        generated_sample = torch.cat((generated_sample, generated_data), dim=0)
+
+                    # Get the MMD statistic for this training epoch.
+                    generated_sample = generated_sample[1:]  # Drop the initial zero row.
+                    sigma = [pairwisedistances(ecg_data_test[:18088].type(torch.DoubleTensor), generated_sample.type(torch.DoubleTensor).squeeze())]
+                    mmd = MMDStatistic(len(ecg_data_test[:18088]), generated_sample.size(0))
+                    mmd_eval = mmd(ecg_data_test[:18088].type(torch.DoubleTensor), generated_sample.type(torch.DoubleTensor).squeeze(), sigma, ret_matrix=False)
+                    mmd_list.append(mmd_eval.item())
+
+                series_list = np.append(series_list, fake[0].numpy().reshape((1, seq_length)), axis=0)
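+    # Reload sketch (an added illustration, not used by the training loop):
+    # restoring a saved generator checkpoint for later sampling. The file name
+    # follows the torch.save calls above and the constructor arguments mirror
+    # generator_1.
+    def _load_generator(epoch):
+        g = Generator(seq_length, sample_size, hidden_dim=hidden_nodes_g, tanh_output=tanh_layer).to(device)
+        g.load_state_dict(torch.load(path + '/generator_state_' + str(epoch) + '.pt', map_location=device))
+        g.eval()
+        return g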
+    # Dump the errors and MMD evaluations for each training epoch.
+    with open(path + '/generator_losses.txt', 'wb') as fp:
+        pickle.dump(G_losses, fp)
+    with open(path + '/discriminator_losses.txt', 'wb') as fp:
+        pickle.dump(D_losses, fp)
+    with open(path + '/mmd_list.txt', 'wb') as fp:
+        pickle.dump(mmd_list, fp)
+
+    # Plot the error graph.
+    plt.plot(G_losses, '-r', label='Generator Error')
+    plt.plot(D_losses, '-b', label='Discriminator Error')
+    plt.title('GAN Errors in Training')
+    plt.legend()
+    plt.savefig(path + '/GAN_errors.png')
+    plt.close()
+
+    # Plot the saved samples in groups of three epochs, with the MMD value in each title.
+    i = 0
+    while i < num_epoch:
+        if i % 3 == 0:
+            fig, ax = plt.subplots(3, 1, constrained_layout=True)
+            fig.suptitle("Generated fake data")
+            for j in range(0, 3):
+                if i >= num_epoch:  # Guard the final group when num_epoch is not a multiple of 3.
+                    break
+                # series_list[0] is the initial zero row, so epoch i sits at index i + 1.
+                ax[j].plot(series_list[i + 1])
+                ax[j].set_title('Epoch ' + str(i) + ', MMD: %.4f' % mmd_list[i])
+                i = i + 1
+            plt.savefig(path + '/Training_Epoch_Samples_MMD_' + str(i) + '.png')
+            plt.close(fig)
+
+    # Check the diversity of the samples.
+    generator_1.eval()
+    h_g = generator_1.init_hidden()
+    test_noise_sample = noise(sample_size, seq_length)
+    gen_data = generator_1.forward(test_noise_sample, h_g).detach()
+
+    plt.title("Generated ECG Waves")
+    plt.plot(gen_data[random.randint(0, sample_size - 1)].tolist(), '-b')
+    plt.plot(gen_data[random.randint(0, sample_size - 1)].tolist(), '-r')
+    plt.plot(gen_data[random.randint(0, sample_size - 1)].tolist(), '-g')
+    plt.plot(gen_data[random.randint(0, sample_size - 1)].tolist(), '-', color='orange')
+    plt.savefig(path + '/Generated_Data_Sample1.png')
+    plt.close()
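+# Post-hoc selection sketch (an added illustration, not from the original
+# script): since a lower MMD means the generated and real distributions are
+# closer, the epoch with the lowest MMD is a reasonable checkpoint to keep.
+# Note this only sees mmd_list from the final minibatch_layer configuration.
+best_epoch = int(np.argmin(mmd_list))
+print('Lowest MMD %.4f at epoch %d' % (mmd_list[best_epoch], best_epoch))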