--- a +++ b/src/biodiscml/rocCurveGraphs.java @@ -0,0 +1,193 @@ +/* + * Create roc curve graph + */ +package biodiscml; + +import java.awt.Color; +import java.io.File; +import java.util.ArrayList; +import java.util.TreeMap; +import org.jfree.chart.ChartFactory; +import org.jfree.chart.ChartUtilities; +import org.jfree.chart.JFreeChart; +import org.jfree.chart.axis.NumberAxis; +import org.jfree.chart.plot.PlotOrientation; +import org.jfree.chart.plot.XYPlot; +import org.jfree.chart.renderer.xy.DeviationRenderer; +import org.jfree.chart.title.LegendTitle; +import org.jfree.data.xy.XYDataset; +import org.jfree.data.xy.XYSeries; +import org.jfree.data.xy.XYSeriesCollection; +import org.jfree.ui.RectangleEdge; +import org.jfree.ui.RectangleInsets; +import utils.Weka_module; +import weka.classifiers.Evaluation; +import weka.classifiers.evaluation.ThresholdCurve; +import weka.core.Instances; + +/** + * + * @author Administrator + */ +public class rocCurveGraphs { + + public static void createRocCurvesWithConfidence(ArrayList<Object> alFoldsROCs, boolean classification, String outfile, String extension) { + TreeMap<String, RocObject> tmFoldsROCs = new TreeMap<>(); + // retreive roc data of each class + for (Object ROC : alFoldsROCs) { + Evaluation eval; + Instances data; + if (classification) { + Weka_module.ClassificationResultsObject cr = (Weka_module.ClassificationResultsObject) ROC; + eval = cr.evaluation; + data = cr.dataset; + } else { + Weka_module.RegressionResultsObject rr = (Weka_module.RegressionResultsObject) ROC; + eval = rr.evaluation; + data = rr.dataset; + } + + boolean hasMoreClasses = true; + int classIndex = 0; + // for each class + for (int i = 0; i < data.numClasses(); i++) { + try { + //get class name + String className = data.classAttribute().value(classIndex); + //get roc object (or create it if first time) + RocObject ro; + if (tmFoldsROCs.containsKey(className)) { + ro = tmFoldsROCs.get(className); + } else { + ro = new RocObject(); + } + //get roc values + Instances rocCurve = new ThresholdCurve().getCurve(eval.predictions(), classIndex); + + //iterate through roc values and store them into Roc objects + ArrayList<Double> alTPR = new ArrayList<>(); + ArrayList<Double> alFPR = new ArrayList<>(); + for (String s : rocCurve.toString().split("\n")) { + if (!s.startsWith("@") && !s.trim().isEmpty()) { + alTPR.add(Double.parseDouble(s.split(",")[rocCurve.attribute("True Positive Rate").index()])); + alFPR.add(Double.parseDouble(s.split(",")[rocCurve.attribute("False Positive Rate").index()])); + } + } + ro.TPRvalues.add(alTPR); + ro.FPRvalues.add(alFPR); + tmFoldsROCs.put(className, ro); + } catch (Exception e) { + e.printStackTrace(); + hasMoreClasses = false; + } + classIndex++; + } + } + //create picture + XYDataset data = createRocDataset(tmFoldsROCs); + String title = outfile.replace(Main.project + ".", "").replace(Main.wd, ""); + JFreeChart chart = createChart(data, title); + exportChartToPNG(chart, outfile + extension); + } + + public static XYDataset createRocDataset(TreeMap<String, RocObject> tm) { + + XYSeriesCollection dataset = new XYSeriesCollection(); + + //for each fold + for (String classe : tm.keySet()) { + XYSeries serie = new XYSeries(classe, false); + RocObject ro = tm.get(classe); + for (int i = 0; i < ro.FPRvalues.size(); i++) { + //XYSeries serie = new XYSeries(classe + "_fold" + i,false); + for (int j = 0; j < ro.FPRvalues.get(i).size(); j++) { + serie.add(ro.FPRvalues.get(i).get(j), ro.TPRvalues.get(i).get(j)); + } + //dataset.addSeries(serie); + } + dataset.addSeries(serie); + + } + + return dataset; + + } + + public static void exportChartToPNG(JFreeChart chart, String outfile) { + try { + ChartUtilities.saveChartAsPNG(new File(outfile), chart, 1024, 768); + } catch (Exception e) { + e.printStackTrace(); + } + } + + /** + * Creates a chart. + * + * @param dataset the data for the chart. + * @return a chart. + */ + private static JFreeChart createChart(XYDataset dataset, String title) { + // create the chart... + JFreeChart chart = ChartFactory.createXYLineChart( + title, // chart title + "False positive rate (1 - specificity)", // x axis label + "True positive rate (sensitivity)", // y axis label + dataset, // data + PlotOrientation.VERTICAL, + true, // include legend + true, // tooltips + false // urls + ); + + chart.setBackgroundPaint(Color.white); + LegendTitle legend = chart.getLegend(); + legend.setPosition(RectangleEdge.RIGHT); + + // get a reference to the plot for further customisation... + XYPlot plot = (XYPlot) chart.getPlot(); + plot.setBackgroundPaint(Color.white); + plot.setAxisOffset(new RectangleInsets(5.0, 5.0, 5.0, 5.0)); + plot.setDomainGridlinePaint(Color.black); + plot.setRangeGridlinePaint(Color.black); + + //deviation renderer + //color codes: http://htmlcolorcodes.com/ + DeviationRenderer renderer = new DeviationRenderer(true, false); + +// renderer.setSeriesStroke(0, new BasicStroke(3.0f, BasicStroke.CAP_ROUND, +// BasicStroke.JOIN_ROUND)); +// renderer.setSeriesStroke(0, new BasicStroke(3.0f)); +// renderer.setSeriesStroke(1, new BasicStroke(3.0f)); +// renderer.setSeriesFillPaint(0, new Color(255, 118, 118)); //light red +// renderer.setSeriesFillPaint(1, new Color(118, 180, 255)); //light blue + plot.setRenderer(renderer); + +// //smooth line renderer +// XYSplineRenderer rend = new XYSplineRenderer(); +// rend.setPrecision(2); +// plot.setRenderer(rend); + //Axis modifications + NumberAxis yAxis = (NumberAxis) plot.getRangeAxis(); + yAxis.setRange(-0.01, 1.01); + + NumberAxis xAxis = (NumberAxis) plot.getDomainAxis(); + xAxis.setRange(-0.01, 1.01); + + return chart; + } + + private static class RocObject { + + public String classe; + public ArrayList<ArrayList<Double>> TPRvalues; + public ArrayList<ArrayList<Double>> FPRvalues; + + public RocObject() { + TPRvalues = new ArrayList<>(); + FPRvalues = new ArrayList<>(); + + } + } + +}