Switch to unified view

a b/src/biodiscml/rocCurveGraphs.java
1
/*
2
 * Create roc curve graph
3
 */
4
package biodiscml;
5
6
import java.awt.Color;
7
import java.io.File;
8
import java.util.ArrayList;
9
import java.util.TreeMap;
10
import org.jfree.chart.ChartFactory;
11
import org.jfree.chart.ChartUtilities;
12
import org.jfree.chart.JFreeChart;
13
import org.jfree.chart.axis.NumberAxis;
14
import org.jfree.chart.plot.PlotOrientation;
15
import org.jfree.chart.plot.XYPlot;
16
import org.jfree.chart.renderer.xy.DeviationRenderer;
17
import org.jfree.chart.title.LegendTitle;
18
import org.jfree.data.xy.XYDataset;
19
import org.jfree.data.xy.XYSeries;
20
import org.jfree.data.xy.XYSeriesCollection;
21
import org.jfree.ui.RectangleEdge;
22
import org.jfree.ui.RectangleInsets;
23
import utils.Weka_module;
24
import weka.classifiers.Evaluation;
25
import weka.classifiers.evaluation.ThresholdCurve;
26
import weka.core.Instances;
27
28
/**
29
 *
30
 * @author Administrator
31
 */
32
public class rocCurveGraphs {
33
34
    public static void createRocCurvesWithConfidence(ArrayList<Object> alFoldsROCs, boolean classification, String outfile, String extension) {
35
        TreeMap<String, RocObject> tmFoldsROCs = new TreeMap<>();
36
        // retreive roc data of each class
37
        for (Object ROC : alFoldsROCs) {
38
            Evaluation eval;
39
            Instances data;
40
            if (classification) {
41
                Weka_module.ClassificationResultsObject cr = (Weka_module.ClassificationResultsObject) ROC;
42
                eval = cr.evaluation;
43
                data = cr.dataset;
44
            } else {
45
                Weka_module.RegressionResultsObject rr = (Weka_module.RegressionResultsObject) ROC;
46
                eval = rr.evaluation;
47
                data = rr.dataset;
48
            }
49
50
            boolean hasMoreClasses = true;
51
            int classIndex = 0;
52
            // for each class
53
            for (int i = 0; i < data.numClasses(); i++) {
54
                try {
55
                    //get class name
56
                    String className = data.classAttribute().value(classIndex);
57
                    //get roc object (or create it if first time)
58
                    RocObject ro;
59
                    if (tmFoldsROCs.containsKey(className)) {
60
                        ro = tmFoldsROCs.get(className);
61
                    } else {
62
                        ro = new RocObject();
63
                    }
64
                    //get roc values
65
                    Instances rocCurve = new ThresholdCurve().getCurve(eval.predictions(), classIndex);
66
67
                    //iterate through roc values and store them into Roc objects
68
                    ArrayList<Double> alTPR = new ArrayList<>();
69
                    ArrayList<Double> alFPR = new ArrayList<>();
70
                    for (String s : rocCurve.toString().split("\n")) {
71
                        if (!s.startsWith("@") && !s.trim().isEmpty()) {
72
                            alTPR.add(Double.parseDouble(s.split(",")[rocCurve.attribute("True Positive Rate").index()]));
73
                            alFPR.add(Double.parseDouble(s.split(",")[rocCurve.attribute("False Positive Rate").index()]));
74
                        }
75
                    }
76
                    ro.TPRvalues.add(alTPR);
77
                    ro.FPRvalues.add(alFPR);
78
                    tmFoldsROCs.put(className, ro);
79
                } catch (Exception e) {
80
                    e.printStackTrace();
81
                    hasMoreClasses = false;
82
                }
83
                classIndex++;
84
            }
85
        }
86
        //create picture
87
        XYDataset data = createRocDataset(tmFoldsROCs);
88
        String title = outfile.replace(Main.project + ".", "").replace(Main.wd, "");
89
        JFreeChart chart = createChart(data, title);
90
        exportChartToPNG(chart, outfile + extension);
91
    }
92
93
    public static XYDataset createRocDataset(TreeMap<String, RocObject> tm) {
94
95
        XYSeriesCollection dataset = new XYSeriesCollection();
96
97
        //for each fold
98
        for (String classe : tm.keySet()) {
99
            XYSeries serie = new XYSeries(classe, false);
100
            RocObject ro = tm.get(classe);
101
            for (int i = 0; i < ro.FPRvalues.size(); i++) {
102
                //XYSeries serie = new XYSeries(classe + "_fold" + i,false);
103
                for (int j = 0; j < ro.FPRvalues.get(i).size(); j++) {
104
                    serie.add(ro.FPRvalues.get(i).get(j), ro.TPRvalues.get(i).get(j));
105
                }
106
                //dataset.addSeries(serie);
107
            }
108
            dataset.addSeries(serie);
109
110
        }
111
112
        return dataset;
113
114
    }
115
116
    public static void exportChartToPNG(JFreeChart chart, String outfile) {
117
        try {
118
            ChartUtilities.saveChartAsPNG(new File(outfile), chart, 1024, 768);
119
        } catch (Exception e) {
120
            e.printStackTrace();
121
        }
122
    }
123
124
    /**
125
     * Creates a chart.
126
     *
127
     * @param dataset the data for the chart.
128
     * @return a chart.
129
     */
130
    private static JFreeChart createChart(XYDataset dataset, String title) {
131
        // create the chart...
132
        JFreeChart chart = ChartFactory.createXYLineChart(
133
                title, // chart title
134
                "False positive rate (1 - specificity)", // x axis label
135
                "True positive rate (sensitivity)", // y axis label
136
                dataset, // data
137
                PlotOrientation.VERTICAL,
138
                true, // include legend
139
                true, // tooltips
140
                false // urls
141
        );
142
143
        chart.setBackgroundPaint(Color.white);
144
        LegendTitle legend = chart.getLegend();
145
        legend.setPosition(RectangleEdge.RIGHT);
146
147
        // get a reference to the plot for further customisation...
148
        XYPlot plot = (XYPlot) chart.getPlot();
149
        plot.setBackgroundPaint(Color.white);
150
        plot.setAxisOffset(new RectangleInsets(5.0, 5.0, 5.0, 5.0));
151
        plot.setDomainGridlinePaint(Color.black);
152
        plot.setRangeGridlinePaint(Color.black);
153
154
        //deviation renderer
155
        //color codes: http://htmlcolorcodes.com/
156
        DeviationRenderer renderer = new DeviationRenderer(true, false);
157
158
//        renderer.setSeriesStroke(0, new BasicStroke(3.0f, BasicStroke.CAP_ROUND,
159
//                BasicStroke.JOIN_ROUND));
160
//        renderer.setSeriesStroke(0, new BasicStroke(3.0f));
161
//        renderer.setSeriesStroke(1, new BasicStroke(3.0f));
162
//        renderer.setSeriesFillPaint(0, new Color(255, 118, 118)); //light red
163
//        renderer.setSeriesFillPaint(1, new Color(118, 180, 255)); //light blue
164
        plot.setRenderer(renderer);
165
166
//        //smooth line renderer
167
//        XYSplineRenderer rend = new XYSplineRenderer();
168
//        rend.setPrecision(2);
169
//        plot.setRenderer(rend);
170
        //Axis modifications
171
        NumberAxis yAxis = (NumberAxis) plot.getRangeAxis();
172
        yAxis.setRange(-0.01, 1.01);
173
174
        NumberAxis xAxis = (NumberAxis) plot.getDomainAxis();
175
        xAxis.setRange(-0.01, 1.01);
176
177
        return chart;
178
    }
179
180
    private static class RocObject {
181
182
        public String classe;
183
        public ArrayList<ArrayList<Double>> TPRvalues;
184
        public ArrayList<ArrayList<Double>> FPRvalues;
185
186
        public RocObject() {
187
            TPRvalues = new ArrayList<>();
188
            FPRvalues = new ArrayList<>();
189
190
        }
191
    }
192
193
}