Diff of /tester_MITDB.py [000000] .. [1a3ddc]

Switch to unified view

a b/tester_MITDB.py
1
import numpy as np
2
import pandas as pd
3
import wfdb
4
import _tester_utils
5
import pathlib
6
import os
7
from ecgdetectors import Detectors
8
9
10
class MITDB_test:
11
    """
12
    Benchmarks detectors with against the MITDB.
13
    You need to download both the MIT arrhythmia database from: https://alpha.physionet.org/content/mitdb/1.0.0/
14
    and needs to be placed below this directory: "../mit-bih-arrhythmia-database-1.0.0/"
15
    """
16
    def __init__(self):
17
        current_dir = pathlib.Path(__file__).resolve()
18
        self.mitdb_dir = str(pathlib.Path(current_dir).parents[1]/'mit-bih-arrhythmia-database-1.0.0')
19
20
    def single_classifier_test(self, detector, tolerance=0):
21
        max_delay_in_samples = 350 / 5
22
        dat_files = []
23
        for file in os.listdir(self.mitdb_dir):
24
            if file.endswith(".dat"):
25
                dat_files.append(file)
26
        
27
        mit_records = [w.replace(".dat", "") for w in dat_files]
28
        
29
        results = np.zeros((len(mit_records), 5), dtype=int)
30
31
        i = 0
32
        for record in mit_records:
33
            progress = int(i/float(len(mit_records))*100.0)
34
            print("MITDB progress: %i%%" % progress)
35
36
            sig, fields = wfdb.rdsamp(self.mitdb_dir+'/'+record)
37
            unfiltered_ecg = sig[:, 0]  
38
39
            ann = wfdb.rdann(str(self.mitdb_dir+'/'+record), 'atr')    
40
            anno = _tester_utils.sort_MIT_annotations(ann)    
41
42
            r_peaks = detector(unfiltered_ecg)
43
44
            delay = _tester_utils.calcMedianDelay(r_peaks, unfiltered_ecg, max_delay_in_samples)
45
46
            if delay > 1:
47
48
                TP, FP, FN = _tester_utils.evaluate_detector(r_peaks, anno, delay, tol=tolerance)
49
                TN = len(unfiltered_ecg)-(TP+FP+FN)
50
                
51
                results[i, 0] = int(record)    
52
                results[i, 1] = TP
53
                results[i, 2] = FP
54
                results[i, 3] = FN
55
                results[i, 4] = TN
56
57
            i = i+1
58
        
59
        return results
60
61
62
    def classifer_test_all(self, tolerance=0):
63
64
        det_names = ['two_average', 'matched_filter', 'swt', 'engzee', 'christov', 'hamilton', 'pan_tompkins']
65
        output_names = ['TP', 'FP', 'FN', 'TN']
66
67
        total_records = 0
68
        for file in os.listdir(self.mitdb_dir):
69
            if file.endswith(".dat"):
70
                total_records = total_records + 1
71
72
        total_results = np.zeros((total_records, 4*len(det_names)), dtype=int)
73
74
        counter = 0
75
        for det_name in det_names:
76
77
            print('\n'+det_name)
78
79
            result = self.single_classifier_test(_tester_utils.det_from_name(det_name, 360), tolerance=tolerance)
80
            index_labels = result[:, 0]
81
            result = result[:, 1:]
82
83
            total_results[:, counter:counter+4] = result
84
85
            counter = counter+4  
86
87
        col_labels = []
88
89
        for det_name in det_names:
90
                for output_name in output_names:
91
                    label = det_name+" "+output_name
92
                    col_labels.append(label)
93
94
        total_results_pd = pd.DataFrame(total_results, index_labels, col_labels, dtype=int)            
95
        total_results_pd.to_csv('results_MITDB'+'.csv', sep=',')
96
97
        return total_results_pd