[8d2107]: / michael_tester.py

Download this file

125 lines (108 with data), 5.1 kB

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
import sys
import cProfile
from model_tester import FeaturePipeline, test_model
from sklearn.pipeline import FeatureUnion
from sklearn.feature_extraction import DictVectorizer
from sklearn.feature_extraction.text import TfidfTransformer
from baseline_transformer import GetConcatenatedNotesTransformer, GetLatestNotesTransformer, GetEncountersFeaturesTransformer, GetLabsCountsDictTransformer, GetLabsLowCountsDictTransformer, GetLabsHighCountsDictTransformer, GetLabsLatestHighDictTransformer, GetLabsLatestLowDictTransformer, GetLabsHistoryDictTransformer
from extract_data import get_doc_rel_dates, get_operation_date, get_ef_values
from extract_data import get_operation_date, is_note_doc, get_date_key
from icd_transformer import ICD9_Transformer
from value_extractor_transformer import EFTransformer, LBBBTransformer, SinusRhythmTransformer, QRSTransformer, NYHATransformer, NICMTransformer
from language_processing import parse_date
def _int_arg(argv, index, default):
    """Return ``argv[index]`` parsed as a non-negative int, else ``default``.

    ``default`` is returned when the argument is missing, is not an integer
    literal, or is negative.
    """
    try:
        value = int(argv[index])
    except (IndexError, ValueError):
        return default
    return value if value >= 0 else default


def _float_arg(argv, index, default):
    """Return ``argv[index]`` parsed as a finite, non-negative float.

    ``default`` is returned when the argument is missing, cannot be parsed,
    is negative, or is NaN/infinite.  Unlike an ``isnumeric()`` check this
    accepts fractional values such as ``0.5``.
    """
    try:
        value = float(argv[index])
    except (IndexError, ValueError):
        return default
    # NaN fails both comparisons, so it falls through to the default as well.
    if not (0.0 <= value < float('inf')):
        return default
    return value


def main():
    """Build the feature union and run ``test_model`` on it.

    Command-line arguments (all optional, positional):
        1. data size      -- int, capped at 906; defaults to 25.
        2. CV splits      -- int; defaults to 5.
        3. method         -- model name passed through to ``test_model``;
                             defaults to 'adaboost'.  This script only sets
                             model arguments for 'lr', 'svm' and 'adaboost'.
        4. model argument -- 'regularization' (float, default 0.) for
                             'lr'/'svm', or 'n_estimators' (int, default 50)
                             for 'adaboost'.

    Arguments that are missing or malformed fall back to their defaults.
    """
    # Feature-group switches; flip these to choose which transformers are
    # included in the union.
    regex_features = True
    icd9_features = False
    labs_features = False
    text_features = False

    # NOTE(review): 906 is presumably the number of available records --
    # confirm against the data loader.
    max_data_size = 906
    # NOTE(review): 30*3 looks like a ~3 month window in days -- confirm
    # against the value_extractor_transformer constructors.
    window = 30 * 3

    transformer_list = []

    if regex_features:
        # Step names must be unique within a FeatureUnion, so each variant
        # of a transformer gets its own suffix.
        transformer_list += [
            ('EF_all', EFTransformer('all', 1, None)),
            ('EF_mean', EFTransformer('mean', 5, None)),
            ('EF_max', EFTransformer('max', 5, None)),
            ('LBBB', LBBBTransformer(window)),
            ('SR', SinusRhythmTransformer(window)),
            ('NYHA', NYHATransformer(window)),
            ('NICM', NICMTransformer(window)),
            ('QRS_all', QRSTransformer('all', 1, None)),
            ('QRS_mean', QRSTransformer('mean', 5, None)),
        ]

    if icd9_features:
        transformer_list += [
            ('Dia', ICD9_Transformer()),
        ]

    if text_features:
        transformer_list += [
            ('Car', FeaturePipeline([
                ('notes_transformer_car', GetConcatenatedNotesTransformer('Car')),
                ('tfidf', TfidfTransformer()),
            ])),
            ('Lno', FeaturePipeline([
                ('notes_transformer_lno', GetConcatenatedNotesTransformer('Lno')),
                # Pipeline steps must be estimator instances, not classes.
                ('tfidf', TfidfTransformer()),
            ])),
        ]

    if labs_features:
        transformer_list += [
            ('Enc', GetEncountersFeaturesTransformer(5)),
            ('Labs_Counts', FeaturePipeline([
                ('labs_counts_transformer', GetLabsCountsDictTransformer()),
                ('dict_vectorizer', DictVectorizer()),
            ])),
            ('Labs_Low_Counts', FeaturePipeline([
                ('labs_low_counts_transformer', GetLabsLowCountsDictTransformer()),
                ('dict_vectorizer', DictVectorizer()),
            ])),
            ('Labs_High_Counts', FeaturePipeline([
                ('labs_high_counts_transformer', GetLabsHighCountsDictTransformer()),
                ('dict_vectorizer', DictVectorizer()),
            ])),
            ('Labs_Latest_Low', FeaturePipeline([
                ('labs_latest_low_transformer', GetLabsLatestLowDictTransformer()),
                ('dict_vectorizer', DictVectorizer()),
            ])),
            ('Labs_Latest_High', FeaturePipeline([
                ('labs_latest_high_transformer', GetLabsLatestHighDictTransformer()),
                ('dict_vectorizer', DictVectorizer()),
            ])),
            ('Labs_History', FeaturePipeline([
                ('labs_history_transformer', GetLabsHistoryDictTransformer([1])),
                ('dict_vectorizer', DictVectorizer()),
            ])),
        ]

    features = FeatureUnion(transformer_list)

    argv = sys.argv
    data_size = min(_int_arg(argv, 1, 25), max_data_size)
    num_cv_splits = _int_arg(argv, 2, 5)

    # Single-argument print(...) emits the same text under Python 2's print
    # statement and Python 3's print function.
    print("Data size: " + str(data_size))
    print("CV splits: " + str(num_cv_splits))

    # Honour the method given on the command line; fall back to AdaBoost.
    method = argv[3] if len(argv) > 3 else 'adaboost'

    model_args = dict()
    if method in ('lr', 'svm'):
        model_args['regularization'] = _float_arg(argv, 4, 0.)
    elif method == 'adaboost':
        model_args['n_estimators'] = _int_arg(argv, 4, 50)

    show_progress = True
    print('Method: ' + str(method))
    test_model(features, data_size, num_cv_splits, method, show_progress, model_args)
# Run the experiment only when this file is executed as a script, so that
# importing the module has no side effects.
if __name__ == '__main__':
    main()