Switch to side-by-side view

--- a
+++ b/src/biodiscml/rocCurveGraphs.java
@@ -0,0 +1,193 @@
+/*
+ * Create roc curve graph
+ */
+package biodiscml;
+
+import java.awt.Color;
+import java.io.File;
+import java.util.ArrayList;
+import java.util.TreeMap;
+import org.jfree.chart.ChartFactory;
+import org.jfree.chart.ChartUtilities;
+import org.jfree.chart.JFreeChart;
+import org.jfree.chart.axis.NumberAxis;
+import org.jfree.chart.plot.PlotOrientation;
+import org.jfree.chart.plot.XYPlot;
+import org.jfree.chart.renderer.xy.DeviationRenderer;
+import org.jfree.chart.title.LegendTitle;
+import org.jfree.data.xy.XYDataset;
+import org.jfree.data.xy.XYSeries;
+import org.jfree.data.xy.XYSeriesCollection;
+import org.jfree.ui.RectangleEdge;
+import org.jfree.ui.RectangleInsets;
+import utils.Weka_module;
+import weka.classifiers.Evaluation;
+import weka.classifiers.evaluation.ThresholdCurve;
+import weka.core.Instances;
+
+/**
+ *
+ * @author Administrator
+ */
+public class rocCurveGraphs {
+
+    public static void createRocCurvesWithConfidence(ArrayList<Object> alFoldsROCs, boolean classification, String outfile, String extension) {
+        TreeMap<String, RocObject> tmFoldsROCs = new TreeMap<>();
+        // retreive roc data of each class
+        for (Object ROC : alFoldsROCs) {
+            Evaluation eval;
+            Instances data;
+            if (classification) {
+                Weka_module.ClassificationResultsObject cr = (Weka_module.ClassificationResultsObject) ROC;
+                eval = cr.evaluation;
+                data = cr.dataset;
+            } else {
+                Weka_module.RegressionResultsObject rr = (Weka_module.RegressionResultsObject) ROC;
+                eval = rr.evaluation;
+                data = rr.dataset;
+            }
+
+            boolean hasMoreClasses = true;
+            int classIndex = 0;
+            // for each class
+            for (int i = 0; i < data.numClasses(); i++) {
+                try {
+                    //get class name
+                    String className = data.classAttribute().value(classIndex);
+                    //get roc object (or create it if first time)
+                    RocObject ro;
+                    if (tmFoldsROCs.containsKey(className)) {
+                        ro = tmFoldsROCs.get(className);
+                    } else {
+                        ro = new RocObject();
+                    }
+                    //get roc values
+                    Instances rocCurve = new ThresholdCurve().getCurve(eval.predictions(), classIndex);
+
+                    //iterate through roc values and store them into Roc objects
+                    ArrayList<Double> alTPR = new ArrayList<>();
+                    ArrayList<Double> alFPR = new ArrayList<>();
+                    for (String s : rocCurve.toString().split("\n")) {
+                        if (!s.startsWith("@") && !s.trim().isEmpty()) {
+                            alTPR.add(Double.parseDouble(s.split(",")[rocCurve.attribute("True Positive Rate").index()]));
+                            alFPR.add(Double.parseDouble(s.split(",")[rocCurve.attribute("False Positive Rate").index()]));
+                        }
+                    }
+                    ro.TPRvalues.add(alTPR);
+                    ro.FPRvalues.add(alFPR);
+                    tmFoldsROCs.put(className, ro);
+                } catch (Exception e) {
+                    e.printStackTrace();
+                    hasMoreClasses = false;
+                }
+                classIndex++;
+            }
+        }
+        //create picture
+        XYDataset data = createRocDataset(tmFoldsROCs);
+        String title = outfile.replace(Main.project + ".", "").replace(Main.wd, "");
+        JFreeChart chart = createChart(data, title);
+        exportChartToPNG(chart, outfile + extension);
+    }
+
+    public static XYDataset createRocDataset(TreeMap<String, RocObject> tm) {
+
+        XYSeriesCollection dataset = new XYSeriesCollection();
+
+        //for each fold
+        for (String classe : tm.keySet()) {
+            XYSeries serie = new XYSeries(classe, false);
+            RocObject ro = tm.get(classe);
+            for (int i = 0; i < ro.FPRvalues.size(); i++) {
+                //XYSeries serie = new XYSeries(classe + "_fold" + i,false);
+                for (int j = 0; j < ro.FPRvalues.get(i).size(); j++) {
+                    serie.add(ro.FPRvalues.get(i).get(j), ro.TPRvalues.get(i).get(j));
+                }
+                //dataset.addSeries(serie);
+            }
+            dataset.addSeries(serie);
+
+        }
+
+        return dataset;
+
+    }
+
+    public static void exportChartToPNG(JFreeChart chart, String outfile) {
+        try {
+            ChartUtilities.saveChartAsPNG(new File(outfile), chart, 1024, 768);
+        } catch (Exception e) {
+            e.printStackTrace();
+        }
+    }
+
+    /**
+     * Creates a chart.
+     *
+     * @param dataset the data for the chart.
+     * @return a chart.
+     */
+    private static JFreeChart createChart(XYDataset dataset, String title) {
+        // create the chart...
+        JFreeChart chart = ChartFactory.createXYLineChart(
+                title, // chart title
+                "False positive rate (1 - specificity)", // x axis label
+                "True positive rate (sensitivity)", // y axis label
+                dataset, // data
+                PlotOrientation.VERTICAL,
+                true, // include legend
+                true, // tooltips
+                false // urls
+        );
+
+        chart.setBackgroundPaint(Color.white);
+        LegendTitle legend = chart.getLegend();
+        legend.setPosition(RectangleEdge.RIGHT);
+
+        // get a reference to the plot for further customisation...
+        XYPlot plot = (XYPlot) chart.getPlot();
+        plot.setBackgroundPaint(Color.white);
+        plot.setAxisOffset(new RectangleInsets(5.0, 5.0, 5.0, 5.0));
+        plot.setDomainGridlinePaint(Color.black);
+        plot.setRangeGridlinePaint(Color.black);
+
+        //deviation renderer
+        //color codes: http://htmlcolorcodes.com/
+        DeviationRenderer renderer = new DeviationRenderer(true, false);
+
+//        renderer.setSeriesStroke(0, new BasicStroke(3.0f, BasicStroke.CAP_ROUND,
+//                BasicStroke.JOIN_ROUND));
+//        renderer.setSeriesStroke(0, new BasicStroke(3.0f));
+//        renderer.setSeriesStroke(1, new BasicStroke(3.0f));
+//        renderer.setSeriesFillPaint(0, new Color(255, 118, 118)); //light red
+//        renderer.setSeriesFillPaint(1, new Color(118, 180, 255)); //light blue
+        plot.setRenderer(renderer);
+
+//        //smooth line renderer
+//        XYSplineRenderer rend = new XYSplineRenderer();
+//        rend.setPrecision(2);
+//        plot.setRenderer(rend);
+        //Axis modifications
+        NumberAxis yAxis = (NumberAxis) plot.getRangeAxis();
+        yAxis.setRange(-0.01, 1.01);
+
+        NumberAxis xAxis = (NumberAxis) plot.getDomainAxis();
+        xAxis.setRange(-0.01, 1.01);
+
+        return chart;
+    }
+
+    private static class RocObject {
+
+        public String classe;
+        public ArrayList<ArrayList<Double>> TPRvalues;
+        public ArrayList<ArrayList<Double>> FPRvalues;
+
+        public RocObject() {
+            TPRvalues = new ArrayList<>();
+            FPRvalues = new ArrayList<>();
+
+        }
+    }
+
+}