Switch to unified view

a b/src/biodiscml/BestModelSelectionAndReport.java
1
/*
 * Select only models reaching a specific minimum MCC (classification) or maximum RMSE (regression).
 * Select all features from the selected models.
 * Retrain a model (best classifier of all tested) with the unique selected features using LOOCV and 75/25 holdout.
 * Repeat 10 times with various seeds and report the average scores (e.g. AUC).
 * Explore the biology behind the set of selected features.
 */
9
package biodiscml;
10
11
import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.PrintWriter;
import java.text.DecimalFormat;
import java.text.DecimalFormatSymbols;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.TreeMap;

import org.apache.commons.math3.stat.descriptive.moment.Mean;
import org.apache.commons.math3.stat.descriptive.moment.StandardDeviation;

import utils.UpSetR;
import utils.Weka_module;
import utils.utils;

import weka.core.SerializationHelper;
30
31
/**
32
 *
33
 * @author Mickael
34
 */
35
public class BestModelSelectionAndReport {
36
37
    public static String wd = Main.wd;
38
    public static Weka_module weka = new Weka_module();
39
    public static HashMap<String, Integer> hmResultsHeaderNames = new HashMap<>();
40
    public static HashMap< Integer, String> hmResultsHeaderIndexes = new HashMap<>();
41
    public static DecimalFormat df = new DecimalFormat();
42
    public static String trainFileName;
43
    public static String featureSelectionFile;
44
    public static String predictionsResultsFile;
45
    public static String correlatedFeatures;
46
47
    /**
48
     *
49
     * @param trainFilName
50
     * @param featureSelFile
51
     * @param predictionsResFile
52
     * @param type
53
     */
54
    public BestModelSelectionAndReport(String trainFilName,
55
            String featureSelFile,
56
            String predictionsResFile,
57
            String type
58
    ) {
59
        trainFileName = trainFilName;
60
        if (Main.noFeatureSelection) {
61
            featureSelectionFile = trainFilName;
62
        } else {
63
            featureSelectionFile = featureSelFile;
64
        }
65
        predictionsResultsFile = predictionsResFile;
66
        df.setMaximumFractionDigits(3);
67
        DecimalFormatSymbols dfs = new DecimalFormatSymbols();
68
        dfs.setDecimalSeparator('.');
69
        df.setDecimalFormatSymbols(dfs);
70
        String bestOrCombine = "Select best ";
71
72
        if (Main.combineModels) {
73
            bestOrCombine = "Combine ";
74
        }
75
        String sign = " >= ";
76
        boolean metricToMinimize = (Main.bestModelsSortingMetric.contains("RMSE")
77
                || Main.bestModelsSortingMetric.contains("BER")
78
                || Main.bestModelsSortingMetric.contains("FPR")
79
                || Main.bestModelsSortingMetric.contains("FNR")
80
                || Main.bestModelsSortingMetric.contains("FDR")
81
                || Main.bestModelsSortingMetric.contains("MAE")
82
                || Main.bestModelsSortingMetric.contains("RAE")
83
                || Main.bestModelsSortingMetric.contains("RRSE"));
84
        if (metricToMinimize) {
85
            sign = " <= ";
86
        }
87
88
        System.out.println("## " + bestOrCombine + " models using " + Main.bestModelsSortingMetric + " as sorting metric.\n"
89
                + "## Parameters: " + Main.numberOfBestModels + " best models and "
90
                + Main.bestModelsSortingMetric + sign + Main.bestModelsSortingMetricThreshold);
91
92
        //Read results file
93
        boolean classification = type.equals("classification");
94
95
        try {
96
            BufferedReader br = new BufferedReader(new FileReader(predictionsResultsFile));
97
            TreeMap<String, Object> tmModels = new TreeMap<>(); //<metric modelID, classification/regression Object>
98
            HashMap<String, Object> hmModels = new HashMap<>(); //<modelID, classification/regression Object>
99
100
            //in case of RMSE or BER, we want the minimum value instead of the maximal one
101
            if (!metricToMinimize) {
102
                tmModels = new TreeMap<>(Collections.reverseOrder());
103
            }
104
            String line = br.readLine();
105
106
            //fill header mapping
107
            String header = line;
108
            int cpt = 0;
109
            for (String s : header.split("\t")) {
110
                hmResultsHeaderNames.put(s, cpt);
111
                hmResultsHeaderIndexes.put(cpt, s);
112
                cpt++;
113
            }
114
            if (!hmResultsHeaderNames.containsKey(Main.bestModelsSortingMetric)) {
115
                System.err.println("[error] " + Main.bestModelsSortingMetric + " column does not exist in the results file.");
116
                if (Main.doRegression) {
117
                    System.out.println("Use AVG_CC instead since we are in regression mode");
118
                    Main.bestModelsSortingMetric = "AVG_CC";
119
                } else {
120
                    System.exit(0);
121
                }
122
123
            }
124
125
            //read results
126
            while (br.ready()) {
127
                line = br.readLine();
128
                if (!line.trim().isEmpty() && !line.contains("[model error]")) {
129
                    if (classification) {
130
                        try {
131
                            classificationObject co = new classificationObject(line);
132
                            tmModels.put(Double.valueOf(co.hmValues.get(Main.bestModelsSortingMetric)) + " " + co.hmValues.get("ID"), co);
133
                            hmModels.put(co.hmValues.get("ID"), co);
134
                        } catch (Exception e) {
135
                            if (Main.debug) {
136
                                e.printStackTrace();
137
                            }
138
                        }
139
                    } else {
140
                        try {
141
                            regressionObject ro = new regressionObject(line);
142
                            tmModels.put(Double.valueOf(ro.hmValues.get(Main.bestModelsSortingMetric)) + " " + ro.hmValues.get("ID"), ro);
143
                            hmModels.put(ro.hmValues.get("ID"), ro);
144
                        } catch (Exception e) {
145
                            if (Main.debug) {
146
                                e.printStackTrace();
147
                            }
148
                        }
149
                    }
150
                }
151
            }
152
            br.close();
153
154
            //control available models
155
            if (Main.numberOfBestModels > tmModels.size()) {
156
                System.out.println("Only " + tmModels.size() + " available models. You have configured " + Main.numberOfBestModels + " best models");
157
                Main.numberOfBestModels = tmModels.size();
158
            }
159
160
            // get best models list
161
            ArrayList<Object> alBestClassifiers = new ArrayList<>();
162
            cpt = 0;
163
            if (Main.hmTrainingBestModelList.isEmpty()) {
164
                for (String metricAndModel : tmModels.keySet()) {
165
                    cpt++;
166
                    boolean condition = false;
167
                    if (metricToMinimize) {
168
                        condition = Double.valueOf(metricAndModel.split(" ")[0]) < Main.bestModelsSortingMetricThreshold;
169
                    } else {
170
                        condition = Double.valueOf(metricAndModel.split(" ")[0]) > Main.bestModelsSortingMetricThreshold;
171
                    }
172
                    if (condition && cpt <= Main.numberOfBestModels) {
173
                        if (classification) {
174
                            alBestClassifiers.add(((classificationObject) tmModels.get(metricAndModel)));
175
                        } else {
176
                            alBestClassifiers.add(((regressionObject) tmModels.get(metricAndModel)));
177
                        }
178
                    }
179
                }
180
            } else {
181
                for (String modelID : Main.hmTrainingBestModelList.keySet()) {
182
                    if (classification) {
183
                        alBestClassifiers.add(((classificationObject) hmModels.get(modelID)));
184
                    } else {
185
                        alBestClassifiers.add(((regressionObject) hmModels.get(modelID)));
186
                    }
187
                }
188
            }
189
190
            //if model combination vote
191
            if (Main.combineModels) {
192
                if (classification) {
193
                    classificationObject co = new classificationObject();
194
                    co.buildVoteClassifier(alBestClassifiers);
195
                    alBestClassifiers = new ArrayList<>();
196
                    alBestClassifiers.add(co);
197
                } else {
198
                    regressionObject ro = new regressionObject();
199
                    ro.buildVoteClassifier(alBestClassifiers);
200
                    alBestClassifiers = new ArrayList<>();
201
                    alBestClassifiers.add(ro);
202
                }
203
            }
204
205
            //perform evaluations and create models
206
            PrintWriter pw = null;
207
            for (Object classifier : alBestClassifiers) {
208
                // initialize weka module
209
                if (classification) {
210
                    init(featureSelectionFile.replace("infoGain.csv", "infoGain.arff"), classification);
211
                } else {
212
                    init(featureSelectionFile.replace("RELIEFF.csv", "RELIEFF.arff"), classification);
213
                }
214
                createBestModel(classifier, classification, pw, br, false);
215
216
                if (Main.generateModelWithCorrelatedGenes) {
217
                    init(trainFilName, classification);
218
                    createBestModel(classifier, classification, pw, br, true);
219
                    correlatedFeatures = null;
220
                }
221
            }
222
223
        } catch (ClassCastException e) {
224
            e.printStackTrace();
225
            System.err.println("Unable to train selected best model(s). Check input files. ");
226
        } catch (Exception e) {
227
            e.printStackTrace();
228
        }
229
    }
230
231
    private void createBestModel(Object classifier,
232
            Boolean classification,
233
            PrintWriter pw, BufferedReader br, Boolean correlatedFeaturesMode) throws Exception {
234
        String corrMode = "";
235
        if (correlatedFeaturesMode) {
236
            corrMode = "_corr";
237
            System.out.print("\n# Model with correlated features ");
238
            ((classificationObject) classifier).featuresSeparatedByCommas = correlatedFeatures;
239
        } else {
240
            System.out.print("\n# Model ");
241
        }
242
        ArrayList<Double> alMCCs = new ArrayList<>();
243
        ArrayList<Double> alMAEs = new ArrayList<>();
244
        ArrayList<Double> alCCs = new ArrayList<>();
245
246
        if (Main.debug) {
247
            System.out.println("Save model ");
248
        }
249
        String modelFilename;
250
        Weka_module.ClassificationResultsObject cr = null;
251
        Weka_module.RegressionResultsObject rr = null;
252
        classificationObject co = null;
253
        regressionObject ro = null;
254
        String classifierName = "";
255
256
        //saving files
257
        if (classification) {
258
            co = (classificationObject) classifier;
259
            //train model
260
            modelFilename = Main.wd + Main.project
261
                    + "d." + co.classifier + "_" + co.printOptions() + "_"
262
                    + co.optimizer.toUpperCase().trim() + "_" + co.mode + corrMode;
263
            Object trainingOutput = weka.trainClassifier(co.classifier, co.options,
264
                    co.featuresSeparatedByCommas, classification, 10);
265
            cr = (Weka_module.ClassificationResultsObject) trainingOutput;
266
267
            classifierName = co.classifier + "_" + co.printOptions() + "_"
268
                    + co.optimizer.toUpperCase().trim() + "_" + co.mode;
269
            //save feature file
270
            weka.saveFilteredDataToCSV(co.featuresSeparatedByCommas, classification, modelFilename + ".train_features.csv");
271
            //call ranking function
272
            cr.featuresRankingResults = weka.featureRankingForClassification(modelFilename + ".train_features.csv");
273
            //save model
274
            try {
275
                SerializationHelper.write(modelFilename + ".model", cr.model);
276
            } catch (Exception e) {
277
                e.printStackTrace();
278
            }
279
            System.out.println(modelFilename);
280
        } else {
281
            ro = (regressionObject) classifier;
282
            modelFilename = Main.wd + Main.project
283
                    + "d." + ro.classifier + "_" + ro.printOptions() + "_"
284
                    + ro.optimizer.toUpperCase().trim() + "_" + ro.mode;
285
            Object trainingOutput = weka.trainClassifier(ro.classifier, ro.options,
286
                    ro.featuresSeparatedByCommas, classification, 10);
287
            rr = (Weka_module.RegressionResultsObject) trainingOutput;
288
289
            classifierName = ro.classifier + "_" + ro.printOptions() + "_"
290
                    + ro.optimizer.toUpperCase().trim() + "_" + ro.mode;
291
                        
292
            //save model and features
293
            try {
294
                SerializationHelper.write(modelFilename + ".model", rr.model);
295
            } catch (Exception e) {
296
                e.printStackTrace();
297
            }
298
            //save feature file
299
            weka.saveFilteredDataToCSV(ro.featuresSeparatedByCommas, classification, modelFilename + ".train_features.csv");
300
            //call ranking function
301
            rr.featuresRankingResults = weka.featureRankingForRegression(modelFilename + ".train_features.csv");
302
            System.out.println(modelFilename);
303
        }
304
305
        //header
306
        pw = new PrintWriter(new FileWriter(modelFilename + ".details.txt"));
307
        pw.println("## Generated by BioDiscML (Leclercq et al. 2019)##");
308
        pw.println("# Project: " + Main.project.substring(0, Main.project.length() - 1));
309
310
        if (classification) {
311
            pw.println("# ID: " + co.identifier);
312
            System.out.println("# ID: " + co.identifier);
313
            pw.println("# Classifier: " + co.classifier + " " + co.options
314
                    + "\n# Optimizer: " + co.optimizer.toUpperCase()
315
                    + "\n# Feature search mode: " + co.mode);
316
        } else {
317
            pw.println("# ID: " + ro.identifier);
318
            System.out.println("# ID: " + ro.identifier);
319
            pw.println("# Classifier: " + ro.classifier + " " + ro.options
320
                    + "\n# Optimizer: " + ro.optimizer.toUpperCase()
321
                    + "\n# Feature search mode: " + ro.mode);
322
        }
323
324
        //show combined models in case of combined vote
325
        if (Main.combineModels) {
326
            pw.println("# Combined classifiers:");
327
            String combOpt = "";
328
            if (classification) {
329
                combOpt = co.options;
330
            } else {
331
                combOpt = ro.options;
332
            }
333
            for (String s : combOpt.split("-B ")) {
334
                if (s.startsWith("\"weka.classifiers.meta.FilteredClassifier")) {
335
                    String usedFeatures = s.split("Remove -V -R ")[1];
336
                    usedFeatures = usedFeatures.split("\\\\")[0];
337
                    String model = s.substring(s.indexOf("-W ") + 2).trim()
338
                            .replace("-- ", "")
339
                            .replace("\\\"", "\"")
340
                            .replace("\\\\\"", "\\\"");
341
                    model = model.substring(0, model.length() - 1);
342
                    pw.println(model + " (features: " + usedFeatures + ")");
343
                }
344
            }
345
        }
346
        pw.flush();
347
348
        //UpSetR
349
        if (Main.UpSetR) {
350
            UpSetR up = new UpSetR();
351
            up.creatUpSetRDatasetFromSignature(co, featureSelectionFile, predictionsResultsFile);
352
        }
353
354
        //10CV performance
355
        System.out.println("# 10 fold cross validation performance");
356
        pw.println("\n# 10 fold cross validation performance");
357
        if (classification) {
358
            System.out.println(cr.toStringDetails());
359
            alMCCs.add(Double.valueOf(cr.MCC));
360
            alMAEs.add(Double.valueOf(cr.MAE));
361
            pw.println(cr.toStringDetails().replace("[score_training] ", ""));
362
        } else {
363
            System.out.println(rr.toStringDetails());
364
            alCCs.add(Double.valueOf(rr.CC));
365
            alMAEs.add(Double.valueOf(rr.MAE));
366
            pw.println(rr.toStringDetails().replace("[score_training] ", ""));
367
        }
368
        pw.flush();
369
370
        //LOOCV performance
371
        if (Main.loocv) {
372
            System.out.println("# LOOCV (Leave-One-Out cross validation) performance");
373
            pw.println("\n# LOOCV (Leave-One-Out Cross Validation) performance");
374
            if (classification) {
375
                Weka_module.ClassificationResultsObject cr2 = (Weka_module.ClassificationResultsObject) weka.trainClassifier(co.classifier, co.options,
376
                        co.featuresSeparatedByCommas, classification, weka.myData.numInstances());
377
                System.out.println(cr2.toStringDetails());
378
                alMCCs.add(Double.valueOf(cr2.MCC));
379
                alMAEs.add(Double.valueOf(cr2.MAE));
380
                pw.println(cr2.toStringDetails().replace("[score_training] ", ""));
381
            } else {
382
                Weka_module.RegressionResultsObject rr2 = (Weka_module.RegressionResultsObject) weka.trainClassifier(ro.classifier, ro.options,
383
                        ro.featuresSeparatedByCommas, classification, weka.myData.numInstances());
384
                System.out.println(rr2.toStringDetails());
385
                alCCs.add(Double.valueOf(rr2.CC));
386
                alMAEs.add(Double.valueOf(rr2.MAE));
387
                pw.println(rr2.toStringDetails().replace("[score_training] ", ""));
388
            }
389
            pw.flush();
390
        }
391
392
        //REPEATED HOLDOUT performance TRAIN set
393
        ArrayList<Object> alROCs = new ArrayList<>();
394
        Weka_module.evaluationPerformancesResultsObject eproRHTrain = new Weka_module.evaluationPerformancesResultsObject();
395
        if (classification) {
396
            System.out.println("Repeated Holdout evaluation on TRAIN set of " + co.classifier + " " + co.options
397
                    + " optimized by " + co.optimizer + "...");
398
            pw.println("\n#Repeated Holdout evaluation performance on TRAIN set, "
399
                    + Main.bootstrapAndRepeatedHoldoutFolds + " times weighted average (and standard deviation) on random seeds");
400
            for (int i = 0; i < Main.bootstrapAndRepeatedHoldoutFolds; i++) {
401
                Weka_module.ClassificationResultsObject cro
402
                        = (Weka_module.ClassificationResultsObject) weka.trainClassifierHoldOutValidation(co.classifier, co.options,
403
                                co.featuresSeparatedByCommas, classification, i);
404
                eproRHTrain.alAUCs.add(Double.valueOf(cro.AUC));
405
                eproRHTrain.alpAUCs.add(Double.valueOf(cro.pAUC));
406
                eproRHTrain.alAUPRCs.add(Double.valueOf(cro.AUPRC));
407
                eproRHTrain.alACCs.add(Double.valueOf(cro.ACC));
408
                eproRHTrain.alSEs.add(Double.valueOf(cro.TPR));
409
                eproRHTrain.alSPs.add(Double.valueOf(cro.TNR));
410
                eproRHTrain.alMCCs.add(Double.valueOf(cro.MCC));
411
                eproRHTrain.alMAEs.add(Double.valueOf(cro.MAE));
412
                eproRHTrain.alBERs.add(Double.valueOf(cro.BER));
413
                alROCs.add(cro);
414
                // System.out.println(i+"\t"+Double.valueOf(cro.AUC));
415
            }
416
            eproRHTrain.computeMeans();
417
            System.out.println(eproRHTrain.toStringClassificationDetails());
418
            alMCCs.add(Double.valueOf(eproRHTrain.meanMCCs));
419
            alMAEs.add(Double.valueOf(eproRHTrain.meanMAEs));
420
            pw.println(eproRHTrain.toStringClassificationDetails().replace("[score_training] ", ""));
421
422
            if (Main.ROCcurves) {
423
                rocCurveGraphs.createRocCurvesWithConfidence(alROCs, classification, modelFilename, ".roc_train.png");
424
            }
425
        } else {
426
            System.out.println("Repeated Holdout evaluation on TRAIN set of " + ro.classifier + " "
427
                    + ro.options + "optimized by " + ro.optimizer.toUpperCase());
428
            pw.println("\n\n#Repeated Holdout evaluation performance on TRAIN set, " + Main.bootstrapAndRepeatedHoldoutFolds + " times average on random seeds");
429
            for (int i = 0; i < Main.bootstrapAndRepeatedHoldoutFolds; i++) {
430
                Weka_module.RegressionResultsObject rro
431
                        = (Weka_module.RegressionResultsObject) weka.trainClassifierHoldOutValidation(ro.classifier, ro.options,
432
                                ro.featuresSeparatedByCommas, classification, i);
433
                eproRHTrain.alCCs.add(Double.valueOf(rro.CC));
434
                eproRHTrain.alMAEs.add(Double.valueOf(rro.MAE));
435
                eproRHTrain.alRMSEs.add(Double.valueOf(rro.RMSE));
436
                eproRHTrain.alRAEs.add(Double.valueOf(rro.RAE));
437
                eproRHTrain.alRRSEs.add(Double.valueOf(rro.RRSE));
438
            }
439
            eproRHTrain.computeMeans();
440
            alCCs.add(Double.valueOf(eproRHTrain.meanCCs));
441
            alMAEs.add(Double.valueOf(eproRHTrain.meanMAEs));
442
            pw.println(eproRHTrain.toStringRegressionDetails().replace("[score_training] ", ""));
443
        }
444
        pw.flush();
445
446
        //BOOTSTRAP performance TRAIN set
447
        double bootstrapTrain632plus = -1;
448
        Weka_module.evaluationPerformancesResultsObject eproBSTrain = new Weka_module.evaluationPerformancesResultsObject();
449
        if (classification) {
450
            System.out.println("Bootstrap evaluation on TRAIN set of " + co.classifier + " " + co.options
451
                    + " optimized by " + co.optimizer + "...");
452
            pw.println("\n#Bootstrap evaluation performance on TRAIN set, "
453
                    + Main.bootstrapAndRepeatedHoldoutFolds + " times weighted average (and standard deviation) on random seeds");
454
            for (int i = 0; i < Main.bootstrapAndRepeatedHoldoutFolds; i++) {
455
                Weka_module.ClassificationResultsObject cro
456
                        = (Weka_module.ClassificationResultsObject) weka.trainClassifierBootstrap(co.classifier, co.options,
457
                                co.featuresSeparatedByCommas, classification, i);
458
                eproBSTrain.alAUCs.add(Double.valueOf(cro.AUC));
459
                eproBSTrain.alpAUCs.add(Double.valueOf(cro.pAUC));
460
                eproBSTrain.alAUPRCs.add(Double.valueOf(cro.AUPRC));
461
                eproBSTrain.alACCs.add(Double.valueOf(cro.ACC));
462
                eproBSTrain.alSEs.add(Double.valueOf(cro.TPR));
463
                eproBSTrain.alSPs.add(Double.valueOf(cro.TNR));
464
                eproBSTrain.alMCCs.add(Double.valueOf(cro.MCC));
465
                eproBSTrain.alMAEs.add(Double.valueOf(cro.MAE));
466
                eproBSTrain.alBERs.add(Double.valueOf(cro.BER));
467
                alROCs.add(cro);
468
                // System.out.println(i+"\t"+Double.valueOf(cro.AUC));
469
            }
470
            eproBSTrain.computeMeans();
471
            alMCCs.add(Double.valueOf(eproBSTrain.meanMCCs));
472
            alMAEs.add(Double.valueOf(eproBSTrain.meanMAEs));
473
            System.out.println(eproBSTrain.toStringClassificationDetails());
474
            pw.println(eproBSTrain.toStringClassificationDetails().replace("[score_training] ", ""));
475
476
            //632+ rule
477
            System.out.println("Bootstrap .632+ rule calculated on TRAIN set of " + co.classifier + " " + co.options
478
                    + " optimized by " + co.optimizer + "...");
479
            pw.println("\n#Bootstrap .632+ rule calculated on TRAIN set, "
480
                    + Main.bootstrapAndRepeatedHoldoutFolds + " folds with random seeds");
481
482
            bootstrapTrain632plus = weka.trainClassifierBootstrap632plus(co.classifier, co.options,
483
                    co.featuresSeparatedByCommas);
484
            System.out.println(df.format(bootstrapTrain632plus));
485
            pw.println(df.format(bootstrapTrain632plus));
486
487
            if (Main.ROCcurves) {
488
                rocCurveGraphs.createRocCurvesWithConfidence(alROCs, classification, modelFilename, ".roc_train.png");
489
            }
490
        } else {
491
            System.out.println("Bootstrap evaluation on TRAIN set of " + ro.classifier + " "
492
                    + ro.options + "optimized by " + ro.optimizer.toUpperCase());
493
            pw.println("\n#Bootstrap evaluation performance on TRAIN set, "
494
                    + Main.bootstrapAndRepeatedHoldoutFolds + " times average on random seeds");
495
            for (int i = 0; i < Main.bootstrapAndRepeatedHoldoutFolds; i++) {
496
                Weka_module.RegressionResultsObject rro
497
                        = (Weka_module.RegressionResultsObject) weka.trainClassifierBootstrap(ro.classifier, ro.options,
498
                                ro.featuresSeparatedByCommas, classification, i);
499
                eproBSTrain.alCCs.add(Double.valueOf(rro.CC));
500
                eproBSTrain.alMAEs.add(Double.valueOf(rro.MAE));
501
                eproBSTrain.alRMSEs.add(Double.valueOf(rro.RMSE));
502
                eproBSTrain.alRAEs.add(Double.valueOf(rro.RAE));
503
                eproBSTrain.alRRSEs.add(Double.valueOf(rro.RRSE));
504
            }
505
            eproBSTrain.computeMeans();
506
            alCCs.add(Double.valueOf(eproBSTrain.meanCCs));
507
            alMAEs.add(Double.valueOf(eproBSTrain.meanMAEs));
508
            pw.println(eproBSTrain.toStringRegressionDetails().replace("[score_training] ", ""));
509
        }
510
        pw.flush();
511
512
        // IF TEST SET
513
        try {
514
            if (Main.doSampling) {
515
                alROCs = new ArrayList<>();
516
                System.out.println("Evaluation performance on test set");
517
                pw.println("\n#Evaluation performance on test set");
518
                //get arff test filename generated before training
519
                String arffTestFile = trainFileName.replace("data_to_train.csv", "data_to_test.arff");
520
                //set the outfile of the extracted features needed to test the current model
521
                String arffTestFileWithExtractedModelFeatures = modelFilename + ".test_features.arff";
522
                //check if test set is here
523
                if (!new File(arffTestFile).exists()) {
524
                    pw.println("Test file " + arffTestFile + " not found");
525
                }
526
                //adapt test file to model (extract the needed features)
527
                //test file come from original dataset preprocessed in AdaptDatasetToTraining, so the arff is compatible
528
                Weka_module weka2 = new Weka_module();
529
                weka2.setARFFfile(arffTestFile);
530
                weka2.setDataFromArff();
531
                // create compatible test file
532
                if (Main.combineModels) {
533
                    //combined model contains the filters, we need to keep the same
534
                    //features indexes as the b.featureSelection.infoGain.arff
535
                    weka2.extractFeaturesFromArffFileBasedOnSelectedFeatures(weka.myData,
536
                            weka2.myData, arffTestFileWithExtractedModelFeatures);
537
                } else {
538
                    weka2.extractFeaturesFromTestFileBasedOnModel(modelFilename + ".model",
539
                            weka2.myData, arffTestFileWithExtractedModelFeatures);
540
                }
541
542
                //TESTING
543
                // reload compatible test file in weka2
544
                weka2 = new Weka_module();
545
                weka2.setARFFfile(arffTestFileWithExtractedModelFeatures);
546
                weka2.setDataFromArff();
547
548
                if (classification) {
549
                    Weka_module.ClassificationResultsObject cr2
550
                            = (Weka_module.ClassificationResultsObject) weka2.testClassifierFromFileSource(new File(weka2.ARFFfile),
551
                                    modelFilename + ".model", true);
552
                    alROCs.add(cr2);
553
                    alROCs.add(cr2);
554
                    System.out.println("[score_testing] ACC: " + cr2.ACC);
555
                    System.out.println("[score_testing] AUC: " + cr2.AUC);
556
                    System.out.println("[score_testing] AUPRC: " + cr2.AUPRC);
557
                    System.out.println("[score_testing] SEN: " + cr2.TPR);
558
                    System.out.println("[score_testing] SPE: " + cr2.TNR);
559
                    System.out.println("[score_testing] MCC: " + cr2.MCC);
560
                    System.out.println("[score_testing] MAE: " + cr2.MAE);
561
                    System.out.println("[score_testing] BER: " + cr2.BER);
562
563
                    pw.println("AUC: " + cr2.AUC);
564
                    pw.println("ACC: " + cr2.ACC);
565
                    pw.println("AUPRC: " + cr2.AUPRC);
566
                    pw.println("SEN: " + cr2.TPR);
567
                    pw.println("SPE: " + cr2.TNR);
568
                    pw.println("MCC: " + cr2.MCC);
569
                    pw.println("MAE: " + cr2.MAE);
570
                    pw.println("BER: " + cr2.BER);
571
572
                    alMCCs.add(Double.valueOf(cr2.MCC));
573
                    alMAEs.add(Double.valueOf(cr2.MAE));
574
575
                    if (Main.ROCcurves) {
576
                        rocCurveGraphs.createRocCurvesWithConfidence(alROCs, classification, modelFilename, ".roc_test.png");
577
                    }
578
                } else {
579
                    Weka_module.RegressionResultsObject rr2
580
                            = (Weka_module.RegressionResultsObject) weka2.testClassifierFromFileSource(new File(weka2.ARFFfile),
581
                                    modelFilename + ".model", false);
582
                    System.out.println("[score_testing] Average CC: " + rr2.CC);
583
                    System.out.println("[score_testing] Average RMSE: " + rr2.RMSE);
584
                    //
585
                    pw.println("Average CC: " + rr2.CC);
586
                    pw.println("Average RMSE: " + rr2.RMSE);
587
588
                    alCCs.add(Double.valueOf(rr2.CC));
589
                    alMAEs.add(Double.valueOf(rr2.MAE));
590
                }
591
                new File(arffTestFileWithExtractedModelFeatures).delete();
592
593
                //REPEATED HOLDOUT TRAIN_TEST
594
                arffTestFileWithExtractedModelFeatures = arffTestFileWithExtractedModelFeatures.replace(".test_features.arff", ".RH_features.arff");
595
                Weka_module.evaluationPerformancesResultsObject eproRHTrainTest = new Weka_module.evaluationPerformancesResultsObject();
596
                try {
597
                    alROCs = new ArrayList<>();
598
                    // adapt original dataset file to model (extract the needed features)
599
                    Weka_module weka3 = new Weka_module();
600
                    weka3.setARFFfile(trainFileName.replace("data_to_train.csv", "all_data.arff"));
601
                    weka3.setDataFromArff();
602
                    // create compatible  file
603
                    if (Main.combineModels) {
604
                        //combined model contains the filters, we need to keep the same
605
                        //features indexes as the b.featureSelection.infoGain.arff
606
                        weka3.extractFeaturesFromArffFileBasedOnSelectedFeatures(weka.myData,
607
                                weka3.myData, arffTestFileWithExtractedModelFeatures);
608
                    } else {
609
                        weka3.extractFeaturesFromTestFileBasedOnModel(modelFilename + ".model",
610
                                weka3.myData, arffTestFileWithExtractedModelFeatures);
611
                    }
612
613
                    // reload compatible file in weka2
614
                    weka3 = new Weka_module();
615
                    weka3.setARFFfile(arffTestFileWithExtractedModelFeatures);
616
                    weka3.setDataFromArff();
617
618
                    if (classification) {
619
                        if (!Main.combineModels) {
620
                            weka3.myData = weka3.extractFeaturesFromDatasetBasedOnModel(cr.model, weka3.myData);
621
                        }
622
                        System.out.println("Repeated Holdout evaluation on TRAIN AND TEST sets of " + co.classifier + " " + co.options
623
                                + " optimized by " + co.optimizer);
624
                        pw.println("\n#Repeated Holdout evaluation performance on TRAIN AND TEST set, "
625
                                + Main.bootstrapAndRepeatedHoldoutFolds + " times weighted average (and standard deviation) on random seeds");
626
                        for (int i = 0; i < Main.bootstrapAndRepeatedHoldoutFolds; i++) {
627
                            Weka_module.ClassificationResultsObject cro
628
                                    = (Weka_module.ClassificationResultsObject) weka3.trainClassifierHoldOutValidation(co.classifier, co.options,
629
                                            null, classification, i);
630
                            eproRHTrainTest.alAUCs.add(Double.valueOf(cro.AUC));
631
                            eproRHTrainTest.alpAUCs.add(Double.valueOf(cro.pAUC));
632
                            eproRHTrainTest.alAUPRCs.add(Double.valueOf(cro.AUPRC));
633
                            eproRHTrainTest.alACCs.add(Double.valueOf(cro.ACC));
634
                            eproRHTrainTest.alSEs.add(Double.valueOf(cro.TPR));
635
                            eproRHTrainTest.alSPs.add(Double.valueOf(cro.TNR));
636
                            eproRHTrainTest.alMCCs.add(Double.valueOf(cro.MCC));
637
                            eproRHTrainTest.alMAEs.add(Double.valueOf(cro.MAE));
638
                            eproRHTrainTest.alBERs.add(Double.valueOf(cro.BER));
639
                            alROCs.add(cro);
640
                        }
641
                        eproRHTrainTest.computeMeans();
642
                        alMCCs.add(Double.valueOf(eproRHTrainTest.meanMCCs));
643
                        alMAEs.add(Double.valueOf(eproRHTrainTest.meanMAEs));
644
                        System.out.println(eproRHTrainTest.toStringClassificationDetails());
645
                        pw.println(eproRHTrainTest.toStringClassificationDetails().replace("[score_training] ", ""));
646
                        if (Main.ROCcurves) {
647
                            rocCurveGraphs.createRocCurvesWithConfidence(alROCs, classification, modelFilename, ".roc.png");
648
                        }
649
                    } else {
650
                        if (!Main.combineModels) {
651
                            weka3.myData = weka3.extractFeaturesFromDatasetBasedOnModel(rr.model, weka3.myData);
652
                        }
653
654
                        System.out.println("Repeated Holdout evaluation on TRAIN AND TEST sets of " + ro.classifier + " "
655
                                + ro.options + "optimized by " + ro.optimizer.toUpperCase());
656
                        pw.println("\n#Repeated Holdout evaluation performance on TRAIN AND TEST set, "
657
                                + Main.bootstrapAndRepeatedHoldoutFolds + " times weighted average (and standard deviation) on random seeds");
658
                        for (int i = 0; i < Main.bootstrapAndRepeatedHoldoutFolds; i++) {
659
                            Weka_module.RegressionResultsObject rro
660
                                    = (Weka_module.RegressionResultsObject) weka3.trainClassifierHoldOutValidation(ro.classifier, ro.options,
661
                                            null, classification, i);
662
                            eproRHTrainTest.alCCs.add(Double.valueOf(rro.CC));
663
                            eproRHTrainTest.alMAEs.add(Double.valueOf(rro.MAE));
664
                            eproRHTrainTest.alRMSEs.add(Double.valueOf(rro.RMSE));
665
                            eproRHTrainTest.alRAEs.add(Double.valueOf(rro.RAE));
666
                            eproRHTrainTest.alRRSEs.add(Double.valueOf(rro.RRSE));
667
                        }
668
                        eproRHTrainTest.computeMeans();
669
                        alCCs.add(Double.valueOf(eproRHTrainTest.meanCCs));
670
                        alMAEs.add(Double.valueOf(eproRHTrainTest.meanMAEs));
671
                        System.out.println(eproRHTrainTest.toStringRegressionDetails());
672
                        pw.println(eproRHTrainTest.toStringRegressionDetails().replace("[score_training] ", ""));
673
                    }
674
                    eproRHTrainTest.computeMeans();
675
                } catch (Exception e) {
676
                    if (Main.debug) {
677
                        e.printStackTrace();
678
                    }
679
                }
680
                new File(arffTestFileWithExtractedModelFeatures).delete();
681
682
                //BOOTSRAP TRAIN_TEST
683
                arffTestFileWithExtractedModelFeatures = arffTestFileWithExtractedModelFeatures.replace(".RH_features.arff", ".BS_features.arff");
684
                Weka_module.evaluationPerformancesResultsObject eproBSTrainTest = new Weka_module.evaluationPerformancesResultsObject();
685
                try {
686
                    alROCs = new ArrayList<>();
687
                    Weka_module weka4 = new Weka_module();
688
                    weka4.setARFFfile(trainFileName.replace("data_to_train.csv", "all_data.arff"));
689
                    weka4.setDataFromArff();
690
691
                    // create compatible  file
692
                    if (Main.combineModels) {
693
                        //combined model contains the filters, we need to keep the same
694
                        //features indexes as the b.featureSelection.infoGain.arff
695
                        weka4.extractFeaturesFromArffFileBasedOnSelectedFeatures(weka.myData,
696
                                weka4.myData, arffTestFileWithExtractedModelFeatures);
697
                    } else {
698
                        weka4.extractFeaturesFromTestFileBasedOnModel(modelFilename + ".model",
699
                                weka4.myData, arffTestFileWithExtractedModelFeatures);
700
                    }
701
702
                    // reload compatible file in weka2
703
                    weka4 = new Weka_module();
704
                    weka4.setARFFfile(arffTestFileWithExtractedModelFeatures);
705
                    weka4.setDataFromArff();
706
                    if (classification) {
707
                        if (!Main.combineModels) {
708
                            weka4.myData = weka4.extractFeaturesFromDatasetBasedOnModel(cr.model, weka4.myData);
709
                        }
710
                        System.out.println("Bootstrap evaluation on TRAIN AND TEST sets of " + co.classifier + " " + co.options
711
                                + " optimized by " + co.optimizer);
712
                        pw.println("\n#Bootstrap evaluation performance on TRAIN AND TEST set, "
713
                                + Main.bootstrapAndRepeatedHoldoutFolds + " times weighted average (and standard deviation) on random seeds");
714
                        for (int i = 0; i < Main.bootstrapAndRepeatedHoldoutFolds; i++) {
715
                            Weka_module.ClassificationResultsObject cro
716
                                    = (Weka_module.ClassificationResultsObject) weka4.trainClassifierBootstrap(co.classifier, co.options,
717
                                            null, classification, i);
718
                            eproBSTrainTest.alAUCs.add(Double.valueOf(cro.AUC));
719
                            eproBSTrainTest.alpAUCs.add(Double.valueOf(cro.pAUC));
720
                            eproBSTrainTest.alAUPRCs.add(Double.valueOf(cro.AUPRC));
721
                            eproBSTrainTest.alACCs.add(Double.valueOf(cro.ACC));
722
                            eproBSTrainTest.alSEs.add(Double.valueOf(cro.TPR));
723
                            eproBSTrainTest.alSPs.add(Double.valueOf(cro.TNR));
724
                            eproBSTrainTest.alMCCs.add(Double.valueOf(cro.MCC));
725
                            eproBSTrainTest.alMAEs.add(Double.valueOf(cro.MAE));
726
                            eproBSTrainTest.alBERs.add(Double.valueOf(cro.BER));
727
                            alROCs.add(cro);
728
                        }
729
                        eproBSTrainTest.computeMeans();
730
                        alMCCs.add(Double.valueOf(eproBSTrainTest.meanMCCs));
731
                        alMAEs.add(Double.valueOf(eproBSTrainTest.meanMAEs));
732
                        System.out.println(eproBSTrainTest.toStringClassificationDetails());
733
                        pw.println(eproBSTrainTest.toStringClassificationDetails().replace("[score_training] ", ""));
734
735
                        //632+ rule
736
                        System.out.println("Bootstrap .632+ rule calculated on TRAIN AND TEST set of " + co.classifier + " " + co.options
737
                                + " optimized by " + co.optimizer + "...");
738
                        pw.println("\n#Bootstrap .632+ rule calculated on TRAIN AND TEST set, "
739
                                + Main.bootstrapAndRepeatedHoldoutFolds + " folds with random seeds");
740
741
                        bootstrapTrain632plus = weka4.trainClassifierBootstrap632plus(co.classifier, co.options,
742
                                null);
743
                        System.out.println(df.format(bootstrapTrain632plus));
744
                        pw.println(df.format(bootstrapTrain632plus));
745
746
                        if (Main.ROCcurves) {
747
                            rocCurveGraphs.createRocCurvesWithConfidence(alROCs, classification, modelFilename, ".roc.png");
748
                        }
749
                    } else {
750
                        if (!Main.combineModels) {
751
                            weka4.myData = weka4.extractFeaturesFromDatasetBasedOnModel(rr.model, weka4.myData);
752
                        }
753
                        System.out.println("Bootstrap evaluation on TRAIN AND TEST sets of " + ro.classifier + " "
754
                                + ro.options + "optimized by " + ro.optimizer.toUpperCase());
755
                        pw.println("\n#Bootstrap evaluation performance on TRAIN AND TEST set, "
756
                                + Main.bootstrapAndRepeatedHoldoutFolds + " times weighted average (and standard deviation) on random seeds");
757
                        for (int i = 0; i < Main.bootstrapAndRepeatedHoldoutFolds; i++) {
758
                            Weka_module.RegressionResultsObject rro
759
                                    = (Weka_module.RegressionResultsObject) weka4.trainClassifierBootstrap(ro.classifier, ro.options,
760
                                            null, classification, i);
761
                            eproBSTrainTest.alCCs.add(Double.valueOf(rro.CC));
762
                            eproBSTrainTest.alMAEs.add(Double.valueOf(rro.MAE));
763
                            eproBSTrainTest.alRMSEs.add(Double.valueOf(rro.RMSE));
764
                            eproBSTrainTest.alRAEs.add(Double.valueOf(rro.RAE));
765
                            eproBSTrainTest.alRRSEs.add(Double.valueOf(rro.RRSE));
766
                        }
767
                        eproBSTrainTest.computeMeans();
768
                        alCCs.add(Double.valueOf(eproBSTrainTest.meanCCs));
769
                        alMAEs.add(Double.valueOf(eproBSTrainTest.meanMAEs));
770
                        System.out.println(eproBSTrainTest.toStringRegressionDetails());
771
                        pw.println(eproBSTrainTest.toStringRegressionDetails().replace("[score_training] ", ""));
772
                    }
773
                    eproBSTrainTest.computeMeans();
774
                } catch (Exception e) {
775
                    if (Main.debug) {
776
                        e.printStackTrace();
777
                    }
778
                }
779
780
                //remove test file arff once done
781
                new File(arffTestFileWithExtractedModelFeatures).delete();
782
            }
783
        } catch (Exception e) {
784
            if (Main.debug) {
785
                e.printStackTrace();
786
            }
787
        }
788
        pw.flush();
789
790
        // show average metrics and standard deviation
791
        if (classification) {
792
            pw.println("\n# Average MCC: " + utils.getMean(alMCCs)
793
                    + "\t(" + utils.getStandardDeviation(alMCCs) + ")");
794
            System.out.println("\n# Average MCC: " + utils.getMean(alMCCs));
795
            pw.println("# Average MAE: " + utils.getMean(alMAEs)
796
                    + "\t(" + utils.getStandardDeviation(alMAEs) + ")");
797
            System.out.println("# Average MAE: " + utils.getMean(alMAEs));
798
        } else {
799
            pw.println("\n# Average CC: " + utils.getMean(alCCs)
800
                    + "\t(" + utils.getStandardDeviation(alCCs) + ")");
801
            System.out.println("\n# Average CC: " + utils.getMean(alCCs));
802
            pw.println("# Average MAE: " + utils.getMean(alMAEs)
803
                    + "\t(" + utils.getStandardDeviation(alMAEs) + ")");
804
            System.out.println("# Average MAE: " + utils.getMean(alMAEs));
805
        }
806
807
        //output features
808
        if (classification) {
809
            try {
810
                pw.print("\n# Selected Attributes (Total attributes:" + cr.numberOfFeatures + "). "
811
                        + "Occurrences are shown if you chose combined model\n");
812
                pw.print(cr.features);
813
                pw.println("\n# Attribute ranking by merit calculated by information gain");
814
                pw.print(cr.getFeatureRankingResults());
815
816
            } catch (Exception e) {
817
                e.printStackTrace();
818
            }
819
        } else {
820
            try {
821
                pw.print("\n# Selected Attributes\t(Total attributes:" + rr.numberOfFeatures + "). "
822
                        + "Occurrences are shown if you chose combined model\n");
823
                pw.print(rr.features);
824
                pw.println("\n# Attribute ranking by merit calculated by RELIEFF");
825
                pw.print(rr.getFeatureRankingResults());
826
827
            } catch (Exception e) {
828
                e.printStackTrace();
829
            }
830
        }
831
        pw.flush();
832
833
        //retrieve correlated features
834
        // do not retreive correlated features if we are already computing a model for the long signature
835
        if (Main.retrieveCorrelatedGenes && !correlatedFeaturesMode) {
836
            if (new File(trainFileName).exists()) {
837
                System.out.print("Search correlated features (spearman)...");
838
                pw.println("\n# Correlated features (Spearman)");
839
                pw.println("FeatureInSignature\tSpearmanCorrelationScore\tCorrelatedFeature");
840
                TreeMap<String, Double> tmsCorrelatedgenes
841
                        = RetreiveCorrelatedGenes.spearmanCorrelation(modelFilename + ".train_features.csv", trainFileName);
842
                for (String correlation : tmsCorrelatedgenes.keySet()) {
843
                    pw.println(correlation);
844
                }
845
                if (tmsCorrelatedgenes.isEmpty()) {
846
                    pw.println("#nothing found !");
847
                }
848
                System.out.println("[done]");
849
850
                System.out.print("Search correlated features (pearson)...");
851
                pw.println("\n# Correlated features (Pearson)");
852
                pw.println("FeatureInSignature\tPearsonCorrelationScore\tCorrelatedFeature");
853
                TreeMap<String, Double> tmpCorrelatedgenes
854
                        = RetreiveCorrelatedGenes.pearsonCorrelation(modelFilename + ".train_features.csv", trainFileName);
855
                for (String correlation : tmpCorrelatedgenes.keySet()) {
856
                    pw.println(correlation);
857
                }
858
                if (tmpCorrelatedgenes.isEmpty()) {
859
                    pw.println("#nothing found !");
860
                }
861
                System.out.println("[done]");
862
            } else {
863
                System.out.println("Feature file " + trainFileName + " not found. Unable to calculate correlated genes");
864
            }
865
            pw.flush();
866
867
            //retreive rankings
868
            String ranking = "";
869
            if (classification) {
870
                ranking = weka.featureRankingForClassification(trainFileName.replace("csv", "arff"));
871
            } else {
872
                ranking = weka.featureRankingForRegression(trainFileName.replace("csv", "arff"));
873
            }
874
            String lines[] = ranking.split("\n");
875
            HashMap<String, ArrayList<RankerObject>> hmRanks = new HashMap();
876
            try {
877
                for (String s : lines) {
878
                    s = s.replaceAll(" +", " ");
879
                    if (!s.startsWith("\t") && !s.trim().isEmpty() && s.trim().split(" ").length == 3) {
880
                        RankerObject rankero = new RankerObject(s.trim());
881
                        if (hmRanks.containsKey(rankero.roundedScore)) {
882
                            ArrayList<RankerObject> alRankero = hmRanks.get(rankero.roundedScore);
883
                            alRankero.add(rankero);
884
                            hmRanks.put(rankero.roundedScore, alRankero);
885
                        } else {
886
                            ArrayList<RankerObject> alRankero = new ArrayList<>();
887
                            alRankero.add(rankero);
888
                            hmRanks.put(rankero.roundedScore, alRankero);
889
                        }
890
                    }
891
                }
892
            } catch (Exception e) {
893
                if (Main.debug) {
894
                    e.printStackTrace();
895
                }
896
            }
897
898
            if (Main.retreiveCorrelatedGenesByRankingScore) {
899
                System.out.print("Search similar ranking scores (infogain for classification or relieFf) in the original dataset...");
900
901
                pw.println("\n# Similar ranking score (maximal difference: " + Main.maxRankingScoreDifference + ")");
902
                pw.println("FeatureInSignature\tRankingScore\tFeatureInDataset\tRankingScore");
903
                String rankedFeaturesSign[];
904
                if (classification) {
905
                    rankedFeaturesSign = cr.getFeatureRankingResults().split("\n");
906
                } else {
907
                    rankedFeaturesSign = rr.getFeatureRankingResults().split("\n");
908
                }
909
                for (String featureSign : rankedFeaturesSign) {
910
                    String featureSignIG = featureSign.split("\t")[0];
911
                    String featureSignIGrounded = df.format(Double.valueOf(featureSignIG));
912
                    if (hmRanks.containsKey(featureSignIGrounded)) {
913
                        for (RankerObject alio : hmRanks.get(featureSignIGrounded)) {
914
                            if (!featureSign.contains(alio.feature) && !alio.feature.equals(Main.mergingID)) {
915
                                //max difference between infogains: 0.005
916
                                if (Math.abs(Double.parseDouble(featureSignIG) - Double.parseDouble(alio.infogain))
917
                                        <= Main.maxRankingScoreDifference) {
918
                                    pw.println(featureSign.split("\t")[1] + "\t"
919
                                            + featureSignIG + "\t"
920
                                            + alio.feature + "\t"
921
                                            + alio.infogain);
922
                                }
923
924
                            }
925
                        }
926
                    }
927
                }
928
                System.out.println("[done]");
929
            }
930
931
            //close file
932
            pw.println("\n\n## End of file ##");
933
            pw.close();
934
935
            //export enriched signature
936
            LinkedHashMap<String, String> lhmCorrFeaturesNames = new LinkedHashMap<>();//signature + correlated features
937
            LinkedHashMap<String, String> lhmFeaturesNames = new LinkedHashMap<>();//signature only
938
            try {
939
                br = new BufferedReader(new FileReader(modelFilename + ".details.txt"));
940
                String line = "";
941
                //go to selected attributes
942
                while (!line.startsWith("# Attribute ranking by")) {
943
                    line = br.readLine();
944
                }
945
                line = br.readLine();
946
                //add attributes to hashmap
947
                while (!line.startsWith("#")) {
948
                    if (!line.isEmpty()) {
949
                        lhmCorrFeaturesNames.put(line.split("\t")[1].trim(), "");
950
                        lhmFeaturesNames.put(line.split("\t")[1].trim(), "");
951
                    }
952
                    line = br.readLine();
953
                }
954
                //go to spearman correlated attributes
955
                while (!line.startsWith("FeatureInSignature")) {
956
                    line = br.readLine();
957
                }
958
                line = br.readLine();
959
                //add attributes to hashmap
960
                while (!line.startsWith("#")) {
961
                    if (!line.isEmpty()) {
962
                        lhmCorrFeaturesNames.put(line.split("\t")[2].trim(), "");
963
                    }
964
                    line = br.readLine();
965
                }
966
                //go to pearson correlated attributes
967
                while (!line.startsWith("FeatureInSignature")) {
968
                    line = br.readLine();
969
                }
970
                line = br.readLine();
971
                //add attributes to hashmap
972
                while (!line.startsWith("#")) {
973
                    if (!line.isEmpty()) {
974
                        lhmCorrFeaturesNames.put(line.split("\t")[2].trim(), "");
975
                    }
976
                    line = br.readLine();
977
                }
978
                if (Main.retreiveCorrelatedGenesByRankingScore) {
979
                    //go to infogain correlated attributes
980
                    while (!line.startsWith("FeatureInSignature")) {
981
                        line = br.readLine();
982
                    }
983
                    line = br.readLine();
984
                    //add attributes to hashmap
985
                    while (br.ready()) {
986
                        if (!line.isEmpty()) {
987
                            lhmCorrFeaturesNames.put(line.split("\t")[0].trim(), "");
988
                            lhmCorrFeaturesNames.put(line.split("\t")[2].trim(), "");
989
                        }
990
                        line = br.readLine();
991
                    }
992
                }
993
                br.close();
994
                //write correlated feature file from training file
995
                correlatedFeatures = writeFeaturesFile(lhmCorrFeaturesNames, trainFileName,
996
                        classification, modelFilename + ".train_corrFeatures.csv");
997
                if (Main.doSampling) {
998
                    //write correlated feature file from all data file
999
                    //if we have done a sampling, then we can't go from trainFeaturesFile
1000
                    //but to allFeaturesFile, which contain test data
1001
                    correlatedFeatures = writeFeaturesFile(lhmCorrFeaturesNames, trainFileName.replace("data_to_train", "all_data"),
1002
                            classification, modelFilename + ".all_corrFeatures.csv");
1003
1004
                    //short signature
1005
                    writeFeaturesFile(lhmFeaturesNames, trainFileName.replace("data_to_train", "all_data"),
1006
                            classification, modelFilename + ".all_features.csv");
1007
                    System.out.println("");
1008
                }
1009
1010
            } catch (Exception e) {
1011
                e.printStackTrace();
1012
            }
1013
1014
        }
1015
        // delete useless files
1016
        if (Main.doSampling) {
1017
            new File(modelFilename + ".train_features.csv").delete();
1018
            new File(modelFilename + ".train_corrFeatures.csv").delete();
1019
        }
1020
        new File(modelFilename + ".test_features.arff").delete();
1021
        new File(modelFilename + ".RH_features.arff").delete();
1022
        new File(modelFilename + ".BS_features.arff").delete();
1023
1024
    }
1025
1026
    /**
1027
     *
1028
     * @param lhm feature names in order
1029
     * @param originFile the training file or all data file
1030
     * @param classification if we are doing a classification
1031
     * @param outfile outfile name
1032
     */
1033
    private String writeFeaturesFile(LinkedHashMap<String, String> lhm, String originFile, boolean classification, String outfile) {
1034
        //find columns indices
1035
        String featuresSeparatedByCommas = "1";
1036
        try {
1037
            BufferedReader br = new BufferedReader(new FileReader(originFile));
1038
            String header = br.readLine();
1039
            String features[] = header.split(utils.detectSeparator(originFile));
1040
            for (int i = 0; i < features.length; i++) {
1041
                String feature = features[i];
1042
                if (lhm.containsKey(feature)) {
1043
                    featuresSeparatedByCommas += "," + (i + 1);
1044
                }
1045
            }
1046
            featuresSeparatedByCommas += "," + features.length;
1047
        } catch (Exception e) {
1048
            e.printStackTrace();
1049
        }
1050
1051
        //extract columns using weka filter
1052
        Weka_module weka2 = new Weka_module();
1053
        weka2.setARFFfile(originFile.replace("csv", "arff"));
1054
        weka2.setDataFromArff();
1055
        weka2.saveFilteredDataToCSV(featuresSeparatedByCommas,
1056
                classification, outfile);
1057
        return featuresSeparatedByCommas;
1058
1059
    }
1060
1061
    /**
1062
     * initialize weka
1063
     *
1064
     * @param infile
1065
     * @param classification
1066
     */
1067
    private static void init(String infile, boolean classification) {
1068
        //convert csv to arff
1069
        if (infile.endsWith(".csv")) {
1070
            weka.setCSVFile(new File(infile));
1071
            weka.csvToArff(classification);
1072
        } else {
1073
            weka.setARFFfile(infile.replace(".csv", ".arff"));
1074
        }
1075
1076
        //set local variable of weka object from ARFFfile
1077
        weka.setDataFromArff();
1078
        weka.myData = weka.convertStringsToNominal(weka.myData);
1079
//        // check if class has numeric values, hence regression, instead of nominal class (classification)
1080
        classification = weka.isClassification();
1081
    }
1082
1083
    /**
1084
     * calculate mean and standard deviation of an array of doubles
1085
     *
1086
     * @param al
1087
     * @return
1088
     */
1089
    private static String getMeanAndStandardDeviation(ArrayList<Double> al) {
1090
1091
        double d[] = new double[al.size()];
1092
        for (int i = 0; i < al.size(); i++) {
1093
            d[i] = (double) al.get(i);
1094
        }
1095
1096
        StandardDeviation sd = new StandardDeviation();
1097
        Mean m = new Mean();
1098
1099
        return df.format(m.evaluate(d)) + " (" + df.format(sd.evaluate(d)) + ")";
1100
    }
1101
1102
    /**
1103
     * classification object
1104
     */
1105
    public static class classificationObject {
1106
1107
        public ArrayList<String> featureList = new ArrayList<>();
1108
        public String featuresSeparatedByCommas = "";
1109
        public String optimizer = "";
1110
        public String mode = "";
1111
        public String classifier = "";
1112
        public String options = "";
1113
        public String identifier = "";
1114
        public TreeMap<Integer, Integer> tmFeatures;
1115
        public HashMap<String, String> hmValues = new HashMap<>(); //Column name, value
1116
1117
        public classificationObject() {
1118
        }
1119
1120
        public classificationObject(String line) {
1121
1122
            identifier = line.split("\t")[hmResultsHeaderNames.get("ID")];
1123
            classifier = line.split("\t")[hmResultsHeaderNames.get("classifier")];
1124
            options = line.split("\t")[hmResultsHeaderNames.get("Options")];
1125
            optimizer = line.split("\t")[hmResultsHeaderNames.get("OptimizedValue")];
1126
            mode = line.split("\t")[hmResultsHeaderNames.get("SearchMode")];
1127
1128
            featureList.addAll(Arrays.asList(line.split("\t")[hmResultsHeaderNames.get("AttributeList")].split(",")));
1129
            featuresSeparatedByCommas = line.split("\t")[hmResultsHeaderNames.get("AttributeList")];
1130
            String s[] = line.split("\t");
1131
            for (int i = 0; i < s.length; i++) {
1132
                hmValues.put(hmResultsHeaderIndexes.get(i), s[i]);
1133
            }
1134
1135
        }
1136
1137
        /**
1138
         * for combined models
1139
         *
1140
         * @param alBestClassifiers
1141
         * @param NumberOfTopModels
1142
         */
1143
        private void buildVoteClassifier(ArrayList<Object> alBestClassifiers) {
1144
            classifier = "meta.Vote"; //Combination rule: average of probabilities
1145
            options = "-S 1 -R " + Main.combinationRule;
1146
1147
            tmFeatures = new TreeMap<>();
1148
            for (Object c : alBestClassifiers) {
1149
                classificationObject co = (classificationObject) c;
1150
                //create filteredclassifier with selected attributes
1151
                String filteredClassifierOptions = "-B \"weka.classifiers.meta.FilteredClassifier -F \\\"weka.filters.unsupervised.attribute.Remove -V -R "
1152
                        + co.featuresSeparatedByCommas.substring(2) + "\\\"";
1153
                //add classifier and its options
1154
                String classif = "-W weka.classifiers." + co.classifier + " --";
1155
                String classifOptions = co.options.replace("\\", "\\\\").replace("\"", "\\\"") + "\"";
1156
                options += " " + filteredClassifierOptions + " " + classif + " " + classifOptions;
1157
1158
                //get all features seen and get their number of views
1159
                for (String f : co.featuresSeparatedByCommas.split(",")) {
1160
                    if (tmFeatures.containsKey(Integer.valueOf(f))) {
1161
                        int i = tmFeatures.get(Integer.valueOf(f));
1162
                        i++;
1163
                        tmFeatures.put(Integer.valueOf(f), i);
1164
                    } else {
1165
                        tmFeatures.put(Integer.valueOf(f), 1);
1166
                    }
1167
                }
1168
            }
1169
            //set features lists
1170
            for (Integer f : tmFeatures.keySet()) {
1171
                featureList.add(f.toString());
1172
                featuresSeparatedByCommas += "," + f;
1173
            }
1174
            featuresSeparatedByCommas = featuresSeparatedByCommas.substring(1);
1175
1176
            //set other variables
1177
            optimizer = "COMB";
1178
            mode = Main.numberOfBestModels + "_" + Main.bestModelsSortingMetric + "_" + Main.bestModelsSortingMetricThreshold;
1179
        }
1180
1181
        /**
1182
         * printable version of options
1183
         *
1184
         * @return
1185
         */
1186
        private String printOptions() {
1187
            if (classifier.contains("meta.Vote")) {
1188
                return options.substring(0, options.indexOf("-B")).replace(" ", "");
1189
            } else {
1190
                return options.replace(" ", "").replace("\\", "").replace("\"", "");
1191
            }
1192
        }
1193
1194
    }
1195
1196
    public static class regressionObject {
1197
1198
        public ArrayList<String> featureList = new ArrayList<>();
1199
        public String featuresSeparatedByCommas = "";
1200
        public String classifier;
1201
        public String optimizer;
1202
        public String options;
1203
        public String mode;
1204
        public String identifier;
1205
        public TreeMap<Integer, Integer> tmFeatures;
1206
        public HashMap<String, String> hmValues = new HashMap<>(); //Column name, value
1207
1208
        public regressionObject() {
1209
        }
1210
1211
        public regressionObject(String line) {
1212
            identifier = line.split("\t")[hmResultsHeaderNames.get("ID")];
1213
            classifier = line.split("\t")[hmResultsHeaderNames.get("classifier")];
1214
            options = line.split("\t")[hmResultsHeaderNames.get("Options")];
1215
            optimizer = line.split("\t")[hmResultsHeaderNames.get("OptimizedValue")];
1216
            mode = line.split("\t")[hmResultsHeaderNames.get("SearchMode")];
1217
            featureList.addAll(Arrays.asList(line.split("\t")[hmResultsHeaderNames.get("AttributeList")].split(",")));
1218
            featuresSeparatedByCommas = line.split("\t")[hmResultsHeaderNames.get("AttributeList")];
1219
1220
            if (options.startsWith("\"")) {
1221
                options = options.substring(1); //remove first "
1222
                options = options.replace("\\\"\"", "\\\""); //  replace \"" by \"
1223
                options = options.replace("\"\"\"", "\""); // replace """ by "
1224
                options = options.replace("\"\"weka", "\"weka"); // replace  "" by "
1225
            }
1226
1227
            String s[] = line.split("\t");
1228
            for (int i = 0; i < s.length; i++) {
1229
                hmValues.put(hmResultsHeaderIndexes.get(i), s[i]);
1230
            }
1231
        }
1232
1233
        /**
1234
         * for combined models
1235
         *
1236
         * @param alBestClassifiers
1237
         * @param cc
1238
         * @param NumberOfTopModels
1239
         */
1240
        private void buildVoteClassifier(ArrayList<Object> alBestClassifiers) {
1241
            classifier = "meta.Vote"; //Combination rule: average of probabilities
1242
            options = "-S 1 -R " + Main.combinationRule;
1243
1244
            tmFeatures = new TreeMap<>();
1245
            int cpt = 0;
1246
            for (Object r : alBestClassifiers) {
1247
                regressionObject ro = (regressionObject) r;
1248
                cpt++;
1249
                //create filteredclassifier with selected attributes
1250
                String filteredClassifierOptions = "-B \"weka.classifiers.meta.FilteredClassifier -F \\\"weka.filters.unsupervised.attribute.Remove -V -R "
1251
                        + ro.featuresSeparatedByCommas.substring(2) + "\\\"";
1252
                //add classifier and its options
1253
                String classif = "-W weka.classifiers." + ro.classifier + " --";
1254
                String classifOptions = ro.options.replace("\\", "\\\\").replace("\"", "\\\"") + "\"";
1255
                options += " " + filteredClassifierOptions + " " + classif + " " + classifOptions;
1256
1257
                //get all features seen and get their number of views
1258
                for (String f : ro.featuresSeparatedByCommas.split(",")) {
1259
                    if (tmFeatures.containsKey(Integer.valueOf(f))) {
1260
                        int i = tmFeatures.get(Integer.valueOf(f));
1261
                        i++;
1262
                        tmFeatures.put(Integer.valueOf(f), i);
1263
                    } else {
1264
                        tmFeatures.put(Integer.valueOf(f), 1);
1265
                    }
1266
                }
1267
1268
            }
1269
1270
            //set features lists
1271
            for (Integer f : tmFeatures.keySet()) {
1272
                featureList.add(f.toString());
1273
                featuresSeparatedByCommas += "," + f;
1274
            }
1275
            featuresSeparatedByCommas = featuresSeparatedByCommas.substring(1);
1276
1277
            //set other variables
1278
            optimizer = "COMB";
1279
            mode = Main.numberOfBestModels + "_" + Main.bestModelsSortingMetric + "_" + Main.bestModelsSortingMetricThreshold;
1280
        }
1281
1282
        /**
1283
         * printable version of options
1284
         *
1285
         * @return
1286
         */
1287
        private String printOptions() {
1288
            if (classifier.contains("meta.Vote")) {
1289
                return options.substring(0, options.indexOf("-B")).replace(" ", "");
1290
            } else {
1291
                return options.replace(" ", "").replace("\\", "").replace("\"", "");
1292
            }
1293
        }
1294
1295
    }
1296
1297
    private static class RankerObject {
1298
1299
        public String infogain;
1300
        public String roundedScore;
1301
        public String feature;
1302
1303
        public RankerObject() {
1304
1305
        }
1306
1307
        private RankerObject(String s) {
1308
            infogain = s.split(" ")[0];
1309
            roundedScore = df.format(Double.valueOf(s.split(" ")[0]));
1310
            feature = s.split(" ")[2];
1311
        }
1312
    }
1313
1314
}