Diff of /ecg_gan/train.py [000000] .. [6bf179]

Switch to unified view

a b/ecg_gan/train.py
1
import os
2
import time
3
4
import numpy as np
5
import pandas as pd 
6
import matplotlib.pyplot as plt
7
import seaborn as sns
8
9
import torch
10
import torch.nn as nn
11
import torch.nn.functional as F
12
from torch.utils.data import Dataset, DataLoader
13
from torch.optim import AdamW, Adam
14
15
from .gan import Generator, Discriminator
16
from .dataset import ECGDataset, get_dataloader
17
from .config import Config
18
19
20
class Trainer:
21
    def __init__(
22
        self,
23
        generator,
24
        discriminator,
25
        batch_size,
26
        num_epochs,
27
        label
28
    ):
29
        self.device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
30
        self.netG = generator.to(self.device)
31
        self.netD = discriminator.to(self.device)
32
        
33
        self.optimizerD = Adam(self.netD.parameters(), lr=0.0002)
34
        self.optimizerG = Adam(self.netG.parameters(), lr=0.0002)
35
        self.criterion = nn.BCELoss()
36
        
37
        self.batch_size = batch_size
38
        self.signal_dim = [self.batch_size, 1, 187]
39
        self.num_epochs = num_epochs
40
        self.dataloader = get_dataloader(
41
            label_name=label, batch_size=self.batch_size
42
        )
43
        self.fixed_noise = torch.randn(self.batch_size, 1, 187,
44
                                       device=self.device)
45
        self.g_errors = []
46
        self.d_errors = []
47
        
48
    def _one_epoch(self):
49
        real_label = 1
50
        fake_label = 0
51
        
52
        for i, data in enumerate(self.dataloader, 0):
53
            ##### Update Discriminator: maximize log(D(x)) + log(1 - D(G(z))) #####
54
            ## train with real data
55
            self.netD.zero_grad()
56
            real_data = data[0].to(self.device)
57
            # dim for noise
58
            batch_size = real_data.size(0)
59
            self.signal_dim[0] = batch_size
60
            
61
            label = torch.full((batch_size,), real_label,
62
                           dtype=real_data.dtype, device=self.device)
63
            
64
            output = self.netD(real_data)
65
            output = output.view(-1)
66
       
67
            errD_real = self.criterion(output, label)
68
            errD_real.backward()
69
            D_x = output.mean().item()
70
            
71
            ## train with fake data
72
            noise = torch.randn(self.signal_dim, device=self.device)
73
            fake = self.netG(noise)
74
            label.fill_(fake_label)
75
            
76
            output = self.netD(fake.detach())
77
            output = output.view(-1)
78
            
79
            errD_fake = self.criterion(output, label)
80
            errD_fake.backward()
81
            D_G_z1 = output.mean().item()
82
            errD = errD_real + errD_fake 
83
            self.optimizerD.step()
84
            
85
            ##### Update Generator: maximaze log(D(G(z)))  
86
            self.netG.zero_grad()
87
            label.fill_(real_label) 
88
            output = self.netD(fake)
89
            output = output.view(-1)
90
            
91
            errG = self.criterion(output, label)
92
            errG.backward()
93
            D_G_z2 = output.mean().item()
94
            self.optimizerG.step()
95
            
96
        return errD.item(), errG.item()
97
        
98
    def run(self):
99
        for epoch in range(self.num_epochs):
100
            errD_, errG_ = self._one_epoch()
101
            self.d_errors.append(errD_)
102
            self.g_errors.append(errG_)
103
            if epoch % 300 == 0:
104
                print(f"Epoch: {epoch} | Loss_D: {errD_} | Loss_G: {errG_} | Time: {time.strftime('%H:%M:%S')}")
105
   
106
                fake = self.netG(self.fixed_noise)
107
                plt.plot(fake.detach().cpu().squeeze(1).numpy()[:].transpose())
108
                plt.show()
109
            
110
        torch.save(self.netG.state_dict(), f"generator.pth")
111
        torch.save(self.netG.state_dict(), f"discriminator.pth")
112
                      
113
               
114
if __name__ == '__main__':
115
    config = Config()
116
    g = Generator()
117
    d = Discriminator()
118
                      
119
    trainer = Trainer(
120
      generator=g,
121
      discriminator=d,
122
      batch_size=96,
123
      num_epochs=3000,
124
      label='Fusion of ventricular and normal'
125
  )
126
    trainer.run()