Diff of /tester_MITDB.py [000000] .. [1a3ddc]

Switch to side-by-side view

--- a
+++ b/tester_MITDB.py
@@ -0,0 +1,97 @@
+import numpy as np
+import pandas as pd
+import wfdb
+import _tester_utils
+import pathlib
+import os
+from ecgdetectors import Detectors
+
+
+class MITDB_test:
+    """
+    Benchmarks detectors with against the MITDB.
+    You need to download both the MIT arrhythmia database from: https://alpha.physionet.org/content/mitdb/1.0.0/
+    and needs to be placed below this directory: "../mit-bih-arrhythmia-database-1.0.0/"
+    """
+    def __init__(self):
+        current_dir = pathlib.Path(__file__).resolve()
+        self.mitdb_dir = str(pathlib.Path(current_dir).parents[1]/'mit-bih-arrhythmia-database-1.0.0')
+
+    def single_classifier_test(self, detector, tolerance=0):
+        max_delay_in_samples = 350 / 5
+        dat_files = []
+        for file in os.listdir(self.mitdb_dir):
+            if file.endswith(".dat"):
+                dat_files.append(file)
+        
+        mit_records = [w.replace(".dat", "") for w in dat_files]
+        
+        results = np.zeros((len(mit_records), 5), dtype=int)
+
+        i = 0
+        for record in mit_records:
+            progress = int(i/float(len(mit_records))*100.0)
+            print("MITDB progress: %i%%" % progress)
+
+            sig, fields = wfdb.rdsamp(self.mitdb_dir+'/'+record)
+            unfiltered_ecg = sig[:, 0]  
+
+            ann = wfdb.rdann(str(self.mitdb_dir+'/'+record), 'atr')    
+            anno = _tester_utils.sort_MIT_annotations(ann)    
+
+            r_peaks = detector(unfiltered_ecg)
+
+            delay = _tester_utils.calcMedianDelay(r_peaks, unfiltered_ecg, max_delay_in_samples)
+
+            if delay > 1:
+
+                TP, FP, FN = _tester_utils.evaluate_detector(r_peaks, anno, delay, tol=tolerance)
+                TN = len(unfiltered_ecg)-(TP+FP+FN)
+                
+                results[i, 0] = int(record)    
+                results[i, 1] = TP
+                results[i, 2] = FP
+                results[i, 3] = FN
+                results[i, 4] = TN
+
+            i = i+1
+        
+        return results
+
+
+    def classifer_test_all(self, tolerance=0):
+
+        det_names = ['two_average', 'matched_filter', 'swt', 'engzee', 'christov', 'hamilton', 'pan_tompkins']
+        output_names = ['TP', 'FP', 'FN', 'TN']
+
+        total_records = 0
+        for file in os.listdir(self.mitdb_dir):
+            if file.endswith(".dat"):
+                total_records = total_records + 1
+
+        total_results = np.zeros((total_records, 4*len(det_names)), dtype=int)
+
+        counter = 0
+        for det_name in det_names:
+
+            print('\n'+det_name)
+
+            result = self.single_classifier_test(_tester_utils.det_from_name(det_name, 360), tolerance=tolerance)
+            index_labels = result[:, 0]
+            result = result[:, 1:]
+
+            total_results[:, counter:counter+4] = result
+
+            counter = counter+4  
+
+        col_labels = []
+
+        for det_name in det_names:
+                for output_name in output_names:
+                    label = det_name+" "+output_name
+                    col_labels.append(label)
+
+        total_results_pd = pd.DataFrame(total_results, index_labels, col_labels, dtype=int)            
+        total_results_pd.to_csv('results_MITDB'+'.csv', sep=',')
+
+        return total_results_pd