--- /dev/null
+++ b/ecg_classification/train.py
@@ -0,0 +1,131 @@
+import time
+
+import pandas as pd
+import matplotlib.pyplot as plt
+
+import torch
+import torch.nn as nn
+from torch.optim import AdamW
+from torch.optim.lr_scheduler import CosineAnnealingLR
+
+from .meter import Meter
+# NOTE: get_dataloader is used below but was not imported; it is assumed
+# to be defined in .dataset alongside ECGDataset.
+from .dataset import ECGDataset, get_dataloader
+from .models import *
+from .config import Config, seed_everything
+
+
+class Trainer:
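+    """Training/validation loop wrapper for an ECG beat classifier."""
+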
+    def __init__(self, net, lr, batch_size, num_epochs):
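+        # NOTE: relies on the module-level `config` created in the __main__ block below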
+        self.net = net.to(config.device)
+        self.num_epochs = num_epochs
+        self.criterion = nn.CrossEntropyLoss()
+        self.optimizer = AdamW(self.net.parameters(), lr=lr)
+        self.scheduler = CosineAnnealingLR(self.optimizer, T_max=num_epochs, eta_min=5e-6)
+        self.best_loss = float('inf')
+        self.phases = ['train', 'val']
+        self.dataloaders = {
+            phase: get_dataloader(phase, batch_size) for phase in self.phases
+        }
+        self.train_df_logs = pd.DataFrame()
+        self.val_df_logs = pd.DataFrame()
+    
+    def _train_epoch(self, phase):
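+        """Run one pass over the given phase's dataloader and return its averaged loss."""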
+        print(f"{phase} mode | time: {time.strftime('%H:%M:%S')}")
+        
+        # train()/eval() toggles training-specific layers such as dropout and batch norm
+        self.net.train(phase == 'train')
+        meter = Meter()
+        meter.init_metrics()
+        
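+        # iterate over mini-batches of (signal, label) pairs from the phase dataloader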
+        for i, (data, target) in enumerate(self.dataloaders[phase]):
+            data = data.to(config.device)
+            target = target.to(config.device)
+            
+            output = self.net(data)
+            loss = self.criterion(output, target)
+
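+            # back-propagate and update the weights only in the training phase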
+            if phase == 'train':
+                self.optimizer.zero_grad()
+                loss.backward()
+                self.optimizer.step()
+            
+            meter.update(output, target, loss.item())
+        
+        metrics = meter.get_metrics()
+        # average the accumulated metrics over the number of batches (i is 0-based)
+        metrics = {k: v / (i + 1) for k, v in metrics.items()}
+        df_logs = pd.DataFrame([metrics])
+        confusion_matrix = meter.get_confusion_matrix()
+        
+        if phase == 'train':
+            self.train_df_logs = pd.concat([self.train_df_logs, df_logs], axis=0)
+        else:
+            self.val_df_logs = pd.concat([self.val_df_logs, df_logs], axis=0)
+        
+        # show logs
+        print(', '.join(f'{k}: {v:.4f}' for k, v in metrics.items()))
+        fig, ax = plt.subplots(figsize=(5, 5))
+        cm_ = ax.imshow(confusion_matrix, cmap='hot')
+        ax.set_title('Confusion matrix', fontsize=15)
+        ax.set_xlabel('Actual', fontsize=13)
+        ax.set_ylabel('Predicted', fontsize=13)
+        plt.colorbar(cm_)
+        plt.show()
+        
+        # return the epoch-averaged loss (Meter is expected to expose a 'loss' metric),
+        # not just the loss of the last batch
+        return metrics['loss']
+    
+    def run(self):
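+        """Train for num_epochs epochs, checkpointing whenever the validation loss improves."""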
+        for epoch in range(self.num_epochs):
+            self._train_epoch(phase='train')
+            with torch.no_grad():
+                val_loss = self._train_epoch(phase='val')
+            self.scheduler.step()
+
+            if val_loss < self.best_loss:
+                self.best_loss = val_loss
+                print('\nNew checkpoint\n')
+                torch.save(self.net.state_dict(), f"best_model_epoch{epoch}.pth")
+
+
+if __name__ == '__main__':
+    # init config and set random seed
+    config = Config()
+    seed_everything(config.seed)
+
+    # init model
+    #model = RNNAttentionModel(1, 64, 'lstm', False)
+    #model = RNNModel(1, 64, 'lstm', True)
+    model = CNN(num_classes=5, hid_size=128)
+
+    # start train
+    trainer = Trainer(net=model, lr=1e-3, batch_size=96, num_epochs=30)
+    trainer.run()
+
+    # write logs
+    # the per-epoch frames carry duplicated 0-indices, so reset them before the
+    # column-wise concat, otherwise pandas cannot align the two frames
+    train_logs = trainer.train_df_logs.reset_index(drop=True)
+    train_logs.columns = ["train_" + colname for colname in train_logs.columns]
+    val_logs = trainer.val_df_logs.reset_index(drop=True)
+    val_logs.columns = ["val_" + colname for colname in val_logs.columns]
+
+    logs = pd.concat([train_logs, val_logs], axis=1)
+    logs = logs.loc[:, [
+        'train_loss', 'val_loss',
+        'train_accuracy', 'val_accuracy',
+        'train_f1', 'val_f1',
+        'train_precision', 'val_precision',
+        'train_recall', 'val_recall']]
+    print(logs.head())
+    logs.to_csv('cnn.csv', index=False)
+