--- a
+++ b/src/utils/Weka_module.java
@@ -0,0 +1,2925 @@
+/*
+ * Weka
+ */
+package utils;
+
+import biodiscml.Main;
+import java.io.BufferedReader;
+import java.io.File;
+import java.io.FileReader;
+import java.io.FileWriter;
+import java.io.PrintWriter;
+import java.text.DecimalFormat;
+import java.text.DecimalFormatSymbols;
+import java.time.Duration;
+import java.time.Instant;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.Enumeration;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Random;
+import java.util.TreeMap;
+import java.util.concurrent.atomic.AtomicIntegerArray;
+import org.apache.commons.math3.analysis.UnivariateFunction;
+import org.apache.commons.math3.analysis.integration.TrapezoidIntegrator;
+import org.apache.commons.math3.analysis.interpolation.SplineInterpolator;
+import org.apache.commons.math3.analysis.interpolation.UnivariateInterpolator;
+import org.apache.commons.math3.exception.MaxCountExceededException;
+import static utils.utils.getMean;
+import weka.attributeSelection.AttributeSelection;
+import weka.classifiers.Classifier;
+import weka.classifiers.Evaluation;
+import weka.classifiers.evaluation.ThresholdCurve;
+import weka.classifiers.evaluation.output.prediction.PlainText;
+import weka.classifiers.meta.FilteredClassifier;
+import weka.core.Attribute;
+import weka.core.Instance;
+import weka.core.Instances;
+import weka.core.Range;
+import weka.core.UnsupportedAttributeTypeException;
+import weka.core.Utils;
+import weka.core.converters.ArffLoader;
+import weka.core.converters.ArffSaver;
+import weka.core.converters.CSVLoader;
+import weka.core.converters.CSVSaver;
+import weka.core.converters.ConverterUtils;
+import weka.filters.Filter;
+import weka.filters.supervised.instance.StratifiedRemoveFolds;
+import weka.filters.unsupervised.attribute.Remove;
+import weka.filters.unsupervised.attribute.Reorder;
+import weka.filters.unsupervised.attribute.StringToNominal;
+import weka.filters.unsupervised.instance.RemoveRange;
+
+/**
+ *
+ * @author Mickael
+ */
+public class Weka_module {
+
+    public DecimalFormat df = new DecimalFormat();
+    public ArrayList<String> alPrint = new ArrayList<>();
+    private File CSVFile;
+    public String ARFFfile;
+    public boolean debug = true;
+    public Instances myData;
+
+    public Weka_module() {
+        df.setMaximumFractionDigits(3);
+        DecimalFormatSymbols dec = new DecimalFormatSymbols();
+        dec.setDecimalSeparator('.');
+        df.setGroupingUsed(false);
+        df.setDecimalFormatSymbols(dec);
+//        svm.svm_set_print_string_function(new libsvm.svm_print_interface() {
+//            @Override
+//            public void print(String s) {
+//            } // Disables svm output
+//        });
+    }
+
+    /**
+     * Converts the CSV file to ARFF, removing every attribute whose name
+     * contains the given keyword, and optionally runs CfsSubsetEval/BestFirst
+     * attribute selection afterwards.
+     *
+     * @param keyword keyword searched in attribute names; matching attributes are removed
+     * @param attributeSelection whether to run attribute selection on the result
+     */
+    public void csvToArffRegressionWithFilter(String keyword, boolean attributeSelection) {
+        try {
+            //load csv
+            if (debug) {
+                System.out.println("loading csv");
+            }
+            CSVLoader csv = new CSVLoader();
+            csv.setSource(CSVFile);
+            csv.setMissingValue("?");
+            //csv.setFieldSeparator("\t");
+            Instances data = csv.getDataSet();
+
+            //remove attributes whose name contains the keyword
+            if (debug) {
+                System.out.println("cleaning csv");
+            }
+            String[] opts = new String[2];
+            opts[0] = "-R";
+            String indexes = "";
+            for (int i = 0; i < data.numAttributes(); i++) {
+                if (data.attribute(i).name().contains(keyword)) {
+                    //System.out.println("removing " + data.attribute(i).name());
+                    indexes += (i + 1) + ",";
+                }
+            }
+            opts[1] = indexes;
+
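+            // note (assumption about Weka's option parsing): the Remove filter's
+            // -R option takes 1-based attribute indices, hence the (i + 1) offset
+            // above; the trailing comma left in the index list is skipped by the
+            // Range parser, and an empty list (no name matching the keyword)
+            // simply removes nothing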
weka.filters.unsupervised.attribute.Remove r = new weka.filters.unsupervised.attribute.Remove(); + r.setOptions(opts); + r.setInputFormat(data); + data = Filter.useFilter(data, r); + + if (debug) { + System.out.println("saving full arff"); + } + ArffSaver arff = new ArffSaver(); + arff.setInstances(data); + ARFFfile = CSVFile.toString().replace(".csv", ".arff"); + arff.setFile(new File(ARFFfile)); + arff.writeBatch(); + + //filter out useless data + if (debug) { + System.out.println("attribute selection"); + } + ArffLoader arf = new ArffLoader(); + arf.setSource(new File(ARFFfile)); + data = arf.getDataSet(); + + if (attributeSelection) { + //weka.filters.supervised.attribute.AttributeSelection -E "weka.attributeSelection.CfsSubsetEval " -S "weka.attributeSelection.BestFirst -D 1 -N 5" + weka.filters.supervised.attribute.AttributeSelection select = new weka.filters.supervised.attribute.AttributeSelection(); + String options = "-E " + + "\"weka.attributeSelection.CfsSubsetEval \" -S \"weka.attributeSelection.BestFirst -D 1 -N 5\""; + select.setOptions(weka.core.Utils.splitOptions((options))); + select.setInputFormat(data); + data = Filter.useFilter(data, select); + } + + //save csv to arff + if (debug) { + System.out.println("saving CfsSubsetEval Best Firsts arff"); + } + arff = new ArffSaver(); + arff.setInstances(data); + ARFFfile = CSVFile.toString().replace(".csv", ".CfsSubset_BestFirst.arff"); + arff.setFile(new File(ARFFfile)); + arff.writeBatch(); + + } catch (Exception e) { + e.printStackTrace(); + } + } + + public void csvToArffRegression() { + if (Main.debug) { + System.out.println("\n\nConvert " + CSVFile + " to arff"); + } + File f = new File(CSVFile.toString().replace(".csv", ".arff")); + f.delete(); + try { + //load csv + CSVLoader csv = new CSVLoader(); + csv.setSource(CSVFile); + csv.setMissingValue("?"); + //csv.setFieldSeparator("\t"); + Instances data = csv.getDataSet(); + + //save csv to arff + ArffSaver arff = new ArffSaver(); + arff.setInstances(data); + ARFFfile = CSVFile.toString().replace(".csv", ".arff"); + arff.setFile(new File(ARFFfile)); + arff.writeBatch(); + + } catch (Exception e) { + e.printStackTrace(); + } + } + + public void csvToArff(boolean classification) { + if (Main.debug) { + System.out.println("Convert " + CSVFile + " to arff"); + } + File f = new File(CSVFile.toString().replace(".csv", ".arff")); + f.delete(); + try { + //load csv + CSVLoader csv = new CSVLoader(); + csv.setSource(CSVFile); + csv.setNominalAttributes("1"); //change ID to nominal + csv.setMissingValue("?"); + String separator = utils.detectSeparator(CSVFile.toString()); + csv.setFieldSeparator(separator); + Instances data = csv.getDataSet(); + data = convertStringsToNominal(data); + + //change numeric to nominal class + if (classification) { + weka.filters.unsupervised.attribute.NumericToNominal n + = new weka.filters.unsupervised.attribute.NumericToNominal(); + String[] opts = new String[2]; + opts[0] = "-R"; + opts[1] = "last"; + n.setOptions(opts); + n.setInputFormat(data); + data = Filter.useFilter(data, n); + if (data.classIndex() == -1) { + data.setClassIndex(data.numAttributes() - 1); + } + + //class order + if (data.classAttribute().enumerateValues().hasMoreElements()) { + if (data.classAttribute().enumerateValues().nextElement().equals("false") + || data.classAttribute().enumerateValues().nextElement().equals("0")) { + weka.filters.unsupervised.attribute.SwapValues s = new weka.filters.unsupervised.attribute.SwapValues(); + s.setOptions(weka.core.Utils.splitOptions(("-C 
last -F first -S last"))); + s.setInputFormat(data); + data = Filter.useFilter(data, s); + } + } + } + + //save csv to arff + ArffSaver arff = new ArffSaver(); + arff.setInstances(data); + ARFFfile = CSVFile.toString().replace(".csv", ".arff"); + arff.setFile(new File(ARFFfile)); + arff.writeBatch(); + } catch (Exception e) { + e.printStackTrace(); + } + } + + public ArrayList<String> trainRegression() { + ArrayList<String> al = new ArrayList<>(); + + try { + // load data + ConverterUtils.DataSource source = new ConverterUtils.DataSource(ARFFfile); + Instances data = source.getDataSet(); + int numberOfFolds = 10; + //reduc numberOfFolds if needed + if (data.numInstances() < numberOfFolds) { + numberOfFolds = data.numInstances(); + } + + //set last attribute as index + if (data.classIndex() == -1) { + data.setClassIndex(data.numAttributes() - 1); + } + + // train model + ////weka.classifiers.functions.GaussianProcesses +// weka.classifiers.functions.GaussianProcesses model = new weka.classifiers.functions.GaussianProcesses(); +// String options = "-L 1.0 -N 0 -K \"weka.classifiers.functions.supportVector.NormalizedPolyKernel -C 250007 -E 2.0\""; +// model.setOptions(weka.core.Utils.splitOptions((options))); + ////weka.classifiers.functions.SMOreg + weka.classifiers.functions.SMOreg model = new weka.classifiers.functions.SMOreg(); + String options = "-C 1.0 -N 0 -I \"weka.classifiers.functions.supportVector.RegSMO -L 0.001 -W 1 -P 1.0E-12\" -K \"weka.classifiers.functions.supportVector.NormalizedPolyKernel -C 250007 -E 2.0\""; + model.setOptions(weka.core.Utils.splitOptions((options))); + + // cross validation 10 times on the model + Evaluation eval = new Evaluation(data); + eval.crossValidateModel(model, data, numberOfFolds, new Random(1)); + + //output classification results + //System.out.println(eval.toSummaryString()); + al.add(Utils.doubleToString(eval.correlationCoefficient(), 12, 4)); + al.add(Utils.doubleToString(eval.meanAbsoluteError(), 12, 4)); + al.add(Utils.doubleToString(eval.rootMeanSquaredError(), 12, 4)); + al.add(Utils.doubleToString(eval.relativeAbsoluteError(), 12, 4)); + al.add(Utils.doubleToString(eval.rootRelativeSquaredError(), 12, 4)); + al.add("" + (data.numAttributes() - 1)); // number of genes + + } catch (Exception e) { + e.printStackTrace(); + } + return al; + } + + /** + * short test on 10% of instances + * + * @param attributesToUse + * @return + */ + public Object shortTestTrainClassifier(String classifier, + String classifier_options, + String attributesToUse, + Boolean classification) { + + try { + // load data + Instances data = myData; + //get a sample of instances + //// randomize data + data.randomize(new Random()); + + //// Percent split + int trainSize = (int) Math.round(data.numInstances() * 95 / 100); + int testSize = data.numInstances() - trainSize; + data = new Instances(data, trainSize, testSize); + + //set last attribute as index + if (data.classIndex() == -1) { + data.setClassIndex(data.numAttributes() - 1); + } + data.attribute(data.numAttributes() - 1).isNumeric(); + /** + * if we are using Vote, it means it already contains the features + * to use for each classifier + */ + String configuration = classifier + " " + classifier_options; + if (!classifier.contains("meta.Vote")) { + //keep only attributesToUse. 
ID is still here + String options = ""; + weka.filters.unsupervised.attribute.Remove r = new weka.filters.unsupervised.attribute.Remove(); + options = "-V -R " + attributesToUse; + r.setOptions(weka.core.Utils.splitOptions((options))); + r.setInputFormat(data); + data = Filter.useFilter(data, r); + data.setClassIndex(data.numAttributes() - 1); + + //create weka configuration string + //remove ID from classification + String filterID = "" + + "weka.classifiers.meta.FilteredClassifier " + + "-F \"weka.filters.unsupervised.attribute.Remove -R 1\" " + + "-W weka.classifiers."; + + //create config + configuration = filterID + "" + classifier + " -- " + classifier_options; + } + + //if cost sensitive case + if (classifier.contains("CostSensitiveClassifier")) { + int falseExemples = data.attributeStats(data.classIndex()).nominalCounts[1]; + int trueExemples = data.attributeStats(data.classIndex()).nominalCounts[0]; + double ratio = (double) trueExemples / (double) falseExemples; + classifier_options = classifier_options.replace("ratio", df.format(ratio)); + } + + //train + final String config[] = Utils.splitOptions(configuration); + String classname = config[0]; + config[0] = ""; + Classifier model = null; + + model = (Classifier) Utils.forName(Classifier.class, classname, config); + + //evaluation + if (Main.debug) { + System.out.println("\tShort test of model on 5% of the data..."); + } + + //build model to save it + model.buildClassifier(data); + if (Main.debug) { + System.out.println("PASSED"); + } + return model; + + } catch (Exception e) { + if (Main.debug) { + System.out.println("FAILED"); + } + String message = "[model error] " + classifier + " " + classifier_options + " | " + e.getMessage(); + if (Main.debug) { +// if (!e.getMessage().contains("handle") && !e.getMessage().contains("supported")) { +// +// } + e.printStackTrace(); + System.err.println(message); + + } + return message; + + } + + } + + /** + * http://bayesianconspiracy.blogspot.ca/2009/10/usr-bin-env-groovy-import-weka_10.html + * + * @param attributesToUse + * @return + */ + public Object trainClassifier(String classifier, + String classifier_options, + String attributesToUse, + Boolean classification, int numberOfFolds) { + + try { + // load data + Instances data = myData; + //reduc numberOfFolds if needed + if (data.numInstances() < numberOfFolds) { + numberOfFolds = data.numInstances(); + } + //set last attribute as index + if (data.classIndex() == -1) { + data.setClassIndex(data.numAttributes() - 1); + } + data.attribute(data.numAttributes() - 1).isNumeric(); + /** + * if we are using Vote, it means it already contains the features + * to use for each classifier + */ + String configuration = classifier + " " + classifier_options; + if (!classifier.contains("meta.Vote")) { + //keep only attributesToUse. 
ID is still here + String options = ""; + weka.filters.unsupervised.attribute.Remove r = new weka.filters.unsupervised.attribute.Remove(); + options = "-V -R " + attributesToUse; + r.setOptions(weka.core.Utils.splitOptions((options))); + r.setInputFormat(data); + data = Filter.useFilter(data, r); + data.setClassIndex(data.numAttributes() - 1); + + //create weka configuration string + //remove ID from classification + String filterID = "" + + "weka.classifiers.meta.FilteredClassifier " + + "-F \"weka.filters.unsupervised.attribute.Remove -R 1\" " + + "-W weka.classifiers."; + + //create config + configuration = filterID + "" + classifier + " -- " + classifier_options; + } + + //if cost sensitive case + if (classifier.contains("CostSensitiveClassifier")) { + int falseExemples = data.attributeStats(data.classIndex()).nominalCounts[1]; + int trueExemples = data.attributeStats(data.classIndex()).nominalCounts[0]; + double ratio = (double) trueExemples / (double) falseExemples; + classifier_options = classifier_options.replace("ratio", df.format(ratio)); + } + + //train + final String config[] = Utils.splitOptions(configuration); + String classname = config[0]; + config[0] = ""; + Classifier model = null; + + model = (Classifier) Utils.forName(Classifier.class, classname, config); + + //evaluation + Instant start = null; + if (Main.debug2) { + System.out.print("\tEvaluation of model with 10CV..."); + start = Instant.now(); + } + + Evaluation eval = new Evaluation(data); + StringBuffer sb = new StringBuffer(); + PlainText pt = new PlainText(); + pt.setBuffer(sb); + + //10 fold cross validation + eval.crossValidateModel(model, data, numberOfFolds, new Random(1), pt, new Range("first,last"), true); + + if (Main.debug2) { + Instant finish = Instant.now(); + long s = Duration.between(start, finish).toMillis(); + System.out.println(" in " + s + "ms"); + } + + //build model to save it + model.buildClassifier(data); + // System.out.println(model.toString().split("\n")[0]); + //return + if (classification) { + return new ClassificationResultsObject(eval, sb, data, model); + } else { + return new RegressionResultsObject(eval, sb, data, model); + } + + } catch (Exception e) { + String message = "[model error] " + classifier + " " + classifier_options + " | " + e.getMessage(); + if (Main.debug) { +// if (!e.getMessage().contains("handle") && !e.getMessage().contains("supported")) { +// +// } + e.printStackTrace(); + System.err.println(message); + + } + return message; + + } catch (Error err) { + String message = "[model error] " + classifier + " " + classifier_options + " | " + err.getMessage(); + if (Main.debug) { +// if (!e.getMessage().contains("handle") && !e.getMessage().contains("supported")) { +// +// } + err.printStackTrace(); + System.err.println(message); + + } + return message; + } + + } + + /** + * HoldOut + * + * @param classifier + * @param classifier_options + * @param attributesToUse + * @param classification + * @return + */ + public Object trainClassifierHoldOutValidation(String classifier, String classifier_options, + String attributesToUse, Boolean classification, int seed) { + try { + // load data + Instances data = myData; + + //set class index + if (data.classIndex() == -1) { + data.setClassIndex(data.numAttributes() - 1); + } + + String configuration = classifier + " " + classifier_options; + if (!classifier.contains("meta.Vote") && attributesToUse != null) { + //keep only attributesToUse. 
ID (if present) is still here + String options = ""; + weka.filters.unsupervised.attribute.Remove r = new weka.filters.unsupervised.attribute.Remove(); + options = "-V -R " + attributesToUse; + r.setOptions(weka.core.Utils.splitOptions((options))); + r.setInputFormat(data); + data = Filter.useFilter(data, r); + data.setClassIndex(data.numAttributes() - 1); + + } + if (!classifier.contains("meta.Vote")) { + //create weka configuration string + String filterID = "weka.classifiers.meta.FilteredClassifier " + + "-F \"weka.filters.unsupervised.attribute.Remove -R 1\" " + + "-W weka.classifiers."; + configuration = filterID + "" + classifier + " -- " + classifier_options; + } + + // randomize data + data.randomize(new Random(seed)); + + // Percent split + int trainSize = (int) Math.round(data.numInstances() * 66 / 100); + int testSize = data.numInstances() - trainSize; + Instances train = new Instances(data, 0, trainSize); + Instances test = new Instances(data, trainSize, testSize); + + //if cost sensitive case + if (classifier.contains("CostSensitiveClassifier")) { + int falseExemples = train.attributeStats(train.classIndex()).nominalCounts[1]; + int trueExemples = train.attributeStats(train.classIndex()).nominalCounts[0]; + double ratio = (double) trueExemples / (double) falseExemples; + classifier_options = classifier_options.replace("ratio", df.format(ratio)); + } + + //train + String config[] = Utils.splitOptions(configuration); + String classname = config[0]; + config[0] = ""; + Classifier model = (Classifier) Utils.forName(Classifier.class, classname, config); + StringBuffer sb = new StringBuffer(); + + //build model to save it + try { + model.buildClassifier(train); + } catch (Exception exception) { + //crashing, try to remove ID and run again + train.deleteAttributeAt(0); + test.deleteAttributeAt(0); + model.buildClassifier(train); + } + + //evaluation + Instant start = null; + if (Main.debug2) { + System.out.print("\tEvaluation of model [" + seed + "]..."); + start = Instant.now(); + } + + Evaluation eval = new Evaluation(data); + eval.evaluateModel(model, test); + + if (Main.debug2) { + Instant finish = Instant.now(); + long s = Duration.between(start, finish).toMillis(); + System.out.println(" in " + s + "ms"); + } + + //return + if (classification) { + return new ClassificationResultsObject(eval, sb, data, model); + } else { + return new RegressionResultsObject(eval, sb, data, model); + } + } catch (Exception e) { + if (Main.debug) { + System.err.println(e.getMessage() + " for " + classifier + " " + classifier_options); + e.printStackTrace(); + } + + return null; + } + + } + + /** + * Bootstrap + * + * @param classifier + * @param classifier_options + * @param attributesToUse + * @param classification + * @return + */ + public Object trainClassifierBootstrap(String classifier, String classifier_options, + String attributesToUse, Boolean classification, int seed) { + try { + // load data + Instances data = myData; + + //set class index + if (data.classIndex() == -1) { + data.setClassIndex(data.numAttributes() - 1); + } + + String configuration = classifier + " " + classifier_options; + if (!classifier.contains("meta.Vote") && attributesToUse != null) { + //keep only attributesToUse. 
ID (if present) is still here + String options = ""; + weka.filters.unsupervised.attribute.Remove r = new weka.filters.unsupervised.attribute.Remove(); + options = "-V -R " + attributesToUse; + r.setOptions(weka.core.Utils.splitOptions((options))); + r.setInputFormat(data); + data = Filter.useFilter(data, r); + data.setClassIndex(data.numAttributes() - 1); + + } + if (!classifier.contains("meta.Vote")) { + //create weka configuration string + String filterID = "weka.classifiers.meta.FilteredClassifier " + + "-F \"weka.filters.unsupervised.attribute.Remove -R 1\" " + + "-W weka.classifiers."; + configuration = filterID + "" + classifier + " -- " + classifier_options; + } + + Random r = new Random(seed); + //train + String config[] = Utils.splitOptions(configuration); + String classname = config[0]; + config[0] = ""; + Classifier model = (Classifier) Utils.forName(Classifier.class, classname, config); + StringBuffer sb = new StringBuffer(); + + //evaluation + Instant start = null; + if (Main.debug2) { + System.out.print("\tEvaluation of model [" + seed + "]..."); + start = Instant.now(); + } + + // Custom sampling (100%, with replacement) + ArrayList<Instance> al_trainSet = new ArrayList<>(data.size()); // Empty list (add one-by-one) + ArrayList<Instance> al_testSet = new ArrayList<>(data); // Full (remove one-by-one) + for (int j = 0; j < data.size(); j++) { + // Random select instance + Instance instance = data.get(r.nextInt(data.size())); + // Add to TRAIN, remove from TEST + al_trainSet.add(instance); + al_testSet.remove(instance); + } + //prepare train and test sets + Instances trainSet = new Instances(data, al_trainSet.size()); + trainSet.addAll(al_trainSet); + Instances testSet = new Instances(data, al_testSet.size()); + testSet.addAll(al_testSet); + + //train the train set + try { + model.buildClassifier(trainSet); + } catch (Exception exception) { + // crashing probably because of presence of ID, removing it + trainSet.deleteAttributeAt(0); + testSet.deleteAttributeAt(0); + model.buildClassifier(trainSet); + } + + // Test the test set + Evaluation eval = new Evaluation(data); + eval.evaluateModel(model, testSet); + + if (Main.debug2) { + Instant finish = Instant.now(); + long s = Duration.between(start, finish).toMillis(); + double d = s / 1000; + System.out.println(" in " + s + "ms"); + } + + //return + if (classification) { + return new ClassificationResultsObject(eval, sb, data, model); + } else { + return new RegressionResultsObject(eval, sb, data, model); + } + } catch (Exception e) { + if (Main.debug) { + e.printStackTrace(); + } + //System.err.println(e.getMessage() + " for " + classifier + " " + classifier_options); + return null; + } + + } + + /** + * Bootstrap .632+ + * + * @param classifier + * @param classifier_options + * @param attributesToUse + * @return + */ + public Double trainClassifierBootstrap632plus(String classifier, String classifier_options, + String attributesToUse) { + try { + // load data + Instances data = myData; + + //set class index + if (data.classIndex() == -1) { + data.setClassIndex(data.numAttributes() - 1); + } + + String configuration = classifier + " " + classifier_options; + if (!classifier.contains("meta.Vote") && attributesToUse != null) { + //keep only attributesToUse. 
ID (if present) is still here + String options = ""; + weka.filters.unsupervised.attribute.Remove r = new weka.filters.unsupervised.attribute.Remove(); + options = "-V -R " + attributesToUse; + r.setOptions(weka.core.Utils.splitOptions((options))); + r.setInputFormat(data); + data = Filter.useFilter(data, r); + data.setClassIndex(data.numAttributes() - 1); + + //create weka configuration string + String filterID = "weka.classifiers.meta.FilteredClassifier " + + "-F \"weka.filters.unsupervised.attribute.Remove -R 1\" " + + "-W weka.classifiers."; + configuration = filterID + "" + classifier + " -- " + classifier_options; + } + + //train model + final String config[] = Utils.splitOptions(configuration); + String classname = config[0]; + config[0] = ""; + Classifier model = null; + + model = (Classifier) Utils.forName(Classifier.class, classname, config); + + //build model + model.buildClassifier(data); + + // evaluate model + Evaluation eval = new Evaluation(data); + eval.evaluateModel(model, data); + + // Apparent error rate + Double err = eval.errorRate(); + + // Calculate Leave One Out (LOO) Bootstrap + AtomicIntegerArray p_l = new AtomicIntegerArray(data.numClasses()); + AtomicIntegerArray q_l = new AtomicIntegerArray(data.numClasses()); + double sum = 0; + + for (int i = 0; i < Main.bootstrapAndRepeatedHoldoutFolds; i++) { + Random r = new Random(); + + // Custom sampling (100%, with replacement) + ArrayList<Instance> al_trainSet = new ArrayList<>(data.size()); // Empty list (add one-by-one) + ArrayList<Instance> al_testSet = new ArrayList<>(data); // Full (remove one-by-one) + for (int j = 0; j < data.size(); j++) { + // Random select instance + Instance instance = data.get(r.nextInt(data.size())); + // Add to TRAIN, remove from TEST + al_trainSet.add(instance); + al_testSet.remove(instance); + } + //train the train set + Instances trainSet = new Instances(data, al_trainSet.size()); + trainSet.addAll(al_trainSet); + model.buildClassifier(trainSet); + + // Test the test set + Instances testSet = new Instances(data, al_testSet.size()); + testSet.addAll(al_testSet); + Evaluation evaluation = new Evaluation(data); + evaluation.evaluateModel(model, testSet); + + // total error rates + sum += evaluation.errorRate(); + + //GAMMA + double[][] confusionMatrix = evaluation.confusionMatrix(); + for (int l = 0; l < data.numClasses(); l++) { + int p_tmp = 0, q_tmp = 0; + for (int n = 0; n < data.numClasses(); n++) { + // Sum for l-th class + p_tmp += confusionMatrix[l][n]; + q_tmp += confusionMatrix[n][l]; + } + + // Add data for l-th class + p_l.addAndGet(l, p_tmp); + q_l.addAndGet(l, q_tmp); + } + } + double Err1 = sum / Main.bootstrapAndRepeatedHoldoutFolds; + + // Plain 0.632 bootstrap + Double Err632 = .368 * err + .632 * Err1; + + // GAMA + final double observations = data.size() * Main.bootstrapAndRepeatedHoldoutFolds; + double gama = 0; + for (int l = 0; l < data.numClasses(); l++) { + // Normalize numbers -> divide by number of all observations (repeats * dataset size) + gama += ((double) p_l.get(l) / observations) * (1 - ((double) q_l.get(l) / observations)); + } + + // Relative overfitting rate (R) + double R = (Err1 - err) / (gama - err); + + // Modified variables (according to original journal article) + double Err1_ = Double.min(Err1, gama); + double R_ = R; + + // R can fall out of [0, 1] -> set it to 0 + if (!(Err1 > err && gama > err)) { + R_ = 0; + } + + // The 0.632+ bootstrap (as used in original article) + double Err632plus = Err632 + (Err1_ - err) * (.368 * .632 * R_) / (1 - 
.368 * R_); + + return Err632plus; + + } catch (Exception e) { + if (Main.debug) { + e.printStackTrace(); + } + return -1.0; + } + } + + public void attributeSelectionByRelieFFAndSaveToCSV(String outfile) { + + // load data + try { + ConverterUtils.DataSource source = new ConverterUtils.DataSource(ARFFfile); + Instances data = source.getDataSet(); + if (data.classIndex() == -1) { + data.setClassIndex(data.numAttributes() - 1); + } + //get id attribute + Attribute id = data.attribute(0); + + weka.filters.supervised.attribute.AttributeSelection select = new weka.filters.supervised.attribute.AttributeSelection(); + + //filter based on ranker score threshold + String limit = " -N " + Main.maxNumberOfSelectedFeatures; + String options = "-E \"weka.attributeSelection.ReliefFAttributeEval -M -1 -D 1 -K 10\" -S \"weka.attributeSelection.Ranker -T 0.01" + limit + "\""; + select.setOptions(weka.core.Utils.splitOptions((options))); + select.setInputFormat(data); + Instances filteredData = Filter.useFilter(data, select); + + System.out.println("Total attributes: " + (filteredData.numAttributes() - 1)); + + if (filteredData.numAttributes() == 0) { + System.out.println("Not enough attributes, probably all non-informative. So keeping all and ranking them, but expect low training performance"); + options = "-E \"weka.attributeSelection.ReliefFAttributeEval -M -1 -D 1 -K 10\" -S \"weka.attributeSelection.Ranker -T -1 -N -1\""; + select.setOptions(weka.core.Utils.splitOptions((options))); + select.setInputFormat(data); + filteredData = Filter.useFilter(data, select); + } + + //add identifier if it has been lost after relieff + if (filteredData.attribute(id.name()) == null) { + filteredData.insertAttributeAt(id, 0); + filteredData.attribute(0).isNominal(); + } + //move identifier to index 0 + if (filteredData.attribute(id.name()) != null && !filteredData.attribute(0).equals(id)) { + //find index of ID + int idIndex = 0; + for (int i = 0; i < filteredData.numAttributes(); i++) { + if (filteredData.attribute(i).equals(id)) { + idIndex = i; + } + } + //delete ID + filteredData.deleteAttributeAt(idIndex); + //reinsert ID + filteredData.insertAttributeAt(id, 0); + } + + //restore IDs + for (int i = 0; i < filteredData.numInstances(); i++) { + filteredData.instance(i).setValue(0, (String) id.value(i)); + } + //save data as csv + CSVSaver csv = new CSVSaver(); + csv.setInstances(filteredData); + + csv.setFile(new File(outfile)); + if (new File(outfile.replace(".csv", ".arff")).exists()) { + new File(outfile.replace(".csv", ".arff")).delete(); + } + csv.writeBatch(); + myData = filteredData; + + } catch (Exception e) { + e.printStackTrace(); + } + } + + public void attributeSelectionByRelieFFAndSaveToCSVbigData(String outfile) { + + // load data + try { + ConverterUtils.DataSource source = new ConverterUtils.DataSource(ARFFfile); + Instances data = source.getDataSet(); + if (data.classIndex() == -1) { + data.setClassIndex(data.numAttributes() - 1); + } + //get id attribute + Attribute id = data.attribute(0); + + weka.filters.supervised.attribute.AttributeSelection select = new weka.filters.supervised.attribute.AttributeSelection(); + + //go parallel one set of features at a time + //filter based on ranker score threshold + String limit = " -N " + Main.maxNumberOfSelectedFeatures; + String options = "-E \"weka.attributeSelection.ReliefFAttributeEval -M -1 -D 1 -K 10\" -S \"weka.attributeSelection.Ranker -T 0.01" + limit + "\""; + select.setOptions(weka.core.Utils.splitOptions((options))); + 
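+            // ReliefF is evaluated on all instances (-M -1, seed -D 1) with 10
+            // nearest neighbours (-K 10); the Ranker keeps attributes whose merit
+            // exceeds 0.01 (-T), capped at Main.maxNumberOfSelectedFeatures (-N).
+            // Note that, as written, this method mirrors
+            // attributeSelectionByRelieFFAndSaveToCSV: the "one set of features
+            // at a time" parallel pass announced above is not actually implemented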
select.setInputFormat(data); + Instances filteredData = Filter.useFilter(data, select); + //restoring instances IDs as first attribute +// Attribute attributeInstances = filteredData.attribute("Instance"); +// filteredData.deleteAttributeAt(filteredData.attribute("Instance").index()); +// filteredData.insertAttributeAt(attributeInstances, 0); +// +// Enumeration e = attributeInstances.enumerateValues(); +// int cpt = 0; +// while (e.hasMoreElements()) { +// String s = (String) e.nextElement(); +// filteredData.instance(cpt).setValue(0, s); +// cpt++; +// } +// filteredData.attribute(0).isNominal(); + + System.out.println("Total attributes: " + (filteredData.numAttributes() - 1)); + + if (filteredData.numAttributes() == 0) { + System.out.println("Not enough attributes, probably all non-informative. So keeping all and ranking them, but expect low training performance"); + options = "-E \"weka.attributeSelection.ReliefFAttributeEval -M -1 -D 1 -K 10\" -S \"weka.attributeSelection.Ranker -T -1 -N -1\""; + select.setOptions(weka.core.Utils.splitOptions((options))); + select.setInputFormat(data); + filteredData = Filter.useFilter(data, select); + } + + //add identifier if it has been lost after relieff + if (filteredData.attribute(id.name()) == null) { + filteredData.insertAttributeAt(id, 0); + filteredData.attribute(0).isNominal(); + } + //move identifier to index 0 + if (filteredData.attribute(id.name()) != null && !filteredData.attribute(0).equals(id)) { + //find index of ID + int idIndex = 0; + for (int i = 0; i < filteredData.numAttributes(); i++) { + if (filteredData.attribute(i).equals(id)) { + idIndex = i; + } + } + //delete ID + filteredData.deleteAttributeAt(idIndex); + //reinsert ID + filteredData.insertAttributeAt(id, 0); + } + + //restore IDs + for (int i = 0; i < filteredData.numInstances(); i++) { + filteredData.instance(i).setValue(0, (String) id.value(i)); + } + + //save data as csv + CSVSaver csv = new CSVSaver(); + csv.setInstances(filteredData); + + csv.setFile(new File(outfile)); + if (new File(outfile.replace(".csv", ".arff")).exists()) { + new File(outfile.replace(".csv", ".arff")).delete(); + } + csv.writeBatch(); + myData = filteredData; + + } catch (Exception e) { + e.printStackTrace(); + } + } + + public void attributeSelectionByInfoGainRankingAndSaveToCSV(String outfile) { + Instances data = myData; + + //load data + try { + if (data.classIndex() == -1) { + data.setClassIndex(data.numAttributes() - 1); + } + //get id attribute + Attribute id = data.attribute(0); + + // info gain filter on value (remove attributes with 0 merit) + weka.filters.supervised.attribute.AttributeSelection select = new weka.filters.supervised.attribute.AttributeSelection(); + String options = "-E \"weka.attributeSelection.InfoGainAttributeEval \" -S \"weka.attributeSelection.Ranker -T 0.001\""; + select.setOptions(weka.core.Utils.splitOptions((options))); + select.setInputFormat(data); + + Instances filteredData = Filter.useFilter(data, select); + + System.out.println("Total attributes: " + (filteredData.numAttributes() - 1)); + + // filter on number of maxNumberOfSelectedFeatures + if (filteredData.numAttributes() > Main.maxNumberOfSelectedFeatures) { + System.out.println("Too many attributes, only keep best " + Main.maxNumberOfSelectedFeatures); + String limit = " -N " + Main.maxNumberOfSelectedFeatures; + options = "-E \"weka.attributeSelection.InfoGainAttributeEval \" -S \"weka.attributeSelection.Ranker -T 0.001" + limit + "\""; + select.setOptions(weka.core.Utils.splitOptions((options))); 
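+                // second pass: the same InfoGain/Ranker pair, now with -N set so
+                // that only the Main.maxNumberOfSelectedFeatures top-ranked
+                // attributes survive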
+ select.setInputFormat(filteredData); + filteredData = Filter.useFilter(filteredData, select); + } + + if (filteredData.numAttributes() <= 10) { + System.out.println("Not enough attributes, probably all non-informative. So keeping all and ranking them, but expect very low performances"); + options = "-E \"weka.attributeSelection.InfoGainAttributeEval \" -S \"weka.attributeSelection.Ranker -T -1 -N -1\""; + select.setOptions(weka.core.Utils.splitOptions((options))); + select.setInputFormat(data); + filteredData = Filter.useFilter(data, select); + } + //add identifier if it has been lost after information gain + if (filteredData.attribute(id.name()) == null) { + filteredData.insertAttributeAt(id, 0); + } + //move identifier to index 0 + if (filteredData.attribute(id.name()) != null && !filteredData.attribute(0).equals(id)) { + //find index of ID + int idIndex = 0; + for (int i = 0; i < filteredData.numAttributes(); i++) { + if (filteredData.attribute(i).equals(id)) { + idIndex = i; + } + } + Reorder r = new Reorder(); + idIndex = idIndex + 1; + if (idIndex != (filteredData.numAttributes() - 1)) { + options = "-R " + idIndex + ",first-" + (idIndex - 1) + "," + (idIndex + 1) + "-last"; + } else { + options = "-R " + idIndex + ",first-" + (idIndex - 1) + "," + (idIndex + 1); + } + r.setOptions(weka.core.Utils.splitOptions((options))); + r.setInputFormat(filteredData); + filteredData = Filter.useFilter(filteredData, r); + } + + //in case of complete lost ID (replaced by ?) + if (filteredData.instance(0).toString().startsWith("?")) { + for (int i = 0; i < filteredData.numInstances(); i++) { + filteredData.instance(i).setValue(0, data.attribute(0).value((int) data.instance(i).value(0))); + } + } + + //save data as csv + CSVSaver csv = new CSVSaver(); + csv.setInstances(filteredData); + + csv.setFile(new File(outfile)); + if (new File(outfile.replace(".csv", ".arff")).exists()) { + new File(outfile.replace(".csv", ".arff")).delete(); + } + csv.writeBatch(); + myData = filteredData; + + //evaluation.crossValidateModel(classifier, data, 10, new Random(1)); + } catch (Exception e) { + e.printStackTrace(); + System.err.println(e.getMessage() + "\nUnable to execute InfoGain. Trying directly Ranking"); + featureRankingForClassification(); + } + } + + public void attributeSelection(String evaluator, String outfileExtension) { + // load data + try { + Instances data = myData; + if (data.classIndex() == -1) { + data.setClassIndex(data.numAttributes() - 1); + } +// //remove patient ID +// if (myData.attribute(0).name().toLowerCase().contains("patient")) { +// data.deleteAttributeAt(0); +// } + + weka.filters.supervised.attribute.AttributeSelection select = new weka.filters.supervised.attribute.AttributeSelection(); + select.setOptions(weka.core.Utils.splitOptions((evaluator))); + select.setInputFormat(data); + data = Filter.useFilter(data, select); + + //save data as arff + ArffSaver arffsav = new ArffSaver(); + arffsav.setInstances(data); + ARFFfile = ARFFfile.replace(".arff", "." 
+ outfileExtension + ".arff"); + arffsav.setFile(new File(ARFFfile)); + arffsav.writeBatch(); + + //evaluation.crossValidateModel(classifier, data, 10, new Random(1)); + } catch (Exception e) { + e.printStackTrace(); + } + } + + public AttributeSelection featureRankingForClassification() { + try { + ConverterUtils.DataSource source = new ConverterUtils.DataSource(ARFFfile); + Instances data = source.getDataSet(); + data = convertStringsToNominal(data); + data.numAttributes(); + AttributeSelection attrsel = new AttributeSelection(); + weka.attributeSelection.InfoGainAttributeEval eval = new weka.attributeSelection.InfoGainAttributeEval(); + + weka.attributeSelection.Ranker rank = new weka.attributeSelection.Ranker(); + rank.setOptions( + weka.core.Utils.splitOptions( + "-T 0.001 -N -1")); + attrsel.setEvaluator(eval); + attrsel.setSearch(rank); + + attrsel.SelectAttributes(data); + return attrsel; + } catch (Exception e) { + e.printStackTrace(); + System.err.println("Check your input files. Probably missing classes"); + System.exit(0); + return null; + } + } + + /** + * + * @param csvFile + * @return ranking output + */ + public String featureRankingForClassification(String csvFile) { + try { + //csvToArff(true); + ConverterUtils.DataSource source = new ConverterUtils.DataSource(csvFile); + Instances data = source.getDataSet(); + AttributeSelection attrsel = new AttributeSelection(); + weka.attributeSelection.InfoGainAttributeEval eval = new weka.attributeSelection.InfoGainAttributeEval(); + + weka.attributeSelection.Ranker rank = new weka.attributeSelection.Ranker(); + rank.setOptions( + weka.core.Utils.splitOptions( + "-T -1 -N -1")); + attrsel.setEvaluator(eval); + attrsel.setSearch(rank); + + try { + attrsel.SelectAttributes(data); + } catch (UnsupportedAttributeTypeException e) { + //in case of id seen as string instead of nominal + weka.filters.unsupervised.attribute.StringToNominal f = new StringToNominal(); + String options = "-R first"; + f.setOptions(weka.core.Utils.splitOptions((options))); + f.setInputFormat(data); + data = Filter.useFilter(data, f); + try { + attrsel.SelectAttributes(data); + } catch (Exception ex) { + //in case where one or more features are seen as nominal + for (int i = 1; i < data.numAttributes() - 1; i++) { + if (data.attribute(i).isString()) { + f = new StringToNominal(); + options = "-R " + (i + 1); + f.setOptions(weka.core.Utils.splitOptions((options))); + f.setInputFormat(data); + data = Filter.useFilter(data, f); + } + } + attrsel.SelectAttributes(data); + } + } + return attrsel.toResultsString(); + } catch (Exception e) { + e.printStackTrace(); + return "Can't compute feature ranking"; + } + } + + public AttributeSelection featureRankingForRegression() { + try { + ConverterUtils.DataSource source = new ConverterUtils.DataSource(ARFFfile); + Instances data = source.getDataSet(); + AttributeSelection attrsel = new AttributeSelection(); + weka.attributeSelection.CfsSubsetEval eval = new weka.attributeSelection.CfsSubsetEval(); + weka.attributeSelection.GreedyStepwise rank = new weka.attributeSelection.GreedyStepwise(); + rank.setOptions( + weka.core.Utils.splitOptions( + "-R -T -1.7976931348623157E308 -N -1")); + + attrsel.setEvaluator(eval); + attrsel.setSearch(rank); + + attrsel.SelectAttributes(data); + return attrsel; + } catch (Exception e) { + e.printStackTrace(); + return null; + } + } + + public String featureRankingForRegression(String csvFile) { + try { + ConverterUtils.DataSource source = new ConverterUtils.DataSource(csvFile); + Instances 
data = source.getDataSet(); + AttributeSelection attrsel = new AttributeSelection(); + + weka.attributeSelection.ReliefFAttributeEval eval = new weka.attributeSelection.ReliefFAttributeEval(); + weka.attributeSelection.Ranker rank = new weka.attributeSelection.Ranker(); + rank.setOptions( + weka.core.Utils.splitOptions( + "-T -1.7976931348623157E308 -N -1")); + + attrsel.setEvaluator(eval); + attrsel.setSearch(rank); + + attrsel.SelectAttributes(data); + return attrsel.toResultsString(); + } catch (Exception e) { + e.printStackTrace(); + return null; + } + } + + public StringBuffer makePredictions(File testarff, String classifier) { + System.out.println("test file " + testarff.toString() + " on model " + classifier); + StringBuffer sb = new StringBuffer(); + + StringBuilder sbIds = new StringBuilder(); + try { + // load data + ConverterUtils.DataSource source = new ConverterUtils.DataSource(testarff.toString()); + Instances data = source.getDataSet(); + if (data.classIndex() == -1) { + data.setClassIndex(data.numAttributes() - 1); + } + + Classifier model = (Classifier) weka.core.SerializationHelper.read(classifier); + + //get class names + HashMap<Integer, String> hmClassCorrespondance = new HashMap<>(); + for (String s : model.toString().split("\n")) { + if (s.startsWith("@attribute class")) { + s = s.split("\\{")[1].replace("}", ""); + String classes[] = s.split(","); + for (int i = 0; i < classes.length; i++) { + hmClassCorrespondance.put(i, classes[i]); + } + } + } + // label instances + sb.append("instance\tprediction\tprobability\n"); + for (int i = 0; i < data.numInstances(); i++) { + double prediction = model.classifyInstance(data.instance(i)); + double probability = model.distributionForInstance(data.instance(i))[(int) prediction]; + + sb.append(data.attribute(0).value(i)). + append("\t").append(hmClassCorrespondance.get((int) prediction)). + append("\t").append(probability). 
+ append("\n"); + } + + } catch (Exception e) { + e.printStackTrace(); + } + + return sb; + } + + /** + * test a classifier + * + * @param testarff + * @param classifier file + * @param classification + * @return + */ + public Object testClassifierFromFileSource(File testarff, String classifier, boolean classification) { + if (Main.debug) { + System.out.println("test file " + testarff.toString() + " on model " + classifier); + } + StringBuffer sb = new StringBuffer(); + try { + // load data + ConverterUtils.DataSource source = new ConverterUtils.DataSource(testarff.toString()); + Instances data = source.getDataSet(); + if (data.classIndex() == -1) { + data.setClassIndex(data.numAttributes() - 1); + } + //load model + Classifier model = (Classifier) weka.core.SerializationHelper.read(classifier); + + Evaluation eval = null; + PlainText pt = null; + + //if vote model, we need to simulate features index + if (model.getClass().toString().endsWith("Vote")) { + HashMap<String, String> hmNewDataFeaturesIndex = new HashMap<>(); + //get models inside vote + weka.classifiers.meta.Vote m = (weka.classifiers.meta.Vote) model; + Classifier[] modelClassifiers = m.getClassifiers(); + + //for each classifier, get features indexes of the signature + int classIndex = 0; + for (int i = 0; i < modelClassifiers.length; i++) { + //get model + FilteredClassifier filteredClassifier = (FilteredClassifier) modelClassifiers[i]; + //get features of the model in the right order + ArrayList<String> alFeatures = getFeaturesFromClassifier(filteredClassifier); + //get remove filter + Remove rf = (Remove) filteredClassifier.getFilter(); + //get index of the remove filter + String indexes[] = rf.getAttributeIndices().split(","); + //set class index + classIndex = Integer.valueOf(indexes[indexes.length - 1]); + //store features index with feature name + for (int j = 0; j < alFeatures.size(); j++) { + hmNewDataFeaturesIndex.put(indexes[j], alFeatures.get(j)); + } + } + //create new dataset + Instances newData = new Instances(data); + for (int i = newData.numAttributes() - 2; i >= 0; i--) { + newData.deleteAttributeAt(i); + } + //insert data for each instance + for (int j = 0; j < data.numInstances(); j++) { + //insert attribute "instance" for 1st instance only + if (j == 0) { + newData.insertAttributeAt(data.attribute(0), 0); + } + //insert instance name + try { + newData.instance(j).setValue(0, data.instance(j).stringValue(0)); + } catch (Exception e) { + //if not string + newData.instance(j).setValue(0, data.instance(j).value(0)); + } + for (int i = 1; i < classIndex - 1; i++) { + //if attribute do not exist in current data + if (hmNewDataFeaturesIndex.containsKey((i + 1) + "")) { + Attribute att = data.attribute(hmNewDataFeaturesIndex.get((i + 1) + "")); + if (j == 0) { + newData.insertAttributeAt(att, i); + } + newData.instance(j).setValue(i, data.instance(j).value(att)); + + } else { + //create an empty one with missing data + if (j == 0) { + newData.insertAttributeAt(new Attribute("Att_" + i), i); + } + } + } + } + data = newData; + } + +// CSVSaver csv = new CSVSaver(); +// csv.setInstances(data); +// csv.setFile(new File("E:\\cloud\\Projects\\bacteria\\test\\test.csv")); +// csv.writeBatch(); +// ArffSaver arff = new ArffSaver(); +// arff.setInstances(data); +// arff.setFile(new File("E:\\cloud\\Projects\\bacteria\\test\\test.arff")); +// arff.writeBatch(); + eval = new Evaluation(data); + pt = new PlainText(); + pt.setHeader(data); + pt.setBuffer(sb); + pt.setOutputDistribution(true); + pt.setAttributes("first"); + + 
eval.evaluateModel(model, data, pt, new Range("first,last"), true); + //return + if (classification) { + return new ClassificationResultsObject(eval, sb, data, model); + } else { + return new RegressionResultsObject(eval, sb, data, model); + } + + } catch (Exception e) { + if (Main.debug) { + e.printStackTrace(); + } + } + return null; + + } + + /** + * test a classifier + * + * @param testarff + * @param classifier file + * @param classification + * @return + */ + public Object testClassifierFromModel(Classifier model, String testFile, boolean classification, String out) { + + StringBuffer sb = new StringBuffer(); + try { + // load data + ConverterUtils.DataSource testSource = new ConverterUtils.DataSource(testFile.replace("csv", "arff")); + Instances testData = testSource.getDataSet(); + + testData = extractFeaturesFromDatasetBasedOnModel(model, testData); + testData.setClassIndex(testData.numAttributes() - 1); + //remove ID + //testData.deleteAttributeAt(0); + + // evaluate test dataset on the model + Evaluation eval = new Evaluation(testData); + PlainText pt = new PlainText(); + pt.setHeader(testData); + pt.setBuffer(sb); + eval.evaluateModel(model, testData, pt, new Range("first,last"), true); + +// //saving for debug +// SerializationHelper.write(Main.wd + "tmp." + out.replace("\t", "_") + ".model", model); +// ArffSaver arff = new ArffSaver(); +// arff.setInstances(testData); +// arff.setFile(new File(Main.wd + "tmp." + out.replace("\t", "_") + ".testdata.arff")); +// arff.writeBatch(); + //return + if (classification) { + return new ClassificationResultsObject(eval, sb, testData, model); + } else { + return new RegressionResultsObject(eval, sb, testData, model); + } + + } catch (Exception e) { + if (Main.debug) { + System.err.println(e.getMessage()); + e.printStackTrace(); + } + + } + return null; + + } + +// /** +// * test a classifier +// * +// * @param testarff +// * @param classifier file +// * @param classification +// * @return +// */ +// public Object testClassifier(String classifier, String classifier_options, String features, String trainFile, String testFile, boolean classification) { +// +// StringBuffer sb = new StringBuffer(); +// try { +// // load data +// ConverterUtils.DataSource trainSource = new ConverterUtils.DataSource(trainFile.replace("csv", "arff")); +// ConverterUtils.DataSource testSource = new ConverterUtils.DataSource(testFile.replace("csv", "arff")); +// +// Instances trainData = trainSource.getDataSet(); +// trainData.setClassIndex(trainData.numAttributes() - 1); +// Instances testData = testSource.getDataSet(); +// testData.setClassIndex(testData.numAttributes() - 1); +// +// //load model +// String configuration = classifier + " " + classifier_options; +// if (!classifier.contains("meta.Vote")) { +// //keep only attributesToUse. 
ID is still here +// String options = ""; +// weka.filters.unsupervised.attribute.Remove r = new weka.filters.unsupervised.attribute.Remove(); +// options = "-V -R " + features; +// r.setOptions(weka.core.Utils.splitOptions((options))); +// r.setInputFormat(trainData); +// trainData = Filter.useFilter(trainData, r); +// trainData.setClassIndex(trainData.numAttributes() - 1); +// testData = Filter.useFilter(testData, r); +// testData.setClassIndex(testData.numAttributes() - 1); +// +// //create weka configuration string +// //remove ID from classification +// String filterID = "" +// + "weka.classifiers.meta.FilteredClassifier " +// + "-F \"weka.filters.unsupervised.attribute.Remove -R 1\" " +// + "-W weka.classifiers."; +// +// //create config +// configuration = filterID + "" + classifier + " -- " + classifier_options; +// +// } +// +// //create model +// final String config[] = Utils.splitOptions(configuration); +// String classname = config[0]; +// config[0] = ""; +// Classifier model = null; +// +// model = (Classifier) Utils.forName(Classifier.class, classname, config); +// +// // evaluate dataset on the model +// Evaluation eval = new Evaluation(testData); +// PlainText pt = new PlainText(); +// pt.setBuffer(sb); +// +// //1 fold cross validation +// eval.crossValidateModel(model, trainData, 10, new Random(1), pt, new Range("first,last"), true); +// +// eval.evaluateModel(model, testData); +// model.buildClassifier(trainData); +// +// //return +// if (classification) { +// return new ClassificationResultsObject(eval, sb, testData, model); +// } else { +// return new RegressionResultsObject(eval, sb, testData, model); +// } +// +// } catch (Exception e) { +// if (Main.debug) { +// e.printStackTrace(); +// } +// } +// return null; +// +// } + public void setCSVFile(File file) { + CSVFile = file; + } + + public void setARFFfile(String file) { + ARFFfile = file; + } + + /** + * set local variable mydata with ARFFfile + */ + public void setDataFromArff() { + // load data + try { + ConverterUtils.DataSource source = new ConverterUtils.DataSource(ARFFfile); + myData = source.getDataSet(); + } catch (Exception e) { + e.printStackTrace(); + } + } + + /** + * set local variable mydata with CSVFfile. VERY DANGEROUS WITHOUT LOADER + * THAT SET THE SEPARATOR Need to load as csvToArff() + */ + public void setDataFromCSV() { +// // load data +// try { +// //implement loader +// ConverterUtils.DataSource source = new ConverterUtils.DataSource(CSVFile.toString()); +// myData = source.getDataSet(); +// } catch (Exception e) { +// e.printStackTrace(); +// } + } + + /** + * TODO: find another way to detect classification vs regression. 
A user + * could use 1 and 0 instead of true false; + * + * @return + */ + public boolean isClassification() { + try { + Double.parseDouble(myData.instance(0).toString(myData.numAttributes() - 1)); + //Double.parseDouble(myData.attribute(myData.numAttributes() - 1).enumerateValues().nextElement().toString()); + return false; + } catch (Exception e) { + return true; + } + + } + + public ArrayList<String> getFeaturesFromClassifier(String classifier) { + //HashMap<String, String> hm = new HashMap<>(); + ArrayList<String> al = new ArrayList<>(); + try { + //load model + Classifier model = (Classifier) weka.core.SerializationHelper.read(classifier); + if (model.getClass().toString().contains("Vote")) { + String tab[] = model.toString().split("\n"); + ArrayList<String> alModels = new ArrayList<>(); + int i = 1; + while (!tab[i].startsWith("using")) { + alModels.add(tab[i].replaceAll("^\t", "")); + i++; + } + int j = 0; + for (i = i; i < tab.length; i++) { + if (tab[i].startsWith("Filtered Header")) { + al.add("Model: " + alModels.get(j)); + j++; + i++; + while (i < tab.length && !tab[i].startsWith("Filtered Header")) { + if (tab[i].startsWith("@attribute")) { + al.add(tab[i].replace("@attribute ", "") + .replaceAll(" \\w+$", "") + .replaceAll(" \\{.*$", "")); + } + i++; + } + i--; + } + } + + } else { + for (String s : model.toString().split("\n")) { + if (s.startsWith("@attribute")) { + if (s.split(" ")[1].equals("class")) { + al.add("class"); + } else { + al.add(s.replace("@attribute ", "") + .replaceAll(" \\w+$", "") + .replaceAll(" \\{.*$", "")); + } + } + } + } + } catch (Exception e) { + if (Main.debug) { + e.printStackTrace(); + } + } + return al; + } + + /** + * get classes from classifier + * + * @param classifier + * @param toArrayList + * @return + */ + public ArrayList<String> getClassesFromClassifier(String classifier, boolean toArrayList) { + ArrayList<String> al = new ArrayList<>(); + String classes = ""; + try { + //load model + Classifier model = (Classifier) weka.core.SerializationHelper.read(classifier); + for (String s : model.toString().split("\n")) { + if (s.startsWith("@attribute class")) { + classes = s; + } + } + } catch (Exception e) { + e.printStackTrace(); + } + if (toArrayList) { + classes = classes.replace("@attribute class ", "").replace("{", "").replace("}", ""); + for (String s : classes.split(",")) { + al.add(s); + } + return al; + } else { + al.add(classes); + return al; + } + } + + public ArrayList<String> getFeaturesFromClassifier(Classifier model) { + //HashMap<String, String> hm = new HashMap<>(); + ArrayList<String> al = new ArrayList<>(); + try { + //load model + for (String s : model.toString().split("\n")) { + if (s.startsWith("@attribute '") && s.endsWith("numeric")) { + s = s.replace("@attribute '", "").trim(); + s = s.substring(0, s.lastIndexOf("'")); + al.add(s); + } else if (s.startsWith("@attribute '")) { + s = s.replace("@attribute '", "").trim(); + s = s.substring(0, s.lastIndexOf("' {")); + al.add(s); + } else if (s.startsWith("@attribute")) { + //hm.put(s.split(" ")[1], s.split(" ")[2]); //feature_name, feature_type + s = s.replace("@attribute ", "").trim(); + try { + s = s.substring(0, s.lastIndexOf(" {")); + } catch (Exception e) { + s = s.substring(0, s.lastIndexOf(" ")); + } + + al.add(s); + } + } + } catch (Exception e) { + e.printStackTrace(); + } + return al; + } + + public HashMap<String, String> getFullFeaturesFromClassifier(String modelFile) { + //HashMap<String, String> hm = new HashMap<>(); + HashMap<String, String> hm = new 
HashMap<>(); + try { + Classifier model = (Classifier) weka.core.SerializationHelper.read(modelFile); + + //load model + for (String s : model.toString().split("\n")) { + if (s.startsWith("@attribute '")) { + s = s.replace("@attribute '", "").trim(); + String attributeName = s.substring(0, s.lastIndexOf("' {")).trim(); + String attributeType = s.substring(s.indexOf("' {")).trim(); + hm.put("'" + attributeName + "'", attributeType); + } else if (s.startsWith("@attribute")) { + //hm.put(s.split(" ")[1], s.split(" ")[2]); //feature_name, feature_type + s = s.replace("@attribute ", "").trim(); + try { + String attributeName = s.substring(0, s.lastIndexOf(" {")).trim(); + String attributeType = s.substring(s.indexOf(" {")).trim(); + hm.put(attributeName, attributeType); + } catch (Exception e) { + String attributeName = s.substring(0, s.lastIndexOf(" ")).trim(); + String attributeType = s.substring(s.indexOf(" ")).trim(); + hm.put(attributeName, attributeType); + } + } + } + } catch (Exception e) { + e.printStackTrace(); + } + return hm; + } + + public void saveFilteredDataToArff(String attributesToUse, boolean classification, String outfile) { + try { + // load data + Instances data = myData; + + //set class index + if (data.classIndex() == -1) { + data.setClassIndex(data.numAttributes() - 1); + } + + //keep only attributesToUse. ID is still here + String options = ""; + weka.filters.unsupervised.attribute.Remove r = new weka.filters.unsupervised.attribute.Remove(); + options = "-V -R " + attributesToUse; + r.setOptions(weka.core.Utils.splitOptions((options))); + r.setInputFormat(data); + data = Filter.useFilter(data, r); + data.setClassIndex(data.numAttributes() - 1); + + //save csv to arff + ArffSaver arff = new ArffSaver(); + arff.setInstances(data); + arff.setFile(new File(outfile)); + arff.writeBatch(); + + } catch (Exception e) { + e.printStackTrace(); + + } + } + + public void saveFilteredDataToCSV(String attributesToUse, boolean classification, String outfile) { + try { + // load data + Instances data = myData; + + //set class index + if (data.classIndex() == -1) { + data.setClassIndex(data.numAttributes() - 1); + } + + //keep only attributesToUse. 
+
+    public void saveFilteredDataToCSV(String attributesToUse, boolean classification, String outfile) {
+        try {
+            // load data
+            Instances data = myData;
+
+            //set class index
+            if (data.classIndex() == -1) {
+                data.setClassIndex(data.numAttributes() - 1);
+            }
+
+            //keep only attributesToUse. ID is still here
+            weka.filters.unsupervised.attribute.Remove r = new weka.filters.unsupervised.attribute.Remove();
+            String options = "-V -R " + attributesToUse;
+            r.setOptions(weka.core.Utils.splitOptions(options));
+            r.setInputFormat(data);
+            data = Filter.useFilter(data, r);
+            data.setClassIndex(data.numAttributes() - 1);
+
+            //save to csv
+            CSVSaver csv = new CSVSaver();
+            csv.setInstances(data);
+            csv.setFile(new File(outfile));
+            csv.writeBatch();
+
+        } catch (Exception e) {
+            e.printStackTrace();
+        }
+    }
+
+    /**
+     * perform sampling on the dataset
+     *
+     * @param fullDatasetFile
+     * @param dataToTrainFile
+     * @param dataToTestFile
+     * @param classification true if the class attribute is nominal
+     * @param range 1-based instance range kept as train set (e.g. "1-100");
+     * empty to use a random stratified 1/Main.samplingFold split
+     */
+    public void sampling(String fullDatasetFile, String dataToTrainFile, String dataToTestFile, boolean classification, String range) {
+        System.out.println("# sampling statistics");
+        boolean useRange = !range.isEmpty();
+        setCSVFile(new File(fullDatasetFile));
+        csvToArff(classification);
+        setDataFromArff();
+        // load data
+        Instances data = myData;
+        //set class index
+        if (data.classIndex() == -1) {
+            data.setClassIndex(data.numAttributes() - 1);
+        }
+        //calculate seed
+        Random rand = new Random();
+        int seed = 1 + rand.nextInt(1000);
+
+        try {
+            Instances TestData = null;
+            Instances TrainData = null;
+
+            //create train set
+            if (useRange) {
+                weka.filters.unsupervised.instance.RemoveRange rr = new RemoveRange();
+                String options = "-V -R " + range;
+                rr.setOptions(weka.core.Utils.splitOptions(options));
+                rr.setInputFormat(data);
+                TrainData = Filter.useFilter(data, rr);
+            } else {
+                weka.filters.supervised.instance.StratifiedRemoveFolds srf = new StratifiedRemoveFolds();
+                String options = "-S " + seed + " -V -N " + Main.samplingFold + " -F 1"; //split into Main.samplingFold folds and invert the selection
+                srf.setOptions(weka.core.Utils.splitOptions(options));
+                srf.setInputFormat(data);
+                TrainData = Filter.useFilter(data, srf);
+            }
+            System.out.println("Total number of instances in train data file: " + TrainData.numInstances());
+            if (classification) {
+                Enumeration e = TrainData.attribute(data.classIndex()).enumerateValues();
+                int cpt = 0;
+                while (e.hasMoreElements()) {
+                    String className = (String) e.nextElement();
+                    double n = TrainData.attributeStats(data.classIndex()).nominalWeights[cpt];
+                    System.out.println("class " + className + ": " + n + " (" + df.format(n / (double) TrainData.numInstances() * 100) + "%)");
+                    cpt++;
+                }
+            }
+
+            //create test set
+            if (useRange) {
+                weka.filters.unsupervised.instance.RemoveRange rr = new RemoveRange();
+                String options = "-R " + range;
+                rr.setOptions(weka.core.Utils.splitOptions(options));
+                rr.setInputFormat(data);
+                TestData = Filter.useFilter(data, rr);
+            } else {
+                weka.filters.supervised.instance.StratifiedRemoveFolds srf = new StratifiedRemoveFolds();
+                String options = "-S " + seed + " -N " + Main.samplingFold + " -F 1"; //split into Main.samplingFold folds
+                srf.setOptions(weka.core.Utils.splitOptions(options));
+                srf.setInputFormat(data);
+                TestData = Filter.useFilter(data, srf);
+            }
+
+            System.out.println("Total number of instances in test data file: " + TestData.numInstances());
+            if (classification) {
+                Enumeration e = TestData.attribute(data.classIndex()).enumerateValues();
+                int cpt = 0;
+                while (e.hasMoreElements()) {
+                    String className = (String) e.nextElement();
+                    double n = TestData.attributeStats(data.classIndex()).nominalWeights[cpt];
+                    System.out.println("class " + className + ": " + n + " (" + df.format(n / (double) TestData.numInstances() * 100) + "%)");
+                    cpt++;
+                }
+            }
+
+            //convert strings to nominal
+            TestData = convertStringsToNominal(TestData);
+            TrainData = convertStringsToNominal(TrainData);
+            //save to csv and arff
+            if (!useRange) {
+                ///test set to csv
+                CSVSaver csv = new CSVSaver();
+                csv.setInstances(TestData);
+                csv.setFile(new File(dataToTestFile));
+                csv.writeBatch();
+                ///train set to csv
+                csv = new CSVSaver();
+                csv.setInstances(TrainData);
+                csv.setFile(new File(dataToTrainFile));
+                csv.writeBatch();
+            }
+
+            //test set to arff
+            ArffSaver arff = new ArffSaver();
+            arff.setInstances(TestData);
+            arff.setFile(new File(dataToTestFile.replace("csv", "arff")));
+            arff.writeBatch();
+
+            //train set to arff
+            arff = new ArffSaver();
+            arff.setInstances(TrainData);
+            arff.setFile(new File(dataToTrainFile.replace("csv", "arff")));
+            arff.writeBatch();
+
+        } catch (Exception e) {
+            e.printStackTrace();
+        }
+    }
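+
+    /*
+     * Usage sketch (illustrative; file names are hypothetical). With an empty
+     * range, a random stratified 1/Main.samplingFold split is produced; with a
+     * range such as "1-100", those instances become the train set:
+     *
+     *   Weka_module wm = new Weka_module();
+     *   wm.sampling("all.csv", "data_to_train.csv", "data_to_test.csv", true, "");
+     */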
+
+    /**
+     * in case of combined models using filters, select specific features based
+     * on their indexes
+     *
+     * @param featureSelectedData data restricted to the selected features
+     * @param arffFile full dataset to filter
+     * @param outfile
+     */
+    public void extractFeaturesFromArffFileBasedOnSelectedFeatures(Instances featureSelectedData, Instances arffFile,
+            String outfile) {
+        //get attribute names of feature selection data
+        ArrayList<String> aFeaturesSelected = new ArrayList<>();
+        Enumeration<Attribute> enumerate = featureSelectedData.enumerateAttributes();
+        while (enumerate.hasMoreElements()) {
+            Attribute a = enumerate.nextElement();
+            aFeaturesSelected.add(a.name());
+        }
+        HashMap<String, String> hmModelFeatures = new HashMap<>();//indexed hashed features
+        for (String f : aFeaturesSelected) {
+            hmModelFeatures.put(f, f);
+        }
+
+        Instances data = arffFile;
+        //get index of model features in the arff file
+        String indexesToKeep = "";
+        Enumeration e = data.enumerateAttributes();
+        while (e.hasMoreElements()) {
+            Attribute a = (Attribute) e.nextElement();
+            if (hmModelFeatures.containsKey(a.name())) {
+                indexesToKeep += (a.index() + 1) + ",";
+            }
+        }
+        indexesToKeep += data.numAttributes(); //keep the class attribute (last column)
+        try {
+            //apply remove filter
+            String[] opts = new String[3];
+            opts[0] = "-V";
+            opts[1] = "-R";
+            opts[2] = indexesToKeep;
+            weka.filters.unsupervised.attribute.Remove r = new weka.filters.unsupervised.attribute.Remove();
+            r.setOptions(opts);
+            r.setInputFormat(data);
+            data = Filter.useFilter(data, r);
+
+            //reorder to match the feature selection order
+            String order = ",";
+            for (String featureOrdered : aFeaturesSelected) {
+                int featureNotOrdered = data.attribute(featureOrdered).index() + 1;
+                order += featureNotOrdered + ",";
+            }
+            order += data.attribute("class").index() + 1;
+
+            Reorder reorder = new Reorder();
+            reorder.setAttributeIndices(order);
+            reorder.setInputFormat(data);
+            data = Filter.useFilter(data, reorder);
+
+            //save
+            ArffSaver arff = new ArffSaver();
+            arff.setInstances(data);
+            arff.setFile(new File(outfile));
+            arff.writeBatch();
+        } catch (Exception ex) {
+            ex.printStackTrace();
+        }
+
+    }
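+
+    /*
+     * Usage sketch (illustrative): given a feature-selected header and a full
+     * dataset, write an arff restricted to the selected columns, in the same
+     * order as the selection:
+     *
+     *   wm.extractFeaturesFromArffFileBasedOnSelectedFeatures(selectedData,
+     *           fullData, "combined_model_input.arff");
+     */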
+
+    /**
+     * extract features from test file while keeping arff compatibility
+     *
+     * @param modelFile
+     * @param arffTestFile
+     * @param outfile
+     */
+    public void extractFeaturesFromTestFileBasedOnModel(String modelFile, Instances arffTestFile, String outfile) {
+        Instances data = arffTestFile;
+        //get model features
+        ArrayList<String> alModelFeatures = getFeaturesFromClassifier(modelFile); //features in the right order
+        HashMap<String, String> hmModelFeatures = new HashMap<>();//indexed hashed features
+        for (String f : alModelFeatures) {
+            hmModelFeatures.put(f, f);
+        }
+
+        //get index of model features in the arff test file
+        //note: indexes come from myData, so arffTestFile is assumed to share its column layout
+        String indexesToKeep = "1"; //keep ID
+        Enumeration e = myData.enumerateAttributes();
+        while (e.hasMoreElements()) {
+            Attribute a = (Attribute) e.nextElement();
+            if (hmModelFeatures.containsKey(a.name())) {
+                indexesToKeep += "," + (a.index() + 1);
+            }
+        }
+        if (!hmModelFeatures.containsKey("class")) {
+            indexesToKeep += "," + myData.numAttributes();
+        }
+        try {
+            //apply remove filter
+            String[] opts = new String[3];
+            opts[0] = "-V";
+            opts[1] = "-R";
+            opts[2] = indexesToKeep;
+            weka.filters.unsupervised.attribute.Remove r = new weka.filters.unsupervised.attribute.Remove();
+            r.setOptions(opts);
+            r.setInputFormat(data);
+            data = Filter.useFilter(data, r);
+
+            //reorder
+            String order = "1,";
+            for (String featureOrdered : alModelFeatures) {
+                int featureNotOrdered = data.attribute(featureOrdered).index() + 1;
+                order += featureNotOrdered + ",";
+            }
+            Reorder reorder = new Reorder();
+            reorder.setAttributeIndices(order);
+            reorder.setInputFormat(data);
+            data = Filter.useFilter(data, reorder);
+
+            //save to arff
+            ArffSaver arff = new ArffSaver();
+            arff.setInstances(data);
+            arff.setFile(new File(outfile + ".tmp"));
+            arff.writeBatch();
+            //change non-numeric attributes to ensure compatibility with the model header
+            Classifier model = (Classifier) weka.core.SerializationHelper.read(modelFile);
+            HashMap<String, String> hm = new HashMap<>();
+            for (String s : model.toString().split("\n")) {
+                if (s.startsWith("@attribute") && !s.endsWith(" numeric")) {
+                    String attribute = s.replace("@attribute ", "");
+                    attribute = attribute.substring(0, attribute.indexOf("{"));
+                    hm.put(attribute, s);
+                }
+            }
+            try {
+                BufferedReader br = new BufferedReader(new FileReader(outfile + ".tmp"));
+                PrintWriter pw = new PrintWriter(new FileWriter(outfile));
+                while (br.ready()) {
+                    String line = br.readLine();
+                    if (line.startsWith("@attribute") && !line.endsWith(" numeric")
+                            && !line.contains("@attribute Instance")
+                            && !line.contains("@attribute " + Main.mergingID)) {
+                        String attribute = line.replace("@attribute ", "").trim();
+                        attribute = attribute.substring(0, attribute.indexOf("{"));
+                        if (hm.get(attribute) != null) {
+                            pw.println(hm.get(attribute));
+                        } else {
+                            pw.println(line);
+                        }
+                    } else {
+                        pw.println(line);
+                    }
+                }
+                pw.close();
+                br.close();
+                new File(outfile + ".tmp").delete();
+            } catch (Exception ex) {
+                ex.printStackTrace();
+            }
+
+        } catch (Exception ex) {
+            ex.printStackTrace();
+        }
+
+    }
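+
+    /*
+     * Usage sketch (illustrative; paths are hypothetical). The test set is cut
+     * down to the model's signature, and its nominal attribute declarations
+     * are rewritten to match the model header:
+     *
+     *   wm.extractFeaturesFromTestFileBasedOnModel("d.model", testData,
+     *           "test_compatible.arff");
+     */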
+
+    /**
+     * extract features from dataset
+     *
+     * @param model
+     * @param dataset
+     * @return the dataset restricted to the model's features
+     */
+    public Instances extractFeaturesFromDatasetBasedOnModel(Classifier model, Instances dataset) {
+        //get model features
+        ArrayList<String> alModelFeatures = getFeaturesFromClassifier(model); //features in the right order
+        HashMap<String, String> hmModelFeatures = new HashMap<>();//indexed hashed features
+        for (String f : alModelFeatures) {
+            hmModelFeatures.put(f, f);
+        }
+        //get index of model features in the dataset
+        String indexesToKeep = "1"; //keep ID
+        Enumeration e = dataset.enumerateAttributes();
+        while (e.hasMoreElements()) {
+            Attribute a = (Attribute) e.nextElement();
+            if (hmModelFeatures.containsKey(a.name())) {
+                indexesToKeep += "," + (a.index() + 1);
+            }
+        }
+        if (!hmModelFeatures.containsKey("class")) {
+            indexesToKeep += "," + dataset.numAttributes();
+        }
+        Instances filteredDataset = null;
+        try {
+            //apply remove filter
+            String[] opts = new String[3];
+            opts[0] = "-V";
+            opts[1] = "-R";
+            opts[2] = indexesToKeep;
+            weka.filters.unsupervised.attribute.Remove r = new weka.filters.unsupervised.attribute.Remove();
+            r.setOptions(opts);
+            r.setInputFormat(dataset);
+            filteredDataset = Filter.useFilter(dataset, r);
+
+            //reorder
+            String order = "1,";
+            for (String featureOrdered : alModelFeatures) {
+                int featureNotOrdered = filteredDataset.attribute(featureOrdered).index() + 1;
+                order += featureNotOrdered + ",";
+            }
+            Reorder reorder = new Reorder();
+            reorder.setAttributeIndices(order);
+            reorder.setInputFormat(filteredDataset);
+            filteredDataset = Filter.useFilter(filteredDataset, reorder);
+        } catch (Exception ex) {
+            ex.printStackTrace();
+        }
+        return filteredDataset;
+
+    }
+
+    /**
+     * ensure compatibility of non-numeric attributes between 2 arff files
+     *
+     * @param ArffFileWithReferenceHeader contains the header you want
+     * @param ArffFileToModify file whose header is modified in place
+     */
+    public void makeCompatibleARFFheaders(String ArffFileWithReferenceHeader, String ArffFileToModify) {
+        HashMap<String, String> hm = new HashMap<>();
+        try {
+            // read source. Get non-numeric attributes
+            BufferedReader br = new BufferedReader(new FileReader(ArffFileWithReferenceHeader));
+            String line = "";
+            while (br.ready()) {
+                line = br.readLine();
+                if (line.startsWith("@attribute") && !line.trim().endsWith("numeric")) {
+                    String attribute = line.replace("@attribute ", "");
+                    attribute = attribute.substring(0, attribute.indexOf("{"));
+                    hm.put(attribute, line);
+                }
+            }
+            br.close();
+            //read destination
+            br = new BufferedReader(new FileReader(ArffFileToModify));
+            PrintWriter pw = new PrintWriter(new FileWriter(ArffFileToModify + "_tmp"));
+            while (br.ready()) {
+                line = br.readLine();
+                if (line.startsWith("@attribute") && !line.endsWith(" numeric")) {
+                    String attribute = line.replace("@attribute ", "").trim();
+                    attribute = attribute.substring(0, attribute.indexOf("{"));
+                    pw.println(hm.get(attribute));
+                } else {
+                    pw.println(line);
+                }
+            }
+            pw.close();
+            br.close();
+            new File(ArffFileToModify).delete();
+            new File(ArffFileToModify + "_tmp").renameTo(new File(ArffFileToModify));
+        } catch (Exception ex) {
+            ex.printStackTrace();
+        }
+    }
+
+    public Instances convertStringsToNominal(Instances data) {
+        try {
+            Enumeration en = data.enumerateAttributes();
+            while (en.hasMoreElements()) {
+                Attribute a = (Attribute) en.nextElement();
+                if (a.isString()) {
+                    StringToNominal stn = new StringToNominal();
+                    stn.setOptions(new String[]{"-R", (a.index() + 1) + ""});
+                    stn.setInputFormat(data);
+                    data = Filter.useFilter(data, stn);
+                }
+            }
+        } catch (Exception e) {
+            e.printStackTrace();
+        }
+
+        return data;
+
+    }
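+
+    /*
+     * Usage sketch (illustrative). Each string attribute is converted in place
+     * with weka.filters.unsupervised.attribute.StringToNominal, one attribute
+     * at a time:
+     *
+     *   Instances nominalData = wm.convertStringsToNominal(rawData);
+     */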
+
+    public static class RegressionResultsObject {
+
+        public String CC;//correlation coefficient
+        public String MAE;//mean absolute error
+        public String RMSE;//root mean squared error
+        public String RAE;//relative absolute error
+        public String RRSE;//root relative squared error
+        public StringBuilder predictions = new StringBuilder();
+        public StringBuilder features = new StringBuilder();
+        public int numberOfFeatures = 0;
+        public Classifier model;
+        public String featuresRankingResults;
+        public Evaluation evaluation;
+        public Instances dataset;
+
+        public RegressionResultsObject() {
+        }
+
+        private RegressionResultsObject(Evaluation eval, StringBuffer sb, Instances data, Classifier classifier) {
+            dataset = data;
+            evaluation = eval;
+
+            if (!sb.toString().isEmpty()) {
+                try {
+                    parsePredictions(sb);
+                } catch (Exception e) {
+                }
+            }
+            //getFeatures(data);
+            getFeatures(classifier);
+            model = classifier;
+
+            //measures
+            try {
+                CC = Utils.doubleToString(eval.correlationCoefficient(), 12, 4).trim();
+                MAE = Utils.doubleToString(eval.meanAbsoluteError(), 12, 4).trim();
+                RMSE = Utils.doubleToString(eval.rootMeanSquaredError(), 12, 4).trim();
+                RAE = Utils.doubleToString(eval.relativeAbsoluteError() / 100, 12, 4).trim();
+                RRSE = Utils.doubleToString(eval.rootRelativeSquaredError() / 100, 12, 4).trim();
+            } catch (Exception exception) {
+                exception.printStackTrace();
+            }
+
+        }
+
+        @Override
+        public String toString() {
+            return (CC + "\t"
+                    + MAE + "\t"
+                    + RMSE + "\t"
+                    + RAE + "\t"
+                    + RRSE);
+        }
+
+        public String toStringDetails() {
+            return ("[score_training] CC: " + CC + "\n"
+                    + "[score_training] MAE: " + MAE + "\n"
+                    + "[score_training] RMSE: " + RMSE + "\n"
+                    + "[score_training] RAE: " + RAE + "\n"
+                    + "[score_training] RRSE: " + RRSE);
+        }
+
+        private void parsePredictions(StringBuffer sb) {
+            String lines[] = sb.toString().split("\n");
+
+            for (int i = 1; i < lines.length; i++) {
+                String p = lines[i].replaceAll(" +", "\t");
+                String tab[] = p.split("\t");
+                String inst = tab[1];
+                String actual = tab[2];
+                String predicted = tab[3];
+                String error = tab[4];
+                String ID = tab[5];
+
+                predictions.append(ID.replace("(", "").replace(")", "") + "\t" + actual + "\t" + predicted + "\t" + error + "\n");
+            }
+
+        }
+
+        private void getFeatures(Instances data) {
+            //attribute 0 is the ID, the last attribute is the class
+            for (int i = 1; i < data.numAttributes() - 1; i++) {
+                numberOfFeatures++;
+                features.append(data.attribute(i).name() + "\n");
+            }
+        }
+
+        private void getFeatures(Classifier model) {
+            boolean voteClassifier = model.toString().split("\n")[0].contains("Vote");
+            TreeMap<String, Integer> tm = new TreeMap<>();
+            for (String s : model.toString().split("\n")) {
+                if (s.startsWith("@attribute")) {
+                    String name = s.split(" ")[1];
+                    if (tm.containsKey(name)) {
+                        int i = tm.get(name);
+                        i++;
+                        tm.put(name, i);
+                    } else {
+                        tm.put(s.split(" ")[1], 1);
+                    }
+                }
+
+            }
+            for (String f : tm.keySet()) {
+                if (!f.equals("class")) {
+                    numberOfFeatures++;
+                    if (voteClassifier) {
+                        features.append(f + "\t" + tm.get(f) + "\n");
+                    } else {
+                        features.append(f + "\n");
+                    }
+                }
+            }
+        }
+
+        public String getFeatureRankingResults() {
+            String[] s = featuresRankingResults.split("\n");
+            StringBuilder toreturn = new StringBuilder();
+            int index = 0;
+            while (!s[index].equals("Ranked attributes:")) {
+                index++;
+            }
+            index++;
+            while (!s[index].startsWith("Selected")) {
+                if (!s[index].contains(Main.mergingID) && !s[index].isEmpty()) {
+                    String out = s[index].substring(1); //remove first space
+                    out = out.replaceAll(" +", "\t");
+                    out = out.split("\t")[0] + "\t" + out.split("\t")[2];
+                    toreturn.append(out + "\n");
+                }
+                index++;
+            }
+            return toreturn.toString();
+        }
+
+    }
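+
+    /*
+     * Reading sketch (illustrative): the scores are exposed as preformatted
+     * public String fields, so a caller typically does something like:
+     *
+     *   RegressionResultsObject rro = ...; // built internally during training
+     *   System.out.println(rro.toStringDetails());
+     *   double cc = Double.parseDouble(rro.CC);
+     */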
+
+    public static class ClassificationResultsObject {
+
+        public String ACC;
+        public String MCC;
+        public String TNR;//specificity
+        public String AUC;
+        public String TPR;//sensitivity
+        public String Fscore;
+        public double TP;
+        public double FN;
+        public double TN;
+        public double FP;
+        public String kappa;
+        public String MAE;
+        public StringBuilder predictions = new StringBuilder();
+        public StringBuilder features = new StringBuilder();
+        public int numberOfFeatures = 0;
+        public Classifier model;
+        public String featuresRankingResults;
+        public Evaluation evaluation;
+        public Instances dataset;
+        public String FPR;
+        public String recall;
+        public String FNR;
+        public String AUPRC;
+        public String precision;
+        public String FDR;
+        public String matrix;
+        public String pAUC;
+        public String BCR;
+        public String BER;
+        public String classes = "";
+
+        private ClassificationResultsObject() {
+        }
+
+        public void getPredictions() {
+            System.out.println("instance\tactual\tpredicted\terror\tprobability");
+            System.out.println(predictions);
+        }
+
+        private ClassificationResultsObject(Evaluation eval, StringBuffer sb, Instances data, Classifier classifier) {
+            dataset = data;
+            evaluation = eval;
+            model = classifier;
+            getFeaturesAndClasses(classifier);
+
+            if (!sb.toString().isEmpty()) {
+                try {
+                    parsePredictions(sb);
+                } catch (Exception e) {
+                    try {
+                        parsePredictions2(sb);
+                    } catch (Exception e2) {
+                        if (Main.debug) {
+                            e.printStackTrace();
+                            e2.printStackTrace();
+                        }
+                    }
+
+                }
+            }
+
+            //measures
+            ACC = Utils.doubleToString(eval.pctCorrect() / 100, 12, 4).trim();
+            MCC = Utils.doubleToString(eval.weightedMatthewsCorrelation(), 12, 4).trim();
+            TPR = Utils.doubleToString(eval.weightedTruePositiveRate(), 12, 4).trim();//sensitivity
+            TNR = Utils.doubleToString(eval.weightedTrueNegativeRate(), 12, 4).trim();//specificity
+            FPR = Utils.doubleToString(eval.weightedFalsePositiveRate(), 12, 4).trim();
+            FNR = Utils.doubleToString(eval.weightedFalseNegativeRate(), 12, 4).trim();
+            AUC = Utils.doubleToString(eval.weightedAreaUnderROC(), 12, 4).trim();
+            AUPRC = Utils.doubleToString(eval.weightedAreaUnderPRC(), 12, 4).trim();
+            Fscore = Utils.doubleToString(eval.weightedFMeasure(), 12, 4).trim();
+            kappa = Utils.doubleToString(eval.kappa(), 12, 4).trim();
+            MAE = Utils.doubleToString(eval.meanAbsoluteError(), 12, 4).trim();
+            precision = Utils.doubleToString(eval.weightedPrecision(), 12, 4).trim();
+            FDR = Utils.doubleToString(1 - eval.weightedPrecision(), 12, 4).trim();
+            recall = Utils.doubleToString(eval.weightedRecall(), 12, 4).trim();
+
+            double[][] m = eval.confusionMatrix();
+            matrix = "";
+            for (int i = 0; i < m.length; i++) {
+                matrix += "[";
+                for (int j = 0; j < m.length; j++) {
+                    matrix += (int) m[i][j] + " ";
+                }
+                matrix = matrix.substring(0, matrix.length() - 1) + "] ";
+            }
+
+            try {
+                //pAUC, averaged over the one-vs-rest curves of all classes
+                int numberOfClasses = m.length;
+                double pAUCs = 0;
+                for (int i = 0; i < numberOfClasses; i++) {
+                    Instances rocCurve = new ThresholdCurve().getCurve(eval.predictions(), i);
+                    pAUCs += getPartialROCArea(rocCurve, Main.pAUC_lower, Main.pAUC_upper);
+                }
+                pAUC = Utils.doubleToString((pAUCs / (double) numberOfClasses), 12, 4).trim();
+            } catch (Exception e) {
+                pAUC = "NaN";
+            }
+
+            //balanced classification rate: (TPR + TNR) / number of classes
+            BCR = Utils.doubleToString(
+                    ((eval.weightedTruePositiveRate() + eval.weightedTrueNegativeRate()) / m.length), 12, 4).trim();
+            //BER (balanced error rate)
+            //https://arxiv.org/pdf/1207.3809.pdf
+            BER = Utils.doubleToString(0.5 * ((eval.weightedFalsePositiveRate()) + (eval.weightedFalseNegativeRate())), 12, 4).trim();
+//            BER = Utils.doubleToString(
+//                    (1 - (eval.weightedTruePositiveRate() + eval.weightedTrueNegativeRate()) / m.length), 12, 4).trim();
+        }
+
+        @Override
+        public String toString() {
+            return (ACC + "\t"
+                    + AUC + "\t"
+                    + AUPRC + "\t"
+                    + TPR + "\t" //SEN == recall
+                    + TNR + "\t" //SPE
+                    + MCC + "\t"
+                    + MAE + "\t"
+                    + BER + "\t"
+                    + FPR + "\t"
+                    + FNR + "\t"
+                    + precision + "\t" //= PPV (positive predictive value)
+                    + FDR + "\t"
+                    + Fscore + "\t"
+                    + kappa + "\t"
+                    + matrix);
+        }
+
+        public String toStringShort() {
+            return (ACC + "\t"
+                    + AUC + "\t"
+                    + AUPRC + "\t"
+                    + TPR + "\t" //SEN == recall
+                    + TNR + "\t" //SPE
+                    + MCC + "\t"
+                    + MAE + "\t"
+                    + BER);
+        }
+
+        public String toStringDetails() {
+            return ("[score_training] ACC: " + ACC + "\n"
+                    + "[score_training] AUC: " + AUC + "\n"
+                    + "[score_training] AUPRC: " + AUPRC + "\n"
+                    + "[score_training] TPR: " + TPR + "\n"
+                    + "[score_training] TNR: " + TNR + "\n"
+                    + "[score_training] MCC: " + MCC + "\n"
+                    + "[score_training] MAE: " + MAE + "\n"
+                    + "[score_training] BER: " + BER + "\n"
+                    + "[score_training] FPR: " + FPR + "\n"
+                    + "[score_training] FNR: " + FNR + "\n"
+                    + "[score_training] PPV: " + precision + "\n"
+                    + "[score_training] FDR: " + FDR + "\n"
+                    + "[score_training] recall: " + recall + "\n"
+                    + "[score_training] Fscore: " + Fscore + "\n"
+                    + "[score_training] kappa: " + kappa + "\n"
+                    + "[score_training] matrix: " + matrix // + "pAUC[" + Main.pAUC_lower + "-" + Main.pAUC_upper + pAUC + "]: \n"
+                    );
+        }
+
+        public String toStringDetailsTesting() {
+            return ("[score_testing] ACC: " + ACC + "\n"
+                    + "[score_testing] AUC: " + AUC + "\n"
+                    + "[score_testing] AUPRC: " + AUPRC + "\n"
+                    + "[score_testing] TPR: " + TPR + "\n"
+                    + "[score_testing] TNR: " + TNR + "\n"
+                    + "[score_testing] MCC: " + MCC + "\n"
+                    + "[score_testing] MAE: " + MAE + "\n"
+                    + "[score_testing] BER: " + BER + "\n"
+                    + "[score_testing] FPR: " + FPR + "\n"
+                    + "[score_testing] FNR: " + FNR + "\n"
+                    + "[score_testing] PPV: " + precision + "\n"
+                    + "[score_testing] FDR: " + FDR + "\n"
+                    + "[score_testing] recall: " + recall + "\n"
+                    + "[score_testing] Fscore: " + Fscore + "\n"
+                    + "[score_testing] kappa: " + kappa + "\n"
+                    + "[score_testing] matrix: " + matrix // + "pAUC[" + Main.pAUC_lower + "-" + Main.pAUC_upper + pAUC + "]: \n"
+                    );
+        }
+
+        /**
+         * first attempt to retrieve predictions
+         *
+         * @param sb
+         */
+        private void parsePredictions(StringBuffer sb) {
+            String lines[] = sb.toString().split("\n");
+
+            //get predictions
+            for (int i = 0; i < lines.length; i++) {
+                String p = lines[i].replaceAll(" +", "\t");
+                if (!p.trim().startsWith("inst#")) {
+                    String tab[] = p.split("\t");
+                    String inst = tab[1];
+                    try {
+                        String actual = tab[2].split(":")[1];
+                        String predicted = "";
+                        try {
+                            predicted = tab[3].split(":")[1];
+                        } catch (Exception e) {
+                            predicted = tab[3];
+                        }
+                        String error;
+                        String prob;
+                        String ID;
+                        if (tab.length == 6) {
+                            if (actual.contains("?")) {
+                                error = "?";
+                            } else {
+                                error = "No";
+                            }
+                            prob = tab[4];
+                            ID = tab[5];
+                            if (prob.contains("?")) {
+                                error = "?";
+                            }
+                        } else {
+                            error = "Yes";
+                            prob = tab[5];
+                            ID = tab[6];
+                        }
+                        predictions.append(ID.replace("(", "").replace(")", "") + "\t" + actual + "\t" + predicted + "\t" + error + "\t" + prob.replace(",", "\t") + "\n");
+                    } catch (Exception e) {
+                        //e.printStackTrace();
+                    }
+                }
+            }
+
+        }
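+
+        /*
+         * The parser above expects Weka PlainText prediction rows; an
+         * illustrative line (values are made up) looks like:
+         *
+         *   inst#  actual  predicted  error  prediction  (ID)
+         *   1      1:case  2:control  +      0.18,0.82   (PATIENT_7)
+         *
+         * parsePredictions handles this padded layout; parsePredictions2
+         * below covers the variant without the leading padding.
+         */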
+
+        //fallback in case of a different prediction format
+        private void parsePredictions2(StringBuffer sb) {
+            String lines[] = sb.toString().split("\n");
+
+            //skip header
+            int a = 0;
+            if (lines[0].trim().startsWith("inst#")) {
+                a = 1;
+            }
+
+            //get predictions
+            for (int i = a; i < lines.length; i++) {
+
+                String p = lines[i].trim().replaceAll(" +", "\t");
+                String tab[] = p.split("\t");
+                try {
+                    String inst = tab[0];
+                    String actual = tab[1].split(":")[1];
+                    String predicted = tab[2].split(":")[1];
+                    String error;
+                    String prob;
+                    String ID;
+                    if (tab.length == 4) {
+                        if (actual.contains("?")) {
+                            error = "?";
+                        } else {
+                            error = "No";
+                        }
+                        prob = tab[3];
+                        try {
+                            ID = tab[4]; //no ID column here; the catch falls back to an empty ID
+                        } catch (Exception e) {
+                            ID = "";
+                        }
+                    } else {
+                        error = "Yes";
+                        prob = tab[4];
+                        try {
+                            ID = tab[5];
+                        } catch (Exception e) {
+                            ID = "";
+                        }
+                    }
+                    predictions.append(ID.replace("(", "").replace(")", "") + "\t"
+                            + actual + "\t" + predicted + "\t" + error + "\t" + prob.replace(",", "\t") + "\n");
+                } catch (Exception e) {
+                    //e.printStackTrace();
+                }
+
+            }
+
+        }
+
+        private void getFeaturesAndClasses(Classifier model) {
+            boolean voteClassifier = model.toString().split("\n")[0].contains("Vote");
+            TreeMap<String, Integer> tm = new TreeMap<>();
+            for (String s : model.toString().split("\n")) {
+                if (s.startsWith("@attribute")) {
+                    String name = s.split(" ")[1];
+                    if (tm.containsKey(name)) {
+                        int i = tm.get(name);
+                        i++;
+                        tm.put(name, i);
+                    } else {
+                        tm.put(s.split(" ")[1], 1);
+                    }
+                }
+                if (s.startsWith("@attribute class ")) {
+                    classes = s.replace("@attribute class {", "").replace("}", "").replace(",", "\t");
+                }
+
+            }
+            for (String f : tm.keySet()) {
+                if (!f.equals("class")) {
+                    numberOfFeatures++;
+                    if (voteClassifier) {
+                        features.append(f + "\t" + tm.get(f) + "\n");
+                    } else {
+                        features.append(f + "\n");
+                    }
+                }
+            }
+        }
+
+        public String getFeatureRankingResults() {
+            String[] s = featuresRankingResults.split("\n");
+            StringBuilder toreturn = new StringBuilder();
+            int index = 0;
+            while (!s[index].equals("Ranked attributes:")) {
+                index++;
+            }
+            index++;
+            while (!s[index].startsWith("Selected")) {
+                if (!s[index].contains(Main.mergingID) && !s[index].isEmpty()) {
+                    String out = s[index].trim(); //remove first space
+                    out = out.replaceAll(" +", "\t");
+                    String output = out.split("\t")[0] + "\t";
+                    for (int i = 2; i < out.split("\t").length; i++) {
+                        output += out.split("\t")[i] + " ";
+                    }
+                    toreturn.append(output + "\n");
+                }
+                index++;
+            }
+            return toreturn.toString();
+        }
+    }
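+
+    /*
+     * Usage sketch (illustrative) for the partial AUC helper below, using
+     * Weka's ThresholdCurve for class index 0 and an FPR window of [0, 0.3]:
+     *
+     *   Instances roc = new ThresholdCurve().getCurve(eval.predictions(), 0);
+     *   double pauc = Weka_module.getPartialROCArea(roc, 0.0, 0.3);
+     */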
+
+    /**
+     * compute partial AUC using the trapezoid rule
+     *
+     * @param tcurve ROC curve produced by ThresholdCurve
+     * @param limit_inf lower bound of the false positive rate window
+     * @param limit_sup upper bound of the false positive rate window
+     */
+    public static double getPartialROCArea(Instances tcurve, double limit_inf, double limit_sup) {
+        int tpr = tcurve.attribute("True Positive Rate").index();
+        int fpr = tcurve.attribute("False Positive Rate").index();
+        double[] tprVals = tcurve.attributeToDoubleArray(tpr);
+        double[] fprVals = tcurve.attributeToDoubleArray(fpr);
+
+        //retrieve coordinates; store them in ArrayLists so their order can be reversed
+        ArrayList<Double> alx = new ArrayList<>();
+        ArrayList<Double> aly = new ArrayList<>();
+
+        for (int i = 0; i < fprVals.length; i++) {
+            alx.add(fprVals[i]);
+            aly.add(tprVals[i]);
+        }
+
+        Collections.reverse(alx);
+        Collections.reverse(aly);
+
+        //create function points
+        double[] x = new double[alx.size()];
+        double[] y = new double[aly.size()];
+        for (int i = 0; i < alx.size(); i++) {
+            x[i] = alx.get(i);
+            y[i] = aly.get(i);
+        }
+        //correct to avoid duplicated x values, which the interpolator rejects
+        try {
+            for (int i = 1; i < x.length; i++) {
+                if (x[i] <= x[i - 1]) {
+                    x[i] = x[i - 1] + 0.0001;
+                    int j = i + 1;
+                    while (j < x.length && x[j] == x[i - 1]) {
+                        x[j] += 0.0001;
+                        j++;
+                    }
+                }
+            }
+        } catch (Exception e) {
+            //e.printStackTrace();
+        }
+        //show curve
+//        System.out.println("x\ty");
+//        for (int i = 0; i < y.length; i++) {
+//            System.out.println(x[i] + "\t" + y[i]);
+//        }
+
+        //https://stackoverflow.com/questions/16896961/how-to-compute-integration-of-a-function-in-apache-commons-math3-library
+        //estimate the curve by spline interpolation
+        UnivariateInterpolator interpolator = new SplineInterpolator();
+        UnivariateFunction function = interpolator.interpolate(x, y);
+        //compute the integral between the FPR bounds
+        TrapezoidIntegrator trapezoid = new TrapezoidIntegrator();
+        double pauc = 0;
+        try {
+            pauc = trapezoid.integrate(10000000, function, limit_inf, limit_sup);
+            //the spline fit can yield a slightly negative area; clamp it to 0
+            if (pauc < 0) {
+                return 0;
+            }
+        } catch (MaxCountExceededException maxCountExceededException) {
+        }
+
+        return pauc;
+    }
+
+    /**
+     * stores the per-iteration performances of bootstrap and repeated-holdout
+     * runs in arrays
+     */
+    public static class evaluationPerformancesResultsObject {
+
+        public ArrayList<Double> alAUCs = new ArrayList<>();
+        public ArrayList<Double> alpAUCs = new ArrayList<>();
+        public ArrayList<Double> alAUPRCs = new ArrayList<>();
+        public ArrayList<Double> alACCs = new ArrayList<>();
+        public ArrayList<Double> alSEs = new ArrayList<>();
+        public ArrayList<Double> alSPs = new ArrayList<>();
+        public ArrayList<Double> alMCCs = new ArrayList<>();
+        public ArrayList<Double> alBERs = new ArrayList<>();
+        public String meanACCs;
+        public String meanAUCs;
+        public String meanAUPRCs;
+        public String meanSENs;
+        public String meanSPEs;
+        public String meanMCCs;
+        public String meanBERs;
+        public ArrayList<Double> alCCs = new ArrayList<>();
+        public ArrayList<Double> alRMSEs = new ArrayList<>();
+        public ArrayList<Double> alMAEs = new ArrayList<>();
+        public ArrayList<Double> alRAEs = new ArrayList<>();
+        public ArrayList<Double> alRRSEs = new ArrayList<>();
+        public String meanCCs;
+        public String meanMAEs;
+        public String meanRMSEs;
+        public String meanRAEs;
+        public String meanRRSEs;
+
+        public evaluationPerformancesResultsObject() {
+        }
+
+        public String toStringClassification() {
+            if (meanACCs == null) {
+                return "\t\t\t\t\t\t\t";
+            } else {
+                return meanACCs + "\t"
+                        + meanAUCs + "\t"
+                        + meanAUPRCs + "\t"
+                        + meanSENs + "\t"
+                        + meanSPEs + "\t"
+                        + meanMCCs + "\t"
+                        + meanMAEs + "\t"
+                        + meanBERs;
+            }
+        }
+
+        public String toStringClassificationDetails() {
+            return "[score_training] Average weighted ACC: " + meanACCs
+                    + "\t(" + utils.getStandardDeviation(alACCs) + ")\n"
+                    + "[score_training] Average weighted AUC: " + meanAUCs
+                    + "\t(" + utils.getStandardDeviation(alAUCs) + ")\n"
+                    + "[score_training] Average weighted AUPRC: " + meanAUPRCs
+                    + "\t(" + utils.getStandardDeviation(alAUPRCs) + ")\n"
+                    + "[score_training] Average weighted SEN: " + meanSENs
+                    + "\t(" + utils.getStandardDeviation(alSEs) + ")\n"
+                    + "[score_training] Average weighted SPE: " + meanSPEs
+                    + "\t(" + utils.getStandardDeviation(alSPs) + ")\n"
+                    + "[score_training] Average weighted MCC: " + meanMCCs
+                    + "\t(" + utils.getStandardDeviation(alMCCs) + ")\n"
+                    + "[score_training] Average weighted MAE: " + meanMAEs
+                    + "\t(" + utils.getStandardDeviation(alMAEs) + ")\n"
+                    + "[score_training] Average weighted BER: " + meanBERs
+                    + "\t(" + utils.getStandardDeviation(alBERs) + ")";
+        }
+
+        public String toStringRegression() {
+            if (meanCCs == null) {
+                return "\t\t\t\t";
+            } else {
+                return meanCCs + "\t"
+                        + meanMAEs + "\t"
+                        + meanRMSEs + "\t"
+                        + meanRAEs + "\t"
+                        + meanRRSEs;
+            }
+        }
+
+        public String toStringRegressionDetails() {
+            return "[score_training] Average weighted CC: " + meanCCs
+                    + "\t(" + utils.getStandardDeviation(alCCs) + ")\n"
+                    + "[score_training] Average weighted MAE: " + meanMAEs
+                    + "\t(" + utils.getStandardDeviation(alMAEs) + ")\n"
+                    + "[score_training] Average weighted RMSE: " + meanRMSEs
+                    + "\t(" + utils.getStandardDeviation(alRMSEs) + ")\n"
+                    + "[score_training] Average weighted RAE: " + meanRAEs
+                    + "\t(" + utils.getStandardDeviation(alRAEs) + ")\n"
+                    + "[score_training] Average weighted RRSE: " + meanRRSEs
+                    + "\t(" + utils.getStandardDeviation(alRRSEs) + ")";
+        }
+
+        public void computeMeans() {
+            meanACCs = getMean(alACCs);
+            meanAUCs = getMean(alAUCs);
+            meanAUPRCs = getMean(alAUPRCs);
+            meanSENs = getMean(alSEs);
+            meanSPEs = getMean(alSPs);
+            meanMCCs = getMean(alMCCs);
+            meanBERs = getMean(alBERs);
+            meanCCs = getMean(alCCs);
+            meanRMSEs = getMean(alRMSEs);
+            meanRAEs = getMean(alRAEs);
+            meanRRSEs = getMean(alRRSEs);
+            meanMAEs = getMean(alMAEs);
+        }
+    }
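+
+    /*
+     * Usage sketch (illustrative): one value per repeated holdout or bootstrap
+     * iteration is appended to the arrays, then computeMeans() fills the
+     * mean* fields used by the toString* reports:
+     *
+     *   evaluationPerformancesResultsObject eval = new evaluationPerformancesResultsObject();
+     *   eval.alACCs.add(0.91);
+     *   eval.alAUCs.add(0.95);
+     *   eval.computeMeans();
+     *   System.out.println(eval.toStringClassification());
+     */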
+ "[score_training] Average weighted RRSE: " + meanRRSEs + + "\t(" + utils.getStandardDeviation(alRRSEs) + ")"; + } + + public void computeMeans() { + meanACCs = getMean(alACCs); + meanAUCs = getMean(alAUCs); + meanAUPRCs = getMean(alAUPRCs); + meanSENs = getMean(alSEs); + meanSPEs = getMean(alSPs); + meanMCCs = getMean(alMCCs); + meanBERs = getMean(alBERs); + meanCCs = getMean(alCCs); + meanRMSEs = getMean(alRMSEs); + meanRAEs = getMean(alRAEs); + meanRRSEs = getMean(alRRSEs); + meanMAEs = getMean(alMAEs); + } + } + + public static class testResultsObject { + + public String ACC; + public String AUC; + public String AUPRC; + public String SP; //TNR + public String SE; //TPR + public String MCC; + public String BER; + public String CC; + public String MAE; + public String RMSE; + public String RAE; + public String RRSE; + + public testResultsObject() { + } + + public String toStringClassification() { + return ACC + "\t" + + AUC + "\t" + + AUPRC + "\t" + + SE + "\t" + + SP + "\t" + + MCC + "\t" + + MAE + "\t" + + BER; + } + + public String toStringRegression() { + return CC + "\t" + + MAE + "\t" + + RMSE + "\t" + + RAE + "\t" + + RRSE; + } + + } +}