# ecg_classification/train.py
import os
import random
import time

import matplotlib.colors as mcolors
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam, AdamW
from torch.optim.lr_scheduler import (CosineAnnealingLR,
                                      CosineAnnealingWarmRestarts,
                                      ExponentialLR,
                                      StepLR)

from .config import Config, seed_everything
from .dataset import ECGDataset
from .meter import Meter
from .models import *


class Trainer:
    """Train/validate an ECG classifier, logging per-epoch metrics and
    checkpointing the model whenever validation loss improves.

    NOTE(review): relies on a module-level ``config`` object (for
    ``config.device``) and a ``get_dataloader(phase, batch_size)`` helper,
    both expected to be in scope at import time of the training run.
    """

    def __init__(self, net, lr, batch_size, num_epochs):
        """
        Args:
            net: the model to train (moved to ``config.device``).
            lr: initial learning rate for AdamW.
            batch_size: batch size passed to ``get_dataloader``.
            num_epochs: total epochs; also the cosine-annealing horizon.
        """
        self.net = net.to(config.device)
        self.num_epochs = num_epochs
        self.criterion = nn.CrossEntropyLoss()
        self.optimizer = AdamW(self.net.parameters(), lr=lr)
        # Cosine decay from `lr` down to eta_min across the whole run
        # (scheduler is stepped once per epoch in run()).
        self.scheduler = CosineAnnealingLR(self.optimizer, T_max=num_epochs, eta_min=5e-6)
        self.best_loss = float('inf')
        self.phases = ['train', 'val']
        self.dataloaders = {
            phase: get_dataloader(phase, batch_size) for phase in self.phases
        }
        # One row per epoch is appended to each of these in _train_epoch().
        self.train_df_logs = pd.DataFrame()
        self.val_df_logs = pd.DataFrame()

    def _train_epoch(self, phase):
        """Run one full pass over ``self.dataloaders[phase]``.

        In 'train' phase this performs optimizer updates; in 'val' phase it
        only evaluates (run() wraps the call in torch.no_grad()).

        Returns:
            The epoch-average loss (float).
        """
        print(f"{phase} mode | time: {time.strftime('%H:%M:%S')}")

        self.net.train() if phase == 'train' else self.net.eval()
        meter = Meter()
        meter.init_metrics()

        num_batches = 0
        for i, (data, target) in enumerate(self.dataloaders[phase]):
            data = data.to(config.device)
            target = target.to(config.device)

            output = self.net(data)
            loss = self.criterion(output, target)

            if phase == 'train':
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

            meter.update(output, target, loss.item())
            num_batches = i + 1

        metrics = meter.get_metrics()
        # BUG FIX: divide by the batch count (i + 1), not the last index `i`.
        # The old code over-weighted every metric by n/(n-1) and raised
        # ZeroDivisionError when the loader yielded a single batch.
        metrics = {k: v / num_batches for k, v in metrics.items()}
        df_logs = pd.DataFrame([metrics])
        confusion_matrix = meter.get_confusion_matrix()

        if phase == 'train':
            self.train_df_logs = pd.concat([self.train_df_logs, df_logs], axis=0)
        else:
            self.val_df_logs = pd.concat([self.val_df_logs, df_logs], axis=0)

        # Show the five tracked metrics as "name: value" pairs.
        print('{}: {}, {}: {}, {}: {}, {}: {}, {}: {}'
              .format(*(x for kv in metrics.items() for x in kv))
             )
        # Axis labels mirror the original code; presumably rows of the
        # meter's matrix are predictions — verify against Meter.
        fig, ax = plt.subplots(figsize=(5, 5))
        cm_ = ax.imshow(confusion_matrix, cmap='hot')
        ax.set_title('Confusion matrix', fontsize=15)
        ax.set_xlabel('Actual', fontsize=13)
        ax.set_ylabel('Predicted', fontsize=13)
        plt.colorbar(cm_)
        plt.show()

        # BUG FIX: return the epoch-average loss instead of the raw loss of
        # whatever batch happened to come last, so checkpoint selection in
        # run() is not driven by single-batch noise.
        return metrics['loss']

    def run(self):
        """Full training loop: train epoch, no-grad validation epoch,
        scheduler step, and best-val-loss checkpointing."""
        for epoch in range(self.num_epochs):
            self._train_epoch(phase='train')
            with torch.no_grad():
                val_loss = self._train_epoch(phase='val')
                self.scheduler.step()

            if val_loss < self.best_loss:
                # (duplicate best_loss assignment removed)
                self.best_loss = val_loss
                print('\nNew checkpoint\n')
                torch.save(self.net.state_dict(), f"best_model_epoc{epoch}.pth")

if __name__ == '__main__':
    # Build the run configuration first and seed every RNG for
    # reproducibility. `config` must stay module-level: Trainer reads it
    # as a global.
    config = Config()
    seed_everything(config.seed)

    # Architecture under training; alternatives kept for reference.
    #model = RNNAttentionModel(1, 64, 'lstm', False)
    #model = RNNModel(1, 64, 'lstm', True)
    model = CNN(num_classes=5, hid_size=128)

    # Train for 30 epochs with AdamW at 1e-3 and batches of 96.
    trainer = Trainer(net=model, lr=1e-3, batch_size=96, num_epochs=30)
    trainer.run()

    # Prefix each phase's log columns, join side by side, and order the
    # columns so train/val pairs sit next to each other before saving.
    train_logs = trainer.train_df_logs
    train_logs.columns = ['train_' + c for c in train_logs.columns]
    val_logs = trainer.val_df_logs
    val_logs.columns = ['val_' + c for c in val_logs.columns]

    logs = pd.concat([train_logs, val_logs], axis=1)
    logs.reset_index(drop=True, inplace=True)
    ordered_columns = [
        'train_loss', 'val_loss',
        'train_accuracy', 'val_accuracy',
        'train_f1', 'val_f1',
        'train_precision', 'val_precision',
        'train_recall', 'val_recall',
    ]
    logs = logs.loc[:, ordered_columns]
    print(logs.head())
    logs.to_csv('cnn.csv', index=False)
131