Switch to unified view

a b/analysis/ml/evaluate_model.py
1
import sys
2
import itertools
3
import pandas as pd
4
from sklearn.model_selection import (StratifiedKFold, cross_val_predict, 
5
                                     GridSearchCV, ParameterGrid, train_test_split)
6
from sklearn.base import clone
7
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score, make_scorer
8
from imblearn.pipeline import Pipeline,make_pipeline
9
from metrics import balanced_accuracy
10
#from imblearn.under_sampling import NearMiss
11
from quartile_exact_match import QuartileExactMatch
12
import warnings
13
import time
14
from tempfile import mkdtemp
15
from shutil import rmtree
16
from sklearn.externals.joblib import Memory
17
from read_file import read_file
18
from utils import feature_importance, compute_imp_score, roc
19
import pdb
20
import numpy as np
21
22
def evaluate_model(dataset, save_file, random_state, clf, clf_name, hyper_params,
                   longitudinal=False, rare=True):
    """Tune, fit and evaluate ``clf`` on a matched case/control sample of ``dataset``.

    Steps performed:

    1. Load the data with ``read_file(dataset, longitudinal, rare)``.
    2. Down-sample with ``QuartileExactMatch`` using the ``'age'`` column as
       the quartile-matched feature and the ``'SEX'`` column as the
       exact-matched feature.
    3. Split the sampled data 50/50 into train and test sets.
    4. If ``hyper_params`` is non-empty, run a manual stratified 10-fold
       grid search on the training half and keep the parameter setting
       with the highest mean validation ROC AUC.
    5. Refit the chosen estimator on the full training half, score it on
       the test half, and append one tab-separated result line to
       ``save_file``. Feature importances (non-longitudinal only) and the
       ROC curve are written through the ``utils`` helpers.

    Parameters
    ----------
    dataset : str
        Path handed to ``read_file``; its last ``/``-separated component is
        written to the results line.
    save_file : str
        File the results line is appended to; also forwarded to the
        ``feature_importance`` and ``roc`` helpers.
    random_state : int
        Seed used for the CV splitter, the sampler, the train/test split and
        the importance computations.
    clf : estimator
        Template estimator. It is cloned once per parameter setting when
        tuning; when ``hyper_params`` is empty it is fitted in place.
    clf_name : str
        Name written to the results. If it contains ``'Feat'`` the estimator
        additionally receives ``feature_names`` and a per-fold ``logfile``.
    hyper_params : dict or list of dict
        Grid passed to ``ParameterGrid``. Empty means "skip tuning".
    longitudinal : bool, default False
        When true, ``zfile`` and the matching patient ids are passed as extra
        positional arguments to ``fit`` / ``predict`` / scoring calls, and
        feature importances are skipped.
    rare : bool, default True
        Forwarded unchanged to ``read_file``.

    Returns
    -------
    estimator
        The selected estimator, fitted on the training half.

    Raises
    ------
    AttributeError
        If an estimator exposes neither ``predict_proba`` nor
        ``decision_function`` (ROC AUC cannot be computed without scores).
    """
    print('reading data...',end='')
    features, labels, pt_ids, feature_names, zfile = read_file(dataset,longitudinal,rare)
    print('done.',len(labels),'samples,',np.sum(labels==1),'cases,',features.shape[1],'features')
    if 'Feat' in clf_name:
        # set feature names (comma-joined, as bytes)
        clf.feature_names = ','.join(feature_names).encode()

    n_splits = 10
    cv = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=random_state)

    ###
    # controls matching on age and sex
    ###
    # Coerce to an array first: comparing a plain list with a string yields a
    # scalar False, which would make argmax silently return column 0.
    # NOTE(review): argmax also returns 0 when the name is absent altogether;
    # confirm every dataset carries both 'age' and 'SEX' columns.
    names = np.asarray(feature_names)
    idx_age = np.argmax(names == 'age')
    idx_sex = np.argmax(names == 'SEX')

    sampler = QuartileExactMatch(quart_locs=[idx_age], exact_locs=[idx_sex],
                                 random_state=random_state)

    print('sampling data...',end='')
    X, y, sidx = sampler.fit_sample(features, labels)
    print('sampled data contains',np.sum(y==1),'cases',np.sum(y==0),'controls')

    ###
    # split into train/test
    ###
    # sidx is split alongside X and y so that rows can still be mapped back to
    # their entries in pt_ids.
    X_train, X_test, y_train, y_test, sidx_train, sidx_test = (
            train_test_split(X, y, sidx,
                             train_size=0.5,
                             test_size=0.5,
                             random_state=random_state))

    tune = len(hyper_params) > 0
    if tune:
        param_grid = list(ParameterGrid(hyper_params))
        # one independent clone per parameter setting
        candidates = [clone(clf).set_params(**p) for p in param_grid]
        cv_scores = np.zeros((len(param_grid), n_splits))        # validation AUC per (setting, fold)
        cv_probs = np.zeros((len(param_grid), len(y_train)))     # out-of-fold scores per setting

        ###########
        # this is a manual version of 10-fold cross validation with hyperparameter tuning
        t0 = time.process_time()
        for j, (train_idx, val_idx) in enumerate(cv.split(X_train, y_train)):
            print('fold',j)

            for i, est in enumerate(candidates):
                print('training',type(est).__name__,i+1,'of',len(candidates))
                if 'Feat' in clf_name:
                    # give every (setting, fold) pair its own log file
                    est.logfile = (est.logfile.decode().split('.log')[0] + '.log.param' + str(i)
                                   + '.cv' + str(j)).encode()

                ##########
                # fit model
                ##########
                if longitudinal:
                    est.fit(X_train[train_idx], y_train[train_idx],
                            zfile, pt_ids[sidx_train[train_idx]])
                    # The ids must describe the rows being scored, i.e. the
                    # validation rows, mirroring how the test set is handled.
                    val_extra = (zfile, pt_ids[sidx_train[val_idx]])
                else:
                    est.fit(X_train[train_idx], y_train[train_idx])
                    val_extra = ()

                ##########
                # get predictions
                ##########
                print('getting validation predictions...')
                cv_probs[i, val_idx] = _positive_class_scores(
                    est, X_train[val_idx], *val_extra)

                ##########
                # scores
                ##########
                cv_scores[i, j] = roc_auc_score(y_train[val_idx], cv_probs[i, val_idx])

        # runtime covers the grid search only, not the final refit below
        runtime = time.process_time() - t0
        ###########

        print('gridsearch finished in',runtime,'seconds')

        ##########
        # get best model and its information
        mean_cv_scores = cv_scores.mean(axis=1)
        best_clf = candidates[np.argmax(mean_cv_scores)]
        ##########
    else:
        print('skipping hyperparameter tuning')
        best_clf = clf  # this option is for skipping model tuning
        t0 = time.process_time()

    print('fitting tuned model to all training data...')
    if longitudinal:
        best_clf.fit(X_train, y_train, zfile, pt_ids[sidx_train])
    else:
        best_clf.fit(X_train, y_train)

    if not tune:
        # without tuning, runtime is the time of this single fit
        runtime = time.process_time() - t0

    if not longitudinal:
        # internal feature importances, computed on the training half
        cv_FI_int = compute_imp_score(best_clf, clf_name, X_train, y_train, random_state, perm=False)
        # permutation importances, computed on the held-out half
        FI = compute_imp_score(best_clf, clf_name, X_test, y_test, random_state, perm=True)

    ##########
    # metrics: test the best classifier on the held-out test set
    print('getting test predictions...')
    if longitudinal:
        test_extra = (zfile, pt_ids[sidx_test])
        print('best_clf.predict(X_test, zfile, pt_ids[sidx_test])')
    else:
        test_extra = ()
    test_predictions = best_clf.predict(X_test, *test_extra)
    if longitudinal and getattr(best_clf, "predict_proba", None):
        print('best_clf.predict_proba(X_test, zfile, pt_ids[sidx_test])')
    test_probabilities = _positive_class_scores(best_clf, X_test, *test_extra)

    accuracy = accuracy_score(y_test, test_predictions)
    macro_f1 = f1_score(y_test, test_predictions, average='macro')
    bal_acc = balanced_accuracy(y_test, test_predictions)
    roc_auc = roc_auc_score(y_test, test_probabilities)

    ##########
    # save results to file
    print('saving results...')
    # feature_names is excluded from the parameter dump; newlines and spaces
    # are stripped so the string stays a single tab-separated field
    param_string = ','.join(['{}={}'.format(p, v)
                             for p, v in best_clf.get_params().items()
                             if p != 'feature_names']).replace('\n','').replace(' ','')

    out_text = '\t'.join([dataset.split('/')[-1],
                          clf_name,
                          param_string,
                          str(random_state),
                          str(accuracy),
                          str(macro_f1),
                          str(bal_acc),
                          str(roc_auc),
                          str(runtime)])
    print(out_text)
    with open(save_file, 'a') as out:
        out.write(out_text+'\n')
    sys.stdout.flush()

    print('saving feature importance')
    # write feature importances
    if not longitudinal:
        feature_importance(save_file, best_clf, feature_names, X_test, y_test, random_state,
                           clf_name, param_string, cv_FI_int, perm=False)
        feature_importance(save_file, best_clf, feature_names, X_test, y_test, random_state,
                           clf_name, param_string, FI, perm=True)
    # write roc curves
    print('saving roc')
    roc(save_file, y_test, test_probabilities, random_state, clf_name, param_string)

    return best_clf


def _positive_class_scores(est, X, *extra):
    """Return one continuous score per row of ``X`` for ROC AUC computation.

    Uses column 1 of ``est.predict_proba`` when available, otherwise
    ``est.decision_function``. ``extra`` is forwarded positionally (used for
    the longitudinal ``zfile`` / patient-id arguments).

    Raises
    ------
    AttributeError
        If ``est`` offers neither method.
    """
    if getattr(est, "predict_proba", None):
        return est.predict_proba(X, *extra)[:, 1]
    if getattr(est, "decision_function", None):
        return est.decision_function(X, *extra)
    raise AttributeError(
        '{} has neither predict_proba nor decision_function; '
        'cannot compute ROC AUC'.format(type(est).__name__))