|
a |
|
b/run_test.py |
|
|
1 |
import logging |
|
|
2 |
import sys |
|
|
3 |
import cProfile |
|
|
4 |
|
|
|
5 |
from model_tester import FeaturePipeline, test_model |
|
|
6 |
from sklearn.pipeline import FeatureUnion |
|
|
7 |
from sklearn.feature_extraction.text import CountVectorizer |
|
|
8 |
|
|
|
9 |
from baseline_transformer import GetConcatenatedNotesTransformer, GetLatestNotesTransformer, GetEncountersFeaturesTransformer, GetLabsCountsDictTransformer, GetLabsLowCountsDictTransformer, GetLabsHighCountsDictTransformer, GetLabsLatestHighDictTransformer, GetLabsLatestLowDictTransformer, GetLabsHistoryDictTransformer |
|
|
10 |
from extract_data import get_doc_rel_dates, get_operation_date, get_ef_values |
|
|
11 |
from extract_data import get_operation_date, is_note_doc, get_date_key |
|
|
12 |
from icd_transformer import ICD9_Transformer |
|
|
13 |
from doc2vec_transformer import Doc2Vec_Note_Transformer |
|
|
14 |
from value_extractor_transformer import EFTransformer, LBBBTransformer, SinusRhythmTransformer, QRSTransformer |
|
|
15 |
from language_processing import parse_date |
|
|
16 |
|
|
|
17 |
def _non_negative_int_arg(argv, position, default):
    """Return argv[position] parsed as a non-negative int, else `default`.

    `default` is returned when the argument is missing, is not a valid
    integer literal, or is negative.
    """
    if len(argv) <= position:
        return default
    try:
        value = int(argv[position])
    except ValueError:
        return default
    return value if value >= 0 else default


def main():
    """Assemble the feature union and evaluate it with `test_model`.

    Command-line arguments (both optional):
        argv[1]: requested data size; capped at 906, defaults to 25.
        argv[2]: value passed to `test_model` as `num_cv_splits`; defaults to 2.

    Arguments that are missing or not non-negative integers fall back to
    their defaults.
    """
    # Upper bound applied to the requested data size.
    # NOTE(review): presumably the number of available records -- confirm.
    max_data_size = 906

    # NOTE(review): the original referenced an undefined name `icd9`, which
    # raised NameError. Assumes ICD9_Transformer takes no constructor
    # arguments -- confirm against icd_transformer.py.
    icd9 = ICD9_Transformer()

    features = FeatureUnion([
        ('Dia', icd9),
        # Each union member needs a distinct name; the three EF variants
        # previously all used 'EF'.
        ('EF_all', EFTransformer('all', 1, None)),
        ('EF_mean', EFTransformer('mean', 5, None)),
        ('EF_max', EFTransformer('max', 5, None)),
        ('LBBB', LBBBTransformer()),
        #('SR', SinusRhythmTransformer()),
        #('Car_Doc2Vec', Doc2Vec_Note_Transformer('Car', 'doc2vec_models/car_1.model', 10, dbow_file='doc2vec_models/car_dbow.model'))
        # ('QRS', QRSTransformer('all', 1, None)),#Bugs with QRS
        ('car_ngram', FeaturePipeline([
            ('notes_car', GetConcatenatedNotesTransformer(note_type='Car', look_back_months=12)),
            ('ngram_car', CountVectorizer(ngram_range=(2, 2), min_df=.05)),
        ])),
        #('Car', FeaturePipeline([
        #    ('notes_transformer_car', GetConcatenatedNotesTransformer('Car')),
        #    ('tfidf', car_tfidf)
        #])),
        #('Lno', FeaturePipeline([
        #    ('notes_transformer_lno', GetConcatenatedNotesTransformer('Lno')),
        #    ('tfidf', lno_tfidf)
        #])),
        #('Enc', enc),
        #('Labs_Counts',FeaturePipeline([
        #    ('labs_counts_transformer', GetLabsCountsDictTransformer()),
        #    ('dict_vectorizer', DictVectorizer())
        #])),
        #('Labs_Low_Counts',FeaturePipeline([
        #    ('labs_low_counts_transformer', GetLabsLowCountsDictTransformer()),
        #    ('dict_vectorizer', DictVectorizer())
        #])),
        #('Labs_High_Counts', FeaturePipeline([
        #    ('labs_high_counts_transformer', GetLabsHighCountsDictTransformer()),
        #    ('dict_vectorizer', DictVectorizer())
        #])),
        #('Labs_Latest_Low', FeaturePipeline([
        #    ('labs_latest_low_transformer', GetLabsLatestLowDictTransformer()),
        #    ('dict_vectorizer', DictVectorizer())
        #])),
        #('Labs_Latest_High',FeaturePipeline([
        #    ('labs_latest_high_transformer', GetLabsLatestHighDictTransformer()),
        #    ('dict_vectorizer', DictVectorizer())
        #])),
        # ('Labs_History', FeaturePipeline([
        #    ('labs_history_transformer', GetLabsHistoryDictTransformer([1])),
        #    ('dict_vectorizer', DictVectorizer())
        # ])),
    ])

    # A missing or invalid first argument yields 25, which is below the cap.
    data_size = min(max_data_size, _non_negative_int_arg(sys.argv, 1, 25))
    num_cv_splits = _non_negative_int_arg(sys.argv, 2, 2)

    method = 'lr'
    #method = 'svm'

    show_progress = True

    test_model(features, data_size, num_cv_splits, method, show_progress)
|
|
83 |
|
|
|
84 |
if __name__ == '__main__':
    # Route INFO-and-above records from the "DaemonLog" logger to stdout,
    # then run the experiment.
    stdout_handler = logging.StreamHandler(sys.stdout)
    stdout_handler.setLevel(logging.INFO)
    stdout_handler.setFormatter(
        logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    )

    logger = logging.getLogger("DaemonLog")
    logger.setLevel(logging.INFO)
    logger.addHandler(stdout_handler)

    main()