Diff of /michael_tester.py [000000] .. [8d2107]

Switch to unified view

a b/michael_tester.py
1
import sys
2
import cProfile
3
4
from model_tester import FeaturePipeline, test_model
5
from sklearn.pipeline import FeatureUnion
6
from sklearn.feature_extraction import DictVectorizer
7
from sklearn.feature_extraction.text import TfidfTransformer
8
from baseline_transformer import GetConcatenatedNotesTransformer, GetLatestNotesTransformer, GetEncountersFeaturesTransformer, GetLabsCountsDictTransformer, GetLabsLowCountsDictTransformer, GetLabsHighCountsDictTransformer, GetLabsLatestHighDictTransformer, GetLabsLatestLowDictTransformer, GetLabsHistoryDictTransformer
9
from extract_data import get_doc_rel_dates, get_operation_date, get_ef_values
10
from extract_data import get_operation_date,  is_note_doc, get_date_key
11
from icd_transformer import ICD9_Transformer
12
from value_extractor_transformer import EFTransformer, LBBBTransformer, SinusRhythmTransformer, QRSTransformer, NYHATransformer, NICMTransformer
13
from language_processing import parse_date 
14
15
def main():
16
17
    transformer_list = []
18
19
    regex_features = True
20
    icd9_features = False
21
    labs_features = False
22
    text_features = False
23
24
    if regex_features:
25
        transformer_list += [ 
26
                    ('EF', EFTransformer('all', 1, None)),
27
                    ('EF', EFTransformer('mean', 5, None)),
28
                    ('EF', EFTransformer('max', 5, None)),
29
                    ('LBBB', LBBBTransformer(30*3)),
30
                    ('SR', SinusRhythmTransformer(30*3)),
31
                    ('NYHA', NYHATransformer(30*3)),
32
                    ('NICM', NICMTransformer(30*3)),
33
                    ('QRS', QRSTransformer('all', 1, None)),
34
                    ('QRS', QRSTransformer('mean', 5, None)),
35
                ]
36
    if icd9_features:
37
        transformer_list += [
38
                    ('Dia', ICD9_Transformer())
39
                ]
40
    if text_features:
41
        transformer_list += [
42
                    ('Car', FeaturePipeline([
43
                        ('notes_transformer_car', GetConcatenatedNotesTransformer('Car')),
44
                        ('tfidf', TfidfTransformer())
45
                    ])),
46
                    ('Lno', FeaturePipeline([
47
                       ('notes_transformer_lno', GetConcatenatedNotesTransformer('Lno')),
48
                       ('tfidf', TfidfTransformer)
49
                    ]))
50
                ]
51
    if labs_features:
52
        transformer_list += [
53
                    ('Enc', GetEncountersFeaturesTransformer(5)),
54
                    ('Labs_Counts',FeaturePipeline([
55
                        ('labs_counts_transformer', GetLabsCountsDictTransformer()),
56
                        ('dict_vectorizer', DictVectorizer())
57
                    ])),
58
                    ('Labs_Low_Counts',FeaturePipeline([
59
                        ('labs_low_counts_transformer', GetLabsLowCountsDictTransformer()),
60
                       ('dict_vectorizer', DictVectorizer())
61
                    ])),
62
                    ('Labs_High_Counts', FeaturePipeline([
63
                        ('labs_high_counts_transformer', GetLabsHighCountsDictTransformer()),
64
                        ('dict_vectorizer', DictVectorizer())
65
                    ])),
66
                    ('Labs_Latest_Low', FeaturePipeline([
67
                        ('labs_latest_low_transformer', GetLabsLatestLowDictTransformer()),
68
                        ('dict_vectorizer', DictVectorizer())
69
                    ])),
70
                    ('Labs_Latest_High',FeaturePipeline([
71
                        ('labs_latest_high_transformer', GetLabsLatestHighDictTransformer()),
72
                        ('dict_vectorizer', DictVectorizer())
73
                    ])),
74
                    ('Labs_History', FeaturePipeline([
75
                        ('labs_history_transformer', GetLabsHistoryDictTransformer([1])),
76
                        ('dict_vectorizer', DictVectorizer())
77
                    ]))
78
                ]
79
80
    
81
    features = FeatureUnion(transformer_list)
82
83
    if len(sys.argv) > 1 and unicode(sys.argv[1]).isnumeric():
84
        data_size = min(int(sys.argv[1]), 906)
85
    else:
86
        data_size = 25
87
88
    if len(sys.argv) > 2 and unicode(sys.argv[2]).isnumeric():
89
        num_cv_splits = int(sys.argv[2])
90
    else:
91
        num_cv_splits = 5
92
93
    print "Data size: " + str(data_size)
94
    print "CV splits: " + str(num_cv_splits)
95
96
    if len(sys.argv) > 3:
97
        method = sys.argv[3]
98
    else:
99
        method = 'adaboost'
100
101
    #method = 'lr'
102
    #method = 'svm'
103
    method = 'adaboost'
104
    #method = 'cdm'
105
106
    model_args = dict()
107
    if method in ['lr', 'svm']:
108
        if len(sys.argv) > 4 and unicode(sys.argv[4]).isnumeric():
109
            model_args['regularization'] = float(sys.argv[4])
110
        else:
111
            model_args['regularization'] = 0.
112
    if method == 'adaboost':
113
        if len(sys.argv) > 4 and unicode(sys.argv[4]).isnumeric():
114
            model_args['n_estimators'] = int(sys.argv[4])
115
        else:
116
            model_args['n_estimators'] = 50
117
        
118
119
    show_progress = True
120
    print 'Method:', method
121
    test_model(features, data_size, num_cv_splits, method, show_progress, model_args)
122
123
if __name__ == '__main__':
124
    main()