Diff of /run_test.py [000000] .. [8d2107]

Switch to unified view

a b/run_test.py
1
import logging
2
import sys
3
import cProfile
4
5
from model_tester import FeaturePipeline, test_model
6
from sklearn.pipeline import FeatureUnion
7
from sklearn.feature_extraction.text import CountVectorizer
8
9
from baseline_transformer import GetConcatenatedNotesTransformer, GetLatestNotesTransformer, GetEncountersFeaturesTransformer, GetLabsCountsDictTransformer, GetLabsLowCountsDictTransformer, GetLabsHighCountsDictTransformer, GetLabsLatestHighDictTransformer, GetLabsLatestLowDictTransformer, GetLabsHistoryDictTransformer
10
from extract_data import get_doc_rel_dates, get_operation_date, get_ef_values
11
from extract_data import get_operation_date,  is_note_doc, get_date_key
12
from icd_transformer import ICD9_Transformer
13
from doc2vec_transformer import Doc2Vec_Note_Transformer
14
from value_extractor_transformer import EFTransformer, LBBBTransformer, SinusRhythmTransformer, QRSTransformer
15
from language_processing import parse_date 
16
17
def main():
18
    features = FeatureUnion([
19
                ('Dia', icd9 ),
20
                ('EF', EFTransformer('all', 1, None)),
21
                ('EF', EFTransformer('mean', 5, None)),
22
                ('EF', EFTransformer('max', 5, None)),
23
                ('LBBB', LBBBTransformer()),
24
                #('SR', SinusRhythmTransformer()),
25
                #('Car_Doc2Vec', Doc2Vec_Note_Transformer('Car', 'doc2vec_models/car_1.model', 10, dbow_file='doc2vec_models/car_dbow.model'))
26
               # ('QRS', QRSTransformer('all', 1, None)),#Bugs with QRS
27
                ('car_ngram', FeaturePipeline([
28
                    ('notes_car', GetConcatenatedNotesTransformer(note_type='Car',look_back_months=12)),
29
                    ('ngram_car', CountVectorizer(ngram_range=(2, 2), min_df=.05))
30
                ]))
31
                #('Car', FeaturePipeline([
32
                #    ('notes_transformer_car', GetConcatenatedNotesTransformer('Car')),
33
                #    ('tfidf', car_tfidf)
34
                #])),
35
                #('Lno', FeaturePipeline([
36
                #    ('notes_transformer_lno', GetConcatenatedNotesTransformer('Lno')),
37
                #    ('tfidf', lno_tfidf)
38
                #])),
39
                #('Enc', enc),
40
                #('Labs_Counts',FeaturePipeline([
41
                #    ('labs_counts_transformer', GetLabsCountsDictTransformer()),
42
                #    ('dict_vectorizer', DictVectorizer())
43
                #])),
44
                #('Labs_Low_Counts',FeaturePipeline([
45
                #    ('labs_low_counts_transformer', GetLabsLowCountsDictTransformer()),
46
                #    ('dict_vectorizer', DictVectorizer())
47
                #])),
48
                #('Labs_High_Counts', FeaturePipeline([
49
                #    ('labs_high_counts_transformer', GetLabsHighCountsDictTransformer()),
50
                #    ('dict_vectorizer', DictVectorizer())
51
                #])),
52
                #('Labs_Latest_Low', FeaturePipeline([
53
                #    ('labs_latest_low_transformer', GetLabsLatestLowDictTransformer()),
54
                #    ('dict_vectorizer', DictVectorizer())
55
                #])),
56
                #('Labs_Latest_High',FeaturePipeline([
57
                #    ('labs_latest_high_transformer', GetLabsLatestHighDictTransformer()),
58
                #    ('dict_vectorizer', DictVectorizer())
59
                #])),
60
               # ('Labs_History', FeaturePipeline([
61
               #     ('labs_history_transformer', GetLabsHistoryDictTransformer([1])),
62
               #     ('dict_vectorizer', DictVectorizer())
63
               # ])),
64
            ])
65
66
67
    if len(sys.argv) > 1 and unicode(sys.argv[1]).isnumeric():
68
        data_size = min(906, int(sys.argv[1]))
69
    else:
70
        data_size = 25
71
72
    if len(sys.argv) > 2 and unicode(sys.argv[2]).isnumeric():
73
        num_cv_splits = int(sys.argv[2])
74
    else:
75
        num_cv_splits = 2
76
77
    method = 'lr'
78
    #method = 'svm'
79
80
    show_progress = True
81
82
    test_model(features, data_size, num_cv_splits, method, show_progress)
83
84
if __name__ == '__main__':
85
86
    # Configure logging
87
    logger = logging.getLogger("DaemonLog")
88
    logger.setLevel(logging.INFO)
89
90
    out = logging.StreamHandler(sys.stdout)
91
    out.setLevel(logging.INFO)
92
    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
93
    out.setFormatter(formatter)
94
95
    logger.addHandler(out)
96
    main()