Diff of /decision_model.py [000000] .. [8d2107]

Switch to unified view

a b/decision_model.py
1
from sklearn.base import TransformerMixin, ClassifierMixin
2
from value_extractor_transformer import EFTransformer, QRSTransformer, LBBBTransformer, SinusRhythmTransformer, NYHATransformer
3
4
5
class ClinicalDecisionModel(TransformerMixin, ClassifierMixin):
6
7
    """
8
    Replicate prediction of clinical descion tree
9
    in 'Circulation-2012-Tracy-1784-800.pdf' Appendix 3
10
    Does not need to be fit and transforms data on its own
11
    """
12
13
    LVEF = 'lvef'
14
    QRS = 'qrs'
15
    LBBB = 'lbbb'
16
    SINUS_RHYTHM = 'sr'
17
    NYHA = 'nyha'
18
19
    def __init__(self):
20
        self.transformers = dict()
21
        self.transformers[self.LVEF]= EFTransformer('max',3)
22
        self.transformers[self.QRS]= QRSTransformer('mean', 3)
23
        self.transformers[self.LBBB]= LBBBTransformer()
24
        self.transformers[self.SINUS_RHYTHM]= SinusRhythmTransformer()
25
        self.transformers[self.NYHA]= NYHATransformer()
26
        
27
    def fit(self, X = None, y = None):
28
        for trans_key in self.transformers:
29
            self.transformers[trans_key].fit(X, y)
30
        return self
31
32
    def transform(self, X):
33
        return [self.__find_values(empi) for empi in X]
34
35
    def __find_values(self, empi):
36
        values = dict()
37
        for trans_key in self.transformers:
38
            values[trans_key] = self.transformers[trans_key].get_feature(empi)
39
        return values
40
41
    def predict(self, X):
42
        X_transformed = self.transform(X)
43
        predicted_colors = map(self.predict_color, X_transformed)
44
        y_hat = map(self.prediction_from_color, predicted_colors)
45
        return y_hat
46
47
    def predict_color(self, values):
48
        """
49
        Summary of logic (all have ef > 35):
50
        NYHA I:
51
            Orange - ef < 30, qrs > 150, lbbb
52
            Red -    ef > 30 or qrs < 150 or no lbbb
53
        NYHA II:
54
            Green -  qrs > 150, lbbb, sinus rhythm
55
            Yellow - qrs < 150, lbbb, sinus rhythm
56
            Orange - qrs > 150, no lbbb
57
            Red -    no lbbb and qrs < 150 or lbbb and no SR
58
        NYHA III:
59
            Green -  qrs > 150, lbbb, sinus rhythm
60
            Yellow - qrs > 150, no lbbb, sinus rhythm
61
            Orange - qrs < 150, no lbbb, sinus rhythm
62
            Red -    qrs < 120 or no sinus rhythm
63
        NYHA IV:
64
            Red -    for all cases
65
        """
66
        nyha_class = values[self.NYHA].index(1) + 1
67
        ef = values[self.LVEF][0]
68
        qrs = values[self.QRS][0]
69
        lbbb = bool(values[self.LBBB][0])
70
        sr = bool(values[self.SINUS_RHYTHM][0])
71
        if ef > 35: #EF higher than 35 disqualified
72
            return 'red'
73
        #print "NYHA:", nyha_class, "EF:", ef, "QRS:", qrs, "LBBB:", lbbb, "SR:", sr
74
        if nyha_class == 1:
75
            if ef < 30 and qrs >= 150 and lbbb:
76
                return 'orange'
77
            else:
78
                return 'red'
79
        elif nyha_class == 2:
80
            if lbbb:
81
                if sr:
82
                    if qrs > 150: 
83
                        return 'green'
84
                    else: 
85
                        return 'yellow'
86
                else: 
87
                    return 'red'
88
            else:
89
                if qrs > 150:
90
                    return 'orange'
91
                else:
92
                    return 'red'
93
        elif nyha_class == 3:
94
            if sr and qrs > 120:
95
                if qrs > 150 and lbbb:
96
                    return 'green'
97
                elif qrs <= 150 and not lbbb:
98
                    return 'orange'
99
                else:
100
                    return 'yellow'
101
            else:
102
                return 'red'
103
        else: #this is None and class 4
104
            return 'red'
105
106
    def prediction_from_color(self, color):
107
        mapping = {'red' : 0, 'orange' : 0, 'yellow': 1, 'green' : 1}
108
        return mapping[color]