|
a |
|
b/tester_MITDB.py |
|
|
1 |
import numpy as np |
|
|
2 |
import pandas as pd |
|
|
3 |
import wfdb |
|
|
4 |
import _tester_utils |
|
|
5 |
import pathlib |
|
|
6 |
import os |
|
|
7 |
from ecgdetectors import Detectors |
|
|
8 |
|
|
|
9 |
|
|
|
10 |
class MITDB_test: |
|
|
11 |
""" |
|
|
12 |
Benchmarks detectors with against the MITDB. |
|
|
13 |
You need to download both the MIT arrhythmia database from: https://alpha.physionet.org/content/mitdb/1.0.0/ |
|
|
14 |
and needs to be placed below this directory: "../mit-bih-arrhythmia-database-1.0.0/" |
|
|
15 |
""" |
|
|
16 |
def __init__(self): |
|
|
17 |
current_dir = pathlib.Path(__file__).resolve() |
|
|
18 |
self.mitdb_dir = str(pathlib.Path(current_dir).parents[1]/'mit-bih-arrhythmia-database-1.0.0') |
|
|
19 |
|
|
|
20 |
def single_classifier_test(self, detector, tolerance=0): |
|
|
21 |
max_delay_in_samples = 350 / 5 |
|
|
22 |
dat_files = [] |
|
|
23 |
for file in os.listdir(self.mitdb_dir): |
|
|
24 |
if file.endswith(".dat"): |
|
|
25 |
dat_files.append(file) |
|
|
26 |
|
|
|
27 |
mit_records = [w.replace(".dat", "") for w in dat_files] |
|
|
28 |
|
|
|
29 |
results = np.zeros((len(mit_records), 5), dtype=int) |
|
|
30 |
|
|
|
31 |
i = 0 |
|
|
32 |
for record in mit_records: |
|
|
33 |
progress = int(i/float(len(mit_records))*100.0) |
|
|
34 |
print("MITDB progress: %i%%" % progress) |
|
|
35 |
|
|
|
36 |
sig, fields = wfdb.rdsamp(self.mitdb_dir+'/'+record) |
|
|
37 |
unfiltered_ecg = sig[:, 0] |
|
|
38 |
|
|
|
39 |
ann = wfdb.rdann(str(self.mitdb_dir+'/'+record), 'atr') |
|
|
40 |
anno = _tester_utils.sort_MIT_annotations(ann) |
|
|
41 |
|
|
|
42 |
r_peaks = detector(unfiltered_ecg) |
|
|
43 |
|
|
|
44 |
delay = _tester_utils.calcMedianDelay(r_peaks, unfiltered_ecg, max_delay_in_samples) |
|
|
45 |
|
|
|
46 |
if delay > 1: |
|
|
47 |
|
|
|
48 |
TP, FP, FN = _tester_utils.evaluate_detector(r_peaks, anno, delay, tol=tolerance) |
|
|
49 |
TN = len(unfiltered_ecg)-(TP+FP+FN) |
|
|
50 |
|
|
|
51 |
results[i, 0] = int(record) |
|
|
52 |
results[i, 1] = TP |
|
|
53 |
results[i, 2] = FP |
|
|
54 |
results[i, 3] = FN |
|
|
55 |
results[i, 4] = TN |
|
|
56 |
|
|
|
57 |
i = i+1 |
|
|
58 |
|
|
|
59 |
return results |
|
|
60 |
|
|
|
61 |
|
|
|
62 |
def classifer_test_all(self, tolerance=0): |
|
|
63 |
|
|
|
64 |
det_names = ['two_average', 'matched_filter', 'swt', 'engzee', 'christov', 'hamilton', 'pan_tompkins'] |
|
|
65 |
output_names = ['TP', 'FP', 'FN', 'TN'] |
|
|
66 |
|
|
|
67 |
total_records = 0 |
|
|
68 |
for file in os.listdir(self.mitdb_dir): |
|
|
69 |
if file.endswith(".dat"): |
|
|
70 |
total_records = total_records + 1 |
|
|
71 |
|
|
|
72 |
total_results = np.zeros((total_records, 4*len(det_names)), dtype=int) |
|
|
73 |
|
|
|
74 |
counter = 0 |
|
|
75 |
for det_name in det_names: |
|
|
76 |
|
|
|
77 |
print('\n'+det_name) |
|
|
78 |
|
|
|
79 |
result = self.single_classifier_test(_tester_utils.det_from_name(det_name, 360), tolerance=tolerance) |
|
|
80 |
index_labels = result[:, 0] |
|
|
81 |
result = result[:, 1:] |
|
|
82 |
|
|
|
83 |
total_results[:, counter:counter+4] = result |
|
|
84 |
|
|
|
85 |
counter = counter+4 |
|
|
86 |
|
|
|
87 |
col_labels = [] |
|
|
88 |
|
|
|
89 |
for det_name in det_names: |
|
|
90 |
for output_name in output_names: |
|
|
91 |
label = det_name+" "+output_name |
|
|
92 |
col_labels.append(label) |
|
|
93 |
|
|
|
94 |
total_results_pd = pd.DataFrame(total_results, index_labels, col_labels, dtype=int) |
|
|
95 |
total_results_pd.to_csv('results_MITDB'+'.csv', sep=',') |
|
|
96 |
|
|
|
97 |
return total_results_pd |