b/python/train_SVM.py
#!/usr/bin/env python

"""
train_SVM.py

VARPA, University of Coruna
Mondejar Guerra, Victor M.
23 Oct 2017
"""

from load_MITBIH import *
from evaluation_AAMI import *
from aggregation_voting_strategies import *
from oversampling import *
from cross_validation import *
from feature_selection import *

import sklearn
from sklearn.externals import joblib
from sklearn.preprocessing import StandardScaler
from sklearn import svm
from sklearn import decomposition

import os
# Explicit imports for names used below (time.time, np.savetxt, np.count_nonzero)
import time
import numpy as np


def create_svm_model_name(model_svm_path, winL, winR, do_preprocess,
                          maxRR, use_RR, norm_RR, compute_morph, use_weight_class, feature_selection,
                          oversamp_method, leads_flag, reduced_DS, pca_k, delimiter):

    if reduced_DS:
        model_svm_path = model_svm_path + delimiter + 'exp_2'

    if leads_flag[0] == 1:
        model_svm_path = model_svm_path + delimiter + 'MLII'

    if leads_flag[1] == 1:
        model_svm_path = model_svm_path + delimiter + 'V1'

    if oversamp_method:
        model_svm_path = model_svm_path + delimiter + oversamp_method

    if feature_selection:
        model_svm_path = model_svm_path + delimiter + feature_selection

    if do_preprocess:
        model_svm_path = model_svm_path + delimiter + 'rm_bsln'

    if maxRR:
        model_svm_path = model_svm_path + delimiter + 'maxRR'

    if use_RR:
        model_svm_path = model_svm_path + delimiter + 'RR'

    if norm_RR:
        model_svm_path = model_svm_path + delimiter + 'norm_RR'

    for descp in compute_morph:
        model_svm_path = model_svm_path + delimiter + descp

    if use_weight_class:
        model_svm_path = model_svm_path + delimiter + 'weighted'

    if pca_k > 0:
        model_svm_path = model_svm_path + delimiter + 'pca_' + str(pca_k)

    return model_svm_path

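# Illustrative example (hypothetical argument values, not from the original script):
# create_svm_model_name('svm_models/ovo_rbf', 90, 90, True, True, True, True, {'wavelets'},
#                       False, '', 'SMOTE', [1, 0], False, 0, '_')
# returns 'svm_models/ovo_rbf_MLII_SMOTE_rm_bsln_maxRR_RR_norm_RR_wavelets'.
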
# Eval the SVM model and export the results
def eval_model(svm_model, features, labels, multi_mode, voting_strategy, output_path, C_value, gamma_value, DS):
    if multi_mode == 'ovo':
        decision_ovo = svm_model.decision_function(features)

        if voting_strategy == 'ovo_voting':
            predict_ovo, counter = ovo_voting(decision_ovo, 4)

        elif voting_strategy == 'ovo_voting_both':
            predict_ovo, counter = ovo_voting_both(decision_ovo, 4)

        elif voting_strategy == 'ovo_voting_exp':
            predict_ovo, counter = ovo_voting_exp(decision_ovo, 4)

        # svm_model.predict_log_proba, svm_model.predict_proba, svm_model.predict, ...
        perf_measures = compute_AAMI_performance_measures(predict_ovo, labels)

    """
    elif multi_mode == 'ovr':
        decision_ovr = svm_model.decision_function(features)
        predict_ovr = svm_model.predict(features)
        perf_measures = compute_AAMI_performance_measures(predict_ovr, labels)
    """

    # Write results and also predictions on DS2
    if not os.path.exists(output_path):
        os.makedirs(output_path)

    if gamma_value != 0.0:
        write_AAMI_results(perf_measures, output_path + '/' + DS + 'C_' + str(C_value) + 'g_' + str(gamma_value) +
                           '_score_Ijk_' + str(format(perf_measures.Ijk, '.2f')) + '_' + voting_strategy + '.txt')
    else:
        write_AAMI_results(perf_measures, output_path + '/' + DS + 'C_' + str(C_value) +
                           '_score_Ijk_' + str(format(perf_measures.Ijk, '.2f')) + '_' + voting_strategy + '.txt')

    # Array to .csv
    if multi_mode == 'ovo':
        if gamma_value != 0.0:
            np.savetxt(output_path + '/' + DS + 'C_' + str(C_value) + 'g_' + str(gamma_value) +
                       '_decision_ovo.csv', decision_ovo)
            np.savetxt(output_path + '/' + DS + 'C_' + str(C_value) + 'g_' + str(gamma_value) +
                       '_predict_' + voting_strategy + '.csv', predict_ovo.astype(int), '%.0f')
        else:
            np.savetxt(output_path + '/' + DS + 'C_' + str(C_value) +
                       '_decision_ovo.csv', decision_ovo)
            np.savetxt(output_path + '/' + DS + 'C_' + str(C_value) +
                       '_predict_' + voting_strategy + '.csv', predict_ovo.astype(int), '%.0f')

    elif multi_mode == 'ovr':
        # NOTE this branch relies on decision_ovr/predict_ovr from the 'ovr' block above,
        # which is currently disabled
        np.savetxt(output_path + '/' + DS + 'C_' + str(C_value) +
                   '_decision_ovr.csv', decision_ovr)
        np.savetxt(output_path + '/' + DS + 'C_' + str(C_value) +
                   '_predict_' + voting_strategy + '.csv', predict_ovr.astype(int), '%.0f')

    print("Results written to " + output_path + '/' + DS + 'C_' + str(C_value))


def create_oversamp_name(reduced_DS, do_preprocess, compute_morph, winL, winR, maxRR, use_RR, norm_RR, pca_k):
    oversamp_features_pickle_name = ''
    if reduced_DS:
        oversamp_features_pickle_name += '_reduced_'

    if do_preprocess:
        oversamp_features_pickle_name += '_rm_bsline'

    if maxRR:
        oversamp_features_pickle_name += '_maxRR'

    if use_RR:
        oversamp_features_pickle_name += '_RR'

    if norm_RR:
        oversamp_features_pickle_name += '_norm_RR'

    for descp in compute_morph:
        oversamp_features_pickle_name += '_' + descp

    if pca_k > 0:
        oversamp_features_pickle_name += '_pca_' + str(pca_k)

    oversamp_features_pickle_name += '_wL_' + str(winL) + '_wR_' + str(winR)

    return oversamp_features_pickle_name


def main(multi_mode='ovo', winL=90, winR=90, do_preprocess=True, use_weight_class=True,
         maxRR=True, use_RR=True, norm_RR=True, compute_morph={''}, oversamp_method='',
         pca_k=0, feature_selection='', do_cross_val='', C_value=0.001, gamma_value=0.0,
         reduced_DS=False, leads_flag=[1, 0]):
    print("Running train_SVM.py!")

    db_path = '/home/mondejar/dataset/ECG/mitdb/m_learning/scikit/'

    # Load train data
    [tr_features, tr_labels, tr_patient_num_beats] = load_mit_db('DS1', winL, winR, do_preprocess,
                                                                 maxRR, use_RR, norm_RR, compute_morph, db_path, reduced_DS, leads_flag)

    # Load test data
    [eval_features, eval_labels, eval_patient_num_beats] = load_mit_db('DS2', winL, winR, do_preprocess,
                                                                       maxRR, use_RR, norm_RR, compute_morph, db_path, reduced_DS, leads_flag)
    if reduced_DS:
        np.savetxt('mit_db/' + 'exp_2_' + 'DS2_labels.csv', eval_labels.astype(int), '%.0f')
    else:
        np.savetxt('mit_db/' + 'DS2_labels.csv', eval_labels.astype(int), '%.0f')

    # if reduced_DS:
    #     np.savetxt('mit_db/' + 'exp_2_' + 'DS1_labels.csv', tr_labels.astype(int), '%.0f')
    # else:
    #     np.savetxt('mit_db/' + 'DS1_labels.csv', tr_labels.astype(int), '%.0f')

    ##############################################################
    # 0) TODO feature_selection: should it run before oversampling?

    # TODO perform normalization before the oversampling?
    if oversamp_method:
        # Filename
        oversamp_features_pickle_name = create_oversamp_name(reduced_DS, do_preprocess, compute_morph, winL, winR, maxRR, use_RR, norm_RR, pca_k)

        # Do oversampling
        tr_features, tr_labels = perform_oversampling(oversamp_method, db_path + 'oversamp/python_mit', oversamp_features_pickle_name, tr_features, tr_labels)

    # Normalization of the input data
    # scaled: zero mean, unit variance (z-score)
    scaler = StandardScaler()
    scaler.fit(tr_features)
    tr_features_scaled = scaler.transform(tr_features)

    # Apply the same scaling (fitted on DS1) to the test data
    eval_features_scaled = scaler.transform(eval_features)
    ##############################################################
    # 0) TODO feature_selection: should it also run after oversampling?
    if feature_selection:
        print("Running feature selection")
        best_features = 7
        tr_features_scaled, features_index_sorted = run_feature_selection(tr_features_scaled, tr_labels, feature_selection, best_features)
        eval_features_scaled = eval_features_scaled[:, features_index_sorted[0:best_features]]

    # 1) Dimensionality reduction
    if pca_k > 0:
        # TODO load the fitted transform if it already exists?
        # NOTE plain PCA runs out of memory on this data!
        # NOTE 11 Jan: test with IPCA!
        start = time.time()

        print("Running IPCA " + str(pca_k) + "...")

        # Run incremental PCA
        IPCA = sklearn.decomposition.IncrementalPCA(pca_k, batch_size=pca_k)  # gamma_pca

        # tr_features_scaled = KPCA.fit_transform(tr_features_scaled)
        IPCA.fit(tr_features_scaled)

        # Apply the fitted PCA to train and test data
        tr_features_scaled = IPCA.transform(tr_features_scaled)
        eval_features_scaled = IPCA.transform(eval_features_scaled)

        """
        print("Running TruncatedSVD (singular value decomposition, an alternative to PCA) " + str(pca_k) + "...")

        svd = decomposition.TruncatedSVD(n_components=pca_k, algorithm='arpack')
        svd.fit(tr_features_scaled)
        tr_features_scaled = svd.transform(tr_features_scaled)
        eval_features_scaled = svd.transform(eval_features_scaled)
        """
        end = time.time()

        print("Time running IPCA: " + str(format(end - start, '.2f')) + " sec")
    ##############################################################
    # 2) Cross-validation:
    if do_cross_val:
        print("Running cross validation...")
        start = time.time()

        # TODO save the data for each k-fold, ranked by best average value, in separate files
        perf_measures_path = create_svm_model_name('/home/mondejar/Dropbox/ECG/code/ecg_classification/python/results/' + multi_mode, winL, winR, do_preprocess,
                                                   maxRR, use_RR, norm_RR, compute_morph, use_weight_class, feature_selection, oversamp_method, leads_flag, reduced_DS, pca_k, '/')

        # TODO implement this method! check to avoid NaN scores...

        if do_cross_val == 'pat_cv':  # Cross-validation with one fold per patient
            cv_scores, c_values = run_cross_val(tr_features_scaled, tr_labels, tr_patient_num_beats, do_cross_val, len(tr_patient_num_beats))

            if not os.path.exists(perf_measures_path):
                os.makedirs(perf_measures_path)
            np.savetxt(perf_measures_path + '/cross_val_k-pat_cv_F_score.csv', (c_values, cv_scores.astype(float)), "%f")

        elif do_cross_val == 'beat_cv':  # Cross-validation over beats, stratified by class id
            k_folds = {5}
            for k in k_folds:
                ijk_scores, c_values = run_cross_val(tr_features_scaled, tr_labels, tr_patient_num_beats, do_cross_val, k)
                # TODO save the data for each k-fold, ranked by best average value, in separate files
                perf_measures_path = create_svm_model_name('/home/mondejar/Dropbox/ECG/code/ecg_classification/python/results/' + multi_mode, winL, winR, do_preprocess,
                                                           maxRR, use_RR, norm_RR, compute_morph, use_weight_class, feature_selection, oversamp_method, leads_flag, reduced_DS, pca_k, '/')

                if not os.path.exists(perf_measures_path):
                    os.makedirs(perf_measures_path)
                np.savetxt(perf_measures_path + '/cross_val_k-' + str(k) + '_Ijk_score.csv', (c_values, ijk_scores.astype(float)), "%f")

        end = time.time()
        print("Time running cross validation: " + str(format(end - start, '.2f')) + " sec")
    else:
        ################################################################################################
        # 3) Train SVM model

        # TODO load the best params from cross-validation!

        use_probability = False

        model_svm_path = db_path + 'svm_models/' + multi_mode + '_rbf'

        model_svm_path = create_svm_model_name(model_svm_path, winL, winR, do_preprocess,
                                               maxRR, use_RR, norm_RR, compute_morph, use_weight_class, feature_selection,
                                               oversamp_method, leads_flag, reduced_DS, pca_k, '_')

        if gamma_value != 0.0:
            model_svm_path = model_svm_path + '_C_' + str(C_value) + '_g_' + str(gamma_value) + '.joblib.pkl'
        else:
            model_svm_path = model_svm_path + '_C_' + str(C_value) + '.joblib.pkl'

        print("Training model on MIT-BIH DS1: " + model_svm_path + "...")

        if os.path.isfile(model_svm_path):
            # Load the trained model!
            svm_model = joblib.load(model_svm_path)

        else:
            # Inverse-frequency class weights: weight_c = n_samples / n_samples_of_class_c
            class_weights = {}
            for c in range(4):
                class_weights.update({c: len(tr_labels) / float(np.count_nonzero(tr_labels == c))})

            # class_weight='balanced',
            if gamma_value != 0.0:  # NOTE 0.0 means the default value 1/n_features
                svm_model = svm.SVC(C=C_value, kernel='rbf', degree=3, gamma=gamma_value,
                                    coef0=0.0, shrinking=True, probability=use_probability, tol=0.001,
                                    cache_size=200, class_weight=class_weights, verbose=False,
                                    max_iter=-1, decision_function_shape=multi_mode, random_state=None)
            else:
                svm_model = svm.SVC(C=C_value, kernel='rbf', degree=3, gamma='auto',
                                    coef0=0.0, shrinking=True, probability=use_probability, tol=0.001,
                                    cache_size=200, class_weight=class_weights, verbose=False,
                                    max_iter=-1, decision_function_shape=multi_mode, random_state=None)

            # Let's train!
            start = time.time()
            svm_model.fit(tr_features_scaled, tr_labels)
            end = time.time()

            # TODO assert that the class IDs appear in the desired order,
            # so that ovo builds the pairwise combinations properly
            print("Training completed!\n\t" + model_svm_path +
                  "\n\tTime required: " + str(format(end - start, '.2f')) + " sec")

            # Export model: save/write trained SVM model
            joblib.dump(svm_model, model_svm_path)

            # TODO export the fitted StandardScaler() as well

        #########################################################################
        # 4) Test SVM model
        print("Testing model on MIT-BIH DS2: " + model_svm_path + "...")

        ############################################################################################################
        # EVALUATION
        ############################################################################################################

        # Evaluate the model on the training data
        perf_measures_path = create_svm_model_name('/home/mondejar/Dropbox/ECG/code/ecg_classification/python/results/' + multi_mode, winL, winR, do_preprocess,
                                                   maxRR, use_RR, norm_RR, compute_morph, use_weight_class, feature_selection, oversamp_method, leads_flag, reduced_DS, pca_k, '/')

        # ovo_voting:
        # simply add 1 to the winning class
        print("Evaluation on DS1 ...")
        eval_model(svm_model, tr_features_scaled, tr_labels, multi_mode, 'ovo_voting', perf_measures_path, C_value, gamma_value, 'Train_')

        # Let's test new data!
        print("Evaluation on DS2 ...")
        eval_model(svm_model, eval_features_scaled, eval_labels, multi_mode, 'ovo_voting', perf_measures_path, C_value, gamma_value, '')

        # ovo_voting_exp:
        # consider the posterior probabilities, adding to both classes
        print("Evaluation on DS1 ...")
        eval_model(svm_model, tr_features_scaled, tr_labels, multi_mode, 'ovo_voting_exp', perf_measures_path, C_value, gamma_value, 'Train_')

        # Let's test new data!
        print("Evaluation on DS2 ...")
        eval_model(svm_model, eval_features_scaled, eval_labels, multi_mode, 'ovo_voting_exp', perf_measures_path, C_value, gamma_value, '')
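
# A minimal, hypothetical command-line entry point (not part of the original script):
# it simply runs main() with its default arguments; pass keyword arguments such as
# oversamp_method, pca_k, C_value or gamma_value to reproduce a specific experiment.
if __name__ == '__main__':
    main()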