Switch to unified view

a b/src/biodiscml/Training.java
1
/*
 * Executes machine learning algorithms in parallel.
 * Works with both regression and classification.
 */
5
package biodiscml;
6
7
import java.io.BufferedReader;
8
import java.io.File;
9
import java.io.FileReader;
10
import java.io.FileWriter;
11
import java.io.PrintWriter;
12
import java.text.DecimalFormat;
13
import java.text.DecimalFormatSymbols;
14
import java.time.Duration;
15
import java.time.Instant;
16
import java.util.ArrayList;
17
import java.util.Collections;
18
import java.util.HashMap;
19
import java.util.List;
20
import java.util.Random;
21
import org.apache.commons.math3.stat.descriptive.moment.*;
22
import utils.Weka_module;
23
import utils.utils;
24
import weka.attributeSelection.AttributeSelection;
25
import weka.classifiers.Classifier;
26
27
/**
28
 *
29
 * @author Mickael
30
 */
31
public class Training {
32
33
    // Shared Weka wrapper used for data conversion, feature selection and training.
    public static Weka_module weka = new Weka_module();
    // Attribute ranking produced by the feature-selection step (classification only).
    public static AttributeSelection ranking;
    // Work queue of model configurations: each entry is {classifier, options, optimizer, searchMode}.
    public static ArrayList<String[]> alClassifiers = new ArrayList<>();
    // true = classification task, false = regression task.
    public static boolean isClassification = true;
    // Tab-separated header line written at the top of the results summary file.
    public static String resultsSummaryHeader = "";
    // Count of models attempted so far (incremented before each training).
    public static int cptPassed = 0;
    // Count of models that failed / were incompatible.
    public static int cptFailed = 0;
    // Run model evaluations in parallel (sic: field name misspelled, kept because it is public).
    // Set to false by the constructor when Main.cpus equals "1".
    public boolean parrallel = true;
    // Path of the training data file passed to the constructor.
    public static String trainFileName = "";
    // Formatter for numeric output: 3 fraction digits, '.' decimal separator (set in constructor).
    public static DecimalFormat df = new DecimalFormat();
    // Writer for the results summary file; opened and closed by the constructor.
    public static PrintWriter pw;

