Diff of /ecgtoHR/train.py [000000] .. [c0487b]

import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader
from torch.utils.tensorboard import SummaryWriter

from tqdm import tqdm
import os
import random
import argparse

from HRnet import IncUNet
from create_dataset import data_preprocess
from utils import testDataEval, save_model

def train(args):
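    """Train IncUNet to regress heart rate from fixed-length ECG windows."""

    # Load the preprocessed ECG windows and heart-rate labels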
    PATH = 'data'
    X_train = torch.load(PATH + '/ecgtoHR_train_data.pt')
    y_train = torch.load(PATH + '/ecgtoHR_train_labels.pt')

    X_test = torch.load(PATH + '/ecgtoHR_test_data.pt')
    y_test = torch.load(PATH + '/ecgtoHR_test_labels.pt')

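    # Training hyperparameters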
    BATCH_SIZE = 64
    NUM_EPOCHS = 400
    best_loss = 1000

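    # Wrap the tensors in TensorDatasets and batch them with DataLoaders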
    train_set = TensorDataset(X_train, y_train)
    val_set = TensorDataset(X_test, y_test)
    trainLoader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True)
    valLoader = DataLoader(val_set, batch_size=BATCH_SIZE, shuffle=True)

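    # Model, Smooth L1 loss, Adam optimizer, and a 10x LR decay at epochs 100 and 200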
    model = IncUNet((1, 1, 5000))
    model.cuda()
    criterion = torch.nn.SmoothL1Loss()
    optim = torch.optim.Adam(model.parameters(), lr=0.001)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optim, milestones=[100, 200], gamma=0.1)

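    # TensorBoard writer for the train/test loss curves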
    writer = SummaryWriter()
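
    # Create the checkpoint directory if it does not already exist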
    if not os.path.isdir("Saved_Model"):
        os.mkdir("Saved_Model")

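    # Epoch loop: train, evaluate on the test split, checkpoint, and log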
    for epoch in tqdm(range(NUM_EPOCHS)):

        model.train()
        totalLoss = 0

        for step, (x, y) in enumerate(trainLoader):

            print('.', end=" ")
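            # Add a channel dimension and move the batch to the GPU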
            ecg = x.unsqueeze(1).cuda()
            HR = y.unsqueeze(1).cuda()
            HR_pred = model(ecg)
            optim.zero_grad()
            loss = criterion(HR_pred, HR)
            totalLoss += loss.cpu().item()
            loss.backward()
            optim.step()

        print('')
        print("Epoch:{} Train Loss:{}".format(epoch + 1, totalLoss / (step + 1)))

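        # Evaluate on the held-out test split, then advance the LR schedule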
        totalTestLoss = testDataEval(model, valLoader, criterion)
        scheduler.step()

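        # Save a checkpoint whenever the test loss improves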
        if best_loss > totalTestLoss:
            print("........Saving Best Model........")
            best_loss = totalTestLoss
            save_model("Saved_Model", epoch, model, optim, best_loss)

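        # Log per-epoch losses to TensorBoard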
        writer.add_scalar("Loss/test", totalTestLoss, epoch)
        writer.add_scalar("Loss/train", totalLoss / (step + 1), epoch)

    writer.close()


if __name__ == "__main__":

    parser = argparse.ArgumentParser()
    parser.add_argument('--preprocess_data', action='store_true', help='Run data preprocessing before training')
    parser.add_argument('--sampling_freq', default=125, type=int, help='Sampling frequency (Hz)')
    parser.add_argument('--upsample_freq', default=500, type=int, help='Resampling frequency (Hz)')
    parser.add_argument('--window_length', default=5, type=int, help='Window length in seconds')
    parser.add_argument('--data_path', help='Path to dataset')

    parser.add_argument('--random_seed', default=5, type=int, help='Random seed')
    args = parser.parse_args()
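
    # Reproducibility: fix cuDNN behaviour and seed the Python, PyTorch, and CUDA RNGs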
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    random.seed(args.random_seed)
    torch.manual_seed(args.random_seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(args.random_seed)

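    # Optionally run the preprocessing step before training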
    if args.preprocess_data:
        data_preprocess(args)

    train(args)