import os
import time
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW, Adam
from .gan import Generator, Discriminator
from .dataset import ECGDataset, get_dataloader
from .config import Config
class Trainer:
    """Adversarial training loop for a GAN over 1-D signals.

    Both networks are moved to ``cuda:0`` when CUDA is available (CPU
    otherwise). They are trained with separate Adam optimizers against
    ``nn.BCELoss``. Because ``nn.BCELoss`` expects probabilities in ``[0, 1]``,
    the discriminator is assumed to end in a sigmoid; TODO confirm against
    ``Discriminator``.

    Args:
        generator: Module called on noise of shape
            ``(batch, 1, signal_length)``. Its output is fed to the
            discriminator.
        discriminator: Module whose output, flattened with ``view(-1)``, gives
            one score per sample.
        batch_size: Nominal batch size, passed to ``get_dataloader`` and used
            for the fixed preview-noise batch.
        num_epochs: Number of passes over the dataloader performed by ``run``.
        label: Forwarded to ``get_dataloader(label_name=...)``.
        lr: Adam learning rate for both optimizers (default ``0.0002``).
        signal_length: Length of the last noise axis (default ``187``).
    """

    def __init__(
        self,
        generator,
        discriminator,
        batch_size,
        num_epochs,
        label,
        *,
        lr=0.0002,
        signal_length=187,
    ):
        self.device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
        self.netG = generator.to(self.device)
        self.netD = discriminator.to(self.device)
        self.optimizerD = Adam(self.netD.parameters(), lr=lr)
        self.optimizerG = Adam(self.netG.parameters(), lr=lr)
        self.criterion = nn.BCELoss()
        self.batch_size = batch_size
        # Nominal noise shape. _one_epoch derives the per-batch shape locally,
        # so a short final batch never rewrites this template.
        self.signal_dim = [self.batch_size, 1, signal_length]
        self.num_epochs = num_epochs
        self.dataloader = get_dataloader(
            label_name=label, batch_size=self.batch_size
        )
        # Constant noise batch so the periodic sample plots are comparable
        # across epochs.
        self.fixed_noise = torch.randn(*self.signal_dim, device=self.device)
        self.g_errors = []
        self.d_errors = []

    def _one_epoch(self):
        """Run one pass over ``self.dataloader`` with alternating D/G updates.

        For every batch:

        1. Discriminator step: BCE on real samples (target 1) plus BCE on
           detached generator samples (target 0). This maximises
           ``log D(x) + log(1 - D(G(z)))``.
        2. Generator step: BCE of ``D(G(z))`` against target 1. This maximises
           ``log D(G(z))``.

        Returns:
            tuple[float, float]: ``(loss_D, loss_G)`` of the last batch in the
            epoch.

        Raises:
            RuntimeError: If the dataloader yields no batches.
        """
        real_label = 1
        fake_label = 0
        last_errD = None
        last_errG = None
        for data in self.dataloader:
            # Assumes each batch is indexable with the signal tensor first
            # (e.g. a (signals, targets) pair). TODO confirm.
            real_data = data[0].to(self.device)
            batch_size = real_data.size(0)
            # Follow the actual batch size so a short final batch still works.
            noise_shape = (batch_size, *self.signal_dim[1:])

            # ----- Discriminator update -----
            self.netD.zero_grad()
            label = torch.full((batch_size,), real_label,
                               dtype=real_data.dtype, device=self.device)
            errD_real = self.criterion(self.netD(real_data).view(-1), label)
            errD_real.backward()

            noise = torch.randn(noise_shape, device=self.device)
            fake = self.netG(noise)
            label.fill_(fake_label)
            # detach(): the discriminator step must not push gradients into G.
            errD_fake = self.criterion(self.netD(fake.detach()).view(-1), label)
            errD_fake.backward()
            self.optimizerD.step()

            # ----- Generator update -----
            self.netG.zero_grad()
            # G is rewarded when D classifies its samples as real.
            label.fill_(real_label)
            errG = self.criterion(self.netD(fake).view(-1), label)
            errG.backward()
            self.optimizerG.step()

            # Keep detached tensors and call .item() once after the loop, so
            # there is no host sync on every batch.
            last_errD = errD_real.detach() + errD_fake.detach()
            last_errG = errG.detach()

        if last_errD is None:
            raise RuntimeError("dataloader yielded no batches; cannot train an epoch")
        return last_errD.item(), last_errG.item()

    def run(self, log_interval=300):
        """Train for ``self.num_epochs`` epochs, then save both networks.

        The last-batch losses of every epoch are appended to
        ``self.d_errors`` and ``self.g_errors``. Every ``log_interval``
        epochs (starting at epoch 0), the losses are printed. Generator
        samples from ``self.fixed_noise`` are then plotted with matplotlib.

        Args:
            log_interval: Epoch period for logging and plotting (default 300).

        Raises:
            ValueError: If ``log_interval`` is not positive.
        """
        if log_interval <= 0:
            raise ValueError("log_interval must be a positive integer")
        for epoch in range(self.num_epochs):
            errD_, errG_ = self._one_epoch()
            self.d_errors.append(errD_)
            self.g_errors.append(errG_)
            if epoch % log_interval == 0:
                print(f"Epoch: {epoch} | Loss_D: {errD_} | Loss_G: {errG_} | Time: {time.strftime('%H:%M:%S')}")
                # Preview only: no autograd graph is needed for plotting.
                with torch.no_grad():
                    fake = self.netG(self.fixed_noise)
                plt.plot(fake.cpu().squeeze(1).numpy().transpose())
                plt.show()
        self.save_models()

    def save_models(self, generator_path="generator.pth",
                    discriminator_path="discriminator.pth"):
        """Save the generator and discriminator ``state_dict``s to separate files.

        Args:
            generator_path: Destination for ``self.netG.state_dict()``.
            discriminator_path: Destination for ``self.netD.state_dict()``.
        """
        torch.save(self.netG.state_dict(), generator_path)
        # Must be netD: writing netG here would silently lose the discriminator.
        torch.save(self.netD.state_dict(), discriminator_path)
if __name__ == '__main__':
    # Config() is constructed but its result is not used below; kept in case
    # its constructor has side effects (TODO confirm).
    _config = Config()
    # Left-to-right evaluation keeps the original Generator -> Discriminator
    # construction order.
    generator_net, discriminator_net = Generator(), Discriminator()
    gan_trainer = Trainer(
        generator=generator_net,
        discriminator=discriminator_net,
        batch_size=96,
        num_epochs=3000,
        label='Fusion of ventricular and normal',
    )
    gan_trainer.run()