#!/usr/bin/env python

"""
train_SVM.py

VARPA, University of Coruna
Mondejar Guerra, Victor M.
23 Oct 2017
"""

from load_MITBIH import *
from evaluation_AAMI import *
from aggregation_voting_strategies import *
from oversampling import *
from cross_validation import *
from feature_selection import *

import sklearn
from sklearn.externals import joblib
from sklearn.preprocessing import StandardScaler
from sklearn import svm
from sklearn import decomposition

import os
import time

import numpy as np


# Compose the model/results path name that encodes the experiment configuration.
def create_svm_model_name(model_svm_path, winL, winR, do_preprocess,
    maxRR, use_RR, norm_RR, compute_morph, use_weight_class, feature_selection,
    oversamp_method, leads_flag, reduced_DS, pca_k, delimiter):

    if reduced_DS:
        model_svm_path = model_svm_path + delimiter + 'exp_2'

    if leads_flag[0] == 1:
        model_svm_path = model_svm_path + delimiter + 'MLII'

    if leads_flag[1] == 1:
        model_svm_path = model_svm_path + delimiter + 'V1'

    if oversamp_method:
        model_svm_path = model_svm_path + delimiter + oversamp_method

    if feature_selection:
        model_svm_path = model_svm_path + delimiter + feature_selection

    if do_preprocess:
        model_svm_path = model_svm_path + delimiter + 'rm_bsln'

    if maxRR:
        model_svm_path = model_svm_path + delimiter + 'maxRR'

    if use_RR:
        model_svm_path = model_svm_path + delimiter + 'RR'

    if norm_RR:
        model_svm_path = model_svm_path + delimiter + 'norm_RR'

    for descp in compute_morph:
        model_svm_path = model_svm_path + delimiter + descp

    if use_weight_class:
        model_svm_path = model_svm_path + delimiter + 'weighted'

    if pca_k > 0:
        model_svm_path = model_svm_path + delimiter + 'pca_' + str(pca_k)

    return model_svm_path


# Eval the SVM model and export the results
def eval_model(svm_model, features, labels, multi_mode, voting_strategy, output_path, C_value, gamma_value, DS):
    if multi_mode == 'ovo':
        decision_ovo = svm_model.decision_function(features)

        if voting_strategy == 'ovo_voting':
            predict_ovo, counter = ovo_voting(decision_ovo, 4)

        elif voting_strategy == 'ovo_voting_both':
            predict_ovo, counter = ovo_voting_both(decision_ovo, 4)

        elif voting_strategy == 'ovo_voting_exp':
            predict_ovo, counter = ovo_voting_exp(decision_ovo, 4)
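
        # Added guard (not part of the original code): fail fast on an unknown
        # voting strategy instead of hitting an undefined predict_ovo below.
        else:
            raise ValueError('Unknown voting_strategy: ' + str(voting_strategy))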

        # svm_model.predict_log_proba  svm_model.predict_proba   svm_model.predict ...
        perf_measures = compute_AAMI_performance_measures(predict_ovo, labels)

    elif multi_mode == 'ovr':
        decision_ovr = svm_model.decision_function(features)
        predict_ovr = svm_model.predict(features)
        perf_measures = compute_AAMI_performance_measures(predict_ovr, labels)

    # Write results and also predictions on DS2
    if not os.path.exists(output_path):
        os.makedirs(output_path)

    if gamma_value != 0.0:
        write_AAMI_results(perf_measures, output_path + '/' + DS + 'C_' + str(C_value) + 'g_' + str(gamma_value) +
            '_score_Ijk_' + str(format(perf_measures.Ijk, '.2f')) + '_' + voting_strategy + '.txt')
    else:
        write_AAMI_results(perf_measures, output_path + '/' + DS + 'C_' + str(C_value) +
            '_score_Ijk_' + str(format(perf_measures.Ijk, '.2f')) + '_' + voting_strategy + '.txt')

    # Array to .csv
    if multi_mode == 'ovo':
        if gamma_value != 0.0:
            np.savetxt(output_path + '/' + DS + 'C_' + str(C_value) + 'g_' + str(gamma_value) +
                '_decision_ovo.csv', decision_ovo)
            np.savetxt(output_path + '/' + DS + 'C_' + str(C_value) + 'g_' + str(gamma_value) +
                '_predict_' + voting_strategy + '.csv', predict_ovo.astype(int), '%.0f')
        else:
            np.savetxt(output_path + '/' + DS + 'C_' + str(C_value) +
                '_decision_ovo.csv', decision_ovo)
            np.savetxt(output_path + '/' + DS + 'C_' + str(C_value) +
                '_predict_' + voting_strategy + '.csv', predict_ovo.astype(int), '%.0f')

    elif multi_mode == 'ovr':
        np.savetxt(output_path + '/' + DS + 'C_' + str(C_value) +
            '_decision_ovr.csv', decision_ovr)
        np.savetxt(output_path + '/' + DS + 'C_' + str(C_value) +
            '_predict_' + voting_strategy + '.csv', predict_ovr.astype(int), '%.0f')

    print("Results written at " + output_path + '/' + DS + 'C_' + str(C_value))


def create_oversamp_name(reduced_DS, do_preprocess, compute_morph, winL, winR, maxRR, use_RR, norm_RR, pca_k):
    oversamp_features_pickle_name = ''
    if reduced_DS:
        oversamp_features_pickle_name += '_reduced_'

    if do_preprocess:
        oversamp_features_pickle_name += '_rm_bsline'

    if maxRR:
        oversamp_features_pickle_name += '_maxRR'

    if use_RR:
        oversamp_features_pickle_name += '_RR'

    if norm_RR:
        oversamp_features_pickle_name += '_norm_RR'

    for descp in compute_morph:
        oversamp_features_pickle_name += '_' + descp

    if pca_k > 0:
        oversamp_features_pickle_name += '_pca_' + str(pca_k)

    oversamp_features_pickle_name += '_wL_' + str(winL) + '_wR_' + str(winR)

    return oversamp_features_pickle_name


def main(multi_mode='ovo', winL=90, winR=90, do_preprocess=True, use_weight_class=True,
    maxRR=True, use_RR=True, norm_RR=True, compute_morph={''}, oversamp_method='',
    pca_k=0, feature_selection='', do_cross_val='', C_value=0.001, gamma_value=0.0,
    reduced_DS=False, leads_flag=[1, 0]):
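    # Pipeline: load DS1 (train) and DS2 (test) features -> optional oversampling ->
    # z-score normalization -> optional feature selection / IPCA -> cross-validation
    # or SVM (RBF) training -> AAMI evaluation with one-vs-one voting strategies.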
    print("Running train_SVM.py!")

    db_path = '/home/mondejar/dataset/ECG/mitdb/m_learning/scikit/'

    # Load train data
    [tr_features, tr_labels, tr_patient_num_beats] = load_mit_db('DS1', winL, winR, do_preprocess,
        maxRR, use_RR, norm_RR, compute_morph, db_path, reduced_DS, leads_flag)

    # Load test data
    [eval_features, eval_labels, eval_patient_num_beats] = load_mit_db('DS2', winL, winR, do_preprocess,
        maxRR, use_RR, norm_RR, compute_morph, db_path, reduced_DS, leads_flag)
    if reduced_DS:
        np.savetxt('mit_db/' + 'exp_2_' + 'DS2_labels.csv', eval_labels.astype(int), '%.0f')
    else:
        np.savetxt('mit_db/' + 'DS2_labels.csv', eval_labels.astype(int), '%.0f')

    #if reduced_DS == True:
    #    np.savetxt('mit_db/' + 'exp_2_' + 'DS1_labels.csv', tr_labels.astype(int), '%.0f')
    #else:
    #np.savetxt('mit_db/' + 'DS1_labels.csv', tr_labels.astype(int), '%.0f')

    ##############################################################
    # 0) TODO: feature selection before oversampling?

    # TODO perform normalization before the oversampling?
    if oversamp_method:
        # Filename
        oversamp_features_pickle_name = create_oversamp_name(reduced_DS, do_preprocess, compute_morph, winL, winR, maxRR, use_RR, norm_RR, pca_k)

        # Do oversampling
        tr_features, tr_labels = perform_oversampling(oversamp_method, db_path + 'oversamp/python_mit', oversamp_features_pickle_name, tr_features, tr_labels)

    # Normalization of the input data
    # scaled: zero mean, unit variance (z-score)
    scaler = StandardScaler()
    scaler.fit(tr_features)
    tr_features_scaled = scaler.transform(tr_features)

    # scaled: zero mean, unit variance (z-score)
    eval_features_scaled = scaler.transform(eval_features)
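    # Note: the scaler statistics come from DS1 only; DS2 is transformed with the
    # same fit, so no test-set information leaks into the normalization step.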
    ##############################################################
    # 0) TODO: should feature selection also be applied after oversampling?
    if feature_selection:
        print("Running feature selection")
        best_features = 7
        tr_features_scaled, features_index_sorted = run_feature_selection(tr_features_scaled, tr_labels, feature_selection, best_features)
        eval_features_scaled = eval_features_scaled[:, features_index_sorted[0:best_features]]

    # 1) Dimensionality reduction
    if pca_k > 0:
        # TODO: load if it already exists?
        # NOTE: standard PCA runs out of memory here!

        # NOTE 11 January: test with IPCA!
        start = time.time()

        print("Running IPCA " + str(pca_k) + "...")

        # Run incremental PCA
        IPCA = sklearn.decomposition.IncrementalPCA(pca_k, batch_size=pca_k)

        #tr_features_scaled = KPCA.fit_transform(tr_features_scaled)
        IPCA.fit(tr_features_scaled)

        # Apply the fitted IPCA to train and test data
        tr_features_scaled = IPCA.transform(tr_features_scaled)
        eval_features_scaled = IPCA.transform(eval_features_scaled)

        """
        print("Running TruncatedSVD (singular value decomposition, an alternative to PCA) " + str(pca_k) + "...")

        svd = decomposition.TruncatedSVD(n_components=pca_k, algorithm='arpack')
        svd.fit(tr_features_scaled)
        tr_features_scaled = svd.transform(tr_features_scaled)
        eval_features_scaled = svd.transform(eval_features_scaled)
        """

        end = time.time()

        print("Time running IPCA: " + str(format(end - start, '.2f')) + " sec")
    ##############################################################
    # 2) Cross-validation

    if do_cross_val:
        print("Running cross-validation...")
        start = time.time()

        # TODO: save the scores over the k-folds, ranked by best average value, in separate files
        perf_measures_path = create_svm_model_name('/home/mondejar/Dropbox/ECG/code/ecg_classification/python/results/' + multi_mode, winL, winR, do_preprocess,
            maxRR, use_RR, norm_RR, compute_morph, use_weight_class, feature_selection, oversamp_method, leads_flag, reduced_DS, pca_k, '/')

        # TODO implement this method! check to avoid NaN scores....

        if do_cross_val == 'pat_cv': # Cross-validation with one fold per patient
            cv_scores, c_values = run_cross_val(tr_features_scaled, tr_labels, tr_patient_num_beats, do_cross_val, len(tr_patient_num_beats))

            if not os.path.exists(perf_measures_path):
                os.makedirs(perf_measures_path)
            np.savetxt(perf_measures_path + '/cross_val_k-pat_cv_F_score.csv', (c_values, cv_scores.astype(float)), "%f")

        elif do_cross_val == 'beat_cv': # cross-validation by class id samples
            k_folds = {5}
            for k in k_folds:
                ijk_scores, c_values = run_cross_val(tr_features_scaled, tr_labels, tr_patient_num_beats, do_cross_val, k)

                # TODO: save the scores over the k-folds, ranked by best average value, in separate files
                perf_measures_path = create_svm_model_name('/home/mondejar/Dropbox/ECG/code/ecg_classification/python/results/' + multi_mode, winL, winR, do_preprocess,
                    maxRR, use_RR, norm_RR, compute_morph, use_weight_class, feature_selection, oversamp_method, leads_flag, reduced_DS, pca_k, '/')

                if not os.path.exists(perf_measures_path):
                    os.makedirs(perf_measures_path)
                np.savetxt(perf_measures_path + '/cross_val_k-' + str(k) + '_Ijk_score.csv', (c_values, ijk_scores.astype(float)), "%f")

            end = time.time()
            print("Time running cross-validation: " + str(format(end - start, '.2f')) + " sec")
    else:

        ################################################################################################
        # 3) Train SVM model

        # TODO load best params from cross validation!

        use_probability = False

        model_svm_path = db_path + 'svm_models/' + multi_mode + '_rbf'

        model_svm_path = create_svm_model_name(model_svm_path, winL, winR, do_preprocess,
            maxRR, use_RR, norm_RR, compute_morph, use_weight_class, feature_selection,
            oversamp_method, leads_flag, reduced_DS, pca_k, '_')

        if gamma_value != 0.0:
            model_svm_path = model_svm_path + '_C_' + str(C_value) + '_g_' + str(gamma_value) + '.joblib.pkl'
        else:
            model_svm_path = model_svm_path + '_C_' + str(C_value) + '.joblib.pkl'

        print("Training model on MIT-BIH DS1: " + model_svm_path + "...")

        if os.path.isfile(model_svm_path):
            # Load the trained model!
            svm_model = joblib.load(model_svm_path)

        else:
            class_weights = {}
            for c in range(4):
                class_weights.update({c: len(tr_labels) / float(np.count_nonzero(tr_labels == c))})
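            # These weights are inversely proportional to class frequency
            # (n_samples / n_c for class c); sklearn's class_weight='balanced'
            # uses the same idea scaled by 1/n_classes.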

            #class_weight='balanced',
            if gamma_value != 0.0:  # NOTE: gamma_value == 0.0 means use the default gamma (1/n_features)
                svm_model = svm.SVC(C=C_value, kernel='rbf', degree=3, gamma=gamma_value,
                    coef0=0.0, shrinking=True, probability=use_probability, tol=0.001,
                    cache_size=200, class_weight=class_weights, verbose=False,
                    max_iter=-1, decision_function_shape=multi_mode, random_state=None)
            else:
                svm_model = svm.SVC(C=C_value, kernel='rbf', degree=3, gamma='auto',
                    coef0=0.0, shrinking=True, probability=use_probability, tol=0.001,
                    cache_size=200, class_weight=class_weights, verbose=False,
                    max_iter=-1, decision_function_shape=multi_mode, random_state=None)

            # Let's train!

            start = time.time()
            svm_model.fit(tr_features_scaled, tr_labels)
            end = time.time()
            # TODO: assert that the class IDs appear in the expected order,
            # so that the one-vs-one combinations are formed properly
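            # One way to address the TODO above (added sanity check): after fit,
            # SVC stores the class labels in svm_model.classes_, and the one-vs-one
            # decision columns follow that order. Assuming the four AAMI classes are
            # encoded as 0..3 (as in the class_weights loop), verify the ordering.
            assert np.array_equal(svm_model.classes_, np.arange(4)), \
                "Unexpected class ordering: " + str(svm_model.classes_)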
            print("Training completed!\n\t" + model_svm_path +
                "\n\tTime required: " + str(format(end - start, '.2f')) + " sec")

            # Export model: save/write trained SVM model
            joblib.dump(svm_model, model_svm_path)

            # TODO Export StandardScaler()
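            # One possible implementation of the TODO above (added sketch; the
            # '_scaler.pkl' filename is an assumption, not an existing convention):
            # persist the fitted StandardScaler so unseen data can be normalized
            # with the same DS1 statistics at prediction time.
            joblib.dump(scaler, model_svm_path + '_scaler.pkl')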

        #########################################################################
        # 4) Test SVM model
        print("Testing model on MIT-BIH DS2: " + model_svm_path + "...")

        ############################################################################################################
        # EVALUATION
        ############################################################################################################

        # Evaluate the model on the training data
        perf_measures_path = create_svm_model_name('/home/mondejar/Dropbox/ECG/code/ecg_classification/python/results/' + multi_mode, winL, winR, do_preprocess,
            maxRR, use_RR, norm_RR, compute_morph, use_weight_class, feature_selection, oversamp_method, leads_flag, reduced_DS, pca_k, '/')

        # ovo_voting:
        # Simply add 1 to the win class
        print("Evaluation on DS1 ...")
        eval_model(svm_model, tr_features_scaled, tr_labels, multi_mode, 'ovo_voting', perf_measures_path, C_value, gamma_value, 'Train_')

        # Let's test new data!
        print("Evaluation on DS2 ...")
        eval_model(svm_model, eval_features_scaled, eval_labels, multi_mode, 'ovo_voting', perf_measures_path, C_value, gamma_value, '')

        # ovo_voting_exp:
        # Consider the post prob adding to both classes
        print("Evaluation on DS1 ...")
        eval_model(svm_model, tr_features_scaled, tr_labels, multi_mode, 'ovo_voting_exp', perf_measures_path, C_value, gamma_value, 'Train_')

        # Let's test new data!
        print("Evaluation on DS2 ...")
        eval_model(svm_model, eval_features_scaled, eval_labels, multi_mode, 'ovo_voting_exp', perf_measures_path, C_value, gamma_value, '')