|
a |
|
b/src/biodiscml/BestModelSelectionAndReport.java |
|
|
1 |
/* |
|
|
2 |
* Select only models getting specific min MCC or max RMSE |
|
|
3 |
* Select all features from selected models |
|
|
4 |
* Retrain a model (best classifier of all tested) with the unique selected features with LOOCV 75/25. |
|
|
5 |
* Do it 10 times with various seeds, report the average scores (ex: AUC) |
|
|
6 |
* Explore biology behind the set of selected features |
|
|
7 |
|
|
|
8 |
*/ |
|
|
9 |
package biodiscml; |
|
|
10 |
|
|
|
11 |
import java.io.BufferedReader; |
|
|
12 |
import java.io.File; |
|
|
13 |
import java.io.FileReader; |
|
|
14 |
import java.io.FileWriter; |
|
|
15 |
import java.io.PrintWriter; |
|
|
16 |
import java.text.DecimalFormat; |
|
|
17 |
import java.text.DecimalFormatSymbols; |
|
|
18 |
import java.util.ArrayList; |
|
|
19 |
import java.util.Arrays; |
|
|
20 |
import java.util.Collections; |
|
|
21 |
import java.util.HashMap; |
|
|
22 |
import java.util.LinkedHashMap; |
|
|
23 |
import java.util.TreeMap; |
|
|
24 |
import org.apache.commons.math3.stat.descriptive.moment.Mean; |
|
|
25 |
import org.apache.commons.math3.stat.descriptive.moment.StandardDeviation; |
|
|
26 |
import utils.UpSetR; |
|
|
27 |
import utils.Weka_module; |
|
|
28 |
import utils.utils; |
|
|
29 |
import weka.core.SerializationHelper; |
|
|
30 |
|
|
|
31 |
/** |
|
|
32 |
* |
|
|
33 |
* @author Mickael |
|
|
34 |
*/ |
|
|
35 |
public class BestModelSelectionAndReport { |
|
|
36 |
|
|
|
37 |
    // Working directory, copied from the global configuration.
    public static String wd = Main.wd;
    // Shared Weka wrapper used to (re)train and evaluate the selected models.
    public static Weka_module weka = new Weka_module();
    // Results-file header mapping: column name -> column index.
    public static HashMap<String, Integer> hmResultsHeaderNames = new HashMap<>();
    // Results-file header mapping: column index -> column name.
    public static HashMap< Integer, String> hmResultsHeaderIndexes = new HashMap<>();
    // Score formatter; configured in the constructor (3 fraction digits, '.' as decimal separator).
    public static DecimalFormat df = new DecimalFormat();
    // Training data file (csv) passed to the constructor.
    public static String trainFileName;
    // Feature-selection output file; falls back to the training file when
    // feature selection is disabled (Main.noFeatureSelection).
    public static String featureSelectionFile;
    // Tab-separated training results file the best models are picked from.
    public static String predictionsResultsFile;
    // Comma-separated correlated features consumed (then reset to null) when building
    // the "_corr" model. NOTE(review): the producer of this value is outside this
    // chunk -- confirm where it is assigned.
    public static String correlatedFeatures;
|
|
46 |
|
|
|
47 |
/** |
|
|
48 |
* |
|
|
49 |
* @param trainFilName |
|
|
50 |
* @param featureSelFile |
|
|
51 |
* @param predictionsResFile |
|
|
52 |
* @param type |
|
|
53 |
*/ |
|
|
54 |
public BestModelSelectionAndReport(String trainFilName, |
|
|
55 |
String featureSelFile, |
|
|
56 |
String predictionsResFile, |
|
|
57 |
String type |
|
|
58 |
) { |
|
|
59 |
trainFileName = trainFilName; |
|
|
60 |
if (Main.noFeatureSelection) { |
|
|
61 |
featureSelectionFile = trainFilName; |
|
|
62 |
} else { |
|
|
63 |
featureSelectionFile = featureSelFile; |
|
|
64 |
} |
|
|
65 |
predictionsResultsFile = predictionsResFile; |
|
|
66 |
df.setMaximumFractionDigits(3); |
|
|
67 |
DecimalFormatSymbols dfs = new DecimalFormatSymbols(); |
|
|
68 |
dfs.setDecimalSeparator('.'); |
|
|
69 |
df.setDecimalFormatSymbols(dfs); |
|
|
70 |
String bestOrCombine = "Select best "; |
|
|
71 |
|
|
|
72 |
if (Main.combineModels) { |
|
|
73 |
bestOrCombine = "Combine "; |
|
|
74 |
} |
|
|
75 |
String sign = " >= "; |
|
|
76 |
boolean metricToMinimize = (Main.bestModelsSortingMetric.contains("RMSE") |
|
|
77 |
|| Main.bestModelsSortingMetric.contains("BER") |
|
|
78 |
|| Main.bestModelsSortingMetric.contains("FPR") |
|
|
79 |
|| Main.bestModelsSortingMetric.contains("FNR") |
|
|
80 |
|| Main.bestModelsSortingMetric.contains("FDR") |
|
|
81 |
|| Main.bestModelsSortingMetric.contains("MAE") |
|
|
82 |
|| Main.bestModelsSortingMetric.contains("RAE") |
|
|
83 |
|| Main.bestModelsSortingMetric.contains("RRSE")); |
|
|
84 |
if (metricToMinimize) { |
|
|
85 |
sign = " <= "; |
|
|
86 |
} |
|
|
87 |
|
|
|
88 |
System.out.println("## " + bestOrCombine + " models using " + Main.bestModelsSortingMetric + " as sorting metric.\n" |
|
|
89 |
+ "## Parameters: " + Main.numberOfBestModels + " best models and " |
|
|
90 |
+ Main.bestModelsSortingMetric + sign + Main.bestModelsSortingMetricThreshold); |
|
|
91 |
|
|
|
92 |
//Read results file |
|
|
93 |
boolean classification = type.equals("classification"); |
|
|
94 |
|
|
|
95 |
try { |
|
|
96 |
BufferedReader br = new BufferedReader(new FileReader(predictionsResultsFile)); |
|
|
97 |
TreeMap<String, Object> tmModels = new TreeMap<>(); //<metric modelID, classification/regression Object> |
|
|
98 |
HashMap<String, Object> hmModels = new HashMap<>(); //<modelID, classification/regression Object> |
|
|
99 |
|
|
|
100 |
//in case of RMSE or BER, we want the minimum value instead of the maximal one |
|
|
101 |
if (!metricToMinimize) { |
|
|
102 |
tmModels = new TreeMap<>(Collections.reverseOrder()); |
|
|
103 |
} |
|
|
104 |
String line = br.readLine(); |
|
|
105 |
|
|
|
106 |
//fill header mapping |
|
|
107 |
String header = line; |
|
|
108 |
int cpt = 0; |
|
|
109 |
for (String s : header.split("\t")) { |
|
|
110 |
hmResultsHeaderNames.put(s, cpt); |
|
|
111 |
hmResultsHeaderIndexes.put(cpt, s); |
|
|
112 |
cpt++; |
|
|
113 |
} |
|
|
114 |
if (!hmResultsHeaderNames.containsKey(Main.bestModelsSortingMetric)) { |
|
|
115 |
System.err.println("[error] " + Main.bestModelsSortingMetric + " column does not exist in the results file."); |
|
|
116 |
if (Main.doRegression) { |
|
|
117 |
System.out.println("Use AVG_CC instead since we are in regression mode"); |
|
|
118 |
Main.bestModelsSortingMetric = "AVG_CC"; |
|
|
119 |
} else { |
|
|
120 |
System.exit(0); |
|
|
121 |
} |
|
|
122 |
|
|
|
123 |
} |
|
|
124 |
|
|
|
125 |
//read results |
|
|
126 |
while (br.ready()) { |
|
|
127 |
line = br.readLine(); |
|
|
128 |
if (!line.trim().isEmpty() && !line.contains("[model error]")) { |
|
|
129 |
if (classification) { |
|
|
130 |
try { |
|
|
131 |
classificationObject co = new classificationObject(line); |
|
|
132 |
tmModels.put(Double.valueOf(co.hmValues.get(Main.bestModelsSortingMetric)) + " " + co.hmValues.get("ID"), co); |
|
|
133 |
hmModels.put(co.hmValues.get("ID"), co); |
|
|
134 |
} catch (Exception e) { |
|
|
135 |
if (Main.debug) { |
|
|
136 |
e.printStackTrace(); |
|
|
137 |
} |
|
|
138 |
} |
|
|
139 |
} else { |
|
|
140 |
try { |
|
|
141 |
regressionObject ro = new regressionObject(line); |
|
|
142 |
tmModels.put(Double.valueOf(ro.hmValues.get(Main.bestModelsSortingMetric)) + " " + ro.hmValues.get("ID"), ro); |
|
|
143 |
hmModels.put(ro.hmValues.get("ID"), ro); |
|
|
144 |
} catch (Exception e) { |
|
|
145 |
if (Main.debug) { |
|
|
146 |
e.printStackTrace(); |
|
|
147 |
} |
|
|
148 |
} |
|
|
149 |
} |
|
|
150 |
} |
|
|
151 |
} |
|
|
152 |
br.close(); |
|
|
153 |
|
|
|
154 |
//control available models |
|
|
155 |
if (Main.numberOfBestModels > tmModels.size()) { |
|
|
156 |
System.out.println("Only " + tmModels.size() + " available models. You have configured " + Main.numberOfBestModels + " best models"); |
|
|
157 |
Main.numberOfBestModels = tmModels.size(); |
|
|
158 |
} |
|
|
159 |
|
|
|
160 |
// get best models list |
|
|
161 |
ArrayList<Object> alBestClassifiers = new ArrayList<>(); |
|
|
162 |
cpt = 0; |
|
|
163 |
if (Main.hmTrainingBestModelList.isEmpty()) { |
|
|
164 |
for (String metricAndModel : tmModels.keySet()) { |
|
|
165 |
cpt++; |
|
|
166 |
boolean condition = false; |
|
|
167 |
if (metricToMinimize) { |
|
|
168 |
condition = Double.valueOf(metricAndModel.split(" ")[0]) < Main.bestModelsSortingMetricThreshold; |
|
|
169 |
} else { |
|
|
170 |
condition = Double.valueOf(metricAndModel.split(" ")[0]) > Main.bestModelsSortingMetricThreshold; |
|
|
171 |
} |
|
|
172 |
if (condition && cpt <= Main.numberOfBestModels) { |
|
|
173 |
if (classification) { |
|
|
174 |
alBestClassifiers.add(((classificationObject) tmModels.get(metricAndModel))); |
|
|
175 |
} else { |
|
|
176 |
alBestClassifiers.add(((regressionObject) tmModels.get(metricAndModel))); |
|
|
177 |
} |
|
|
178 |
} |
|
|
179 |
} |
|
|
180 |
} else { |
|
|
181 |
for (String modelID : Main.hmTrainingBestModelList.keySet()) { |
|
|
182 |
if (classification) { |
|
|
183 |
alBestClassifiers.add(((classificationObject) hmModels.get(modelID))); |
|
|
184 |
} else { |
|
|
185 |
alBestClassifiers.add(((regressionObject) hmModels.get(modelID))); |
|
|
186 |
} |
|
|
187 |
} |
|
|
188 |
} |
|
|
189 |
|
|
|
190 |
//if model combination vote |
|
|
191 |
if (Main.combineModels) { |
|
|
192 |
if (classification) { |
|
|
193 |
classificationObject co = new classificationObject(); |
|
|
194 |
co.buildVoteClassifier(alBestClassifiers); |
|
|
195 |
alBestClassifiers = new ArrayList<>(); |
|
|
196 |
alBestClassifiers.add(co); |
|
|
197 |
} else { |
|
|
198 |
regressionObject ro = new regressionObject(); |
|
|
199 |
ro.buildVoteClassifier(alBestClassifiers); |
|
|
200 |
alBestClassifiers = new ArrayList<>(); |
|
|
201 |
alBestClassifiers.add(ro); |
|
|
202 |
} |
|
|
203 |
} |
|
|
204 |
|
|
|
205 |
//perform evaluations and create models |
|
|
206 |
PrintWriter pw = null; |
|
|
207 |
for (Object classifier : alBestClassifiers) { |
|
|
208 |
// initialize weka module |
|
|
209 |
if (classification) { |
|
|
210 |
init(featureSelectionFile.replace("infoGain.csv", "infoGain.arff"), classification); |
|
|
211 |
} else { |
|
|
212 |
init(featureSelectionFile.replace("RELIEFF.csv", "RELIEFF.arff"), classification); |
|
|
213 |
} |
|
|
214 |
createBestModel(classifier, classification, pw, br, false); |
|
|
215 |
|
|
|
216 |
if (Main.generateModelWithCorrelatedGenes) { |
|
|
217 |
init(trainFilName, classification); |
|
|
218 |
createBestModel(classifier, classification, pw, br, true); |
|
|
219 |
correlatedFeatures = null; |
|
|
220 |
} |
|
|
221 |
} |
|
|
222 |
|
|
|
223 |
} catch (ClassCastException e) { |
|
|
224 |
e.printStackTrace(); |
|
|
225 |
System.err.println("Unable to train selected best model(s). Check input files. "); |
|
|
226 |
} catch (Exception e) { |
|
|
227 |
e.printStackTrace(); |
|
|
228 |
} |
|
|
229 |
} |
|
|
230 |
|
|
|
231 |
private void createBestModel(Object classifier, |
|
|
232 |
Boolean classification, |
|
|
233 |
PrintWriter pw, BufferedReader br, Boolean correlatedFeaturesMode) throws Exception { |
|
|
234 |
String corrMode = ""; |
|
|
235 |
if (correlatedFeaturesMode) { |
|
|
236 |
corrMode = "_corr"; |
|
|
237 |
System.out.print("\n# Model with correlated features "); |
|
|
238 |
((classificationObject) classifier).featuresSeparatedByCommas = correlatedFeatures; |
|
|
239 |
} else { |
|
|
240 |
System.out.print("\n# Model "); |
|
|
241 |
} |
|
|
242 |
ArrayList<Double> alMCCs = new ArrayList<>(); |
|
|
243 |
ArrayList<Double> alMAEs = new ArrayList<>(); |
|
|
244 |
ArrayList<Double> alCCs = new ArrayList<>(); |
|
|
245 |
|
|
|
246 |
if (Main.debug) { |
|
|
247 |
System.out.println("Save model "); |
|
|
248 |
} |
|
|
249 |
String modelFilename; |
|
|
250 |
Weka_module.ClassificationResultsObject cr = null; |
|
|
251 |
Weka_module.RegressionResultsObject rr = null; |
|
|
252 |
classificationObject co = null; |
|
|
253 |
regressionObject ro = null; |
|
|
254 |
String classifierName = ""; |
|
|
255 |
|
|
|
256 |
//saving files |
|
|
257 |
if (classification) { |
|
|
258 |
co = (classificationObject) classifier; |
|
|
259 |
//train model |
|
|
260 |
modelFilename = Main.wd + Main.project |
|
|
261 |
+ "d." + co.classifier + "_" + co.printOptions() + "_" |
|
|
262 |
+ co.optimizer.toUpperCase().trim() + "_" + co.mode + corrMode; |
|
|
263 |
Object trainingOutput = weka.trainClassifier(co.classifier, co.options, |
|
|
264 |
co.featuresSeparatedByCommas, classification, 10); |
|
|
265 |
cr = (Weka_module.ClassificationResultsObject) trainingOutput; |
|
|
266 |
|
|
|
267 |
classifierName = co.classifier + "_" + co.printOptions() + "_" |
|
|
268 |
+ co.optimizer.toUpperCase().trim() + "_" + co.mode; |
|
|
269 |
//save feature file |
|
|
270 |
weka.saveFilteredDataToCSV(co.featuresSeparatedByCommas, classification, modelFilename + ".train_features.csv"); |
|
|
271 |
//call ranking function |
|
|
272 |
cr.featuresRankingResults = weka.featureRankingForClassification(modelFilename + ".train_features.csv"); |
|
|
273 |
//save model |
|
|
274 |
try { |
|
|
275 |
SerializationHelper.write(modelFilename + ".model", cr.model); |
|
|
276 |
} catch (Exception e) { |
|
|
277 |
e.printStackTrace(); |
|
|
278 |
} |
|
|
279 |
System.out.println(modelFilename); |
|
|
280 |
} else { |
|
|
281 |
ro = (regressionObject) classifier; |
|
|
282 |
modelFilename = Main.wd + Main.project |
|
|
283 |
+ "d." + ro.classifier + "_" + ro.printOptions() + "_" |
|
|
284 |
+ ro.optimizer.toUpperCase().trim() + "_" + ro.mode; |
|
|
285 |
Object trainingOutput = weka.trainClassifier(ro.classifier, ro.options, |
|
|
286 |
ro.featuresSeparatedByCommas, classification, 10); |
|
|
287 |
rr = (Weka_module.RegressionResultsObject) trainingOutput; |
|
|
288 |
|
|
|
289 |
classifierName = ro.classifier + "_" + ro.printOptions() + "_" |
|
|
290 |
+ ro.optimizer.toUpperCase().trim() + "_" + ro.mode; |
|
|
291 |
|
|
|
292 |
//save model and features |
|
|
293 |
try { |
|
|
294 |
SerializationHelper.write(modelFilename + ".model", rr.model); |
|
|
295 |
} catch (Exception e) { |
|
|
296 |
e.printStackTrace(); |
|
|
297 |
} |
|
|
298 |
//save feature file |
|
|
299 |
weka.saveFilteredDataToCSV(ro.featuresSeparatedByCommas, classification, modelFilename + ".train_features.csv"); |
|
|
300 |
//call ranking function |
|
|
301 |
rr.featuresRankingResults = weka.featureRankingForRegression(modelFilename + ".train_features.csv"); |
|
|
302 |
System.out.println(modelFilename); |
|
|
303 |
} |
|
|
304 |
|
|
|
305 |
//header |
|
|
306 |
pw = new PrintWriter(new FileWriter(modelFilename + ".details.txt")); |
|
|
307 |
pw.println("## Generated by BioDiscML (Leclercq et al. 2019)##"); |
|
|
308 |
pw.println("# Project: " + Main.project.substring(0, Main.project.length() - 1)); |
|
|
309 |
|
|
|
310 |
if (classification) { |
|
|
311 |
pw.println("# ID: " + co.identifier); |
|
|
312 |
System.out.println("# ID: " + co.identifier); |
|
|
313 |
pw.println("# Classifier: " + co.classifier + " " + co.options |
|
|
314 |
+ "\n# Optimizer: " + co.optimizer.toUpperCase() |
|
|
315 |
+ "\n# Feature search mode: " + co.mode); |
|
|
316 |
} else { |
|
|
317 |
pw.println("# ID: " + ro.identifier); |
|
|
318 |
System.out.println("# ID: " + ro.identifier); |
|
|
319 |
pw.println("# Classifier: " + ro.classifier + " " + ro.options |
|
|
320 |
+ "\n# Optimizer: " + ro.optimizer.toUpperCase() |
|
|
321 |
+ "\n# Feature search mode: " + ro.mode); |
|
|
322 |
} |
|
|
323 |
|
|
|
324 |
//show combined models in case of combined vote |
|
|
325 |
if (Main.combineModels) { |
|
|
326 |
pw.println("# Combined classifiers:"); |
|
|
327 |
String combOpt = ""; |
|
|
328 |
if (classification) { |
|
|
329 |
combOpt = co.options; |
|
|
330 |
} else { |
|
|
331 |
combOpt = ro.options; |
|
|
332 |
} |
|
|
333 |
for (String s : combOpt.split("-B ")) { |
|
|
334 |
if (s.startsWith("\"weka.classifiers.meta.FilteredClassifier")) { |
|
|
335 |
String usedFeatures = s.split("Remove -V -R ")[1]; |
|
|
336 |
usedFeatures = usedFeatures.split("\\\\")[0]; |
|
|
337 |
String model = s.substring(s.indexOf("-W ") + 2).trim() |
|
|
338 |
.replace("-- ", "") |
|
|
339 |
.replace("\\\"", "\"") |
|
|
340 |
.replace("\\\\\"", "\\\""); |
|
|
341 |
model = model.substring(0, model.length() - 1); |
|
|
342 |
pw.println(model + " (features: " + usedFeatures + ")"); |
|
|
343 |
} |
|
|
344 |
} |
|
|
345 |
} |
|
|
346 |
pw.flush(); |
|
|
347 |
|
|
|
348 |
//UpSetR |
|
|
349 |
if (Main.UpSetR) { |
|
|
350 |
UpSetR up = new UpSetR(); |
|
|
351 |
up.creatUpSetRDatasetFromSignature(co, featureSelectionFile, predictionsResultsFile); |
|
|
352 |
} |
|
|
353 |
|
|
|
354 |
//10CV performance |
|
|
355 |
System.out.println("# 10 fold cross validation performance"); |
|
|
356 |
pw.println("\n# 10 fold cross validation performance"); |
|
|
357 |
if (classification) { |
|
|
358 |
System.out.println(cr.toStringDetails()); |
|
|
359 |
alMCCs.add(Double.valueOf(cr.MCC)); |
|
|
360 |
alMAEs.add(Double.valueOf(cr.MAE)); |
|
|
361 |
pw.println(cr.toStringDetails().replace("[score_training] ", "")); |
|
|
362 |
} else { |
|
|
363 |
System.out.println(rr.toStringDetails()); |
|
|
364 |
alCCs.add(Double.valueOf(rr.CC)); |
|
|
365 |
alMAEs.add(Double.valueOf(rr.MAE)); |
|
|
366 |
pw.println(rr.toStringDetails().replace("[score_training] ", "")); |
|
|
367 |
} |
|
|
368 |
pw.flush(); |
|
|
369 |
|
|
|
370 |
//LOOCV performance |
|
|
371 |
if (Main.loocv) { |
|
|
372 |
System.out.println("# LOOCV (Leave-One-Out cross validation) performance"); |
|
|
373 |
pw.println("\n# LOOCV (Leave-One-Out Cross Validation) performance"); |
|
|
374 |
if (classification) { |
|
|
375 |
Weka_module.ClassificationResultsObject cr2 = (Weka_module.ClassificationResultsObject) weka.trainClassifier(co.classifier, co.options, |
|
|
376 |
co.featuresSeparatedByCommas, classification, weka.myData.numInstances()); |
|
|
377 |
System.out.println(cr2.toStringDetails()); |
|
|
378 |
alMCCs.add(Double.valueOf(cr2.MCC)); |
|
|
379 |
alMAEs.add(Double.valueOf(cr2.MAE)); |
|
|
380 |
pw.println(cr2.toStringDetails().replace("[score_training] ", "")); |
|
|
381 |
} else { |
|
|
382 |
Weka_module.RegressionResultsObject rr2 = (Weka_module.RegressionResultsObject) weka.trainClassifier(ro.classifier, ro.options, |
|
|
383 |
ro.featuresSeparatedByCommas, classification, weka.myData.numInstances()); |
|
|
384 |
System.out.println(rr2.toStringDetails()); |
|
|
385 |
alCCs.add(Double.valueOf(rr2.CC)); |
|
|
386 |
alMAEs.add(Double.valueOf(rr2.MAE)); |
|
|
387 |
pw.println(rr2.toStringDetails().replace("[score_training] ", "")); |
|
|
388 |
} |
|
|
389 |
pw.flush(); |
|
|
390 |
} |
|
|
391 |
|
|
|
392 |
//REPEATED HOLDOUT performance TRAIN set |
|
|
393 |
ArrayList<Object> alROCs = new ArrayList<>(); |
|
|
394 |
Weka_module.evaluationPerformancesResultsObject eproRHTrain = new Weka_module.evaluationPerformancesResultsObject(); |
|
|
395 |
if (classification) { |
|
|
396 |
System.out.println("Repeated Holdout evaluation on TRAIN set of " + co.classifier + " " + co.options |
|
|
397 |
+ " optimized by " + co.optimizer + "..."); |
|
|
398 |
pw.println("\n#Repeated Holdout evaluation performance on TRAIN set, " |
|
|
399 |
+ Main.bootstrapAndRepeatedHoldoutFolds + " times weighted average (and standard deviation) on random seeds"); |
|
|
400 |
for (int i = 0; i < Main.bootstrapAndRepeatedHoldoutFolds; i++) { |
|
|
401 |
Weka_module.ClassificationResultsObject cro |
|
|
402 |
= (Weka_module.ClassificationResultsObject) weka.trainClassifierHoldOutValidation(co.classifier, co.options, |
|
|
403 |
co.featuresSeparatedByCommas, classification, i); |
|
|
404 |
eproRHTrain.alAUCs.add(Double.valueOf(cro.AUC)); |
|
|
405 |
eproRHTrain.alpAUCs.add(Double.valueOf(cro.pAUC)); |
|
|
406 |
eproRHTrain.alAUPRCs.add(Double.valueOf(cro.AUPRC)); |
|
|
407 |
eproRHTrain.alACCs.add(Double.valueOf(cro.ACC)); |
|
|
408 |
eproRHTrain.alSEs.add(Double.valueOf(cro.TPR)); |
|
|
409 |
eproRHTrain.alSPs.add(Double.valueOf(cro.TNR)); |
|
|
410 |
eproRHTrain.alMCCs.add(Double.valueOf(cro.MCC)); |
|
|
411 |
eproRHTrain.alMAEs.add(Double.valueOf(cro.MAE)); |
|
|
412 |
eproRHTrain.alBERs.add(Double.valueOf(cro.BER)); |
|
|
413 |
alROCs.add(cro); |
|
|
414 |
// System.out.println(i+"\t"+Double.valueOf(cro.AUC)); |
|
|
415 |
} |
|
|
416 |
eproRHTrain.computeMeans(); |
|
|
417 |
System.out.println(eproRHTrain.toStringClassificationDetails()); |
|
|
418 |
alMCCs.add(Double.valueOf(eproRHTrain.meanMCCs)); |
|
|
419 |
alMAEs.add(Double.valueOf(eproRHTrain.meanMAEs)); |
|
|
420 |
pw.println(eproRHTrain.toStringClassificationDetails().replace("[score_training] ", "")); |
|
|
421 |
|
|
|
422 |
if (Main.ROCcurves) { |
|
|
423 |
rocCurveGraphs.createRocCurvesWithConfidence(alROCs, classification, modelFilename, ".roc_train.png"); |
|
|
424 |
} |
|
|
425 |
} else { |
|
|
426 |
System.out.println("Repeated Holdout evaluation on TRAIN set of " + ro.classifier + " " |
|
|
427 |
+ ro.options + "optimized by " + ro.optimizer.toUpperCase()); |
|
|
428 |
pw.println("\n\n#Repeated Holdout evaluation performance on TRAIN set, " + Main.bootstrapAndRepeatedHoldoutFolds + " times average on random seeds"); |
|
|
429 |
for (int i = 0; i < Main.bootstrapAndRepeatedHoldoutFolds; i++) { |
|
|
430 |
Weka_module.RegressionResultsObject rro |
|
|
431 |
= (Weka_module.RegressionResultsObject) weka.trainClassifierHoldOutValidation(ro.classifier, ro.options, |
|
|
432 |
ro.featuresSeparatedByCommas, classification, i); |
|
|
433 |
eproRHTrain.alCCs.add(Double.valueOf(rro.CC)); |
|
|
434 |
eproRHTrain.alMAEs.add(Double.valueOf(rro.MAE)); |
|
|
435 |
eproRHTrain.alRMSEs.add(Double.valueOf(rro.RMSE)); |
|
|
436 |
eproRHTrain.alRAEs.add(Double.valueOf(rro.RAE)); |
|
|
437 |
eproRHTrain.alRRSEs.add(Double.valueOf(rro.RRSE)); |
|
|
438 |
} |
|
|
439 |
eproRHTrain.computeMeans(); |
|
|
440 |
alCCs.add(Double.valueOf(eproRHTrain.meanCCs)); |
|
|
441 |
alMAEs.add(Double.valueOf(eproRHTrain.meanMAEs)); |
|
|
442 |
pw.println(eproRHTrain.toStringRegressionDetails().replace("[score_training] ", "")); |
|
|
443 |
} |
|
|
444 |
pw.flush(); |
|
|
445 |
|
|
|
446 |
//BOOTSTRAP performance TRAIN set |
|
|
447 |
double bootstrapTrain632plus = -1; |
|
|
448 |
Weka_module.evaluationPerformancesResultsObject eproBSTrain = new Weka_module.evaluationPerformancesResultsObject(); |
|
|
449 |
if (classification) { |
|
|
450 |
System.out.println("Bootstrap evaluation on TRAIN set of " + co.classifier + " " + co.options |
|
|
451 |
+ " optimized by " + co.optimizer + "..."); |
|
|
452 |
pw.println("\n#Bootstrap evaluation performance on TRAIN set, " |
|
|
453 |
+ Main.bootstrapAndRepeatedHoldoutFolds + " times weighted average (and standard deviation) on random seeds"); |
|
|
454 |
for (int i = 0; i < Main.bootstrapAndRepeatedHoldoutFolds; i++) { |
|
|
455 |
Weka_module.ClassificationResultsObject cro |
|
|
456 |
= (Weka_module.ClassificationResultsObject) weka.trainClassifierBootstrap(co.classifier, co.options, |
|
|
457 |
co.featuresSeparatedByCommas, classification, i); |
|
|
458 |
eproBSTrain.alAUCs.add(Double.valueOf(cro.AUC)); |
|
|
459 |
eproBSTrain.alpAUCs.add(Double.valueOf(cro.pAUC)); |
|
|
460 |
eproBSTrain.alAUPRCs.add(Double.valueOf(cro.AUPRC)); |
|
|
461 |
eproBSTrain.alACCs.add(Double.valueOf(cro.ACC)); |
|
|
462 |
eproBSTrain.alSEs.add(Double.valueOf(cro.TPR)); |
|
|
463 |
eproBSTrain.alSPs.add(Double.valueOf(cro.TNR)); |
|
|
464 |
eproBSTrain.alMCCs.add(Double.valueOf(cro.MCC)); |
|
|
465 |
eproBSTrain.alMAEs.add(Double.valueOf(cro.MAE)); |
|
|
466 |
eproBSTrain.alBERs.add(Double.valueOf(cro.BER)); |
|
|
467 |
alROCs.add(cro); |
|
|
468 |
// System.out.println(i+"\t"+Double.valueOf(cro.AUC)); |
|
|
469 |
} |
|
|
470 |
eproBSTrain.computeMeans(); |
|
|
471 |
alMCCs.add(Double.valueOf(eproBSTrain.meanMCCs)); |
|
|
472 |
alMAEs.add(Double.valueOf(eproBSTrain.meanMAEs)); |
|
|
473 |
System.out.println(eproBSTrain.toStringClassificationDetails()); |
|
|
474 |
pw.println(eproBSTrain.toStringClassificationDetails().replace("[score_training] ", "")); |
|
|
475 |
|
|
|
476 |
//632+ rule |
|
|
477 |
System.out.println("Bootstrap .632+ rule calculated on TRAIN set of " + co.classifier + " " + co.options |
|
|
478 |
+ " optimized by " + co.optimizer + "..."); |
|
|
479 |
pw.println("\n#Bootstrap .632+ rule calculated on TRAIN set, " |
|
|
480 |
+ Main.bootstrapAndRepeatedHoldoutFolds + " folds with random seeds"); |
|
|
481 |
|
|
|
482 |
bootstrapTrain632plus = weka.trainClassifierBootstrap632plus(co.classifier, co.options, |
|
|
483 |
co.featuresSeparatedByCommas); |
|
|
484 |
System.out.println(df.format(bootstrapTrain632plus)); |
|
|
485 |
pw.println(df.format(bootstrapTrain632plus)); |
|
|
486 |
|
|
|
487 |
if (Main.ROCcurves) { |
|
|
488 |
rocCurveGraphs.createRocCurvesWithConfidence(alROCs, classification, modelFilename, ".roc_train.png"); |
|
|
489 |
} |
|
|
490 |
} else { |
|
|
491 |
System.out.println("Bootstrap evaluation on TRAIN set of " + ro.classifier + " " |
|
|
492 |
+ ro.options + "optimized by " + ro.optimizer.toUpperCase()); |
|
|
493 |
pw.println("\n#Bootstrap evaluation performance on TRAIN set, " |
|
|
494 |
+ Main.bootstrapAndRepeatedHoldoutFolds + " times average on random seeds"); |
|
|
495 |
for (int i = 0; i < Main.bootstrapAndRepeatedHoldoutFolds; i++) { |
|
|
496 |
Weka_module.RegressionResultsObject rro |
|
|
497 |
= (Weka_module.RegressionResultsObject) weka.trainClassifierBootstrap(ro.classifier, ro.options, |
|
|
498 |
ro.featuresSeparatedByCommas, classification, i); |
|
|
499 |
eproBSTrain.alCCs.add(Double.valueOf(rro.CC)); |
|
|
500 |
eproBSTrain.alMAEs.add(Double.valueOf(rro.MAE)); |
|
|
501 |
eproBSTrain.alRMSEs.add(Double.valueOf(rro.RMSE)); |
|
|
502 |
eproBSTrain.alRAEs.add(Double.valueOf(rro.RAE)); |
|
|
503 |
eproBSTrain.alRRSEs.add(Double.valueOf(rro.RRSE)); |
|
|
504 |
} |
|
|
505 |
eproBSTrain.computeMeans(); |
|
|
506 |
alCCs.add(Double.valueOf(eproBSTrain.meanCCs)); |
|
|
507 |
alMAEs.add(Double.valueOf(eproBSTrain.meanMAEs)); |
|
|
508 |
pw.println(eproBSTrain.toStringRegressionDetails().replace("[score_training] ", "")); |
|
|
509 |
} |
|
|
510 |
pw.flush(); |
|
|
511 |
|
|
|
512 |
// IF TEST SET |
|
|
513 |
try { |
|
|
514 |
if (Main.doSampling) { |
|
|
515 |
alROCs = new ArrayList<>(); |
|
|
516 |
System.out.println("Evaluation performance on test set"); |
|
|
517 |
pw.println("\n#Evaluation performance on test set"); |
|
|
518 |
//get arff test filename generated before training |
|
|
519 |
String arffTestFile = trainFileName.replace("data_to_train.csv", "data_to_test.arff"); |
|
|
520 |
//set the outfile of the extracted features needed to test the current model |
|
|
521 |
String arffTestFileWithExtractedModelFeatures = modelFilename + ".test_features.arff"; |
|
|
522 |
//check if test set is here |
|
|
523 |
if (!new File(arffTestFile).exists()) { |
|
|
524 |
pw.println("Test file " + arffTestFile + " not found"); |
|
|
525 |
} |
|
|
526 |
//adapt test file to model (extract the needed features) |
|
|
527 |
//test file come from original dataset preprocessed in AdaptDatasetToTraining, so the arff is compatible |
|
|
528 |
Weka_module weka2 = new Weka_module(); |
|
|
529 |
weka2.setARFFfile(arffTestFile); |
|
|
530 |
weka2.setDataFromArff(); |
|
|
531 |
// create compatible test file |
|
|
532 |
if (Main.combineModels) { |
|
|
533 |
//combined model contains the filters, we need to keep the same |
|
|
534 |
//features indexes as the b.featureSelection.infoGain.arff |
|
|
535 |
weka2.extractFeaturesFromArffFileBasedOnSelectedFeatures(weka.myData, |
|
|
536 |
weka2.myData, arffTestFileWithExtractedModelFeatures); |
|
|
537 |
} else { |
|
|
538 |
weka2.extractFeaturesFromTestFileBasedOnModel(modelFilename + ".model", |
|
|
539 |
weka2.myData, arffTestFileWithExtractedModelFeatures); |
|
|
540 |
} |
|
|
541 |
|
|
|
542 |
//TESTING |
|
|
543 |
// reload compatible test file in weka2 |
|
|
544 |
weka2 = new Weka_module(); |
|
|
545 |
weka2.setARFFfile(arffTestFileWithExtractedModelFeatures); |
|
|
546 |
weka2.setDataFromArff(); |
|
|
547 |
|
|
|
548 |
if (classification) { |
|
|
549 |
Weka_module.ClassificationResultsObject cr2 |
|
|
550 |
= (Weka_module.ClassificationResultsObject) weka2.testClassifierFromFileSource(new File(weka2.ARFFfile), |
|
|
551 |
modelFilename + ".model", true); |
|
|
552 |
alROCs.add(cr2); |
|
|
553 |
alROCs.add(cr2); |
|
|
554 |
System.out.println("[score_testing] ACC: " + cr2.ACC); |
|
|
555 |
System.out.println("[score_testing] AUC: " + cr2.AUC); |
|
|
556 |
System.out.println("[score_testing] AUPRC: " + cr2.AUPRC); |
|
|
557 |
System.out.println("[score_testing] SEN: " + cr2.TPR); |
|
|
558 |
System.out.println("[score_testing] SPE: " + cr2.TNR); |
|
|
559 |
System.out.println("[score_testing] MCC: " + cr2.MCC); |
|
|
560 |
System.out.println("[score_testing] MAE: " + cr2.MAE); |
|
|
561 |
System.out.println("[score_testing] BER: " + cr2.BER); |
|
|
562 |
|
|
|
563 |
pw.println("AUC: " + cr2.AUC); |
|
|
564 |
pw.println("ACC: " + cr2.ACC); |
|
|
565 |
pw.println("AUPRC: " + cr2.AUPRC); |
|
|
566 |
pw.println("SEN: " + cr2.TPR); |
|
|
567 |
pw.println("SPE: " + cr2.TNR); |
|
|
568 |
pw.println("MCC: " + cr2.MCC); |
|
|
569 |
pw.println("MAE: " + cr2.MAE); |
|
|
570 |
pw.println("BER: " + cr2.BER); |
|
|
571 |
|
|
|
572 |
alMCCs.add(Double.valueOf(cr2.MCC)); |
|
|
573 |
alMAEs.add(Double.valueOf(cr2.MAE)); |
|
|
574 |
|
|
|
575 |
if (Main.ROCcurves) { |
|
|
576 |
rocCurveGraphs.createRocCurvesWithConfidence(alROCs, classification, modelFilename, ".roc_test.png"); |
|
|
577 |
} |
|
|
578 |
} else { |
|
|
579 |
Weka_module.RegressionResultsObject rr2 |
|
|
580 |
= (Weka_module.RegressionResultsObject) weka2.testClassifierFromFileSource(new File(weka2.ARFFfile), |
|
|
581 |
modelFilename + ".model", false); |
|
|
582 |
System.out.println("[score_testing] Average CC: " + rr2.CC); |
|
|
583 |
System.out.println("[score_testing] Average RMSE: " + rr2.RMSE); |
|
|
584 |
// |
|
|
585 |
pw.println("Average CC: " + rr2.CC); |
|
|
586 |
pw.println("Average RMSE: " + rr2.RMSE); |
|
|
587 |
|
|
|
588 |
alCCs.add(Double.valueOf(rr2.CC)); |
|
|
589 |
alMAEs.add(Double.valueOf(rr2.MAE)); |
|
|
590 |
} |
|
|
591 |
new File(arffTestFileWithExtractedModelFeatures).delete(); |
|
|
592 |
|
|
|
593 |
//REPEATED HOLDOUT TRAIN_TEST |
|
|
594 |
arffTestFileWithExtractedModelFeatures = arffTestFileWithExtractedModelFeatures.replace(".test_features.arff", ".RH_features.arff"); |
|
|
595 |
Weka_module.evaluationPerformancesResultsObject eproRHTrainTest = new Weka_module.evaluationPerformancesResultsObject(); |
|
|
596 |
try { |
|
|
597 |
alROCs = new ArrayList<>(); |
|
|
598 |
// adapt original dataset file to model (extract the needed features) |
|
|
599 |
Weka_module weka3 = new Weka_module(); |
|
|
600 |
weka3.setARFFfile(trainFileName.replace("data_to_train.csv", "all_data.arff")); |
|
|
601 |
weka3.setDataFromArff(); |
|
|
602 |
// create compatible file |
|
|
603 |
if (Main.combineModels) { |
|
|
604 |
//combined model contains the filters, we need to keep the same |
|
|
605 |
//features indexes as the b.featureSelection.infoGain.arff |
|
|
606 |
weka3.extractFeaturesFromArffFileBasedOnSelectedFeatures(weka.myData, |
|
|
607 |
weka3.myData, arffTestFileWithExtractedModelFeatures); |
|
|
608 |
} else { |
|
|
609 |
weka3.extractFeaturesFromTestFileBasedOnModel(modelFilename + ".model", |
|
|
610 |
weka3.myData, arffTestFileWithExtractedModelFeatures); |
|
|
611 |
} |
|
|
612 |
|
|
|
613 |
// reload compatible file in weka2 |
|
|
614 |
weka3 = new Weka_module(); |
|
|
615 |
weka3.setARFFfile(arffTestFileWithExtractedModelFeatures); |
|
|
616 |
weka3.setDataFromArff(); |
|
|
617 |
|
|
|
618 |
if (classification) { |
|
|
619 |
if (!Main.combineModels) { |
|
|
620 |
weka3.myData = weka3.extractFeaturesFromDatasetBasedOnModel(cr.model, weka3.myData); |
|
|
621 |
} |
|
|
622 |
System.out.println("Repeated Holdout evaluation on TRAIN AND TEST sets of " + co.classifier + " " + co.options |
|
|
623 |
+ " optimized by " + co.optimizer); |
|
|
624 |
pw.println("\n#Repeated Holdout evaluation performance on TRAIN AND TEST set, " |
|
|
625 |
+ Main.bootstrapAndRepeatedHoldoutFolds + " times weighted average (and standard deviation) on random seeds"); |
|
|
626 |
for (int i = 0; i < Main.bootstrapAndRepeatedHoldoutFolds; i++) { |
|
|
627 |
Weka_module.ClassificationResultsObject cro |
|
|
628 |
= (Weka_module.ClassificationResultsObject) weka3.trainClassifierHoldOutValidation(co.classifier, co.options, |
|
|
629 |
null, classification, i); |
|
|
630 |
eproRHTrainTest.alAUCs.add(Double.valueOf(cro.AUC)); |
|
|
631 |
eproRHTrainTest.alpAUCs.add(Double.valueOf(cro.pAUC)); |
|
|
632 |
eproRHTrainTest.alAUPRCs.add(Double.valueOf(cro.AUPRC)); |
|
|
633 |
eproRHTrainTest.alACCs.add(Double.valueOf(cro.ACC)); |
|
|
634 |
eproRHTrainTest.alSEs.add(Double.valueOf(cro.TPR)); |
|
|
635 |
eproRHTrainTest.alSPs.add(Double.valueOf(cro.TNR)); |
|
|
636 |
eproRHTrainTest.alMCCs.add(Double.valueOf(cro.MCC)); |
|
|
637 |
eproRHTrainTest.alMAEs.add(Double.valueOf(cro.MAE)); |
|
|
638 |
eproRHTrainTest.alBERs.add(Double.valueOf(cro.BER)); |
|
|
639 |
alROCs.add(cro); |
|
|
640 |
} |
|
|
641 |
eproRHTrainTest.computeMeans(); |
|
|
642 |
alMCCs.add(Double.valueOf(eproRHTrainTest.meanMCCs)); |
|
|
643 |
alMAEs.add(Double.valueOf(eproRHTrainTest.meanMAEs)); |
|
|
644 |
System.out.println(eproRHTrainTest.toStringClassificationDetails()); |
|
|
645 |
pw.println(eproRHTrainTest.toStringClassificationDetails().replace("[score_training] ", "")); |
|
|
646 |
if (Main.ROCcurves) { |
|
|
647 |
rocCurveGraphs.createRocCurvesWithConfidence(alROCs, classification, modelFilename, ".roc.png"); |
|
|
648 |
} |
|
|
649 |
} else { |
|
|
650 |
if (!Main.combineModels) { |
|
|
651 |
weka3.myData = weka3.extractFeaturesFromDatasetBasedOnModel(rr.model, weka3.myData); |
|
|
652 |
} |
|
|
653 |
|
|
|
654 |
System.out.println("Repeated Holdout evaluation on TRAIN AND TEST sets of " + ro.classifier + " " |
|
|
655 |
+ ro.options + "optimized by " + ro.optimizer.toUpperCase()); |
|
|
656 |
pw.println("\n#Repeated Holdout evaluation performance on TRAIN AND TEST set, " |
|
|
657 |
+ Main.bootstrapAndRepeatedHoldoutFolds + " times weighted average (and standard deviation) on random seeds"); |
|
|
658 |
for (int i = 0; i < Main.bootstrapAndRepeatedHoldoutFolds; i++) { |
|
|
659 |
Weka_module.RegressionResultsObject rro |
|
|
660 |
= (Weka_module.RegressionResultsObject) weka3.trainClassifierHoldOutValidation(ro.classifier, ro.options, |
|
|
661 |
null, classification, i); |
|
|
662 |
eproRHTrainTest.alCCs.add(Double.valueOf(rro.CC)); |
|
|
663 |
eproRHTrainTest.alMAEs.add(Double.valueOf(rro.MAE)); |
|
|
664 |
eproRHTrainTest.alRMSEs.add(Double.valueOf(rro.RMSE)); |
|
|
665 |
eproRHTrainTest.alRAEs.add(Double.valueOf(rro.RAE)); |
|
|
666 |
eproRHTrainTest.alRRSEs.add(Double.valueOf(rro.RRSE)); |
|
|
667 |
} |
|
|
668 |
eproRHTrainTest.computeMeans(); |
|
|
669 |
alCCs.add(Double.valueOf(eproRHTrainTest.meanCCs)); |
|
|
670 |
alMAEs.add(Double.valueOf(eproRHTrainTest.meanMAEs)); |
|
|
671 |
System.out.println(eproRHTrainTest.toStringRegressionDetails()); |
|
|
672 |
pw.println(eproRHTrainTest.toStringRegressionDetails().replace("[score_training] ", "")); |
|
|
673 |
} |
|
|
674 |
eproRHTrainTest.computeMeans(); |
|
|
675 |
} catch (Exception e) { |
|
|
676 |
if (Main.debug) { |
|
|
677 |
e.printStackTrace(); |
|
|
678 |
} |
|
|
679 |
} |
|
|
680 |
new File(arffTestFileWithExtractedModelFeatures).delete(); |
|
|
681 |
|
|
|
682 |
//BOOTSRAP TRAIN_TEST |
|
|
683 |
arffTestFileWithExtractedModelFeatures = arffTestFileWithExtractedModelFeatures.replace(".RH_features.arff", ".BS_features.arff"); |
|
|
684 |
Weka_module.evaluationPerformancesResultsObject eproBSTrainTest = new Weka_module.evaluationPerformancesResultsObject(); |
|
|
685 |
try { |
|
|
686 |
alROCs = new ArrayList<>(); |
|
|
687 |
Weka_module weka4 = new Weka_module(); |
|
|
688 |
weka4.setARFFfile(trainFileName.replace("data_to_train.csv", "all_data.arff")); |
|
|
689 |
weka4.setDataFromArff(); |
|
|
690 |
|
|
|
691 |
// create compatible file |
|
|
692 |
if (Main.combineModels) { |
|
|
693 |
//combined model contains the filters, we need to keep the same |
|
|
694 |
//features indexes as the b.featureSelection.infoGain.arff |
|
|
695 |
weka4.extractFeaturesFromArffFileBasedOnSelectedFeatures(weka.myData, |
|
|
696 |
weka4.myData, arffTestFileWithExtractedModelFeatures); |
|
|
697 |
} else { |
|
|
698 |
weka4.extractFeaturesFromTestFileBasedOnModel(modelFilename + ".model", |
|
|
699 |
weka4.myData, arffTestFileWithExtractedModelFeatures); |
|
|
700 |
} |
|
|
701 |
|
|
|
702 |
// reload compatible file in weka2 |
|
|
703 |
weka4 = new Weka_module(); |
|
|
704 |
weka4.setARFFfile(arffTestFileWithExtractedModelFeatures); |
|
|
705 |
weka4.setDataFromArff(); |
|
|
706 |
if (classification) { |
|
|
707 |
if (!Main.combineModels) { |
|
|
708 |
weka4.myData = weka4.extractFeaturesFromDatasetBasedOnModel(cr.model, weka4.myData); |
|
|
709 |
} |
|
|
710 |
System.out.println("Bootstrap evaluation on TRAIN AND TEST sets of " + co.classifier + " " + co.options |
|
|
711 |
+ " optimized by " + co.optimizer); |
|
|
712 |
pw.println("\n#Bootstrap evaluation performance on TRAIN AND TEST set, " |
|
|
713 |
+ Main.bootstrapAndRepeatedHoldoutFolds + " times weighted average (and standard deviation) on random seeds"); |
|
|
714 |
for (int i = 0; i < Main.bootstrapAndRepeatedHoldoutFolds; i++) { |
|
|
715 |
Weka_module.ClassificationResultsObject cro |
|
|
716 |
= (Weka_module.ClassificationResultsObject) weka4.trainClassifierBootstrap(co.classifier, co.options, |
|
|
717 |
null, classification, i); |
|
|
718 |
eproBSTrainTest.alAUCs.add(Double.valueOf(cro.AUC)); |
|
|
719 |
eproBSTrainTest.alpAUCs.add(Double.valueOf(cro.pAUC)); |
|
|
720 |
eproBSTrainTest.alAUPRCs.add(Double.valueOf(cro.AUPRC)); |
|
|
721 |
eproBSTrainTest.alACCs.add(Double.valueOf(cro.ACC)); |
|
|
722 |
eproBSTrainTest.alSEs.add(Double.valueOf(cro.TPR)); |
|
|
723 |
eproBSTrainTest.alSPs.add(Double.valueOf(cro.TNR)); |
|
|
724 |
eproBSTrainTest.alMCCs.add(Double.valueOf(cro.MCC)); |
|
|
725 |
eproBSTrainTest.alMAEs.add(Double.valueOf(cro.MAE)); |
|
|
726 |
eproBSTrainTest.alBERs.add(Double.valueOf(cro.BER)); |
|
|
727 |
alROCs.add(cro); |
|
|
728 |
} |
|
|
729 |
eproBSTrainTest.computeMeans(); |
|
|
730 |
alMCCs.add(Double.valueOf(eproBSTrainTest.meanMCCs)); |
|
|
731 |
alMAEs.add(Double.valueOf(eproBSTrainTest.meanMAEs)); |
|
|
732 |
System.out.println(eproBSTrainTest.toStringClassificationDetails()); |
|
|
733 |
pw.println(eproBSTrainTest.toStringClassificationDetails().replace("[score_training] ", "")); |
|
|
734 |
|
|
|
735 |
//632+ rule |
|
|
736 |
System.out.println("Bootstrap .632+ rule calculated on TRAIN AND TEST set of " + co.classifier + " " + co.options |
|
|
737 |
+ " optimized by " + co.optimizer + "..."); |
|
|
738 |
pw.println("\n#Bootstrap .632+ rule calculated on TRAIN AND TEST set, " |
|
|
739 |
+ Main.bootstrapAndRepeatedHoldoutFolds + " folds with random seeds"); |
|
|
740 |
|
|
|
741 |
bootstrapTrain632plus = weka4.trainClassifierBootstrap632plus(co.classifier, co.options, |
|
|
742 |
null); |
|
|
743 |
System.out.println(df.format(bootstrapTrain632plus)); |
|
|
744 |
pw.println(df.format(bootstrapTrain632plus)); |
|
|
745 |
|
|
|
746 |
if (Main.ROCcurves) { |
|
|
747 |
rocCurveGraphs.createRocCurvesWithConfidence(alROCs, classification, modelFilename, ".roc.png"); |
|
|
748 |
} |
|
|
749 |
} else { |
|
|
750 |
if (!Main.combineModels) { |
|
|
751 |
weka4.myData = weka4.extractFeaturesFromDatasetBasedOnModel(rr.model, weka4.myData); |
|
|
752 |
} |
|
|
753 |
System.out.println("Bootstrap evaluation on TRAIN AND TEST sets of " + ro.classifier + " " |
|
|
754 |
+ ro.options + "optimized by " + ro.optimizer.toUpperCase()); |
|
|
755 |
pw.println("\n#Bootstrap evaluation performance on TRAIN AND TEST set, " |
|
|
756 |
+ Main.bootstrapAndRepeatedHoldoutFolds + " times weighted average (and standard deviation) on random seeds"); |
|
|
757 |
for (int i = 0; i < Main.bootstrapAndRepeatedHoldoutFolds; i++) { |
|
|
758 |
Weka_module.RegressionResultsObject rro |
|
|
759 |
= (Weka_module.RegressionResultsObject) weka4.trainClassifierBootstrap(ro.classifier, ro.options, |
|
|
760 |
null, classification, i); |
|
|
761 |
eproBSTrainTest.alCCs.add(Double.valueOf(rro.CC)); |
|
|
762 |
eproBSTrainTest.alMAEs.add(Double.valueOf(rro.MAE)); |
|
|
763 |
eproBSTrainTest.alRMSEs.add(Double.valueOf(rro.RMSE)); |
|
|
764 |
eproBSTrainTest.alRAEs.add(Double.valueOf(rro.RAE)); |
|
|
765 |
eproBSTrainTest.alRRSEs.add(Double.valueOf(rro.RRSE)); |
|
|
766 |
} |
|
|
767 |
eproBSTrainTest.computeMeans(); |
|
|
768 |
alCCs.add(Double.valueOf(eproBSTrainTest.meanCCs)); |
|
|
769 |
alMAEs.add(Double.valueOf(eproBSTrainTest.meanMAEs)); |
|
|
770 |
System.out.println(eproBSTrainTest.toStringRegressionDetails()); |
|
|
771 |
pw.println(eproBSTrainTest.toStringRegressionDetails().replace("[score_training] ", "")); |
|
|
772 |
} |
|
|
773 |
eproBSTrainTest.computeMeans(); |
|
|
774 |
} catch (Exception e) { |
|
|
775 |
if (Main.debug) { |
|
|
776 |
e.printStackTrace(); |
|
|
777 |
} |
|
|
778 |
} |
|
|
779 |
|
|
|
780 |
//remove test file arff once done |
|
|
781 |
new File(arffTestFileWithExtractedModelFeatures).delete(); |
|
|
782 |
} |
|
|
783 |
} catch (Exception e) { |
|
|
784 |
if (Main.debug) { |
|
|
785 |
e.printStackTrace(); |
|
|
786 |
} |
|
|
787 |
} |
|
|
788 |
pw.flush(); |
|
|
789 |
|
|
|
790 |
// show average metrics and standard deviation |
|
|
791 |
if (classification) { |
|
|
792 |
pw.println("\n# Average MCC: " + utils.getMean(alMCCs) |
|
|
793 |
+ "\t(" + utils.getStandardDeviation(alMCCs) + ")"); |
|
|
794 |
System.out.println("\n# Average MCC: " + utils.getMean(alMCCs)); |
|
|
795 |
pw.println("# Average MAE: " + utils.getMean(alMAEs) |
|
|
796 |
+ "\t(" + utils.getStandardDeviation(alMAEs) + ")"); |
|
|
797 |
System.out.println("# Average MAE: " + utils.getMean(alMAEs)); |
|
|
798 |
} else { |
|
|
799 |
pw.println("\n# Average CC: " + utils.getMean(alCCs) |
|
|
800 |
+ "\t(" + utils.getStandardDeviation(alCCs) + ")"); |
|
|
801 |
System.out.println("\n# Average CC: " + utils.getMean(alCCs)); |
|
|
802 |
pw.println("# Average MAE: " + utils.getMean(alMAEs) |
|
|
803 |
+ "\t(" + utils.getStandardDeviation(alMAEs) + ")"); |
|
|
804 |
System.out.println("# Average MAE: " + utils.getMean(alMAEs)); |
|
|
805 |
} |
|
|
806 |
|
|
|
807 |
//output features |
|
|
808 |
if (classification) { |
|
|
809 |
try { |
|
|
810 |
pw.print("\n# Selected Attributes (Total attributes:" + cr.numberOfFeatures + "). " |
|
|
811 |
+ "Occurrences are shown if you chose combined model\n"); |
|
|
812 |
pw.print(cr.features); |
|
|
813 |
pw.println("\n# Attribute ranking by merit calculated by information gain"); |
|
|
814 |
pw.print(cr.getFeatureRankingResults()); |
|
|
815 |
|
|
|
816 |
} catch (Exception e) { |
|
|
817 |
e.printStackTrace(); |
|
|
818 |
} |
|
|
819 |
} else { |
|
|
820 |
try { |
|
|
821 |
pw.print("\n# Selected Attributes\t(Total attributes:" + rr.numberOfFeatures + "). " |
|
|
822 |
+ "Occurrences are shown if you chose combined model\n"); |
|
|
823 |
pw.print(rr.features); |
|
|
824 |
pw.println("\n# Attribute ranking by merit calculated by RELIEFF"); |
|
|
825 |
pw.print(rr.getFeatureRankingResults()); |
|
|
826 |
|
|
|
827 |
} catch (Exception e) { |
|
|
828 |
e.printStackTrace(); |
|
|
829 |
} |
|
|
830 |
} |
|
|
831 |
pw.flush(); |
|
|
832 |
|
|
|
833 |
//retrieve correlated features |
|
|
834 |
// do not retreive correlated features if we are already computing a model for the long signature |
|
|
835 |
if (Main.retrieveCorrelatedGenes && !correlatedFeaturesMode) { |
|
|
836 |
if (new File(trainFileName).exists()) { |
|
|
837 |
System.out.print("Search correlated features (spearman)..."); |
|
|
838 |
pw.println("\n# Correlated features (Spearman)"); |
|
|
839 |
pw.println("FeatureInSignature\tSpearmanCorrelationScore\tCorrelatedFeature"); |
|
|
840 |
TreeMap<String, Double> tmsCorrelatedgenes |
|
|
841 |
= RetreiveCorrelatedGenes.spearmanCorrelation(modelFilename + ".train_features.csv", trainFileName); |
|
|
842 |
for (String correlation : tmsCorrelatedgenes.keySet()) { |
|
|
843 |
pw.println(correlation); |
|
|
844 |
} |
|
|
845 |
if (tmsCorrelatedgenes.isEmpty()) { |
|
|
846 |
pw.println("#nothing found !"); |
|
|
847 |
} |
|
|
848 |
System.out.println("[done]"); |
|
|
849 |
|
|
|
850 |
System.out.print("Search correlated features (pearson)..."); |
|
|
851 |
pw.println("\n# Correlated features (Pearson)"); |
|
|
852 |
pw.println("FeatureInSignature\tPearsonCorrelationScore\tCorrelatedFeature"); |
|
|
853 |
TreeMap<String, Double> tmpCorrelatedgenes |
|
|
854 |
= RetreiveCorrelatedGenes.pearsonCorrelation(modelFilename + ".train_features.csv", trainFileName); |
|
|
855 |
for (String correlation : tmpCorrelatedgenes.keySet()) { |
|
|
856 |
pw.println(correlation); |
|
|
857 |
} |
|
|
858 |
if (tmpCorrelatedgenes.isEmpty()) { |
|
|
859 |
pw.println("#nothing found !"); |
|
|
860 |
} |
|
|
861 |
System.out.println("[done]"); |
|
|
862 |
} else { |
|
|
863 |
System.out.println("Feature file " + trainFileName + " not found. Unable to calculate correlated genes"); |
|
|
864 |
} |
|
|
865 |
pw.flush(); |
|
|
866 |
|
|
|
867 |
//retreive rankings |
|
|
868 |
String ranking = ""; |
|
|
869 |
if (classification) { |
|
|
870 |
ranking = weka.featureRankingForClassification(trainFileName.replace("csv", "arff")); |
|
|
871 |
} else { |
|
|
872 |
ranking = weka.featureRankingForRegression(trainFileName.replace("csv", "arff")); |
|
|
873 |
} |
|
|
874 |
String lines[] = ranking.split("\n"); |
|
|
875 |
HashMap<String, ArrayList<RankerObject>> hmRanks = new HashMap(); |
|
|
876 |
try { |
|
|
877 |
for (String s : lines) { |
|
|
878 |
s = s.replaceAll(" +", " "); |
|
|
879 |
if (!s.startsWith("\t") && !s.trim().isEmpty() && s.trim().split(" ").length == 3) { |
|
|
880 |
RankerObject rankero = new RankerObject(s.trim()); |
|
|
881 |
if (hmRanks.containsKey(rankero.roundedScore)) { |
|
|
882 |
ArrayList<RankerObject> alRankero = hmRanks.get(rankero.roundedScore); |
|
|
883 |
alRankero.add(rankero); |
|
|
884 |
hmRanks.put(rankero.roundedScore, alRankero); |
|
|
885 |
} else { |
|
|
886 |
ArrayList<RankerObject> alRankero = new ArrayList<>(); |
|
|
887 |
alRankero.add(rankero); |
|
|
888 |
hmRanks.put(rankero.roundedScore, alRankero); |
|
|
889 |
} |
|
|
890 |
} |
|
|
891 |
} |
|
|
892 |
} catch (Exception e) { |
|
|
893 |
if (Main.debug) { |
|
|
894 |
e.printStackTrace(); |
|
|
895 |
} |
|
|
896 |
} |
|
|
897 |
|
|
|
898 |
if (Main.retreiveCorrelatedGenesByRankingScore) { |
|
|
899 |
System.out.print("Search similar ranking scores (infogain for classification or relieFf) in the original dataset..."); |
|
|
900 |
|
|
|
901 |
pw.println("\n# Similar ranking score (maximal difference: " + Main.maxRankingScoreDifference + ")"); |
|
|
902 |
pw.println("FeatureInSignature\tRankingScore\tFeatureInDataset\tRankingScore"); |
|
|
903 |
String rankedFeaturesSign[]; |
|
|
904 |
if (classification) { |
|
|
905 |
rankedFeaturesSign = cr.getFeatureRankingResults().split("\n"); |
|
|
906 |
} else { |
|
|
907 |
rankedFeaturesSign = rr.getFeatureRankingResults().split("\n"); |
|
|
908 |
} |
|
|
909 |
for (String featureSign : rankedFeaturesSign) { |
|
|
910 |
String featureSignIG = featureSign.split("\t")[0]; |
|
|
911 |
String featureSignIGrounded = df.format(Double.valueOf(featureSignIG)); |
|
|
912 |
if (hmRanks.containsKey(featureSignIGrounded)) { |
|
|
913 |
for (RankerObject alio : hmRanks.get(featureSignIGrounded)) { |
|
|
914 |
if (!featureSign.contains(alio.feature) && !alio.feature.equals(Main.mergingID)) { |
|
|
915 |
//max difference between infogains: 0.005 |
|
|
916 |
if (Math.abs(Double.parseDouble(featureSignIG) - Double.parseDouble(alio.infogain)) |
|
|
917 |
<= Main.maxRankingScoreDifference) { |
|
|
918 |
pw.println(featureSign.split("\t")[1] + "\t" |
|
|
919 |
+ featureSignIG + "\t" |
|
|
920 |
+ alio.feature + "\t" |
|
|
921 |
+ alio.infogain); |
|
|
922 |
} |
|
|
923 |
|
|
|
924 |
} |
|
|
925 |
} |
|
|
926 |
} |
|
|
927 |
} |
|
|
928 |
System.out.println("[done]"); |
|
|
929 |
} |
|
|
930 |
|
|
|
931 |
//close file |
|
|
932 |
pw.println("\n\n## End of file ##"); |
|
|
933 |
pw.close(); |
|
|
934 |
|
|
|
935 |
//export enriched signature |
|
|
936 |
LinkedHashMap<String, String> lhmCorrFeaturesNames = new LinkedHashMap<>();//signature + correlated features |
|
|
937 |
LinkedHashMap<String, String> lhmFeaturesNames = new LinkedHashMap<>();//signature only |
|
|
938 |
try { |
|
|
939 |
br = new BufferedReader(new FileReader(modelFilename + ".details.txt")); |
|
|
940 |
String line = ""; |
|
|
941 |
//go to selected attributes |
|
|
942 |
while (!line.startsWith("# Attribute ranking by")) { |
|
|
943 |
line = br.readLine(); |
|
|
944 |
} |
|
|
945 |
line = br.readLine(); |
|
|
946 |
//add attributes to hashmap |
|
|
947 |
while (!line.startsWith("#")) { |
|
|
948 |
if (!line.isEmpty()) { |
|
|
949 |
lhmCorrFeaturesNames.put(line.split("\t")[1].trim(), ""); |
|
|
950 |
lhmFeaturesNames.put(line.split("\t")[1].trim(), ""); |
|
|
951 |
} |
|
|
952 |
line = br.readLine(); |
|
|
953 |
} |
|
|
954 |
//go to spearman correlated attributes |
|
|
955 |
while (!line.startsWith("FeatureInSignature")) { |
|
|
956 |
line = br.readLine(); |
|
|
957 |
} |
|
|
958 |
line = br.readLine(); |
|
|
959 |
//add attributes to hashmap |
|
|
960 |
while (!line.startsWith("#")) { |
|
|
961 |
if (!line.isEmpty()) { |
|
|
962 |
lhmCorrFeaturesNames.put(line.split("\t")[2].trim(), ""); |
|
|
963 |
} |
|
|
964 |
line = br.readLine(); |
|
|
965 |
} |
|
|
966 |
//go to pearson correlated attributes |
|
|
967 |
while (!line.startsWith("FeatureInSignature")) { |
|
|
968 |
line = br.readLine(); |
|
|
969 |
} |
|
|
970 |
line = br.readLine(); |
|
|
971 |
//add attributes to hashmap |
|
|
972 |
while (!line.startsWith("#")) { |
|
|
973 |
if (!line.isEmpty()) { |
|
|
974 |
lhmCorrFeaturesNames.put(line.split("\t")[2].trim(), ""); |
|
|
975 |
} |
|
|
976 |
line = br.readLine(); |
|
|
977 |
} |
|
|
978 |
if (Main.retreiveCorrelatedGenesByRankingScore) { |
|
|
979 |
//go to infogain correlated attributes |
|
|
980 |
while (!line.startsWith("FeatureInSignature")) { |
|
|
981 |
line = br.readLine(); |
|
|
982 |
} |
|
|
983 |
line = br.readLine(); |
|
|
984 |
//add attributes to hashmap |
|
|
985 |
while (br.ready()) { |
|
|
986 |
if (!line.isEmpty()) { |
|
|
987 |
lhmCorrFeaturesNames.put(line.split("\t")[0].trim(), ""); |
|
|
988 |
lhmCorrFeaturesNames.put(line.split("\t")[2].trim(), ""); |
|
|
989 |
} |
|
|
990 |
line = br.readLine(); |
|
|
991 |
} |
|
|
992 |
} |
|
|
993 |
br.close(); |
|
|
994 |
//write correlated feature file from training file |
|
|
995 |
correlatedFeatures = writeFeaturesFile(lhmCorrFeaturesNames, trainFileName, |
|
|
996 |
classification, modelFilename + ".train_corrFeatures.csv"); |
|
|
997 |
if (Main.doSampling) { |
|
|
998 |
//write correlated feature file from all data file |
|
|
999 |
//if we have done a sampling, then we can't go from trainFeaturesFile |
|
|
1000 |
//but to allFeaturesFile, which contain test data |
|
|
1001 |
correlatedFeatures = writeFeaturesFile(lhmCorrFeaturesNames, trainFileName.replace("data_to_train", "all_data"), |
|
|
1002 |
classification, modelFilename + ".all_corrFeatures.csv"); |
|
|
1003 |
|
|
|
1004 |
//short signature |
|
|
1005 |
writeFeaturesFile(lhmFeaturesNames, trainFileName.replace("data_to_train", "all_data"), |
|
|
1006 |
classification, modelFilename + ".all_features.csv"); |
|
|
1007 |
System.out.println(""); |
|
|
1008 |
} |
|
|
1009 |
|
|
|
1010 |
} catch (Exception e) { |
|
|
1011 |
e.printStackTrace(); |
|
|
1012 |
} |
|
|
1013 |
|
|
|
1014 |
} |
|
|
1015 |
// delete useless files |
|
|
1016 |
if (Main.doSampling) { |
|
|
1017 |
new File(modelFilename + ".train_features.csv").delete(); |
|
|
1018 |
new File(modelFilename + ".train_corrFeatures.csv").delete(); |
|
|
1019 |
} |
|
|
1020 |
new File(modelFilename + ".test_features.arff").delete(); |
|
|
1021 |
new File(modelFilename + ".RH_features.arff").delete(); |
|
|
1022 |
new File(modelFilename + ".BS_features.arff").delete(); |
|
|
1023 |
|
|
|
1024 |
} |
|
|
1025 |
|
|
|
1026 |
/** |
|
|
1027 |
* |
|
|
1028 |
* @param lhm feature names in order |
|
|
1029 |
* @param originFile the training file or all data file |
|
|
1030 |
* @param classification if we are doing a classification |
|
|
1031 |
* @param outfile outfile name |
|
|
1032 |
*/ |
|
|
1033 |
private String writeFeaturesFile(LinkedHashMap<String, String> lhm, String originFile, boolean classification, String outfile) { |
|
|
1034 |
//find columns indices |
|
|
1035 |
String featuresSeparatedByCommas = "1"; |
|
|
1036 |
try { |
|
|
1037 |
BufferedReader br = new BufferedReader(new FileReader(originFile)); |
|
|
1038 |
String header = br.readLine(); |
|
|
1039 |
String features[] = header.split(utils.detectSeparator(originFile)); |
|
|
1040 |
for (int i = 0; i < features.length; i++) { |
|
|
1041 |
String feature = features[i]; |
|
|
1042 |
if (lhm.containsKey(feature)) { |
|
|
1043 |
featuresSeparatedByCommas += "," + (i + 1); |
|
|
1044 |
} |
|
|
1045 |
} |
|
|
1046 |
featuresSeparatedByCommas += "," + features.length; |
|
|
1047 |
} catch (Exception e) { |
|
|
1048 |
e.printStackTrace(); |
|
|
1049 |
} |
|
|
1050 |
|
|
|
1051 |
//extract columns using weka filter |
|
|
1052 |
Weka_module weka2 = new Weka_module(); |
|
|
1053 |
weka2.setARFFfile(originFile.replace("csv", "arff")); |
|
|
1054 |
weka2.setDataFromArff(); |
|
|
1055 |
weka2.saveFilteredDataToCSV(featuresSeparatedByCommas, |
|
|
1056 |
classification, outfile); |
|
|
1057 |
return featuresSeparatedByCommas; |
|
|
1058 |
|
|
|
1059 |
} |
|
|
1060 |
|
|
|
1061 |
/** |
|
|
1062 |
* initialize weka |
|
|
1063 |
* |
|
|
1064 |
* @param infile |
|
|
1065 |
* @param classification |
|
|
1066 |
*/ |
|
|
1067 |
private static void init(String infile, boolean classification) {
    // Prepare the static weka module from the input file: CSV inputs are
    // converted to ARFF first, otherwise the matching .arff file is used directly.
    //convert csv to arff
    if (infile.endsWith(".csv")) {
        weka.setCSVFile(new File(infile));
        weka.csvToArff(classification);
    } else {
        weka.setARFFfile(infile.replace(".csv", ".arff"));
    }

    //set local variable of weka object from ARFFfile
    weka.setDataFromArff();
    weka.myData = weka.convertStringsToNominal(weka.myData);
    // // check if class has numeric values, hence regression, instead of nominal class (classification)
    // NOTE(review): Java is pass-by-value, so this assignment never reaches the
    // caller — it is a dead store. The call is kept in case
    // Weka_module.isClassification() has side effects; TODO confirm and clean up.
    classification = weka.isClassification();
}
|
|
1082 |
|
|
|
1083 |
/** |
|
|
1084 |
* calculate mean and standard deviation of an array of doubles |
|
|
1085 |
* |
|
|
1086 |
* @param al |
|
|
1087 |
* @return |
|
|
1088 |
*/ |
|
|
1089 |
private static String getMeanAndStandardDeviation(ArrayList<Double> al) { |
|
|
1090 |
|
|
|
1091 |
double d[] = new double[al.size()]; |
|
|
1092 |
for (int i = 0; i < al.size(); i++) { |
|
|
1093 |
d[i] = (double) al.get(i); |
|
|
1094 |
} |
|
|
1095 |
|
|
|
1096 |
StandardDeviation sd = new StandardDeviation(); |
|
|
1097 |
Mean m = new Mean(); |
|
|
1098 |
|
|
|
1099 |
return df.format(m.evaluate(d)) + " (" + df.format(sd.evaluate(d)) + ")"; |
|
|
1100 |
} |
|
|
1101 |
|
|
|
1102 |
/** |
|
|
1103 |
* classification object |
|
|
1104 |
*/ |
|
|
1105 |
public static class classificationObject { |
|
|
1106 |
|
|
|
1107 |
public ArrayList<String> featureList = new ArrayList<>(); |
|
|
1108 |
public String featuresSeparatedByCommas = ""; |
|
|
1109 |
public String optimizer = ""; |
|
|
1110 |
public String mode = ""; |
|
|
1111 |
public String classifier = ""; |
|
|
1112 |
public String options = ""; |
|
|
1113 |
public String identifier = ""; |
|
|
1114 |
public TreeMap<Integer, Integer> tmFeatures; |
|
|
1115 |
public HashMap<String, String> hmValues = new HashMap<>(); //Column name, value |
|
|
1116 |
|
|
|
1117 |
public classificationObject() { |
|
|
1118 |
} |
|
|
1119 |
|
|
|
1120 |
public classificationObject(String line) { |
|
|
1121 |
|
|
|
1122 |
identifier = line.split("\t")[hmResultsHeaderNames.get("ID")]; |
|
|
1123 |
classifier = line.split("\t")[hmResultsHeaderNames.get("classifier")]; |
|
|
1124 |
options = line.split("\t")[hmResultsHeaderNames.get("Options")]; |
|
|
1125 |
optimizer = line.split("\t")[hmResultsHeaderNames.get("OptimizedValue")]; |
|
|
1126 |
mode = line.split("\t")[hmResultsHeaderNames.get("SearchMode")]; |
|
|
1127 |
|
|
|
1128 |
featureList.addAll(Arrays.asList(line.split("\t")[hmResultsHeaderNames.get("AttributeList")].split(","))); |
|
|
1129 |
featuresSeparatedByCommas = line.split("\t")[hmResultsHeaderNames.get("AttributeList")]; |
|
|
1130 |
String s[] = line.split("\t"); |
|
|
1131 |
for (int i = 0; i < s.length; i++) { |
|
|
1132 |
hmValues.put(hmResultsHeaderIndexes.get(i), s[i]); |
|
|
1133 |
} |
|
|
1134 |
|
|
|
1135 |
} |
|
|
1136 |
|
|
|
1137 |
/** |
|
|
1138 |
* for combined models |
|
|
1139 |
* |
|
|
1140 |
* @param alBestClassifiers |
|
|
1141 |
* @param NumberOfTopModels |
|
|
1142 |
*/ |
|
|
1143 |
private void buildVoteClassifier(ArrayList<Object> alBestClassifiers) { |
|
|
1144 |
classifier = "meta.Vote"; //Combination rule: average of probabilities |
|
|
1145 |
options = "-S 1 -R " + Main.combinationRule; |
|
|
1146 |
|
|
|
1147 |
tmFeatures = new TreeMap<>(); |
|
|
1148 |
for (Object c : alBestClassifiers) { |
|
|
1149 |
classificationObject co = (classificationObject) c; |
|
|
1150 |
//create filteredclassifier with selected attributes |
|
|
1151 |
String filteredClassifierOptions = "-B \"weka.classifiers.meta.FilteredClassifier -F \\\"weka.filters.unsupervised.attribute.Remove -V -R " |
|
|
1152 |
+ co.featuresSeparatedByCommas.substring(2) + "\\\""; |
|
|
1153 |
//add classifier and its options |
|
|
1154 |
String classif = "-W weka.classifiers." + co.classifier + " --"; |
|
|
1155 |
String classifOptions = co.options.replace("\\", "\\\\").replace("\"", "\\\"") + "\""; |
|
|
1156 |
options += " " + filteredClassifierOptions + " " + classif + " " + classifOptions; |
|
|
1157 |
|
|
|
1158 |
//get all features seen and get their number of views |
|
|
1159 |
for (String f : co.featuresSeparatedByCommas.split(",")) { |
|
|
1160 |
if (tmFeatures.containsKey(Integer.valueOf(f))) { |
|
|
1161 |
int i = tmFeatures.get(Integer.valueOf(f)); |
|
|
1162 |
i++; |
|
|
1163 |
tmFeatures.put(Integer.valueOf(f), i); |
|
|
1164 |
} else { |
|
|
1165 |
tmFeatures.put(Integer.valueOf(f), 1); |
|
|
1166 |
} |
|
|
1167 |
} |
|
|
1168 |
} |
|
|
1169 |
//set features lists |
|
|
1170 |
for (Integer f : tmFeatures.keySet()) { |
|
|
1171 |
featureList.add(f.toString()); |
|
|
1172 |
featuresSeparatedByCommas += "," + f; |
|
|
1173 |
} |
|
|
1174 |
featuresSeparatedByCommas = featuresSeparatedByCommas.substring(1); |
|
|
1175 |
|
|
|
1176 |
//set other variables |
|
|
1177 |
optimizer = "COMB"; |
|
|
1178 |
mode = Main.numberOfBestModels + "_" + Main.bestModelsSortingMetric + "_" + Main.bestModelsSortingMetricThreshold; |
|
|
1179 |
} |
|
|
1180 |
|
|
|
1181 |
/** |
|
|
1182 |
* printable version of options |
|
|
1183 |
* |
|
|
1184 |
* @return |
|
|
1185 |
*/ |
|
|
1186 |
private String printOptions() { |
|
|
1187 |
if (classifier.contains("meta.Vote")) { |
|
|
1188 |
return options.substring(0, options.indexOf("-B")).replace(" ", ""); |
|
|
1189 |
} else { |
|
|
1190 |
return options.replace(" ", "").replace("\\", "").replace("\"", ""); |
|
|
1191 |
} |
|
|
1192 |
} |
|
|
1193 |
|
|
|
1194 |
} |
|
|
1195 |
|
|
|
1196 |
public static class regressionObject { |
|
|
1197 |
|
|
|
1198 |
public ArrayList<String> featureList = new ArrayList<>(); |
|
|
1199 |
public String featuresSeparatedByCommas = ""; |
|
|
1200 |
public String classifier; |
|
|
1201 |
public String optimizer; |
|
|
1202 |
public String options; |
|
|
1203 |
public String mode; |
|
|
1204 |
public String identifier; |
|
|
1205 |
public TreeMap<Integer, Integer> tmFeatures; |
|
|
1206 |
public HashMap<String, String> hmValues = new HashMap<>(); //Column name, value |
|
|
1207 |
|
|
|
1208 |
public regressionObject() { |
|
|
1209 |
} |
|
|
1210 |
|
|
|
1211 |
public regressionObject(String line) { |
|
|
1212 |
identifier = line.split("\t")[hmResultsHeaderNames.get("ID")]; |
|
|
1213 |
classifier = line.split("\t")[hmResultsHeaderNames.get("classifier")]; |
|
|
1214 |
options = line.split("\t")[hmResultsHeaderNames.get("Options")]; |
|
|
1215 |
optimizer = line.split("\t")[hmResultsHeaderNames.get("OptimizedValue")]; |
|
|
1216 |
mode = line.split("\t")[hmResultsHeaderNames.get("SearchMode")]; |
|
|
1217 |
featureList.addAll(Arrays.asList(line.split("\t")[hmResultsHeaderNames.get("AttributeList")].split(","))); |
|
|
1218 |
featuresSeparatedByCommas = line.split("\t")[hmResultsHeaderNames.get("AttributeList")]; |
|
|
1219 |
|
|
|
1220 |
if (options.startsWith("\"")) { |
|
|
1221 |
options = options.substring(1); //remove first " |
|
|
1222 |
options = options.replace("\\\"\"", "\\\""); // replace \"" by \" |
|
|
1223 |
options = options.replace("\"\"\"", "\""); // replace """ by " |
|
|
1224 |
options = options.replace("\"\"weka", "\"weka"); // replace "" by " |
|
|
1225 |
} |
|
|
1226 |
|
|
|
1227 |
String s[] = line.split("\t"); |
|
|
1228 |
for (int i = 0; i < s.length; i++) { |
|
|
1229 |
hmValues.put(hmResultsHeaderIndexes.get(i), s[i]); |
|
|
1230 |
} |
|
|
1231 |
} |
|
|
1232 |
|
|
|
1233 |
/** |
|
|
1234 |
* for combined models |
|
|
1235 |
* |
|
|
1236 |
* @param alBestClassifiers |
|
|
1237 |
* @param cc |
|
|
1238 |
* @param NumberOfTopModels |
|
|
1239 |
*/ |
|
|
1240 |
private void buildVoteClassifier(ArrayList<Object> alBestClassifiers) { |
|
|
1241 |
classifier = "meta.Vote"; //Combination rule: average of probabilities |
|
|
1242 |
options = "-S 1 -R " + Main.combinationRule; |
|
|
1243 |
|
|
|
1244 |
tmFeatures = new TreeMap<>(); |
|
|
1245 |
int cpt = 0; |
|
|
1246 |
for (Object r : alBestClassifiers) { |
|
|
1247 |
regressionObject ro = (regressionObject) r; |
|
|
1248 |
cpt++; |
|
|
1249 |
//create filteredclassifier with selected attributes |
|
|
1250 |
String filteredClassifierOptions = "-B \"weka.classifiers.meta.FilteredClassifier -F \\\"weka.filters.unsupervised.attribute.Remove -V -R " |
|
|
1251 |
+ ro.featuresSeparatedByCommas.substring(2) + "\\\""; |
|
|
1252 |
//add classifier and its options |
|
|
1253 |
String classif = "-W weka.classifiers." + ro.classifier + " --"; |
|
|
1254 |
String classifOptions = ro.options.replace("\\", "\\\\").replace("\"", "\\\"") + "\""; |
|
|
1255 |
options += " " + filteredClassifierOptions + " " + classif + " " + classifOptions; |
|
|
1256 |
|
|
|
1257 |
//get all features seen and get their number of views |
|
|
1258 |
for (String f : ro.featuresSeparatedByCommas.split(",")) { |
|
|
1259 |
if (tmFeatures.containsKey(Integer.valueOf(f))) { |
|
|
1260 |
int i = tmFeatures.get(Integer.valueOf(f)); |
|
|
1261 |
i++; |
|
|
1262 |
tmFeatures.put(Integer.valueOf(f), i); |
|
|
1263 |
} else { |
|
|
1264 |
tmFeatures.put(Integer.valueOf(f), 1); |
|
|
1265 |
} |
|
|
1266 |
} |
|
|
1267 |
|
|
|
1268 |
} |
|
|
1269 |
|
|
|
1270 |
//set features lists |
|
|
1271 |
for (Integer f : tmFeatures.keySet()) { |
|
|
1272 |
featureList.add(f.toString()); |
|
|
1273 |
featuresSeparatedByCommas += "," + f; |
|
|
1274 |
} |
|
|
1275 |
featuresSeparatedByCommas = featuresSeparatedByCommas.substring(1); |
|
|
1276 |
|
|
|
1277 |
//set other variables |
|
|
1278 |
optimizer = "COMB"; |
|
|
1279 |
mode = Main.numberOfBestModels + "_" + Main.bestModelsSortingMetric + "_" + Main.bestModelsSortingMetricThreshold; |
|
|
1280 |
} |
|
|
1281 |
|
|
|
1282 |
/** |
|
|
1283 |
* printable version of options |
|
|
1284 |
* |
|
|
1285 |
* @return |
|
|
1286 |
*/ |
|
|
1287 |
private String printOptions() { |
|
|
1288 |
if (classifier.contains("meta.Vote")) { |
|
|
1289 |
return options.substring(0, options.indexOf("-B")).replace(" ", ""); |
|
|
1290 |
} else { |
|
|
1291 |
return options.replace(" ", "").replace("\\", "").replace("\"", ""); |
|
|
1292 |
} |
|
|
1293 |
} |
|
|
1294 |
|
|
|
1295 |
} |
|
|
1296 |
|
|
|
1297 |
private static class RankerObject { |
|
|
1298 |
|
|
|
1299 |
public String infogain; |
|
|
1300 |
public String roundedScore; |
|
|
1301 |
public String feature; |
|
|
1302 |
|
|
|
1303 |
public RankerObject() { |
|
|
1304 |
|
|
|
1305 |
} |
|
|
1306 |
|
|
|
1307 |
private RankerObject(String s) { |
|
|
1308 |
infogain = s.split(" ")[0]; |
|
|
1309 |
roundedScore = df.format(Double.valueOf(s.split(" ")[0])); |
|
|
1310 |
feature = s.split(" ")[2]; |
|
|
1311 |
} |
|
|
1312 |
} |
|
|
1313 |
|
|
|
1314 |
} |