Diff of /Inference/run_model.py [000000] .. [c0487b]

Switch to side-by-side view

--- a
+++ b/Inference/run_model.py
@@ -0,0 +1,128 @@
+import argparse
+import h5py 
+import numpy as np
+import os
+from functools import reduce
+import pandas as pd
+import matplotlib.pyplot as plt
+from matplotlib.widgets import Slider, Button, RadioButtons
+import sys
+from pyqtgraph.Qt import QtGui, QtCore
+
+import torch
+from torch.utils.data import TensorDataset,DataLoader
+
+from PyQT_Plot import create_dashboard
+from preprocess_data import data_read,windowing_and_resampling_hr,windowing_and_resampling_br
+from utils import load_model_HR,load_model_BR,compute_heart_rate
+
+def main(args):
+    
+    preprocessed_patient_data = data_read(args)
+    print('-------- Data Acquisition Complete --------')
+    windowed_patient_overlap,windowed_patient = windowing_and_resampling_hr(preprocessed_patient_data)
+    print('-------- Pre-processing Complete for HR---------')
+    windowed_patient_overlap_br = windowing_and_resampling_br(preprocessed_patient_data)
+    print('-------- Pre-processing Complete for BR---------')
+
+    ###
+    patient_ecg = np.asarray(windowed_patient_overlap['ecg'][0][:60])
+    actual_ecg_windows = np.asarray(windowed_patient['ecg'][0][:60])
+    
+    patient_ecg_br = np.asarray(windowed_patient_overlap_br['ecg'][0][:60])
+    ###
+
+    batch_len = 32
+    batch_len_br = 1
+    window_size = 5000
+
+    patient_ecg = torch.from_numpy(patient_ecg).view(patient_ecg.shape[0],1,patient_ecg.shape[1]).float()
+    input_ecg = TensorDataset(patient_ecg)
+    testloader = DataLoader(input_ecg,batch_len)
+
+    patient_ecg_br = torch.from_numpy(patient_ecg_br).view(patient_ecg_br.shape[0],1,patient_ecg_br.shape[1]).float()
+    input_ecg_br = TensorDataset(patient_ecg_br)
+    testloader_br = DataLoader(input_ecg_br,batch_len_br)
+
+    SAVED_HR_MODEL_PATH = args.saved_hr_model_path
+    SAVED_BR_MODEL_PATH = args.saved_br_model_path
+    device = args.device
+    
+    ecg_peak_locs = load_model_HR(SAVED_HR_MODEL_PATH,testloader,device,batch_len,window_size)     
+    br_peak_locs = load_model_BR(SAVED_BR_MODEL_PATH,testloader_br,device,batch_len,window_size)
+
+    ### Finding Stored Paths
+    save_dir = args.save_dir
+    if not(os.path.isdir(save_dir)):
+        os.mkdir(save_dir)
+
+    save_path =  save_dir + '/r_peaks_patient_' + str(args.patient_no) + '.csv'
+
+    all_hr = []
+    initial_hr = len([peak for peak in list(ecg_peak_locs) if peak < 5000 * 6])
+    
+    for i in range(patient_ecg.shape[0]):
+        all_hr.append( len([peak for peak in list(ecg_peak_locs) if peak > i * 2500 and peak < (i * 2500 ) + 5000 * 6 ]))
+    unique = np.unique(np.asarray(all_hr))
+    peak_no = np.linspace(1,len(ecg_peak_locs),len(ecg_peak_locs)).astype(int)
+    peak_no = peak_no.reshape(-1,1)
+    ecg_peak_locs = ecg_peak_locs.reshape(-1,1) 
+    ecg_peak_locs = np.hstack((peak_no,ecg_peak_locs))
+
+    pd.DataFrame(ecg_peak_locs).to_csv(save_path , header=None, index=None)  
+    print('-------- R Peaks Saved --------')
+
+    all_br = []
+    initial_br = len([peak for peak in list(br_peak_locs) if peak < 1250 * 6])
+    for i in range(patient_ecg.shape[0]):
+        all_br.append( len([peak for peak in list(br_peak_locs) if peak > i * 625 and peak < (i * 625 ) + 1250 * 6 ]))
+        # all_br.append( len([peak for peak in list(br_peak_locs) if peak > i * 2500 and peak < (i * 2500 ) + 5000 * 6 ]))
+
+    i = 1
+    scatter_peak = []
+    scatter_peak_1 = []
+    ecg_point = []
+    ecg_point_1 = []
+    k = 0
+    hr = []
+    peak_locs = ecg_peak_locs[:,1]
+    for j in range(len(peak_locs)):     
+        if(peak_locs[j] < 5000*i):
+            scatter_peak.append(peak_locs[j]-5000*(i-1))
+            if(i< len(actual_ecg_windows)):
+                ecg_point.append(actual_ecg_windows[i-1,scatter_peak[k]])
+                k = k+1                         
+        elif(peak_locs[j] >= 5000*i):
+            scatter_peak_1.append(np.asarray(scatter_peak))
+            hr.append(compute_heart_rate(scatter_peak_1[i-1]))
+            ecg_point_1.append(np.asarray(ecg_point))                     
+            scatter_peak = []
+            ecg_point = []
+            i = i+1
+            scatter_peak.append(peak_locs[j]-5000*(i-1))
+            k = 0
+            if(i< len(actual_ecg_windows)):
+                ecg_point.append(actual_ecg_windows[i-1,scatter_peak[k]])
+                k = k+1
+    import pdb;pdb.set_trace()
+    if(args.viewer):
+        create_dashboard(actual_ecg_windows,scatter_peak_1,all_hr,all_br)
+    
+    
+        
+
+if __name__ == "__main__":
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--path_dir',help = 'Path to all the records')
+    parser.add_argument('--saved_hr_model_path',help = 'Path to saved Heart rate model')
+    parser.add_argument('--saved_br_model_path',help = 'Path to saved breathing rate model')
+    parser.add_argument('--patient_no',default = 8,type = int,help = 'Patient used for testing')
+    parser.add_argument('--device',default = 'cuda', help = 'cpu/cuda')
+    parser.add_argument('--save_dir',default = 'saved_models/',help = 'Directory used for saving')
+    parser.add_argument('--viewer',default = 0,type = int, help = 'To view ECG plot: 1, else: 0')
+    
+    args = parser.parse_args()
+
+    main(args)
+