|
a |
|
b/ecgtoBR/train.py |
|
|
1 |
import torch |
|
|
2 |
import torch.nn as nn |
|
|
3 |
from torch.utils.data import TensorDataset,DataLoader |
|
|
4 |
from torch.utils.tensorboard import SummaryWriter |
|
|
5 |
|
|
|
6 |
from tqdm import tqdm |
|
|
7 |
import os |
|
|
8 |
import random |
|
|
9 |
import argparse |
|
|
10 |
|
|
|
11 |
from BRnet import IncUNet |
|
|
12 |
from create_dataset import data_preprocess |
|
|
13 |
from utils import testDataEval,save_model |
|
|
14 |
|
|
|
15 |
def train(args): |
|
|
16 |
|
|
|
17 |
"""Train model to get Breathing rate from ECG |
|
|
18 |
""" |
|
|
19 |
|
|
|
20 |
PATH = 'data' |
|
|
21 |
X_train = torch.load(PATH + '/ecgtoBR_train_data.pt') |
|
|
22 |
y_train = torch.load(PATH + '/ecgtoBR_train_labels.pt') |
|
|
23 |
|
|
|
24 |
X_test = torch.load(PATH + '/ecgtoBR_test_data.pt') |
|
|
25 |
y_test = torch.load(PATH + '/ecgtoBR_test_labels.pt') |
|
|
26 |
|
|
|
27 |
BATCH_SIZE= 64 |
|
|
28 |
NUM_EPOCHS = 400 |
|
|
29 |
best_loss = 1000 |
|
|
30 |
|
|
|
31 |
train = TensorDataset(X_train,y_train) |
|
|
32 |
val = TensorDataset(X_test,y_test) |
|
|
33 |
trainLoader = DataLoader(train,batch_size = BATCH_SIZE,shuffle = True) |
|
|
34 |
valLoader = DataLoader(val, batch_size= BATCH_SIZE, shuffle=True) |
|
|
35 |
|
|
|
36 |
model = IncUNet((1,1,5000)) |
|
|
37 |
model.cuda() |
|
|
38 |
criterion = torch.nn.SmoothL1Loss() |
|
|
39 |
optim = torch.optim.Adam(model.parameters(),lr = 0.001) |
|
|
40 |
scheduler = torch.optim.lr_scheduler.MultiStepLR(optim,milestones=[100,200], gamma=0.1) |
|
|
41 |
|
|
|
42 |
NUM_EPOCHS = 400 |
|
|
43 |
best_loss = 1000 |
|
|
44 |
|
|
|
45 |
writer = SummaryWriter() |
|
|
46 |
|
|
|
47 |
if not(os.path.isdir("Saved_Model")): |
|
|
48 |
os.mkdir("Saved_Model") |
|
|
49 |
|
|
|
50 |
for epoch in tqdm(range(NUM_EPOCHS)): |
|
|
51 |
|
|
|
52 |
model.train() |
|
|
53 |
totalLoss = 0 |
|
|
54 |
|
|
|
55 |
for step,(x,y) in enumerate(trainLoader): |
|
|
56 |
|
|
|
57 |
print('.',end = " ") |
|
|
58 |
ecg= x.cuda() |
|
|
59 |
BR = y.cuda() |
|
|
60 |
BR_pred = model(ecg) |
|
|
61 |
optim.zero_grad() |
|
|
62 |
loss = criterion(BR_pred,BR) |
|
|
63 |
totalLoss += loss.cpu().item() |
|
|
64 |
loss.backward() |
|
|
65 |
optim.step() |
|
|
66 |
|
|
|
67 |
print ('') |
|
|
68 |
print ("Epoch:{} Train Loss:{}".format(epoch + 1,totalLoss/(step+1))) |
|
|
69 |
|
|
|
70 |
totalTestLoss = testDataEval(model, valLoader, criterion) |
|
|
71 |
scheduler.step() |
|
|
72 |
|
|
|
73 |
if best_loss > totalTestLoss: |
|
|
74 |
print ("........Saving Best Model........") |
|
|
75 |
best_loss = totalTestLoss |
|
|
76 |
save_model("Saved_Model", epoch, model, optim, best_loss ) |
|
|
77 |
|
|
|
78 |
writer.add_scalar("Loss/test",totalTestLoss, epoch ) |
|
|
79 |
writer.add_scalar("Loss/train",totalLoss/(step+1),epoch ) |
|
|
80 |
|
|
|
81 |
writer.close() |
|
|
82 |
|
|
|
83 |
|
|
|
84 |
if __name__ == "__main__": |
|
|
85 |
|
|
|
86 |
|
|
|
87 |
parser = argparse.ArgumentParser() |
|
|
88 |
parser.add_argument('--preprocess_data',action = 'store_true', help = 'Use if True') |
|
|
89 |
parser.add_argument('--sampling_freq',default = 125, type = int, help = 'Sampling Frequency') |
|
|
90 |
parser.add_argument('--upsample_freq',default = 500, type = int, help = 'Resampling Frequency') |
|
|
91 |
parser.add_argument('--window_length',default = 5,type = int,help = 'Window Length in seconds') |
|
|
92 |
parser.add_argument('--data_path',help = 'Path to data') |
|
|
93 |
|
|
|
94 |
parser.add_argument('--random_seed',default = 5,type = int,help = 'Random Seed initializer') |
|
|
95 |
args = parser.parse_args() |
|
|
96 |
|
|
|
97 |
torch.backends.cudnn.deterministic = True |
|
|
98 |
torch.backends.cudnn.benchmark = False |
|
|
99 |
random.seed(args.random_seed) |
|
|
100 |
torch.manual_seed(args.random_seed) |
|
|
101 |
if torch.cuda.is_available(): |
|
|
102 |
torch.cuda.manual_seed(args.random_seed) |
|
|
103 |
|
|
|
104 |
if args.preprocess_data: |
|
|
105 |
data_preprocess(args) |
|
|
106 |
|
|
|
107 |
train(args) |