    // No-argument constructor: performs no initialization; all setup happens in
    // the main constructor below.
    public Training() {
    }
47
48
    /**
49
     * @param dataToTrainModel
50
     * @param resultsFile
51
     * @param featureSelectionFile
52
     * @param type. Is "class" for classification, else "reg" for regression
53
     */
54
    public Training(String dataToTrainModel, String resultsFile,
55
            String featureSelectionFile, String type) {
56
        // deleting previous run
57
        if (!Main.restoreRun && !Main.resumeTraining) {
58
            if (new File(featureSelectionFile).exists()) {
59
                System.out.println("\t" + featureSelectionFile + " exist... deleting...");
60
                new File(featureSelectionFile).delete();
61
                new File(featureSelectionFile.replace(".csv", ".arff")).delete();
62
            }
63
            if (new File(resultsFile).exists()) {
64
                System.out.println("\t" + resultsFile + " exist... deleting...");
65
                new File(resultsFile).delete();
66
                new File(resultsFile.replace(".csv", ".arff")).delete();
67
            }
68
        }
69
70
        df.setMaximumFractionDigits(3);
71
        DecimalFormatSymbols dfs = new DecimalFormatSymbols();
72
        dfs.setDecimalSeparator('.');
73
        df.setDecimalFormatSymbols(dfs);
74
        if (Main.cpus.equals("1")) {
75
            parrallel = false;
76
        }
77
        System.out.println("Training on " + dataToTrainModel + ". All tested models results will be in " + resultsFile);
78
        alClassifiers = new ArrayList<>();
79
        isClassification = type.equals("class");
80
        trainFileName = dataToTrainModel;
81
//        File f = new File(Main.CVfolder);//cross validation folder outputs
82
//        if (!f.exists()) {
83
//            f.mkdir();
84
//        }
85
86
        //convert csv to arff
87
        if (dataToTrainModel.endsWith(".csv") && !new File(dataToTrainModel.replace(".csv", ".arff")).exists()) {
88
            weka.setCSVFile(new File(dataToTrainModel));
89
            weka.csvToArff(isClassification);
90
        } else {
91
            weka.setARFFfile(dataToTrainModel.replace(".csv", ".arff"));
92
        }
93
94
        //set local variable of weka object from ARFFfile
95
        weka.setDataFromArff();
96
        weka.myData = weka.convertStringsToNominal(weka.myData);
97
98
//        // check if class has numeric values, hence regression, instead of nominal class (classification)
99
        //classification = weka.isClassification();
100
        //CLASSIFICATION
101
        if (isClassification) {
102
            if (Main.debug) {
103
                System.out.println("Only use classification algorithms");
104
            }
105
            //HEADER FOR SUMMARY OUTPUT
106
            resultsSummaryHeader = "ID"
107
                    + "\tclassifier"
108
                    + "\tOptions"
109
                    + "\tOptimizedValue"
110
                    + "\tSearchMode"
111
                    + "\tnbrOfFeatures"
112
                    //10CV
113
                    + "\tTRAIN_10CV_ACC"
114
                    + "\tTRAIN_10CV_AUC"
115
                    + "\tTRAIN_10CV_AUPRC"
116
                    + "\tTRAIN_10CV_SEN"
117
                    + "\tTRAIN_10CV_SPE"
118
                    + "\tTRAIN_10CV_MCC"
119
                    + "\tTRAIN_10CV_MAE"
120
                    + "\tTRAIN_10CV_BER"
121
                    + "\tTRAIN_10CV_FPR"
122
                    + "\tTRAIN_10CV_FNR"
123
                    + "\tTRAIN_10CV_PPV"
124
                    + "\tTRAIN_10CV_FDR"
125
                    + "\tTRAIN_10CV_Fscore"
126
                    + "\tTRAIN_10CV_kappa"
127
                    + "\tTRAIN_matrix"
128
                    //LOOCV
129
                    + "\tTRAIN_LOOCV_ACC"
130
                    + "\tTRAIN_LOOCV_AUC"
131
                    + "\tTRAIN_LOOCV_AUPRC"
132
                    + "\tTRAIN_LOOCV_SEN"
133
                    + "\tTRAIN_LOOCV_SPE"
134
                    + "\tTRAIN_LOOCV_MCC"
135
                    + "\tTRAIN_LOOCV_MAE"
136
                    + "\tTRAIN_LOOCV_BER"
137
                    //Repeated Holdout TRAIN
138
                    + "\tTRAIN_RH_ACC"
139
                    + "\tTRAIN_RH_AUC"
140
                    + "\tTRAIN_RH_AUPRC"
141
                    + "\tTRAIN_RH_SEN"
142
                    + "\tTRAIN_RH_SPE"
143
                    + "\tTRAIN_RH_MCC"
144
                    + "\tTRAIN_RH_MAE"
145
                    + "\tTRAIN_RH_BER"
146
                    //Bootstrap TRAIN
147
                    + "\tTRAIN_BS_ACC"
148
                    + "\tTRAIN_BS_AUC"
149
                    + "\tTRAIN_BS_AUPRC"
150
                    + "\tTRAIN_BS_SEN"
151
                    + "\tTRAIN_BS_SPE"
152
                    + "\tTRAIN_BS_MCC"
153
                    + "\tTRAIN_BS_MAE"
154
                    + "\tTRAIN_BS_BER"
155
                    //Bootstrap .632+ TRAIN
156
                    + "\tTRAIN_BS.632+"
157
                    //test
158
                    + "\tTEST_ACC"
159
                    + "\tTEST_AUC"
160
                    + "\tTEST_AUPRC"
161
                    + "\tTEST_SEN"
162
                    + "\tTEST_SPE"
163
                    + "\tTEST_MCC"
164
                    + "\tTEST_MAE"
165
                    + "\tTEST_BER"
166
                    //Repeated Holdout TRAIN_TEST
167
                    + "\tTRAIN_TEST_RH_ACC"
168
                    + "\tTRAIN_TEST_RH_AUC"
169
                    + "\tTRAIN_TEST_RH_AUPRC"
170
                    + "\tTRAIN_TEST_RH_SEN"
171
                    + "\tTRAIN_TEST_RH_SPE"
172
                    + "\tTRAIN_TEST_RH_MCC"
173
                    + "\tTRAIN_TEST_RH_MAE"
174
                    + "\tTRAIN_TEST_RH_BER"
175
                    //Bootstrap TRAIN_TEST
176
                    + "\tTRAIN_TEST_BS_ACC"
177
                    + "\tTRAIN_TEST_BS_AUC"
178
                    + "\tTRAIN_TEST_BS_AUPRC"
179
                    + "\tTRAIN_TEST_BS_SEN"
180
                    + "\tTRAIN_TEST_BS_SPE"
181
                    + "\tTRAIN_TEST_BS_MCC"
182
                    + "\tTRAIN_TEST_BS_MAE"
183
                    + "\tTRAIN_TEST_BS_BER"
184
                    //Bootstrap .632+ TRAIN
185
                    + "\tTRAIN_TEST_BS.632+"
186
                    //stats
187
                    + "\tAVG_BER"
188
                    + "\tSTD_BER"
189
                    + "\tAVG_MAE"
190
                    + "\tSTD_MAE"
191
                    + "\tAVG_MCC"
192
                    + "\tSTD_MCC"
193
                    + "\tAttributeList";
194
195
            //fast way classification. Predetermined commands.
196
            if (Main.classificationFastWay) {
197
                //hmclassifiers.put(new String[]{"misc.VFI", "-B 0.6", "AUC"}, "");
198
                //hmclassifiers.put(new String[]{"meta.CostSensitiveClassifier", "-cost-matrix \"[0.0 ratio; 100.0 0.0]\" -S 1 -W weka.classifiers.misc.VFI -- -B 0.4", "MCC"}, "");
199
200
                for (String cmd : Main.classificationFastWayCommands) {
201
                    String optimizer = cmd.split(":")[1];
202
                    String searchmode = cmd.split(":")[2];
203
                    String classifier = cmd.split(":")[0].split(" ")[0];
204
                    String options = cmd.split(":")[0].replace(classifier, "");
205
                    //case where optimizer AND search modes are empty
206
                    if (optimizer.equals("allopt") && searchmode.equals("allsearch")) {
207
                        for (String allOptimizers : Main.classificationOptimizers.split(",")) {
208
                            for (String allSearchModes : Main.searchmodes.split(",")) {
209
                                alClassifiers.add(new String[]{classifier, options, allOptimizers.trim(), allSearchModes.trim()});
210
                            }
211
                        }
212
                        //Case where only searchmode is empty
213
                    } else if (!optimizer.equals("allopt") && searchmode.equals("allsearch")) {
214
                        for (String allSearchModes : Main.searchmodes.split(",")) {
215
                            alClassifiers.add(new String[]{classifier, options, optimizer.trim(), allSearchModes.trim()});
216
                        }
217
                        //case where only optimizer is empty
218
                    } else if (optimizer.equals("allopt") && !searchmode.equals("allsearch")) {
219
                        for (String allOptimizers : Main.classificationOptimizers.split(",")) {
220
                            alClassifiers.add(new String[]{classifier, options, allOptimizers.trim(), searchmode.trim()});
221
                        }
222
                        //case where optimizer and searchmode are provided
223
                    } else {
224
                        alClassifiers.add(new String[]{classifier, options, optimizer, searchmode});
225
                    }
226
227
                }
228
            } else {
229
                // if no feature selection is asked
230
                if (Main.noFeatureSelection) {
231
                    Main.maxNumberOfSelectedFeatures = weka.myData.numAttributes();
232
                    Main.maxNumberOfFeaturesInModel = weka.myData.numAttributes();
233
                    for (String cmd : Main.classificationBruteForceCommands) {
234
                        String classifier = cmd.split(" ")[0];
235
                        String options = cmd.replace(classifier, "").trim();
236
                        addClassificationToQueue(classifier, options);
237
                    }
238
239
                } else {
240
                    //brute force classification, try everything in the provided classifiers and optimizers
241
                    for (String cmd : Main.classificationBruteForceCommands) {
242
                        String classifier = cmd.split(" ")[0];
243
                        String options = cmd.replace(classifier, "").trim();
244
                        addClassificationToQueue(classifier, options);
245
                    }
246
                }
247
            }
248
249
            if (!Main.noFeatureSelection) {
250
                System.out.print("Feature selection and ranking...");
251
                if (new File(featureSelectionFile).exists() && Main.resumeTraining) {
252
                    System.out.print("\nFeature selection and ranking already done... skipping");
253
                } else {
254
                    //ATTRIBUTE SELECTION for classification
255
                    weka.attributeSelectionByInfoGainRankingAndSaveToCSV(featureSelectionFile);
256
                    //get the rank of attributes
257
                    ranking = weka.featureRankingForClassification();
258
                    System.out.println("[done]");
259
                    //reset arff and keep compatible header
260
                    weka.setCSVFile(new File(featureSelectionFile));
261
                    weka.csvToArff(isClassification);
262
                    weka.makeCompatibleARFFheaders(dataToTrainModel.replace("data_to_train.csv", "data_to_train.arff"),
263
                            featureSelectionFile.replace("infoGain.csv", "infoGain.arff"));
264
                }
265
                weka.setARFFfile(featureSelectionFile.replace("infoGain.csv", "infoGain.arff"));
266
                weka.setDataFromArff();
267
            }
268
269
        } else {
270
            //REGRESSION
271
            if (Main.debug) {
272
                System.out.println("Only use regression algorithms");
273
            }
274
            resultsSummaryHeader = "ID"
275
                    + "\tclassifier"
276
                    + "\tOptions"
277
                    + "\tOptimizedValue"
278
                    + "\tSearchMode"
279
                    + "\tnbrOfFeatures"
280
                    //10cv
281
                    + "\tTRAIN_10CV_CC"
282
                    + "\tTRAIN_10CV_MAE"
283
                    + "\tTRAIN_10CV_RMSE"
284
                    + "\tTRAIN_10CV_RAE"
285
                    + "\tTRAIN_10CV_RRSE"
286
                    //LOOCV
287
                    + "\tTRAIN_LOOCV_CC"
288
                    + "\tTRAIN_LOOCV_MAE"
289
                    + "\tTRAIN_LOOCV_RMSE"
290
                    + "\tTRAIN_LOOCV_RAE"
291
                    + "\tTRAIN_LOOCV_RRSE"
292
                    //Repeated Holdout
293
                    + "\tTRAIN_RH_CC"
294
                    + "\tTRAIN_RH_MAE"
295
                    + "\tTRAIN_RH_RMSE"
296
                    + "\tTRAIN_RH_RAE"
297
                    + "\tTRAIN_RH_RRSE"
298
                    //Bootstrap
299
                    + "\tTRAIN_BS_CC"
300
                    + "\tTRAIN_BS_MAE"
301
                    + "\tTRAIN_BS_RMSE"
302
                    + "\tTRAIN_BS_RAE"
303
                    + "\tTRAIN_BS_RRSE"
304
                    //TEST SET
305
                    + "\tTEST_CC"
306
                    + "\tTEST_MAE"
307
                    + "\tTEST_RMSE"
308
                    + "\tTEST_RAE"
309
                    + "\tTEST_RRSE"
310
                    //Repeated Holdout TRAIN_TEST
311
                    + "\tTRAIN_TEST_RH_CC"
312
                    + "\tTRAIN_TEST_RH_MAE"
313
                    + "\tTRAIN_TEST_RH_RMSE"
314
                    + "\tTRAIN_TEST_RH_RAE"
315
                    + "\tTRAIN_TEST_RH_RRSE"
316
                    //Bootstrap TRAIN_TEST
317
                    + "\tTRAIN_TEST_BS_CC"
318
                    + "\tTRAIN_TEST_BS_MAE"
319
                    + "\tTRAIN_TEST_BS_RMSE"
320
                    + "\tTRAIN_TEST_BS_RAE"
321
                    + "\tTRAIN_TEST_BS_RRSE"
322
                    //stats
323
                    + "\tAVG_CC"
324
                    + "\tSTD_CC"
325
                    + "\tAVG_MAE"
326
                    + "\tSTD_MAE"
327
                    + "\tAVG_RMSE"
328
                    + "\tSTD_RMSE"
329
                    + "\tAttributeList";
330
331
            //fast way regression. Predetermined commands.
332
            if (Main.regressionFastWay) {
333
                //hmclassifiers.put(new String[]{"functions.GaussianProcesses", "-L 1.0 -N 0 -K \"weka.classifiers.functions.supportVector.RBFKernel -C 250007 -G 1.0\"", "CC"}, "");
334
                //hmclassifiers.put(new String[]{"meta.AdditiveRegression",
335
                //       "-S 1.0 -I 10 -W weka.classifiers.functions.GaussianProcesses -- -L 1.0 -N 0 -K \"weka.classifiers.functions.supportVector.PolyKernel -C 250007 -E 1.0 -L\"", "CC"}, "");
336
337
                for (String cmd : Main.regressionFastWayCommands) {
338
                    String optimizer = cmd.split(":")[1];
339
                    String searchmode = cmd.split(":")[2];
340
                    String classifier = cmd.split(":")[0].split(" ")[0];
341
                    String options = cmd.split(":")[0].replace(classifier, "");
342
343
                    //case where optimizer AND search modes are empty
344
                    if (optimizer.equals("allopt") && searchmode.equals("allsearch")) {
345
                        for (String allOptimizers : Main.regressionOptimizers.split(",")) {
346
                            for (String allSearchModes : Main.searchmodes.split(",")) {
347
                                alClassifiers.add(new String[]{classifier, options, allOptimizers.trim(), allSearchModes.trim()});
348
                            }
349
                        }
350
                        //Case where only searchmode is empty
351
                    } else if (!optimizer.equals("allopt") && searchmode.equals("allsearch")) {
352
                        for (String allSearchModes : Main.searchmodes.split(",")) {
353
                            alClassifiers.add(new String[]{classifier, options, optimizer.trim(), allSearchModes.trim()});
354
                        }
355
                        //case where only optimizer is empty
356
                    } else if (optimizer.equals("allopt") && !searchmode.equals("allsearch")) {
357
                        for (String allOptimizers : Main.regressionOptimizers.split(",")) {
358
                            alClassifiers.add(new String[]{classifier, options, allOptimizers.trim(), searchmode.trim()});
359
                        }
360
                        //case where optimizer and searchmode are provided
361
                    } else {
362
                        alClassifiers.add(new String[]{classifier, options, optimizer, searchmode});
363
                    }
364
                }
365
            } else {
366
                //brute force classification, try everything in the provided classifiers and optimizers
367
                for (String cmd : Main.regressionBruteForceCommands) {
368
                    String classifier = cmd.split(" ")[0];
369
                    String options = cmd.replace(classifier, "");
370
                    addRegressionToQueue(classifier, options);
371
                }
372
            }
373
374
            //ATTRIBUTE SELECTION for regression
375
            if (!Main.noFeatureSelection) {
376
                System.out.print("Selecting attributes and ranking by RelieFF...");
377
                if (new File(featureSelectionFile).exists() && Main.resumeTraining) {
378
                    System.out.print("Selecting attributes and ranking by RelieFF already done... skipped by resumeTraining");
379
                } else {
380
                    weka.attributeSelectionByRelieFFAndSaveToCSV(featureSelectionFile);
381
382
                    System.out.println("[done]");
383
384
                    //reset arff and keep compatible header
385
                    weka.setCSVFile(new File(featureSelectionFile));
386
                    weka.csvToArff(isClassification);
387
                    weka.makeCompatibleARFFheaders(dataToTrainModel.replace("data_to_train.csv", "data_to_train.arff"),
388
                            featureSelectionFile.replace("RELIEFF.csv", "RELIEFF.arff"));
389
                    weka.setARFFfile(featureSelectionFile.replace("RELIEFF.csv", "RELIEFF.arff"));
390
                    weka.setDataFromArff();
391
                }
392
            }
393
394
        }
395
        //resume training, remove from alClassifier all classifiers already trained
396
        if (Main.resumeTraining && !Main.restoreRun) {
397
            HashMap<String, String> hm = new HashMap<>();
398
            //get data
399
            try {
400
                BufferedReader br = new BufferedReader(new FileReader(resultsFile));
401
                br.readLine(); //skip header
402
                while (br.ready()) {
403
                    String line[] = br.readLine().split("\t");
404
                    if (!line[0].startsWith("ERROR")) {
405
                        String s[] = new String[4];
406
                        s[0] = line[1];
407
                        s[1] = line[2];
408
                        s[2] = line[3].toLowerCase();
409
                        s[3] = line[4].toLowerCase();
410
                        hm.put(line[1] + "\t" + line[2] + "\t" + line[3].toLowerCase() + "\t" + line[4].toLowerCase(), "");
411
                    }
412
                }
413
            } catch (Exception e) {
414
                if (Main.debug) {
415
                    e.printStackTrace();
416
                }
417
            }
418
            //remove from alClassifiers
419
            int alClassifiersBeforeRemoval = alClassifiers.size();
420
            for (int i = 0; i < alClassifiers.size(); i++) {
421
                String s = alClassifiers.get(i)[0] + "\t" + alClassifiers.get(i)[1] + "\t" + alClassifiers.get(i)[2] + "\t" + alClassifiers.get(i)[3];
422
                if (hm.containsKey(s)) {
423
                    alClassifiers.remove(i);
424
                }
425
            }
426
            int alClassifiersAfterRemoval = alClassifiers.size();
427
            int totalRemoved = alClassifiersBeforeRemoval - alClassifiersAfterRemoval;
428
            if (Main.debug) {
429
                System.out.println("Total removed from alClassifier after resumeTraining = " + totalRemoved);
430
            }
431
432
            System.out.println("ResumeTraining: Remains " + alClassifiersAfterRemoval
433
                    + " classifiers to train on the " + alClassifiersBeforeRemoval);
434
435
        }
436
437
        try {
438
            //PREPARE OUTPUT
439
            if (Main.resumeTraining && !Main.restoreRun) {
440
                pw = new PrintWriter(new FileWriter(resultsFile, true));
441
                //pw.println("Resumed here");
442
                pw.flush();
443
            } else {
444
                pw = new PrintWriter(new FileWriter(resultsFile));
445
                pw.println(resultsSummaryHeader);
446
            }
447
448
            //EXECUTE IN PARRALLEL
449
            System.out.println("Total classifiers to test: " + alClassifiers.size());
450
451
            if (parrallel) {
452
                alClassifiers
453
                        .parallelStream()
454
                        .map((classif) -> StepWiseFeatureSelectionTraining(classif[0], classif[1], classif[2], classif[3]))
455
                        .map((s) -> {
456
                            if (!s.toLowerCase().contains("error")) {
457
                                pw.println(s);
458
                                pw.flush();
459
                            } else if (Main.printFailedModels) {
460
                                pw.println(s);
461
                                pw.flush();
462
                            }
463
                            return s;
464
                        })
465
                        .sorted()
466
                        .forEach((_item) -> {
467
                            pw.flush();
468
                        });
469
470
            } else {
471
                alClassifiers.stream().map((classif) -> {
472
                    String s = StepWiseFeatureSelectionTraining(classif[0], classif[1], classif[2], classif[3]);
473
                    if (!s.toLowerCase().contains("error")) {
474
                        pw.println(s);
475
                    } else if (Main.printFailedModels) {
476
                        pw.println(s);
477
                    }
478
                    return classif;
479
                }).forEach((_item) -> {
480
                    pw.flush();
481
                });
482
            }
483
            pw.close();
484
485
            //END
486
            System.out.println("Total model tested: " + cptPassed + "/" + alClassifiers.size()
487
                    + ", including " + cptFailed + " incompatible models");
488
489
        } catch (Exception e) {
490
            if (Main.debug) {
491
                e.printStackTrace();
492
            }
493
            System.out.println("Total model tested: " + cptPassed + "/" + alClassifiers.size()
494
                    + ", including " + cptFailed + " incompatible models");
495
            pw.close();
496
497
        } catch (Error err) {
498
            if (Main.debug) {
499
                err.printStackTrace();
500
            }
501
            System.out.println("Total model tested: " + cptPassed + "/" + alClassifiers.size()
502
                    + ", including " + cptFailed + " incompatible models");
503
            pw.close();
504
        }
505
    }
506
507
    /**
508
     * StepWise training. We add the features one by one in the model if they
509
     * improve the value to maximize We only test the first 500 attributes (or
510
     * less if the number of attributes is less than 500)
511
     *
512
     * @param classifier
513
     * @param classifier_options
514
     * @param valueToMaximizeOrMinimize
515
     * @return
516
     */
517
    private static String StepWiseFeatureSelectionTraining(String classifier, String classifier_options,
518
            String valueToMaximizeOrMinimize, String searchMethod) {
519
        String out = classifier + "\t" + classifier_options + "\t" + valueToMaximizeOrMinimize.toUpperCase() + "\t" + searchMethod;
520
        Instant start = Instant.now();
521
        System.out.println("[model] (" + (cptPassed++) + "/" + alClassifiers.size() + ")" + out);
522
        String lastOutput = "";
523
524
        boolean minimize = valueToMaximizeOrMinimize.equals("fdr")
525
                || valueToMaximizeOrMinimize.equals("mae")
526
                || valueToMaximizeOrMinimize.equals("rmse")
527
                || valueToMaximizeOrMinimize.equals("ber")
528
                || valueToMaximizeOrMinimize.equals("rae")
529
                || valueToMaximizeOrMinimize.equals("rrse");
530
531
        try {
532
            Object o = null;
533
            int numberOfAttributes = 0;
534
            double previousMeasureToMaximize = -1000.0;
535
            double previousMeasureToMinimize = 1000.0;
536
            ArrayList<Integer> alAttributes = new ArrayList<>();
537
            int classIndex = weka.myData.numAttributes();
538
            for (int i = 1; i <= classIndex; i++) { // weka starts at 1, not 0
539
                alAttributes.add(i);
540
            }
541
            AttributeOject ao = new AttributeOject(alAttributes);
542
            int cpt = 0;
543
            Weka_module.ClassificationResultsObject cr = null;
544
            Weka_module.RegressionResultsObject rr = null;
545
546
            // SHORT TEST to ensure compatibility of the model
547
            if (Main.performShortTest) {
548
                o = weka.shortTestTrainClassifier(classifier, classifier_options,
549
                        ao.getAttributesIdClassInString(), isClassification);
550
                if (o.getClass().getName().equals("java.lang.String")) {
551
                    if (Main.debug) {
552
                        System.out.println("SHORT TEST FAILED");
553
                    }
554
                    return "ERROR\t" + o.toString();
555
                }
556
            }
557
            if (Main.debug) {
558
                System.out.println("\tGoing to training");
559
            }
560
561
            //all features search, no feature selection
562
            if (searchMethod.startsWith("all")) {
563
                int top = ao.alAttributes.size();
564
                //Create attribute list
565
                for (int i = 0; i < top; i++) { //from ID to class (excluded)
566
                    //add new attribute to the set of retainedAttributes
567
                    ao.addNewAttributeToRetainedAttributes(i);
568
                }
569
                //TRAIN 10CV
570
                o = weka.trainClassifier(classifier, classifier_options,
571
                        ao.getRetainedAttributesIdClassInString(), isClassification, 10);
572
573
                if (!(o instanceof String) || o == null) {
574
                    if (isClassification) {
575
                        cr = (Weka_module.ClassificationResultsObject) o;
576
                    } else {
577
                        rr = (Weka_module.RegressionResultsObject) o;
578
                    }
579
                }
580
            }
581
582
            //TOP X features search
583
            if (searchMethod.startsWith("top")) {
584
                int top = Integer.valueOf(searchMethod.replace("top", ""));
585
                if (top <= ao.alAttributes.size()) {
586
                    //Create attribute list
587
                    for (int i = 0; i < top; i++) { //from ID to class (excluded)
588
                        //add new attribute to the set of retainedAttributes
589
                        ao.addNewAttributeToRetainedAttributes(i);
590
                    }
591
                    //TRAIN 10CV
592
                    o = weka.trainClassifier(classifier, classifier_options,
593
                            ao.getRetainedAttributesIdClassInString(), isClassification, 10);
594
595
                    if (!(o instanceof String) || o == null) {
596
                        if (isClassification) {
597
                            cr = (Weka_module.ClassificationResultsObject) o;
598
                        } else {
599
                            rr = (Weka_module.RegressionResultsObject) o;
600
                        }
601
                    }
602
                } else {
603
                    o = null;
604
                }
605
            }
606
607
            //FORWARD AND FORWARD-BACKWARD search AND BACKWARD AND BACKWARD-FORWARD search
608
            if (searchMethod.startsWith("f") || searchMethod.startsWith("b")) {
609
                boolean doForwardBackward_OR_BackwardForward = searchMethod.equals("fb") || searchMethod.equals("bf");
610
                boolean StartBackwardInsteadForward = searchMethod.startsWith("b");
611
                if (StartBackwardInsteadForward) {
612
                    Collections.reverse(ao.alAttributes);
613
                }
614
615
                for (int i = 0; i < ao.alAttributes.size(); i++) { //from ID to class (excluded)
616
                    cpt++;
617
                    //add new attribute to the set of retainedAttributes
618
                    ao.addNewAttributeToRetainedAttributes(i);
619
                    Weka_module.ClassificationResultsObject oldcr = cr;
620
                    //do feature selection by forward(-backward)
621
                    if (ao.retainedAttributesOnly.size() <= Main.maxNumberOfFeaturesInModel) {
622
                        o = weka.trainClassifier(classifier, classifier_options,
623
                                ao.getRetainedAttributesIdClassInString(), isClassification, 10);
624
625
                        if (o instanceof String || o == null) {
626
                            return (String) o;
627
                        } else if (isClassification) {
628
                            cr = (Weka_module.ClassificationResultsObject) o;
629
                        } else {
630
                            rr = (Weka_module.RegressionResultsObject) o;
631
                        }
632
633
                        //choose what we want to maximize or minimize (such as error rates)
634
                        //this will crash if model had an error
635
                        double currentMeasure = getValueToMaximize(valueToMaximizeOrMinimize, cr, rr);
636
637
                        //Report results
638
                        boolean modelIsImproved = false;
639
                        //i<2 is to avoid return no attribute at all
640
                        // test if model is improved
641
                        if (minimize) {
642
                            modelIsImproved = (i < 2 && currentMeasure <= previousMeasureToMinimize) || currentMeasure < previousMeasureToMinimize;
643
                        } else {
644
                            modelIsImproved = (i < 2 && currentMeasure >= previousMeasureToMaximize) || currentMeasure > previousMeasureToMaximize;
645
                        }
646
647
                        if (modelIsImproved) {
648
                            if (!minimize) {
649
                                previousMeasureToMaximize = currentMeasure;
650
                            } else {
651
                                previousMeasureToMinimize = currentMeasure;
652
                            }
653
654
                            // do backward OR forward, check if we have an improvement if we remove previously chosen features
655
                            if (doForwardBackward_OR_BackwardForward && numberOfAttributes > 1) {
656
                                oldcr = cr;
657
                                ArrayList<Integer> attributesToTestInBackward = ao.getRetainedAttributesIdClassInArrayList();
658
                                for (int j = 1/*skip ID*/;
659
                                        j < attributesToTestInBackward.size() - 2/*skip last attribute we added by forward and class*/; j++) {
660
                                    attributesToTestInBackward.remove(j);
661
662
                                    String featuresToTest = utils.arrayToString(attributesToTestInBackward, ",");
663
                                    //train
664
                                    if (isClassification) {
665
                                        cr = (Weka_module.ClassificationResultsObject) weka.trainClassifier(classifier, classifier_options,
666
                                                featuresToTest, isClassification, 10);
667
                                    } else {
668
                                        rr = (Weka_module.RegressionResultsObject) weka.trainClassifier(classifier, classifier_options,
669
                                                featuresToTest, isClassification, 10);
670
                                    }
671
                                    //get measure
672
                                    double measureWithRemovedFeature = getValueToMaximize(valueToMaximizeOrMinimize, cr, rr);
673
                                    //check if we have improvement
674
                                    if (minimize) {
675
                                        modelIsImproved = (measureWithRemovedFeature <= currentMeasure)
676
                                                || measureWithRemovedFeature < currentMeasure;
677
                                    } else {
678
                                        modelIsImproved = (measureWithRemovedFeature >= currentMeasure)
679
                                                || measureWithRemovedFeature > currentMeasure;
680
                                    }
681
682
                                    if (modelIsImproved) {
683
                                        //if model is improved definitly discard feature is improvement
684
                                        ao.changeRetainedAttributes(featuresToTest);
685
                                        oldcr = cr;
686
                                        currentMeasure = measureWithRemovedFeature;
687
                                    } else {
688
                                        //restore the feature
689
                                        attributesToTestInBackward = ao.getRetainedAttributesIdClassInArrayList();
690
                                        cr = oldcr;
691
                                    }
692
                                }
693
                            }
694
//
695
                            //modify results summary output
696
                            //only for DEBUG purposes
697
                            if (isClassification) {
698
                                lastOutput = out
699
                                        + "\t" + cr.numberOfFeatures + "\t" + cr.toString() + "\t" + ao.getRetainedAttributesIdClassInString();
700
701
                            } else {
702
                                lastOutput = out
703
                                        + "\t" + rr.numberOfFeatures + "\t" + rr.toString() + "\t" + ao.getRetainedAttributesIdClassInString();
704
                            }
705
706
                        } else {
707
                            //back to previous attribute if no improvement with the new attribute and go to the next
708
                            ao.retainedAttributesOnly.remove(ao.retainedAttributesOnly.size() - 1);
709
                            cr = oldcr;
710
                        }
711
                    }
712
                }
713
            }
714
715
            //now we have a good model, evaluate the performance using various approaches
716
            if (!(o instanceof String) && o != null) {
717
                //LOOCV
718
                Weka_module.ClassificationResultsObject crLoocv = null;
719
                Weka_module.RegressionResultsObject rrLoocv = null;
720
                String loocvOut = "";
721
                if (Main.loocv) {
722
                    if (Main.debug) {
723
                        System.out.println("\tLOOCV Train set " + Main.bootstrapAndRepeatedHoldoutFolds + " times");
724
                    }
725
                    Object oLoocv;
726
                    if (isClassification) {
727
                        oLoocv = weka.trainClassifier(classifier, classifier_options,
728
                                ao.getRetainedAttributesIdClassInString(), isClassification,
729
                                cr.dataset.numInstances());
730
                    } else {
731
                        oLoocv = weka.trainClassifier(classifier, classifier_options,
732
                                ao.getRetainedAttributesIdClassInString(), isClassification,
733
                                rr.dataset.numInstances());
734
                    }
735
736
                    if (oLoocv instanceof String || oLoocv == null) {
737
                        if (Main.debug) {
738
                            System.err.println("[error] LOOCV failed");
739
                        }
740
                    } else if (isClassification) {
741
                        crLoocv = (Weka_module.ClassificationResultsObject) oLoocv;
742
                        loocvOut = crLoocv.toStringShort();
743
                    } else {
744
                        rrLoocv = (Weka_module.RegressionResultsObject) oLoocv;
745
                        loocvOut = rrLoocv.toString();
746
                    }
747
748
                } else {
749
                    if (isClassification) {
750
                        loocvOut = "\t" + "\t" + "\t" + "\t" + "\t" + "\t" + "\t";
751
                    } else {
752
                        loocvOut = "\t" + "\t" + "\t" + "\t";
753
                    }
754
                }
755
                //Repeated Holdout
756
                Weka_module.evaluationPerformancesResultsObject eproRHTrain = new Weka_module.evaluationPerformancesResultsObject();
757
                if (Main.repeatedHoldout) {
758
                    if (Main.debug) {
759
                        System.out.println("\tRepeated Holdout on Train set " + Main.bootstrapAndRepeatedHoldoutFolds + " times");
760
                    }
761
                    if (isClassification) {
762
                        for (int i = 0; i < Main.bootstrapAndRepeatedHoldoutFolds; i++) {
763
                            Weka_module.ClassificationResultsObject cro
764
                                    = (Weka_module.ClassificationResultsObject) weka.trainClassifierHoldOutValidation(classifier, classifier_options,
765
                                            ao.getRetainedAttributesIdClassInString(), isClassification, i);
766
                            eproRHTrain.alAUCs.add(Double.valueOf(cro.AUC));
767
                            eproRHTrain.alpAUCs.add(Double.valueOf(cro.pAUC));
768
                            eproRHTrain.alAUPRCs.add(Double.valueOf(cro.AUPRC));
769
                            eproRHTrain.alACCs.add(Double.valueOf(cro.ACC));
770
                            eproRHTrain.alSEs.add(Double.valueOf(cro.TPR));
771
                            eproRHTrain.alSPs.add(Double.valueOf(cro.TNR));
772
                            eproRHTrain.alMCCs.add(Double.valueOf(cro.MCC));
773
                            eproRHTrain.alMAEs.add(Double.valueOf(cro.MAE));
774
                            eproRHTrain.alBERs.add(Double.valueOf(cro.BER));
775
                        }
776
                    } else {
777
                        for (int i = 0; i < Main.bootstrapAndRepeatedHoldoutFolds; i++) {
778
                            Weka_module.RegressionResultsObject rro
779
                                    = (Weka_module.RegressionResultsObject) weka.trainClassifierHoldOutValidation(classifier, classifier_options,
780
                                            ao.getRetainedAttributesIdClassInString(), isClassification, i);
781
                            eproRHTrain.alCCs.add(Double.valueOf(rro.CC));
782
                            eproRHTrain.alMAEs.add(Double.valueOf(rro.MAE));
783
                            eproRHTrain.alRMSEs.add(Double.valueOf(rro.RMSE));
784
                            eproRHTrain.alRAEs.add(Double.valueOf(rro.RAE));
785
                            eproRHTrain.alRRSEs.add(Double.valueOf(rro.RRSE));
786
                        }
787
                    }
788
                    eproRHTrain.computeMeans();
789
                }
790
791
                //BOOTSTRAP AND BOOTSTRAP .632+ rule TRAIN
792
                if (Main.debug) {
793
                    System.out.println("\tBootstrapping on Train set " + Main.bootstrapAndRepeatedHoldoutFolds + " times");
794
                }
795
                double bootstrapTrain632plus = 1000;
796
                Weka_module.evaluationPerformancesResultsObject eproBSTrain
797
                        = new Weka_module.evaluationPerformancesResultsObject();
798
                if (Main.bootstrap) {
799
                    if (isClassification) {
800
                        bootstrapTrain632plus = weka.trainClassifierBootstrap632plus(classifier, classifier_options,
801
                                ao.getRetainedAttributesIdClassInString());
802
                        for (int i = 0; i < Main.bootstrapAndRepeatedHoldoutFolds; i++) {
803
                            Weka_module.ClassificationResultsObject cro
804
                                    = (Weka_module.ClassificationResultsObject) weka.trainClassifierBootstrap(classifier, classifier_options,
805
                                            ao.getRetainedAttributesIdClassInString(), isClassification, i);
806
                            eproBSTrain.alAUCs.add(Double.valueOf(cro.AUC));
807
                            eproBSTrain.alpAUCs.add(Double.valueOf(cro.pAUC));
808
                            eproBSTrain.alAUPRCs.add(Double.valueOf(cro.AUPRC));
809
                            eproBSTrain.alACCs.add(Double.valueOf(cro.ACC));
810
                            eproBSTrain.alSEs.add(Double.valueOf(cro.TPR));
811
                            eproBSTrain.alSPs.add(Double.valueOf(cro.TNR));
812
                            eproBSTrain.alMCCs.add(Double.valueOf(cro.MCC));
813
                            eproBSTrain.alMAEs.add(Double.valueOf(cro.MAE));
814
                            eproBSTrain.alBERs.add(Double.valueOf(cro.BER));
815
                        }
816
                    } else {
817
                        for (int i = 0; i < Main.bootstrapAndRepeatedHoldoutFolds; i++) {
818
                            Weka_module.RegressionResultsObject rro
819
                                    = (Weka_module.RegressionResultsObject) weka.trainClassifierBootstrap(classifier, classifier_options,
820
                                            ao.getRetainedAttributesIdClassInString(), isClassification, i);
821
822
                            eproBSTrain.alCCs.add(Double.valueOf(rro.CC));
823
                            eproBSTrain.alMAEs.add(Double.valueOf(rro.MAE));
824
                            eproBSTrain.alRMSEs.add(Double.valueOf(rro.RMSE));
825
                            eproBSTrain.alRAEs.add(Double.valueOf(rro.RAE));
826
                            eproBSTrain.alRRSEs.add(Double.valueOf(rro.RRSE));
827
                        }
828
                    }
829
                    eproBSTrain.computeMeans();
830
                }
831
832
                // TEST SET
833
                if (Main.debug) {
834
                    System.out.println("\tEvalutation on Test set ");
835
                }
836
                Weka_module.testResultsObject tro = new Weka_module.testResultsObject();
837
                String testResults = null;
838
                if (isClassification) {
839
                    testResults = "\t" + "\t" + "\t" + "\t" + "\t" + "\t" + "\t";
840
                } else {
841
                    testResults = "\t" + "\t" + "\t" + "\t";
842
                }
843
                try {
844
                    if (Main.doSampling) {
845
                        //testing
846
                        Weka_module weka2 = new Weka_module();
847
                        if (isClassification) {
848
                            Weka_module.ClassificationResultsObject cr2 = null;
849
                            try {
850
                                cr2 = (Weka_module.ClassificationResultsObject) weka2.testClassifierFromModel(cr.model,
851
                                        trainFileName.replace("data_to_train.csv", "data_to_test.csv"),//test file
852
                                        Main.isClassification, out);
853
                            } catch (Exception e) {
854
                                if (Main.debug) {
855
                                    e.printStackTrace();
856
                                }
857
                            }
858
                            tro.ACC = cr2.ACC;
859
                            tro.AUC = cr2.AUC;
860
                            tro.AUPRC = cr2.AUPRC;
861
                            tro.SE = cr2.TPR;
862
                            tro.SP = cr2.TNR;
863
                            tro.MCC = cr2.MCC;
864
                            tro.MAE = cr2.MAE;
865
                            tro.BER = cr2.BER;
866
                            testResults = tro.toStringClassification();
867
868
                        } else {
869
                            Weka_module.RegressionResultsObject rr2
870
                                    = (Weka_module.RegressionResultsObject) weka2.testClassifierFromModel(rr.model,
871
                                            trainFileName.replace("data_to_train.csv", "data_to_test.csv"),//test file
872
                                            Main.isClassification, out);
873
                            tro.CC = rr2.CC;
874
                            tro.RRSE = rr2.RRSE;
875
                            tro.RMSE = rr2.RMSE;
876
                            tro.MAE = rr2.MAE;
877
                            tro.RAE = rr2.RAE;
878
                            testResults = tro.toStringRegression();
879
                        }
880
                    }
881
                } catch (Exception e) {
882
                    if (Main.debug) {
883
                        e.printStackTrace();
884
                    }
885
                }
886
887
                //REPEATED HOLDOUT TRAIN_TEST
888
                if (Main.debug) {
889
                    System.out.println("\tRepeated Holdout on Train AND Test set "
890
                            + Main.bootstrapAndRepeatedHoldoutFolds + " times");
891
                }
892
                Weka_module.evaluationPerformancesResultsObject eproRHTrainTest
893
                        = new Weka_module.evaluationPerformancesResultsObject();
894
                try {
895
                    if (Main.doSampling && Main.repeatedHoldout) {
896
                        Weka_module weka2 = new Weka_module();
897
                        weka2.setARFFfile(trainFileName.replace("data_to_train.csv", "all_data.arff"));
898
                        weka2.setDataFromArff();
899
                        if (isClassification) {
900
                            weka2.myData = weka2.extractFeaturesFromDatasetBasedOnModel(cr.model, weka2.myData);
901
                            for (int i = 0; i < Main.bootstrapAndRepeatedHoldoutFolds; i++) {
902
                                Weka_module.ClassificationResultsObject cro
903
                                        = (Weka_module.ClassificationResultsObject) weka2.trainClassifierHoldOutValidation(
904
                                                classifier, classifier_options,
905
                                                null, isClassification, i);
906
907
                                eproRHTrainTest.alAUCs.add(Double.valueOf(cro.AUC));
908
                                eproRHTrainTest.alpAUCs.add(Double.valueOf(cro.pAUC));
909
                                eproRHTrainTest.alAUPRCs.add(Double.valueOf(cro.AUPRC));
910
                                eproRHTrainTest.alACCs.add(Double.valueOf(cro.ACC));
911
                                eproRHTrainTest.alSEs.add(Double.valueOf(cro.TPR));
912
                                eproRHTrainTest.alSPs.add(Double.valueOf(cro.TNR));
913
                                eproRHTrainTest.alMCCs.add(Double.valueOf(cro.MCC));
914
                                eproRHTrainTest.alMAEs.add(Double.valueOf(cro.MAE));
915
                                eproRHTrainTest.alBERs.add(Double.valueOf(cro.BER));
916
                            }
917
                        } else {
918
                            weka2.myData = weka2.extractFeaturesFromDatasetBasedOnModel(rr.model, weka2.myData);
919
                            for (int i = 0; i < Main.bootstrapAndRepeatedHoldoutFolds; i++) {
920
                                Weka_module.RegressionResultsObject rro
921
                                        = (Weka_module.RegressionResultsObject) weka2.trainClassifierHoldOutValidation(
922
                                                classifier, classifier_options,
923
                                                null, isClassification, i);
924
                                eproRHTrainTest.alCCs.add(Double.valueOf(rro.CC));
925
                                eproRHTrainTest.alMAEs.add(Double.valueOf(rro.MAE));
926
                                eproRHTrainTest.alRMSEs.add(Double.valueOf(rro.RMSE));
927
                                eproRHTrainTest.alRAEs.add(Double.valueOf(rro.RAE));
928
                                eproRHTrainTest.alRRSEs.add(Double.valueOf(rro.RRSE));
929
                            }
930
                        }
931
                        eproRHTrainTest.computeMeans();
932
                    }
933
                } catch (Exception e) {
934
                    if (Main.debug) {
935
                        e.printStackTrace();
936
                    }
937
                }
938
939
                //BOOTSTRAP TRAIN_TEST AND BOOTSTRAP .632+ rule TRAIN_TEST
940
                if (Main.debug) {
941
                    System.out.println("\tBootstrapping on Train AND Test set " + Main.bootstrapAndRepeatedHoldoutFolds + " times");
942
                }
943
                double bootstrapTrainTest632plus = 1000;
944
                Weka_module.evaluationPerformancesResultsObject eproBSTrainTest = new Weka_module.evaluationPerformancesResultsObject();
945
                try {
946
                    if (Main.doSampling && Main.bootstrap) {
947
                        Weka_module weka2 = new Weka_module();
948
                        weka2.setARFFfile(trainFileName.replace("data_to_train.csv", "all_data.arff"));
949
                        weka2.setDataFromArff();
950
                        if (isClassification) {
951
                            weka2.myData = weka2.extractFeaturesFromDatasetBasedOnModel(cr.model, weka2.myData);
952
                            bootstrapTrainTest632plus = weka2.trainClassifierBootstrap632plus(classifier, classifier_options,
953
                                    null);
954
                            for (int i = 0; i < Main.bootstrapAndRepeatedHoldoutFolds; i++) {
955
                                Weka_module.ClassificationResultsObject cro
956
                                        = (Weka_module.ClassificationResultsObject) weka2.trainClassifierBootstrap(
957
                                                classifier, classifier_options,
958
                                                null, isClassification, i);
959
960
                                eproBSTrainTest.alAUCs.add(Double.valueOf(cro.AUC));
961
                                eproBSTrainTest.alpAUCs.add(Double.valueOf(cro.pAUC));
962
                                eproBSTrainTest.alAUPRCs.add(Double.valueOf(cro.AUPRC));
963
                                eproBSTrainTest.alACCs.add(Double.valueOf(cro.ACC));
964
                                eproBSTrainTest.alSEs.add(Double.valueOf(cro.TPR));
965
                                eproBSTrainTest.alSPs.add(Double.valueOf(cro.TNR));
966
                                eproBSTrainTest.alMCCs.add(Double.valueOf(cro.MCC));
967
                                eproBSTrainTest.alMAEs.add(Double.valueOf(cro.MAE));
968
                                eproBSTrainTest.alBERs.add(Double.valueOf(cro.BER));
969
                            }
970
                        } else {
971
                            weka2.myData = weka2.extractFeaturesFromDatasetBasedOnModel(rr.model, weka2.myData);
972
                            for (int i = 0; i < Main.bootstrapAndRepeatedHoldoutFolds; i++) {
973
                                Weka_module.RegressionResultsObject rro
974
                                        = (Weka_module.RegressionResultsObject) weka2.trainClassifierBootstrap(
975
                                                classifier, classifier_options,
976
                                                null, isClassification, i);
977
                                eproBSTrainTest.alCCs.add(Double.valueOf(rro.CC));
978
                                eproBSTrainTest.alMAEs.add(Double.valueOf(rro.MAE));
979
                                eproBSTrainTest.alRMSEs.add(Double.valueOf(rro.RMSE));
980
                                eproBSTrainTest.alRAEs.add(Double.valueOf(rro.RAE));
981
                                eproBSTrainTest.alRRSEs.add(Double.valueOf(rro.RRSE));
982
                            }
983
                        }
984
                        eproBSTrainTest.computeMeans();
985
                    }
986
                } catch (Exception e) {
987
                    if (Main.debug) {
988
                        e.printStackTrace();
989
                    }
990
                }
991
992
                //STATISTICS AND OUTPUT
993
                if (isClassification) {
994
                    //average BER
995
                    List<Double> alBERs = new ArrayList<>();
996
                    try {
997
                        alBERs.add(Double.valueOf(cr.BER));
998
                    } catch (Exception e) {
999
                    }
1000
                    try {
1001
                        alBERs.add(Double.valueOf(eproRHTrain.meanBERs));
1002
                    } catch (Exception e) {
1003
                    }
1004
                    try {
1005
                        alBERs.add(Double.valueOf(eproBSTrain.meanBERs));
1006
                    } catch (Exception e) {
1007
                    }
1008
                    try {
1009
                        alBERs.add(Double.valueOf(crLoocv.BER));
1010
                    } catch (Exception e) {
1011
                    }
1012
                    try {
1013
                        alBERs.add(Double.valueOf(tro.BER));
1014
                    } catch (Exception e) {
1015
                    }
1016
                    try {
1017
                        alBERs.add(Double.valueOf(eproRHTrainTest.meanBERs));
1018
                    } catch (Exception e) {
1019
                    }
1020
                    try {
1021
                        alBERs.add(Double.valueOf(eproBSTrainTest.meanBERs));
1022
                    } catch (Exception e) {
1023
                    }
1024
1025
                    double BERs[] = alBERs.stream().mapToDouble(Double::doubleValue).toArray();
1026
1027
                    StandardDeviation std = new StandardDeviation();
1028
                    std.setData(BERs);
1029
                    double StdBER = std.evaluate();
1030
1031
                    Mean m = new Mean();
1032
                    m.setData(BERs);
1033
                    double averageBER = m.evaluate();
1034
1035
                    //average MAE
1036
                    List<Double> alMAEs = new ArrayList<>();
1037
                    try {
1038
                        alMAEs.add(Double.valueOf(cr.MAE));
1039
                    } catch (Exception e) {
1040
                    }
1041
                    try {
1042
                        alMAEs.add(Double.valueOf(eproRHTrain.meanMAEs));
1043
                    } catch (Exception e) {
1044
                    }
1045
                    try {
1046
                        alMAEs.add(Double.valueOf(eproBSTrain.meanMAEs));
1047
                    } catch (Exception e) {
1048
                    }
1049
                    try {
1050
                        alMAEs.add(Double.valueOf(crLoocv.MAE));
1051
                    } catch (Exception e) {
1052
                    }
1053
                    try {
1054
                        alMAEs.add(Double.valueOf(tro.MAE));
1055
                    } catch (Exception e) {
1056
                    }
1057
                    try {
1058
                        alMAEs.add(Double.valueOf(eproRHTrainTest.meanMAEs));
1059
                    } catch (Exception e) {
1060
                    }
1061
                    try {
1062
                        alMAEs.add(Double.valueOf(eproBSTrainTest.meanMAEs));
1063
                    } catch (Exception e) {
1064
                    }
1065
1066
                    double MAEs[] = alMAEs.stream().mapToDouble(Double::doubleValue).toArray();
1067
1068
                    std = new StandardDeviation();
1069
                    std.setData(MAEs);
1070
                    double StdMAE = std.evaluate();
1071
1072
                    m = new Mean();
1073
                    m.setData(MAEs);
1074
                    double averageMAE = m.evaluate();
1075
1076
                    //average MCC
1077
                    List<Double> alMCCs = new ArrayList<>();
1078
                    try {
1079
                        alMCCs.add(Double.valueOf(cr.MCC));
1080
                    } catch (Exception e) {
1081
                    }
1082
                    try {
1083
                        alMCCs.add(Double.valueOf(eproRHTrain.meanMCCs));
1084
                    } catch (Exception e) {
1085
                    }
1086
                    try {
1087
                        alMCCs.add(Double.valueOf(eproBSTrain.meanMCCs));
1088
                    } catch (Exception e) {
1089
                    }
1090
                    try {
1091
                        alMCCs.add(Double.valueOf(crLoocv.MCC));
1092
                    } catch (Exception e) {
1093
                    }
1094
                    try {
1095
                        alMCCs.add(Double.valueOf(tro.MCC));
1096
                    } catch (Exception e) {
1097
                    }
1098
                    try {
1099
                        alMCCs.add(Double.valueOf(eproRHTrainTest.meanMCCs));
1100
                    } catch (Exception e) {
1101
                    }
1102
                    try {
1103
                        alMCCs.add(Double.valueOf(eproBSTrainTest.meanMCCs));
1104
                    } catch (Exception e) {
1105
                    }
1106
1107
                    double MCCs[] = alMCCs.stream().mapToDouble(Double::doubleValue).toArray();
1108
1109
                    std = new StandardDeviation();
1110
                    std.setData(MCCs);
1111
                    double StdMCC = std.evaluate();
1112
1113
                    m = new Mean();
1114
                    m.setData(MCCs);
1115
                    double averageMCC = m.evaluate();
1116
1117
                    String stats = df.format(averageBER) + "\t"
1118
                            + df.format(StdBER) + "\t"
1119
                            + df.format(averageMAE) + "\t"
1120
                            + df.format(StdMAE) + "\t"
1121
                            + df.format(averageMCC) + "\t"
1122
                            + df.format(StdMCC);
1123
1124
                    //output
1125
                    String bt632 = df.format(bootstrapTrain632plus);
1126
                    if (bt632.equals(1000)) {
1127
                        bt632 = "";
1128
                    }
1129
                    String btt632 = df.format(bootstrapTrainTest632plus);
1130
                    if (bootstrapTrainTest632plus == 1000) {
1131
                        btt632 = "";
1132
                    }
1133
1134
                    lastOutput = out
1135
                            + "\t" + cr.numberOfFeatures
1136
                            + "\t" + cr.toString()
1137
                            + "\t" + loocvOut
1138
                            + "\t" + eproRHTrain.toStringClassification()
1139
                            + "\t" + eproBSTrain.toStringClassification()
1140
                            + "\t" + bt632
1141
                            + "\t" + testResults
1142
                            + "\t" + eproRHTrainTest.toStringClassification()
1143
                            + "\t" + eproBSTrainTest.toStringClassification()
1144
                            + "\t" + btt632
1145
                            + "\t" + stats + "\t" + ao.getRetainedAttributesIdClassInString();
1146
                } else {
1147
                    //statistics
1148
1149
                    //average CC
1150
                    List<Double> alCCs = new ArrayList<>();
1151
                    try {
1152
                        alCCs.add(Double.valueOf(rr.CC));
1153
                    } catch (Exception e) {
1154
                    }
1155
                    try {
1156
                        alCCs.add(Double.valueOf(eproRHTrain.meanCCs));
1157
                    } catch (Exception e) {
1158
                    }
1159
                    try {
1160
                        alCCs.add(Double.valueOf(eproBSTrain.meanCCs));
1161
                    } catch (Exception e) {
1162
                    }
1163
                    try {
1164
                        alCCs.add(Double.valueOf(rrLoocv.CC));
1165
                    } catch (Exception e) {
1166
                    }
1167
                    try {
1168
                        alCCs.add(Double.valueOf(tro.CC));
1169
                    } catch (Exception e) {
1170
                    }
1171
                    try {
1172
                        alCCs.add(Double.valueOf(eproRHTrainTest.meanCCs));
1173
                    } catch (Exception e) {
1174
                    }
1175
                    try {
1176
                        alCCs.add(Double.valueOf(eproBSTrainTest.meanCCs));
1177
                    } catch (Exception e) {
1178
                    }
1179
1180
                    double CCs[] = alCCs.stream().mapToDouble(Double::doubleValue).toArray();
1181
1182
                    StandardDeviation std = new StandardDeviation();
1183
                    std.setData(CCs);
1184
                    double StdCC = std.evaluate();
1185
1186
                    Mean m = new Mean();
1187
                    m.setData(CCs);
1188
                    double averageCC = m.evaluate();
1189
1190
                    //average MAE
1191
                    List<Double> alMAEs = new ArrayList<>();
1192
                    try {
1193
                        alMAEs.add(Double.valueOf(rr.MAE));
1194
                    } catch (Exception e) {
1195
                    }
1196
                    try {
1197
                        alMAEs.add(Double.valueOf(eproRHTrain.meanMAEs));
1198
                    } catch (Exception e) {
1199
                    }
1200
                    try {
1201
                        alMAEs.add(Double.valueOf(eproBSTrain.meanMAEs));
1202
                    } catch (Exception e) {
1203
                    }
1204
                    try {
1205
                        alMAEs.add(Double.valueOf(rrLoocv.MAE));
1206
                    } catch (Exception e) {
1207
                    }
1208
                    try {
1209
                        alMAEs.add(Double.valueOf(tro.MAE));
1210
                    } catch (Exception e) {
1211
                    }
1212
                    try {
1213
                        alMAEs.add(Double.valueOf(eproRHTrainTest.meanMAEs));
1214
                    } catch (Exception e) {
1215
                    }
1216
                    try {
1217
                        alMAEs.add(Double.valueOf(eproBSTrainTest.meanMAEs));
1218
                    } catch (Exception e) {
1219
                    }
1220
1221
                    double MAEs[] = alMAEs.stream().mapToDouble(Double::doubleValue).toArray();
1222
1223
                    std = new StandardDeviation();
1224
                    std.setData(MAEs);
1225
                    double StdMAE = std.evaluate();
1226
1227
                    m = new Mean();
1228
                    m.setData(MAEs);
1229
                    double averageMAE = m.evaluate();
1230
1231
                    //average RMSE
1232
                    List<Double> alRMSEs = new ArrayList<>();
1233
                    try {
1234
                        alRMSEs.add(Double.valueOf(rr.RMSE));
1235
                    } catch (Exception e) {
1236
                    }
1237
                    try {
1238
                        alRMSEs.add(Double.valueOf(eproRHTrain.meanRMSEs));
1239
                    } catch (Exception e) {
1240
                    }
1241
                    try {
1242
                        alRMSEs.add(Double.valueOf(eproBSTrain.meanRMSEs));
1243
                    } catch (Exception e) {
1244
                    }
1245
                    try {
1246
                        alRMSEs.add(Double.valueOf(rrLoocv.RMSE));
1247
                    } catch (Exception e) {
1248
                    }
1249
                    try {
1250
                        alRMSEs.add(Double.valueOf(tro.RMSE));
1251
                    } catch (Exception e) {
1252
                    }
1253
                    try {
1254
                        alRMSEs.add(Double.valueOf(eproRHTrainTest.meanRMSEs));
1255
                    } catch (Exception e) {
1256
                    }
1257
                    try {
1258
                        alRMSEs.add(Double.valueOf(eproBSTrainTest.meanRMSEs));
1259
                    } catch (Exception e) {
1260
                    }
1261
1262
                    double RMSEs[] = alRMSEs.stream().mapToDouble(Double::doubleValue).toArray();
1263
1264
                    std = new StandardDeviation();
1265
                    std.setData(RMSEs);
1266
                    double StdRMSE = std.evaluate();
1267
1268
                    m = new Mean();
1269
                    m.setData(RMSEs);
1270
                    double averageRMSE = m.evaluate();
1271
1272
                    String stats = df.format(averageCC) + "\t"
1273
                            + df.format(StdCC) + "\t"
1274
                            + df.format(averageMAE) + "\t"
1275
                            + df.format(StdMAE) + "\t"
1276
                            + df.format(averageRMSE) + "\t"
1277
                            + df.format(StdRMSE);
1278
                    lastOutput = out
1279
                            + "\t" + rr.numberOfFeatures + "\t" + rr.toString()
1280
                            + "\t" + loocvOut
1281
                            + "\t" + eproRHTrain.toStringRegression()
1282
                            + "\t" + eproBSTrain.toStringRegression()
1283
                            + "\t" + testResults
1284
                            + "\t" + eproRHTrainTest.toStringRegression()
1285
                            + "\t" + eproBSTrainTest.toStringRegression()
1286
                            + "\t" + stats
1287
                            + "\t" + ao.getRetainedAttributesIdClassInString();
1288
                }
1289
1290
                //CREATE ID
1291
                Random r = new Random();
1292
                int randomNumber = r.nextInt(10000 - 10) + 10;
1293
                out = (lastOutput.split("\t")[0] + "_" + lastOutput.split("\t")[2]
1294
                        + "_" + lastOutput.split("\t")[3] + "_" + lastOutput.split("\t")[4] + "_" + lastOutput.split("\t")[5] + "_" + randomNumber);
1295
1296
                // temporary models output
1297
//            String filename = Main.CVfolder + File.separatorChar + out;
1298
//            PrintWriter pw = new PrintWriter(new FileWriter(filename + ".txt"));
1299
//            pw.println("#" + classifier + " " + classifier_options + "\n#Value to maximize or minimize:" + valueToMaximizeOrMinimize);
1300
//            if (isClassification) {
1301
//                pw.print("#ID\tactual\tpredicted\terror\tprobability\n");
1302
//            } else {
1303
//                pw.print("#ID\tactual\tpredicted\terror\n");
1304
//            }
1305
//            pw.print(predictions + "\n\n");
1306
//            pw.print("#Selected Attributes\t(Total attributes:" + numberOfAttributes + ")\n");
1307
//            pw.print(selectedAttributes);
1308
//            pw.close();
1309
//            //save model
1310
//            SerializationHelper.write(filename + ".model", model);
1311
            } else {
1312
                out = "ERROR";
1313
                if (Main.debug) {
1314
                    System.err.println(o);
1315
                }
1316
                if (o.toString().equals("null")) {
1317
                    out += "\t " + classifier + " " + classifier_options + " | " + searchMethod + " | Error probably because of number of features inferior to topX";
1318
                }
1319
            }
1320
1321
        } catch (Exception e) {
1322
            out = "ERROR\t" + classifier + " " + classifier_options + " | " + searchMethod + " | " + e.getMessage();
1323
            cptFailed++;
1324
            if (Main.debug) {
1325
                e.printStackTrace();
1326
            }
1327
        }
1328
        Instant finish = Instant.now();
1329
        long s = Duration.between(start, finish).toMillis() / 1000;
1330
        if (Main.debug) {
1331
            System.out.println("model created in [" + s + "s]");
1332
        }
1333
        return out + "\t" + lastOutput;
1334
    }
    /**
1337
     * get value of interest
1338
     *
1339
     * @param valueWanted
1340
     * @param cr
1341
     * @param rr
1342
     * @return
1343
     */
1344
    private static double getValueToMaximize(String valueWanted, Weka_module.ClassificationResultsObject cr,
1345
            Weka_module.RegressionResultsObject rr) {
1346
1347
        switch (valueWanted.toLowerCase()) {
1348
            //classification
1349
            case "auc":
1350
                return Double.parseDouble(cr.AUC);
1351
            case "pauc":
1352
                return Double.parseDouble(cr.pAUC);
1353
            case "acc":
1354
                return Double.parseDouble(cr.ACC);
1355
            case "sen":
1356
                return Double.parseDouble(cr.TPR);
1357
            case "spe":
1358
                return Double.parseDouble(cr.TNR);
1359
            case "mcc":
1360
                return Double.parseDouble(cr.MCC);
1361
            case "kappa":
1362
                return Double.parseDouble(cr.kappa);
1363
            case "aupcr":
1364
                return Double.parseDouble(cr.AUPRC);
1365
            case "fscore":
1366
                return Double.parseDouble(cr.Fscore);
1367
            case "precision":
1368
                return Double.parseDouble(cr.precision);
1369
            case "recall":
1370
                return Double.parseDouble(cr.recall);
1371
            case "fdr":
1372
                return Double.parseDouble(cr.FDR);//to minimize
1373
            case "ber":
1374
                return Double.parseDouble(cr.BER);//to minimize
1375
            case "tp+fn":
1376
                double sumOfTP_TN = Double.parseDouble(cr.TPR) + Double.parseDouble(cr.TNR);
1377
                return sumOfTP_TN;
1378
            //regression
1379
            case "cc":
1380
                return Double.parseDouble(rr.CC);
1381
            case "mae":
1382
                return Double.parseDouble(rr.MAE);//to minimize
1383
            case "rmse":
1384
                return Double.parseDouble(rr.RMSE);//to minimize
1385
            case "rae":
1386
                return Double.parseDouble(rr.RAE);//to minimize
1387
            case "rrse":
1388
                return Double.parseDouble(rr.RRSE);//to minimize
1389
        }
1390
        return 0;
1391
    }
    /**
1394
     * add to hm weka configurations
1395
     * <["misc.VFI", "-B 0.6", "AUC"],"">
1396
     * <["CostSensitiveClassifier", "-cost-matrix "[0.0 ratio; 100.0 0.0]" -S 1 -W weka.classifiers.misc.VFI -- -B 0.6", "AUC"],"">
1397
     *
1398
     * @param classifier
1399
     * @param options
1400
     * @param costSensitive
1401
     */
1402
    private static void addClassificationToQueue(String classifier, String options) {
1403
        String optimizers[] = Main.classificationOptimizers.split(",");
1404
        String searchmodes[] = Main.searchmodes.split(",");
1405
        if (Main.noFeatureSelection) {
1406
            searchmodes = new String[]{"all"};
1407
        } 
1408
//        else if(weka.myData.numAttributes()>500){
1409
//            
1410
//        }
1411
        
1412
1413
        for (String optimizer : optimizers) {//AUC, ACC, SEN, SPE, MCC, TP+FN, kappa            
1414
            for (String searchmode : searchmodes) {//F, FB, B, BF, topk
1415
                alClassifiers.add(new String[]{classifier, options, optimizer.trim(), searchmode.trim()});
1416
            }
1417
        }
1418
        if (Main.metaCostSensitiveClassifier) {
1419
            options = "-cost-matrix \"[0.0 ratio; 100.0 0.0]\" -S 1 -W weka.classifiers." + classifier + " -- " + options;
1420
            for (String value : optimizers) {
1421
                alClassifiers.add(new String[]{"meta.CostSensitiveClassifier", options, value.trim(), "f"});
1422
                alClassifiers.add(new String[]{"meta.CostSensitiveClassifier", options, value.trim(), "bf"});
1423
            }
1424
        }
1425
    }
    private static void addRegressionToQueue(String classifier, String options) {
1428
        //String valuesToMaximizeOrMinimize[] = new String[]{"CC", "MAE", "RMSE", "RAE", "RRSE"};
1429
        String optimizers[] = Main.regressionOptimizers.split(",");
1430
        String searchmodes[] = Main.searchmodes.split(",");
1431
        if (Main.noFeatureSelection) {
1432
            searchmodes = new String[]{"all"};
1433
        }
1434
        for (String optimizer : optimizers) {//only CC is maximized here
1435
            for (String searchmode : searchmodes) {
1436
                alClassifiers.add(new String[]{classifier, options, optimizer.trim(), searchmode.trim()});
1437
            }
1438
1439
        }
1440
        if (Main.metaAdditiveRegression) {
1441
            options = "-S 1.0 -I 10 -W weka.classifiers." + classifier + " -- " + options;
1442
            for (String value : optimizers) {
1443
                alClassifiers.add(new String[]{"meta.AdditiveRegression", options, value, "F"});
1444
                alClassifiers.add(new String[]{"meta.AdditiveRegression", options, value, "FB"});
1445
            }
1446
        }
1447
    }
    private static class AttributeOject {
1450
1451
        public Integer ID;
1452
        public ArrayList<Integer> alAttributes = new ArrayList<>();
1453
        public Integer Class;
1454
        public ArrayList<Integer> retainedAttributesOnly = new ArrayList<>();
1455
1456
        public AttributeOject() {
1457
1458
        }
1459
1460
        private AttributeOject(ArrayList<Integer> attributes) {
1461
            ID = attributes.get(0);
1462
            Class = attributes.get(attributes.size() - 1);
1463
            attributes.remove(attributes.size() - 1);
1464
            attributes.remove(0);
1465
            alAttributes = attributes;
1466
        }
1467
1468
        private void addNewAttributeToRetainedAttributes(int index) {
1469
            retainedAttributesOnly.add(alAttributes.get(index));
1470
        }
1471
1472
        private String getRetainedAttributesIdClassInString() {
1473
            String attributes = "";
1474
            attributes += (ID) + ",";
1475
            for (Integer retainedAttribute : retainedAttributesOnly) {
1476
                attributes += (retainedAttribute) + ",";
1477
            }
1478
            attributes += Class;
1479
            return attributes;
1480
        }
1481
1482
        private ArrayList getRetainedAttributesIdClassInArrayList() {
1483
            ArrayList<Integer> al = new ArrayList<>();
1484
            al.add(ID);
1485
            for (Integer retainedAttribute : retainedAttributesOnly) {
1486
                al.add(retainedAttribute);
1487
            }
1488
            al.add(Class);
1489
            return al;
1490
        }
1491
1492
        private void changeRetainedAttributes(String featuresToTest) {
1493
            String features[] = featuresToTest.split(",");
1494
            retainedAttributesOnly = new ArrayList<>();
1495
            for (int i = 1; i < features.length - 1; i++) {//skip ID and class indexes
1496
                retainedAttributesOnly.add(Integer.valueOf(features[i]));
1497
            }
1498
        }
1499
1500
        private String getAttributesIdClassInString() {
1501
            String attributes = "";
1502
            attributes += (ID) + ",";
1503
            for (Integer att : alAttributes) {
1504
                attributes += (att) + ",";
1505
            }
1506
            attributes += Class;
1507
            return attributes;
1508
        }
1509
    }
}