Diff of /ecgtoBR/train.py [000000] .. [c0487b]

Switch to unified view

a b/ecgtoBR/train.py
1
import torch
2
import torch.nn as nn
3
from torch.utils.data import TensorDataset,DataLoader
4
from torch.utils.tensorboard import SummaryWriter
5
6
from tqdm import tqdm
7
import os
8
import random
9
import argparse
10
11
from BRnet import IncUNet
12
from create_dataset import data_preprocess
13
from utils import testDataEval,save_model
14
15
def train(args):
16
17
    """Train model to get Breathing rate from ECG
18
    """
19
    
20
    PATH = 'data'
21
    X_train = torch.load(PATH + '/ecgtoBR_train_data.pt')
22
    y_train = torch.load(PATH + '/ecgtoBR_train_labels.pt')
23
24
    X_test = torch.load(PATH + '/ecgtoBR_test_data.pt')
25
    y_test = torch.load(PATH + '/ecgtoBR_test_labels.pt')
26
    
27
    BATCH_SIZE= 64
28
    NUM_EPOCHS = 400
29
    best_loss = 1000
30
31
    train = TensorDataset(X_train,y_train)
32
    val = TensorDataset(X_test,y_test)
33
    trainLoader = DataLoader(train,batch_size = BATCH_SIZE,shuffle = True)
34
    valLoader = DataLoader(val, batch_size= BATCH_SIZE, shuffle=True)
35
36
    model = IncUNet((1,1,5000))
37
    model.cuda()
38
    criterion = torch.nn.SmoothL1Loss()
39
    optim = torch.optim.Adam(model.parameters(),lr = 0.001)
40
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optim,milestones=[100,200], gamma=0.1)
41
42
    NUM_EPOCHS = 400
43
    best_loss = 1000
44
45
    writer = SummaryWriter()
46
47
    if not(os.path.isdir("Saved_Model")):
48
        os.mkdir("Saved_Model")
49
50
    for epoch in tqdm(range(NUM_EPOCHS)):
51
        
52
        model.train()
53
        totalLoss = 0
54
        
55
        for step,(x,y) in enumerate(trainLoader):
56
            
57
            print('.',end = " ")
58
            ecg= x.cuda()
59
            BR = y.cuda()
60
            BR_pred = model(ecg)
61
            optim.zero_grad()
62
            loss = criterion(BR_pred,BR)
63
            totalLoss += loss.cpu().item()
64
            loss.backward()
65
            optim.step()
66
            
67
        print ('')
68
        print ("Epoch:{} Train Loss:{}".format(epoch + 1,totalLoss/(step+1)))
69
        
70
        totalTestLoss = testDataEval(model, valLoader, criterion)
71
        scheduler.step()
72
        
73
        if best_loss > totalTestLoss:
74
            print ("........Saving Best Model........")
75
            best_loss = totalTestLoss
76
            save_model("Saved_Model", epoch, model, optim, best_loss )
77
        
78
        writer.add_scalar("Loss/test",totalTestLoss, epoch )
79
        writer.add_scalar("Loss/train",totalLoss/(step+1),epoch )
80
            
81
    writer.close()
82
83
84
if __name__ == "__main__":
85
86
87
    parser = argparse.ArgumentParser()
88
    parser.add_argument('--preprocess_data',action = 'store_true', help = 'Use if True')
89
    parser.add_argument('--sampling_freq',default = 125, type = int, help = 'Sampling Frequency')
90
    parser.add_argument('--upsample_freq',default = 500, type = int, help = 'Resampling Frequency')
91
    parser.add_argument('--window_length',default = 5,type = int,help = 'Window Length in seconds')
92
    parser.add_argument('--data_path',help = 'Path to data')
93
    
94
    parser.add_argument('--random_seed',default = 5,type = int,help = 'Random Seed initializer')
95
    args = parser.parse_args()
96
    
97
    torch.backends.cudnn.deterministic = True
98
    torch.backends.cudnn.benchmark = False
99
    random.seed(args.random_seed)
100
    torch.manual_seed(args.random_seed)
101
    if torch.cuda.is_available():
102
        torch.cuda.manual_seed(args.random_seed) 
103
104
    if args.preprocess_data:
105
        data_preprocess(args)
106
107
    train(args)