|
a |
|
b/Inference/run_model.py |
|
|
1 |
import argparse |
|
|
2 |
import h5py |
|
|
3 |
import numpy as np |
|
|
4 |
import os |
|
|
5 |
from functools import reduce |
|
|
6 |
import pandas as pd |
|
|
7 |
import matplotlib.pyplot as plt |
|
|
8 |
from matplotlib.widgets import Slider, Button, RadioButtons |
|
|
9 |
import sys |
|
|
10 |
from pyqtgraph.Qt import QtGui, QtCore |
|
|
11 |
|
|
|
12 |
import torch |
|
|
13 |
from torch.utils.data import TensorDataset,DataLoader |
|
|
14 |
|
|
|
15 |
from PyQT_Plot import create_dashboard |
|
|
16 |
from preprocess_data import data_read,windowing_and_resampling_hr,windowing_and_resampling_br |
|
|
17 |
from utils import load_model_HR,load_model_BR,compute_heart_rate |
|
|
18 |
|
|
|
19 |
def main(args): |
|
|
20 |
|
|
|
21 |
preprocessed_patient_data = data_read(args) |
|
|
22 |
print('-------- Data Acquisition Complete --------') |
|
|
23 |
windowed_patient_overlap,windowed_patient = windowing_and_resampling_hr(preprocessed_patient_data) |
|
|
24 |
print('-------- Pre-processing Complete for HR---------') |
|
|
25 |
windowed_patient_overlap_br = windowing_and_resampling_br(preprocessed_patient_data) |
|
|
26 |
print('-------- Pre-processing Complete for BR---------') |
|
|
27 |
|
|
|
28 |
### |
|
|
29 |
patient_ecg = np.asarray(windowed_patient_overlap['ecg'][0][:60]) |
|
|
30 |
actual_ecg_windows = np.asarray(windowed_patient['ecg'][0][:60]) |
|
|
31 |
|
|
|
32 |
patient_ecg_br = np.asarray(windowed_patient_overlap_br['ecg'][0][:60]) |
|
|
33 |
### |
|
|
34 |
|
|
|
35 |
batch_len = 32 |
|
|
36 |
batch_len_br = 1 |
|
|
37 |
window_size = 5000 |
|
|
38 |
|
|
|
39 |
patient_ecg = torch.from_numpy(patient_ecg).view(patient_ecg.shape[0],1,patient_ecg.shape[1]).float() |
|
|
40 |
input_ecg = TensorDataset(patient_ecg) |
|
|
41 |
testloader = DataLoader(input_ecg,batch_len) |
|
|
42 |
|
|
|
43 |
patient_ecg_br = torch.from_numpy(patient_ecg_br).view(patient_ecg_br.shape[0],1,patient_ecg_br.shape[1]).float() |
|
|
44 |
input_ecg_br = TensorDataset(patient_ecg_br) |
|
|
45 |
testloader_br = DataLoader(input_ecg_br,batch_len_br) |
|
|
46 |
|
|
|
47 |
SAVED_HR_MODEL_PATH = args.saved_hr_model_path |
|
|
48 |
SAVED_BR_MODEL_PATH = args.saved_br_model_path |
|
|
49 |
device = args.device |
|
|
50 |
|
|
|
51 |
ecg_peak_locs = load_model_HR(SAVED_HR_MODEL_PATH,testloader,device,batch_len,window_size) |
|
|
52 |
br_peak_locs = load_model_BR(SAVED_BR_MODEL_PATH,testloader_br,device,batch_len,window_size) |
|
|
53 |
|
|
|
54 |
### Finding Stored Paths |
|
|
55 |
save_dir = args.save_dir |
|
|
56 |
if not(os.path.isdir(save_dir)): |
|
|
57 |
os.mkdir(save_dir) |
|
|
58 |
|
|
|
59 |
save_path = save_dir + '/r_peaks_patient_' + str(args.patient_no) + '.csv' |
|
|
60 |
|
|
|
61 |
all_hr = [] |
|
|
62 |
initial_hr = len([peak for peak in list(ecg_peak_locs) if peak < 5000 * 6]) |
|
|
63 |
|
|
|
64 |
for i in range(patient_ecg.shape[0]): |
|
|
65 |
all_hr.append( len([peak for peak in list(ecg_peak_locs) if peak > i * 2500 and peak < (i * 2500 ) + 5000 * 6 ])) |
|
|
66 |
unique = np.unique(np.asarray(all_hr)) |
|
|
67 |
peak_no = np.linspace(1,len(ecg_peak_locs),len(ecg_peak_locs)).astype(int) |
|
|
68 |
peak_no = peak_no.reshape(-1,1) |
|
|
69 |
ecg_peak_locs = ecg_peak_locs.reshape(-1,1) |
|
|
70 |
ecg_peak_locs = np.hstack((peak_no,ecg_peak_locs)) |
|
|
71 |
|
|
|
72 |
pd.DataFrame(ecg_peak_locs).to_csv(save_path , header=None, index=None) |
|
|
73 |
print('-------- R Peaks Saved --------') |
|
|
74 |
|
|
|
75 |
all_br = [] |
|
|
76 |
initial_br = len([peak for peak in list(br_peak_locs) if peak < 1250 * 6]) |
|
|
77 |
for i in range(patient_ecg.shape[0]): |
|
|
78 |
all_br.append( len([peak for peak in list(br_peak_locs) if peak > i * 625 and peak < (i * 625 ) + 1250 * 6 ])) |
|
|
79 |
# all_br.append( len([peak for peak in list(br_peak_locs) if peak > i * 2500 and peak < (i * 2500 ) + 5000 * 6 ])) |
|
|
80 |
|
|
|
81 |
i = 1 |
|
|
82 |
scatter_peak = [] |
|
|
83 |
scatter_peak_1 = [] |
|
|
84 |
ecg_point = [] |
|
|
85 |
ecg_point_1 = [] |
|
|
86 |
k = 0 |
|
|
87 |
hr = [] |
|
|
88 |
peak_locs = ecg_peak_locs[:,1] |
|
|
89 |
for j in range(len(peak_locs)): |
|
|
90 |
if(peak_locs[j] < 5000*i): |
|
|
91 |
scatter_peak.append(peak_locs[j]-5000*(i-1)) |
|
|
92 |
if(i< len(actual_ecg_windows)): |
|
|
93 |
ecg_point.append(actual_ecg_windows[i-1,scatter_peak[k]]) |
|
|
94 |
k = k+1 |
|
|
95 |
elif(peak_locs[j] >= 5000*i): |
|
|
96 |
scatter_peak_1.append(np.asarray(scatter_peak)) |
|
|
97 |
hr.append(compute_heart_rate(scatter_peak_1[i-1])) |
|
|
98 |
ecg_point_1.append(np.asarray(ecg_point)) |
|
|
99 |
scatter_peak = [] |
|
|
100 |
ecg_point = [] |
|
|
101 |
i = i+1 |
|
|
102 |
scatter_peak.append(peak_locs[j]-5000*(i-1)) |
|
|
103 |
k = 0 |
|
|
104 |
if(i< len(actual_ecg_windows)): |
|
|
105 |
ecg_point.append(actual_ecg_windows[i-1,scatter_peak[k]]) |
|
|
106 |
k = k+1 |
|
|
107 |
import pdb;pdb.set_trace() |
|
|
108 |
if(args.viewer): |
|
|
109 |
create_dashboard(actual_ecg_windows,scatter_peak_1,all_hr,all_br) |
|
|
110 |
|
|
|
111 |
|
|
|
112 |
|
|
|
113 |
|
|
|
114 |
if __name__ == "__main__": |
|
|
115 |
|
|
|
116 |
parser = argparse.ArgumentParser() |
|
|
117 |
parser.add_argument('--path_dir',help = 'Path to all the records') |
|
|
118 |
parser.add_argument('--saved_hr_model_path',help = 'Path to saved Heart rate model') |
|
|
119 |
parser.add_argument('--saved_br_model_path',help = 'Path to saved breathing rate model') |
|
|
120 |
parser.add_argument('--patient_no',default = 8,type = int,help = 'Patient used for testing') |
|
|
121 |
parser.add_argument('--device',default = 'cuda', help = 'cpu/cuda') |
|
|
122 |
parser.add_argument('--save_dir',default = 'saved_models/',help = 'Directory used for saving') |
|
|
123 |
parser.add_argument('--viewer',default = 0,type = int, help = 'To view ECG plot: 1, else: 0') |
|
|
124 |
|
|
|
125 |
args = parser.parse_args() |
|
|
126 |
|
|
|
127 |
main(args) |
|
|
128 |
|