import os
import time

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW, Adam

from .gan import Generator, Discriminator
from .dataset import ECGDataset, get_dataloader
from .config import Config


class Trainer:
    """Adversarial trainer for a 1-D ECG GAN.

    Runs the classic two-step GAN update with ``BCELoss``:
    the discriminator is trained to maximize log(D(x)) + log(1 - D(G(z)))
    and the generator to maximize log(D(G(z))).  Signals are assumed to be
    shaped ``(batch, 1, 187)`` — 187 samples per heartbeat segment
    (TODO confirm against ECGDataset).
    """

    def __init__(
        self,
        generator,
        discriminator,
        batch_size,
        num_epochs,
        label
    ):
        """Set up models, optimizers, data and bookkeeping.

        Args:
            generator: untrained Generator network (moved to the device here).
            discriminator: untrained Discriminator network.
            batch_size: samples per training batch; also sizes the fixed
                noise used for the post-training preview plot.
            num_epochs: number of full passes over the dataloader.
            label: class-name string forwarded to ``get_dataloader`` to
                select which ECG class to train on.
        """
        self.device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
        self.netG = generator.to(self.device)
        self.netD = discriminator.to(self.device)

        # Same Adam hyperparameters for both networks (DCGAN-style lr).
        self.optimizerD = Adam(self.netD.parameters(), lr=0.0002)
        self.optimizerG = Adam(self.netG.parameters(), lr=0.0002)
        self.criterion = nn.BCELoss()

        self.batch_size = batch_size
        # Mutable noise shape; element 0 is overwritten each batch so the
        # last (possibly short) batch still gets a matching noise tensor.
        self.signal_dim = [self.batch_size, 1, 187]
        self.num_epochs = num_epochs
        self.dataloader = get_dataloader(
            label_name=label, batch_size=self.batch_size
        )
        # Fixed noise reused after training for a reproducible sample plot.
        self.fixed_noise = torch.randn(self.batch_size, 1, 187,
                                       device=self.device)
        self.g_errors = []
        self.d_errors = []

    def _one_epoch(self):
        """Run one pass over the dataloader; return last-batch (errD, errG).

        NOTE(review): if the dataloader yields no batches, ``errD``/``errG``
        are never bound and this raises ``NameError`` — assumed non-empty.
        """
        real_label = 1
        fake_label = 0

        for i, data in enumerate(self.dataloader, 0):
            ##### Update Discriminator: maximize log(D(x)) + log(1 - D(G(z))) #####
            ## train with real data
            self.netD.zero_grad()
            real_data = data[0].to(self.device)
            # Actual batch size may be smaller on the final batch.
            batch_size = real_data.size(0)
            self.signal_dim[0] = batch_size

            label = torch.full((batch_size,), real_label,
                               dtype=real_data.dtype, device=self.device)

            output = self.netD(real_data)
            output = output.view(-1)

            errD_real = self.criterion(output, label)
            errD_real.backward()
            D_x = output.mean().item()

            ## train with fake data
            noise = torch.randn(self.signal_dim, device=self.device)
            fake = self.netG(noise)
            label.fill_(fake_label)

            # detach() so this backward pass does not touch generator grads.
            output = self.netD(fake.detach())
            output = output.view(-1)

            errD_fake = self.criterion(output, label)
            errD_fake.backward()
            D_G_z1 = output.mean().item()
            errD = errD_real + errD_fake
            self.optimizerD.step()

            ##### Update Generator: maximize log(D(G(z)))
            self.netG.zero_grad()
            # Generator wants the discriminator to output "real".
            label.fill_(real_label)
            output = self.netD(fake)
            output = output.view(-1)

            errG = self.criterion(output, label)
            errG.backward()
            D_G_z2 = output.mean().item()
            self.optimizerG.step()

        return errD.item(), errG.item()

    def run(self):
        """Train for ``num_epochs`` epochs, then plot samples and save weights.

        Loss history is accumulated in ``self.d_errors`` / ``self.g_errors``;
        progress is printed every 300 epochs.  Saves ``generator.pth`` and
        ``discriminator.pth`` in the working directory.
        """
        for epoch in range(self.num_epochs):
            errD_, errG_ = self._one_epoch()
            self.d_errors.append(errD_)
            self.g_errors.append(errG_)
            if epoch % 300 == 0:
                print(f"Epoch: {epoch} | Loss_D: {errD_} | Loss_G: {errG_} | Time: {time.strftime('%H:%M:%S')}")

        # Preview the generator on the fixed noise batch.
        fake = self.netG(self.fixed_noise)
        plt.plot(fake.detach().cpu().squeeze(1).numpy()[:].transpose())
        plt.show()

        torch.save(self.netG.state_dict(), "generator.pth")
        # BUG FIX: original saved netG's weights into discriminator.pth.
        torch.save(self.netD.state_dict(), "discriminator.pth")


if __name__ == '__main__':
    config = Config()
    g = Generator()
    d = Discriminator()

    trainer = Trainer(
        generator=g,
        discriminator=d,
        batch_size=96,
        num_epochs=3000,
        label='Fusion of ventricular and normal'
    )
    trainer.run()