|
a |
|
b/analysis/ml/evaluate_model.py |
|
|
1 |
import sys |
|
|
2 |
import itertools |
|
|
3 |
import pandas as pd |
|
|
4 |
from sklearn.model_selection import (StratifiedKFold, cross_val_predict, |
|
|
5 |
GridSearchCV, ParameterGrid, train_test_split) |
|
|
6 |
from sklearn.base import clone |
|
|
7 |
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score, make_scorer |
|
|
8 |
from imblearn.pipeline import Pipeline,make_pipeline |
|
|
9 |
from metrics import balanced_accuracy |
|
|
10 |
#from imblearn.under_sampling import NearMiss |
|
|
11 |
from quartile_exact_match import QuartileExactMatch |
|
|
12 |
import warnings |
|
|
13 |
import time |
|
|
14 |
from tempfile import mkdtemp |
|
|
15 |
from shutil import rmtree |
|
|
16 |
from sklearn.externals.joblib import Memory |
|
|
17 |
from read_file import read_file |
|
|
18 |
from utils import feature_importance, compute_imp_score, roc |
|
|
19 |
import pdb |
|
|
20 |
import numpy as np |
|
|
21 |
|
|
|
22 |
def evaluate_model(dataset, save_file, random_state, clf, clf_name, hyper_params,
                   longitudinal=False, rare=True):
    """Tune, fit and score a classifier on a matched case/control sample.

    Pipeline:
      1. Load the data with ``read_file(dataset, longitudinal, rare)``.
      2. Resample with ``QuartileExactMatch`` using the 'age' column as the
         quartile-matched location and the 'SEX' column as the exact-matched
         location.
      3. Split the resampled data 50/50 into train and test halves.
      4. If ``hyper_params`` is non-empty, run a manual stratified k-fold grid
         search on the training half, scoring each parameter setting by
         ROC AUC on the validation folds.
      5. Fit the selected estimator on the full training half and compute
         accuracy, macro F1, balanced accuracy and ROC AUC on the test half.
      6. Append one tab-separated result row to ``save_file`` and hand the
         feature importances (non-longitudinal only) and ROC data to the
         ``feature_importance`` / ``roc`` helpers.

    Parameters
    ----------
    dataset : str
        Dataset path; only the part after the last '/' is written to the
        result row.
    save_file : str
        Results file, opened in append mode.
    random_state : int
        Seed passed to the CV splitter, the sampler, the train/test split and
        the importance helpers.
    clf : estimator
        Base estimator. It is cloned for each grid point, or used directly
        when ``hyper_params`` is empty.
    clf_name : str
        Name written to the results. If it contains 'Feat', ``feature_names``
        and per-fold ``logfile`` attributes are set on the estimators.
    hyper_params : dict or list of dict
        Grid passed to ``ParameterGrid``; an empty value skips tuning.
    longitudinal : bool, default False
        When True, ``zfile`` and the matching patient ids are passed as extra
        positional arguments to ``fit`` / ``predict`` / ``predict_proba`` /
        ``decision_function``, and feature importances are skipped.
    rare : bool, default True
        Forwarded unchanged to ``read_file``.

    Returns
    -------
    estimator
        The selected estimator, fitted on the training half.

    Raises
    ------
    AttributeError
        If ``clf`` provides neither ``predict_proba`` nor
        ``decision_function``, so no scores are available for ROC AUC.

    NOTE(review): the nesting of this function was reconstructed from a
    listing whose indentation had been flattened -- confirm against the
    repository copy, in particular which statements sit under
    ``if not longitudinal``.
    """
    print('reading data...', end='')
    features, labels, pt_ids, feature_names, zfile = read_file(dataset, longitudinal, rare)
    print('done.', len(labels), 'samples,', np.sum(labels == 1), 'cases,',
          features.shape[1], 'features')
    if 'Feat' in clf_name:
        # set feature names
        clf.feature_names = ','.join(feature_names).encode()
    n_splits = 10
    cv = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=random_state)

    ###
    # controls matching on age and sex
    ###
    # np.asarray guarantees an element-wise comparison; with a plain list,
    # `feature_names == 'age'` is a scalar False and argmax would silently
    # return 0. argmax still returns 0 when the name is absent.
    names = np.asarray(feature_names)
    idx_age = np.argmax(names == 'age')
    idx_sex = np.argmax(names == 'SEX')

    sampler = QuartileExactMatch(quart_locs=[idx_age], exact_locs=[idx_sex],
                                 random_state=random_state)

    print('sampling data...', end='')
    X, y, sidx = sampler.fit_sample(features, labels)
    print('sampled data contains', np.sum(y == 1), 'cases', np.sum(y == 0), 'controls')

    ###
    # split into train/test
    ###
    X_train, X_test, y_train, y_test, sidx_train, sidx_test = (
        train_test_split(X, y, sidx,
                         train_size=0.5,
                         test_size=0.5,
                         random_state=random_state))

    if len(hyper_params) > 0:
        param_grid = list(ParameterGrid(hyper_params))
        # one independent clone of the base estimator per grid point
        Clfs = [clone(clf).set_params(**p) for p in param_grid]
        # rows: grid points; columns: folds / training samples
        cv_scores = np.zeros((len(param_grid), n_splits))
        cv_probs = np.zeros((len(param_grid), len(y_train)))

        ###########
        # manual k-fold cross validation with hyperparameter tuning
        t0 = time.process_time()
        for j, (train_idx, val_idx) in enumerate(cv.split(X_train, y_train)):
            print('fold', j)

            if longitudinal:
                # sidx_* map rows of the sampled data to rows of pt_ids, so
                # the ids must be selected with the same fold indices as X.
                fit_extra = (zfile, pt_ids[sidx_train[train_idx]])
                val_extra = (zfile, pt_ids[sidx_train[val_idx]])
            else:
                fit_extra = val_extra = ()

            for i, est in enumerate(Clfs):
                print('training', type(est).__name__, i + 1, 'of', len(Clfs))
                if 'Feat' in clf_name:
                    # distinct log file per (grid point, fold)
                    est.logfile = (est.logfile.decode().split('.log')[0]
                                   + '.log.param' + str(i)
                                   + '.cv' + str(j)).encode()

                # fit model
                est.fit(X_train[train_idx], y_train[train_idx], *fit_extra)

                # get validation scores
                print('getting validation predictions...')
                if getattr(clf, "predict_proba", None):
                    cv_probs[i, val_idx] = est.predict_proba(
                        X_train[val_idx], *val_extra)[:, 1]
                elif getattr(clf, "decision_function", None):
                    cv_probs[i, val_idx] = est.decision_function(
                        X_train[val_idx], *val_extra)

                cv_scores[i, j] = roc_auc_score(y_train[val_idx], cv_probs[i, val_idx])

        runtime = time.process_time() - t0
        ###########

        print('gridsearch finished in', runtime, 'seconds')

        # best model: highest mean validation ROC AUC across folds
        mean_cv_scores = np.mean(cv_scores, axis=1)
        best_clf = Clfs[np.argmax(mean_cv_scores)]
    else:
        print('skipping hyperparameter tuning')
        best_clf = clf  # this option is for skipping model tuning
        t0 = time.process_time()

    print('fitting tuned model to all training data...')
    if longitudinal:
        best_clf.fit(X_train, y_train, zfile, pt_ids[sidx_train])
    else:
        best_clf.fit(X_train, y_train)

    if len(hyper_params) == 0:
        # without tuning, the reported runtime is the single fit above
        runtime = time.process_time() - t0

    if not longitudinal:
        # internal feature importances (training half)
        cv_FI_int = compute_imp_score(best_clf, clf_name, X_train, y_train,
                                      random_state, perm=False)
        # permutation importances (test half)
        FI = compute_imp_score(best_clf, clf_name, X_test, y_test,
                               random_state, perm=True)

    ##########
    # metrics: test the best classifier on the held-out test set
    print('getting test predictions...')
    if longitudinal:
        print('best_clf.predict(X_test, zfile, pt_ids[sidx_test])')
        test_extra = (zfile, pt_ids[sidx_test])
    else:
        test_extra = ()
    test_predictions = best_clf.predict(X_test, *test_extra)

    if getattr(clf, "predict_proba", None):
        if longitudinal:
            print('best_clf.predict_proba(X_test, zfile, pt_ids[sidx_test])')
        test_probabilities = best_clf.predict_proba(X_test, *test_extra)[:, 1]
    elif getattr(clf, "decision_function", None):
        test_probabilities = best_clf.decision_function(X_test, *test_extra)
    else:
        # previously surfaced later as an unbound-variable error
        raise AttributeError(
            '{} provides neither predict_proba nor decision_function; '
            'cannot compute ROC AUC'.format(type(clf).__name__))

    accuracy = accuracy_score(y_test, test_predictions)
    macro_f1 = f1_score(y_test, test_predictions, average='macro')
    bal_acc = balanced_accuracy(y_test, test_predictions)
    roc_auc = roc_auc_score(y_test, test_probabilities)

    ##########
    # save results to file
    print('saving results...')
    # feature_names is excluded from the parameter string; newlines and
    # spaces are stripped so the string stays a single tab-delimited field
    param_string = ','.join(['{}={}'.format(p, v)
                             for p, v in best_clf.get_params().items()
                             if p != 'feature_names']).replace('\n', '').replace(' ', '')

    out_text = '\t'.join([dataset.split('/')[-1],
                          clf_name,
                          param_string,
                          str(random_state),
                          str(accuracy),
                          str(macro_f1),
                          str(bal_acc),
                          str(roc_auc),
                          str(runtime)])
    print(out_text)
    with open(save_file, 'a') as out:
        out.write(out_text + '\n')
    sys.stdout.flush()

    print('saving feature importance')
    # write feature importances
    if not longitudinal:
        feature_importance(save_file, best_clf, feature_names, X_test, y_test, random_state,
                           clf_name, param_string, cv_FI_int, perm=False)
        feature_importance(save_file, best_clf, feature_names, X_test, y_test, random_state,
                           clf_name, param_string, FI, perm=True)
    # write roc curves
    print('saving roc')
    roc(save_file, y_test, test_probabilities, random_state, clf_name, param_string)

    return best_clf