--- a
+++ b/ecg_gan/train.py
@@ -0,0 +1,126 @@
+import os
+import time
+
+import numpy as np
+import pandas as pd 
+import matplotlib.pyplot as plt
+import seaborn as sns
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.utils.data import Dataset, DataLoader
+from torch.optim import AdamW, Adam
+
+from .gan import Generator, Discriminator
+from .dataset import ECGDataset, get_dataloader
+from .config import Config
+
+
+class Trainer:
+    def __init__(
+        self,
+        generator,
+        discriminator,
+        batch_size,
+        num_epochs,
+        label
+    ):
+        self.device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
+        self.netG = generator.to(self.device)
+        self.netD = discriminator.to(self.device)
+        
+        self.optimizerD = Adam(self.netD.parameters(), lr=0.0002)
+        self.optimizerG = Adam(self.netG.parameters(), lr=0.0002)
+        self.criterion = nn.BCELoss()
+        
+        self.batch_size = batch_size
+        self.signal_dim = [self.batch_size, 1, 187]
+        self.num_epochs = num_epochs
+        self.dataloader = get_dataloader(
+            label_name=label, batch_size=self.batch_size
+        )
+        self.fixed_noise = torch.randn(self.batch_size, 1, 187,
+                                       device=self.device)
+        self.g_errors = []
+        self.d_errors = []
+        
+    def _one_epoch(self):
+        real_label = 1
+        fake_label = 0
+        
+        for i, data in enumerate(self.dataloader, 0):
+            ##### Update Discriminator: maximize log(D(x)) + log(1 - D(G(z))) #####
+            ## train with real data
+            self.netD.zero_grad()
+            real_data = data[0].to(self.device)
+            # dim for noise
+            batch_size = real_data.size(0)
+            self.signal_dim[0] = batch_size
+            
+            label = torch.full((batch_size,), real_label,
+                           dtype=real_data.dtype, device=self.device)
+            
+            output = self.netD(real_data)
+            output = output.view(-1)
+       
+            errD_real = self.criterion(output, label)
+            errD_real.backward()
+            D_x = output.mean().item()
+            
+            ## train with fake data
+            noise = torch.randn(self.signal_dim, device=self.device)
+            fake = self.netG(noise)
+            label.fill_(fake_label)
+            
+            output = self.netD(fake.detach())
+            output = output.view(-1)
+            
+            errD_fake = self.criterion(output, label)
+            errD_fake.backward()
+            D_G_z1 = output.mean().item()
+            errD = errD_real + errD_fake 
+            self.optimizerD.step()
+            
+            ##### Update Generator: maximaze log(D(G(z)))  
+            self.netG.zero_grad()
+            label.fill_(real_label) 
+            output = self.netD(fake)
+            output = output.view(-1)
+            
+            errG = self.criterion(output, label)
+            errG.backward()
+            D_G_z2 = output.mean().item()
+            self.optimizerG.step()
+            
+        return errD.item(), errG.item()
+        
+    def run(self):
+        for epoch in range(self.num_epochs):
+            errD_, errG_ = self._one_epoch()
+            self.d_errors.append(errD_)
+            self.g_errors.append(errG_)
+            if epoch % 300 == 0:
+                print(f"Epoch: {epoch} | Loss_D: {errD_} | Loss_G: {errG_} | Time: {time.strftime('%H:%M:%S')}")
+   
+                fake = self.netG(self.fixed_noise)
+                plt.plot(fake.detach().cpu().squeeze(1).numpy()[:].transpose())
+                plt.show()
+            
+        torch.save(self.netG.state_dict(), f"generator.pth")
+        torch.save(self.netG.state_dict(), f"discriminator.pth")
+                      
+               
+if __name__ == '__main__':
+    config = Config()
+    g = Generator()
+    d = Discriminator()
+                      
+    trainer = Trainer(
+      generator=g,
+      discriminator=d,
+      batch_size=96,
+      num_epochs=3000,
+      label='Fusion of ventricular and normal'
+  )
+    trainer.run()