[6bf179]: / ecg_gan / train.py

Download this file

127 lines (102 with data), 4.1 kB

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
import os
import time
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW, Adam
from .gan import Generator, Discriminator
from .dataset import ECGDataset, get_dataloader
from .config import Config
class Trainer:
    """Adversarial training loop for a generator/discriminator pair.

    Both networks are optimised with Adam (lr=0.0002) against a binary
    cross-entropy objective. Per-epoch losses are collected in
    ``d_errors`` / ``g_errors`` and both networks' weights are written to
    disk when ``run`` finishes.
    """

    def __init__(
        self,
        generator,
        discriminator,
        batch_size,
        num_epochs,
        label
    ):
        """Move the networks to the device and prepare optimisers and data.

        Args:
            generator: network mapping a noise tensor to a fake signal.
            discriminator: network scoring a signal; its output is fed to
                ``nn.BCELoss`` so it is presumably a probability in [0, 1]
                (TODO confirm against the Discriminator definition).
            batch_size: batch size passed to ``get_dataloader`` and used
                for the fixed preview noise.
            num_epochs: number of passes over the dataloader in ``run``.
            label: forwarded to ``get_dataloader`` as ``label_name``.
        """
        self.device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
        self.netG = generator.to(self.device)
        self.netD = discriminator.to(self.device)
        self.optimizerD = Adam(self.netD.parameters(), lr=0.0002)
        self.optimizerG = Adam(self.netG.parameters(), lr=0.0002)
        self.criterion = nn.BCELoss()
        self.batch_size = batch_size
        # Shape of the noise fed to the generator; element 0 is overwritten
        # per batch in _one_epoch to follow the actual batch size.
        self.signal_dim = [self.batch_size, 1, 187]
        self.num_epochs = num_epochs
        self.dataloader = get_dataloader(
            label_name=label, batch_size=self.batch_size
        )
        # Constant noise reused for every preview plot in run().
        self.fixed_noise = torch.randn(self.batch_size, 1, 187,
                                       device=self.device)
        self.g_errors = []
        self.d_errors = []

    def _one_epoch(self):
        """Run one pass over the dataloader, updating both networks per batch.

        Returns:
            Tuple ``(loss_D, loss_G)`` of Python floats taken from the
            LAST batch of the epoch (not an average over batches).

        Raises:
            RuntimeError: if the dataloader yields no batches.
        """
        real_label = 1
        fake_label = 0
        errD = None
        errG = None
        for data in self.dataloader:
            ##### Update Discriminator: maximize log(D(x)) + log(1 - D(G(z))) #####
            ## train with real data
            self.netD.zero_grad()
            # The first element of each batch is used as the real samples.
            real_data = data[0].to(self.device)
            # The last batch may be smaller, so size the noise and labels
            # from the batch actually received.
            batch_size = real_data.size(0)
            self.signal_dim[0] = batch_size
            label = torch.full((batch_size,), real_label,
                               dtype=real_data.dtype, device=self.device)
            output = self.netD(real_data).view(-1)
            errD_real = self.criterion(output, label)
            errD_real.backward()

            ## train with fake data
            noise = torch.randn(self.signal_dim, device=self.device)
            fake = self.netG(noise)
            label.fill_(fake_label)
            # detach(): the discriminator step must not backpropagate
            # into the generator.
            output = self.netD(fake.detach()).view(-1)
            errD_fake = self.criterion(output, label)
            errD_fake.backward()
            errD = errD_real + errD_fake
            self.optimizerD.step()

            ##### Update Generator: maximize log(D(G(z))) #####
            self.netG.zero_grad()
            # The generator is rewarded when its fakes are scored as real.
            label.fill_(real_label)
            output = self.netD(fake).view(-1)
            errG = self.criterion(output, label)
            errG.backward()
            self.optimizerG.step()

        if errD is None or errG is None:
            raise RuntimeError("dataloader yielded no batches")
        return errD.item(), errG.item()

    def run(self):
        """Train for ``num_epochs`` epochs, then save both networks.

        Every 300th epoch (starting at epoch 0) the losses are printed and
        the generator's output for the fixed noise is plotted. After
        training, the generator weights are saved to ``generator.pth`` and
        the discriminator weights to ``discriminator.pth`` in the current
        working directory.
        """
        for epoch in range(self.num_epochs):
            errD_, errG_ = self._one_epoch()
            self.d_errors.append(errD_)
            self.g_errors.append(errG_)
            if epoch % 300 == 0:
                print(f"Epoch: {epoch} | Loss_D: {errD_} | Loss_G: {errG_} | Time: {time.strftime('%H:%M:%S')}")
                fake = self.netG(self.fixed_noise)
                plt.plot(fake.detach().cpu().squeeze(1).numpy()[:].transpose())
                plt.show()

        torch.save(self.netG.state_dict(), "generator.pth")
        # Previously netG was saved here as well, so the discriminator
        # checkpoint silently contained generator weights.
        torch.save(self.netD.state_dict(), "discriminator.pth")
if __name__ == '__main__':
    # Script entry point: build both networks and train the GAN on a
    # single ECG class.
    # NOTE(review): `config` is not read below; it is still constructed in
    # case Config() has side effects — confirm and remove if it has none.
    config = Config()
    generator_net = Generator()
    discriminator_net = Discriminator()
    Trainer(
        generator=generator_net,
        discriminator=discriminator_net,
        batch_size=96,
        num_epochs=3000,
        label='Fusion of ventricular and normal',
    ).run()