|
a |
|
b/michael_tester.py |
|
|
1 |
import sys |
|
|
2 |
import cProfile |
|
|
3 |
|
|
|
4 |
from model_tester import FeaturePipeline, test_model |
|
|
5 |
from sklearn.pipeline import FeatureUnion |
|
|
6 |
from sklearn.feature_extraction import DictVectorizer |
|
|
7 |
from sklearn.feature_extraction.text import TfidfTransformer |
|
|
8 |
from baseline_transformer import GetConcatenatedNotesTransformer, GetLatestNotesTransformer, GetEncountersFeaturesTransformer, GetLabsCountsDictTransformer, GetLabsLowCountsDictTransformer, GetLabsHighCountsDictTransformer, GetLabsLatestHighDictTransformer, GetLabsLatestLowDictTransformer, GetLabsHistoryDictTransformer |
|
|
9 |
from extract_data import get_doc_rel_dates, get_operation_date, get_ef_values |
|
|
10 |
from extract_data import get_operation_date, is_note_doc, get_date_key |
|
|
11 |
from icd_transformer import ICD9_Transformer |
|
|
12 |
from value_extractor_transformer import EFTransformer, LBBBTransformer, SinusRhythmTransformer, QRSTransformer, NYHATransformer, NICMTransformer |
|
|
13 |
from language_processing import parse_date |
|
|
14 |
|
|
|
15 |
# Upper bound applied to the data size requested on the command line.
# NOTE(review): presumably the number of records available -- confirm.
_MAX_DATA_SIZE = 906


def _int_arg(argv, index, default):
    """Return argv[index] as a non-negative int, or default if absent/invalid."""
    try:
        value = int(argv[index])
    except (IndexError, ValueError):
        return default
    return value if value >= 0 else default


def _float_arg(argv, index, default):
    """Return argv[index] as a finite non-negative float, or default if absent/invalid."""
    try:
        value = float(argv[index])
    except (IndexError, ValueError):
        return default
    # The chained comparison is False for NaN, negatives and infinity alike.
    return value if 0 <= value < float('inf') else default


def _build_transformers(regex_features, icd9_features, text_features,
                        labs_features):
    """Return the (name, transformer) list for the enabled feature groups.

    Groups are appended in the order regex, ICD9, text, labs. Every name is
    unique, which FeatureUnion needs in order to address steps through
    get_params/set_params.
    """
    transformers = []

    if regex_features:
        # NOTE(review): 30*3 looks like a ~3 month window in days -- confirm
        # against the value_extractor_transformer constructors.
        window = 30 * 3
        transformers += [
            # The EF and QRS variants used to share one name each; they are
            # suffixed with their first constructor argument to stay unique.
            ('EF_all', EFTransformer('all', 1, None)),
            ('EF_mean', EFTransformer('mean', 5, None)),
            ('EF_max', EFTransformer('max', 5, None)),
            ('LBBB', LBBBTransformer(window)),
            ('SR', SinusRhythmTransformer(window)),
            ('NYHA', NYHATransformer(window)),
            ('NICM', NICMTransformer(window)),
            ('QRS_all', QRSTransformer('all', 1, None)),
            ('QRS_mean', QRSTransformer('mean', 5, None)),
        ]

    if icd9_features:
        transformers.append(('Dia', ICD9_Transformer()))

    if text_features:
        for union_name, step_name in (('Car', 'notes_transformer_car'),
                                      ('Lno', 'notes_transformer_lno')):
            transformers.append((union_name, FeaturePipeline([
                (step_name, GetConcatenatedNotesTransformer(union_name)),
                # Must be an instance: the 'Lno' pipeline previously received
                # the TfidfTransformer class itself.
                ('tfidf', TfidfTransformer()),
            ])))

    if labs_features:
        transformers.append(('Enc', GetEncountersFeaturesTransformer(5)))
        lab_steps = [
            ('Labs_Counts', 'labs_counts_transformer',
             GetLabsCountsDictTransformer()),
            ('Labs_Low_Counts', 'labs_low_counts_transformer',
             GetLabsLowCountsDictTransformer()),
            ('Labs_High_Counts', 'labs_high_counts_transformer',
             GetLabsHighCountsDictTransformer()),
            ('Labs_Latest_Low', 'labs_latest_low_transformer',
             GetLabsLatestLowDictTransformer()),
            ('Labs_Latest_High', 'labs_latest_high_transformer',
             GetLabsLatestHighDictTransformer()),
            ('Labs_History', 'labs_history_transformer',
             GetLabsHistoryDictTransformer([1])),
        ]
        # Each lab extractor is followed by its own DictVectorizer.
        for union_name, step_name, extractor in lab_steps:
            transformers.append((union_name, FeaturePipeline([
                (step_name, extractor),
                ('dict_vectorizer', DictVectorizer()),
            ])))

    return transformers


def main(argv=None):
    """Build the feature union and run test_model with command-line settings.

    Positional command-line arguments (all optional):
        1: data size, capped at _MAX_DATA_SIZE (default 25)
        2: number of cross-validation splits (default 5)
        3: method name passed to test_model (default 'adaboost')
        4: for 'lr'/'svm' the regularization value (default 0.);
           for 'adaboost' the number of estimators (default 50)

    Missing or unparsable numeric arguments fall back to their defaults.

    Args:
        argv: argument list laid out like sys.argv; defaults to sys.argv.
    """
    if argv is None:
        argv = sys.argv

    # Feature groups to include in this run; edit here to change the setup.
    regex_features = True
    icd9_features = False
    labs_features = False
    text_features = False

    features = FeatureUnion(_build_transformers(
        regex_features, icd9_features, text_features, labs_features))

    if len(argv) > 1:
        data_size = _int_arg(argv, 1, None)
        # A present but invalid value falls back to 25, not to the cap.
        data_size = 25 if data_size is None else min(data_size, _MAX_DATA_SIZE)
    else:
        data_size = 25
    num_cv_splits = _int_arg(argv, 2, 5)

    print("Data size: " + str(data_size))
    print("CV splits: " + str(num_cv_splits))

    # The command-line method is honoured; it used to be overwritten with
    # 'adaboost' unconditionally, which made the lr/svm branch unreachable.
    method = argv[3] if len(argv) > 3 else 'adaboost'

    model_args = dict()
    if method in ('lr', 'svm'):
        model_args['regularization'] = _float_arg(argv, 4, 0.)
    elif method == 'adaboost':
        model_args['n_estimators'] = _int_arg(argv, 4, 50)

    show_progress = True
    print("Method: " + method)
    test_model(features, data_size, num_cv_splits, method, show_progress,
               model_args)
|
|
122 |
|
|
|
123 |
# Run the experiment only when this file is executed as a script, so that
# importing the module has no side effects.
if __name__ == "__main__":
    main()