|
a |
|
b/src/biodiscml/rocCurveGraphs.java |
|
|
1 |
/* |
|
|
2 |
* Create roc curve graph |
|
|
3 |
*/ |
|
|
4 |
package biodiscml; |
|
|
5 |
|
|
|
6 |
import java.awt.Color; |
|
|
7 |
import java.io.File; |
|
|
8 |
import java.util.ArrayList; |
|
|
9 |
import java.util.TreeMap; |
|
|
10 |
import org.jfree.chart.ChartFactory; |
|
|
11 |
import org.jfree.chart.ChartUtilities; |
|
|
12 |
import org.jfree.chart.JFreeChart; |
|
|
13 |
import org.jfree.chart.axis.NumberAxis; |
|
|
14 |
import org.jfree.chart.plot.PlotOrientation; |
|
|
15 |
import org.jfree.chart.plot.XYPlot; |
|
|
16 |
import org.jfree.chart.renderer.xy.DeviationRenderer; |
|
|
17 |
import org.jfree.chart.title.LegendTitle; |
|
|
18 |
import org.jfree.data.xy.XYDataset; |
|
|
19 |
import org.jfree.data.xy.XYSeries; |
|
|
20 |
import org.jfree.data.xy.XYSeriesCollection; |
|
|
21 |
import org.jfree.ui.RectangleEdge; |
|
|
22 |
import org.jfree.ui.RectangleInsets; |
|
|
23 |
import utils.Weka_module; |
|
|
24 |
import weka.classifiers.Evaluation; |
|
|
25 |
import weka.classifiers.evaluation.ThresholdCurve; |
|
|
26 |
import weka.core.Instances; |
|
|
27 |
|
|
|
28 |
/** |
|
|
29 |
* |
|
|
30 |
* @author Administrator |
|
|
31 |
*/ |
|
|
32 |
public class rocCurveGraphs { |
|
|
33 |
|
|
|
34 |
public static void createRocCurvesWithConfidence(ArrayList<Object> alFoldsROCs, boolean classification, String outfile, String extension) { |
|
|
35 |
TreeMap<String, RocObject> tmFoldsROCs = new TreeMap<>(); |
|
|
36 |
// retreive roc data of each class |
|
|
37 |
for (Object ROC : alFoldsROCs) { |
|
|
38 |
Evaluation eval; |
|
|
39 |
Instances data; |
|
|
40 |
if (classification) { |
|
|
41 |
Weka_module.ClassificationResultsObject cr = (Weka_module.ClassificationResultsObject) ROC; |
|
|
42 |
eval = cr.evaluation; |
|
|
43 |
data = cr.dataset; |
|
|
44 |
} else { |
|
|
45 |
Weka_module.RegressionResultsObject rr = (Weka_module.RegressionResultsObject) ROC; |
|
|
46 |
eval = rr.evaluation; |
|
|
47 |
data = rr.dataset; |
|
|
48 |
} |
|
|
49 |
|
|
|
50 |
boolean hasMoreClasses = true; |
|
|
51 |
int classIndex = 0; |
|
|
52 |
// for each class |
|
|
53 |
for (int i = 0; i < data.numClasses(); i++) { |
|
|
54 |
try { |
|
|
55 |
//get class name |
|
|
56 |
String className = data.classAttribute().value(classIndex); |
|
|
57 |
//get roc object (or create it if first time) |
|
|
58 |
RocObject ro; |
|
|
59 |
if (tmFoldsROCs.containsKey(className)) { |
|
|
60 |
ro = tmFoldsROCs.get(className); |
|
|
61 |
} else { |
|
|
62 |
ro = new RocObject(); |
|
|
63 |
} |
|
|
64 |
//get roc values |
|
|
65 |
Instances rocCurve = new ThresholdCurve().getCurve(eval.predictions(), classIndex); |
|
|
66 |
|
|
|
67 |
//iterate through roc values and store them into Roc objects |
|
|
68 |
ArrayList<Double> alTPR = new ArrayList<>(); |
|
|
69 |
ArrayList<Double> alFPR = new ArrayList<>(); |
|
|
70 |
for (String s : rocCurve.toString().split("\n")) { |
|
|
71 |
if (!s.startsWith("@") && !s.trim().isEmpty()) { |
|
|
72 |
alTPR.add(Double.parseDouble(s.split(",")[rocCurve.attribute("True Positive Rate").index()])); |
|
|
73 |
alFPR.add(Double.parseDouble(s.split(",")[rocCurve.attribute("False Positive Rate").index()])); |
|
|
74 |
} |
|
|
75 |
} |
|
|
76 |
ro.TPRvalues.add(alTPR); |
|
|
77 |
ro.FPRvalues.add(alFPR); |
|
|
78 |
tmFoldsROCs.put(className, ro); |
|
|
79 |
} catch (Exception e) { |
|
|
80 |
e.printStackTrace(); |
|
|
81 |
hasMoreClasses = false; |
|
|
82 |
} |
|
|
83 |
classIndex++; |
|
|
84 |
} |
|
|
85 |
} |
|
|
86 |
//create picture |
|
|
87 |
XYDataset data = createRocDataset(tmFoldsROCs); |
|
|
88 |
String title = outfile.replace(Main.project + ".", "").replace(Main.wd, ""); |
|
|
89 |
JFreeChart chart = createChart(data, title); |
|
|
90 |
exportChartToPNG(chart, outfile + extension); |
|
|
91 |
} |
|
|
92 |
|
|
|
93 |
public static XYDataset createRocDataset(TreeMap<String, RocObject> tm) { |
|
|
94 |
|
|
|
95 |
XYSeriesCollection dataset = new XYSeriesCollection(); |
|
|
96 |
|
|
|
97 |
//for each fold |
|
|
98 |
for (String classe : tm.keySet()) { |
|
|
99 |
XYSeries serie = new XYSeries(classe, false); |
|
|
100 |
RocObject ro = tm.get(classe); |
|
|
101 |
for (int i = 0; i < ro.FPRvalues.size(); i++) { |
|
|
102 |
//XYSeries serie = new XYSeries(classe + "_fold" + i,false); |
|
|
103 |
for (int j = 0; j < ro.FPRvalues.get(i).size(); j++) { |
|
|
104 |
serie.add(ro.FPRvalues.get(i).get(j), ro.TPRvalues.get(i).get(j)); |
|
|
105 |
} |
|
|
106 |
//dataset.addSeries(serie); |
|
|
107 |
} |
|
|
108 |
dataset.addSeries(serie); |
|
|
109 |
|
|
|
110 |
} |
|
|
111 |
|
|
|
112 |
return dataset; |
|
|
113 |
|
|
|
114 |
} |
|
|
115 |
|
|
|
116 |
public static void exportChartToPNG(JFreeChart chart, String outfile) { |
|
|
117 |
try { |
|
|
118 |
ChartUtilities.saveChartAsPNG(new File(outfile), chart, 1024, 768); |
|
|
119 |
} catch (Exception e) { |
|
|
120 |
e.printStackTrace(); |
|
|
121 |
} |
|
|
122 |
} |
|
|
123 |
|
|
|
124 |
/** |
|
|
125 |
* Creates a chart. |
|
|
126 |
* |
|
|
127 |
* @param dataset the data for the chart. |
|
|
128 |
* @return a chart. |
|
|
129 |
*/ |
|
|
130 |
private static JFreeChart createChart(XYDataset dataset, String title) { |
|
|
131 |
// create the chart... |
|
|
132 |
JFreeChart chart = ChartFactory.createXYLineChart( |
|
|
133 |
title, // chart title |
|
|
134 |
"False positive rate (1 - specificity)", // x axis label |
|
|
135 |
"True positive rate (sensitivity)", // y axis label |
|
|
136 |
dataset, // data |
|
|
137 |
PlotOrientation.VERTICAL, |
|
|
138 |
true, // include legend |
|
|
139 |
true, // tooltips |
|
|
140 |
false // urls |
|
|
141 |
); |
|
|
142 |
|
|
|
143 |
chart.setBackgroundPaint(Color.white); |
|
|
144 |
LegendTitle legend = chart.getLegend(); |
|
|
145 |
legend.setPosition(RectangleEdge.RIGHT); |
|
|
146 |
|
|
|
147 |
// get a reference to the plot for further customisation... |
|
|
148 |
XYPlot plot = (XYPlot) chart.getPlot(); |
|
|
149 |
plot.setBackgroundPaint(Color.white); |
|
|
150 |
plot.setAxisOffset(new RectangleInsets(5.0, 5.0, 5.0, 5.0)); |
|
|
151 |
plot.setDomainGridlinePaint(Color.black); |
|
|
152 |
plot.setRangeGridlinePaint(Color.black); |
|
|
153 |
|
|
|
154 |
//deviation renderer |
|
|
155 |
//color codes: http://htmlcolorcodes.com/ |
|
|
156 |
DeviationRenderer renderer = new DeviationRenderer(true, false); |
|
|
157 |
|
|
|
158 |
// renderer.setSeriesStroke(0, new BasicStroke(3.0f, BasicStroke.CAP_ROUND, |
|
|
159 |
// BasicStroke.JOIN_ROUND)); |
|
|
160 |
// renderer.setSeriesStroke(0, new BasicStroke(3.0f)); |
|
|
161 |
// renderer.setSeriesStroke(1, new BasicStroke(3.0f)); |
|
|
162 |
// renderer.setSeriesFillPaint(0, new Color(255, 118, 118)); //light red |
|
|
163 |
// renderer.setSeriesFillPaint(1, new Color(118, 180, 255)); //light blue |
|
|
164 |
plot.setRenderer(renderer); |
|
|
165 |
|
|
|
166 |
// //smooth line renderer |
|
|
167 |
// XYSplineRenderer rend = new XYSplineRenderer(); |
|
|
168 |
// rend.setPrecision(2); |
|
|
169 |
// plot.setRenderer(rend); |
|
|
170 |
//Axis modifications |
|
|
171 |
NumberAxis yAxis = (NumberAxis) plot.getRangeAxis(); |
|
|
172 |
yAxis.setRange(-0.01, 1.01); |
|
|
173 |
|
|
|
174 |
NumberAxis xAxis = (NumberAxis) plot.getDomainAxis(); |
|
|
175 |
xAxis.setRange(-0.01, 1.01); |
|
|
176 |
|
|
|
177 |
return chart; |
|
|
178 |
} |
|
|
179 |
|
|
|
180 |
private static class RocObject { |
|
|
181 |
|
|
|
182 |
public String classe; |
|
|
183 |
public ArrayList<ArrayList<Double>> TPRvalues; |
|
|
184 |
public ArrayList<ArrayList<Double>> FPRvalues; |
|
|
185 |
|
|
|
186 |
public RocObject() { |
|
|
187 |
TPRvalues = new ArrayList<>(); |
|
|
188 |
FPRvalues = new ArrayList<>(); |
|
|
189 |
|
|
|
190 |
} |
|
|
191 |
} |
|
|
192 |
|
|
|
193 |
} |