[ce076b]: / src / biodiscml / rocCurveGraphs.java

Download this file

194 lines (168 with data), 7.3 kB

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
/*
* Create roc curve graph
*/
package biodiscml;
import java.awt.Color;
import java.io.File;
import java.util.ArrayList;
import java.util.TreeMap;
import org.jfree.chart.ChartFactory;
import org.jfree.chart.ChartUtilities;
import org.jfree.chart.JFreeChart;
import org.jfree.chart.axis.NumberAxis;
import org.jfree.chart.plot.PlotOrientation;
import org.jfree.chart.plot.XYPlot;
import org.jfree.chart.renderer.xy.DeviationRenderer;
import org.jfree.chart.title.LegendTitle;
import org.jfree.data.xy.XYDataset;
import org.jfree.data.xy.XYSeries;
import org.jfree.data.xy.XYSeriesCollection;
import org.jfree.ui.RectangleEdge;
import org.jfree.ui.RectangleInsets;
import utils.Weka_module;
import weka.classifiers.Evaluation;
import weka.classifiers.evaluation.ThresholdCurve;
import weka.core.Instances;
/**
*
* @author Administrator
*/
public class rocCurveGraphs {
public static void createRocCurvesWithConfidence(ArrayList<Object> alFoldsROCs, boolean classification, String outfile, String extension) {
TreeMap<String, RocObject> tmFoldsROCs = new TreeMap<>();
// retreive roc data of each class
for (Object ROC : alFoldsROCs) {
Evaluation eval;
Instances data;
if (classification) {
Weka_module.ClassificationResultsObject cr = (Weka_module.ClassificationResultsObject) ROC;
eval = cr.evaluation;
data = cr.dataset;
} else {
Weka_module.RegressionResultsObject rr = (Weka_module.RegressionResultsObject) ROC;
eval = rr.evaluation;
data = rr.dataset;
}
boolean hasMoreClasses = true;
int classIndex = 0;
// for each class
for (int i = 0; i < data.numClasses(); i++) {
try {
//get class name
String className = data.classAttribute().value(classIndex);
//get roc object (or create it if first time)
RocObject ro;
if (tmFoldsROCs.containsKey(className)) {
ro = tmFoldsROCs.get(className);
} else {
ro = new RocObject();
}
//get roc values
Instances rocCurve = new ThresholdCurve().getCurve(eval.predictions(), classIndex);
//iterate through roc values and store them into Roc objects
ArrayList<Double> alTPR = new ArrayList<>();
ArrayList<Double> alFPR = new ArrayList<>();
for (String s : rocCurve.toString().split("\n")) {
if (!s.startsWith("@") && !s.trim().isEmpty()) {
alTPR.add(Double.parseDouble(s.split(",")[rocCurve.attribute("True Positive Rate").index()]));
alFPR.add(Double.parseDouble(s.split(",")[rocCurve.attribute("False Positive Rate").index()]));
}
}
ro.TPRvalues.add(alTPR);
ro.FPRvalues.add(alFPR);
tmFoldsROCs.put(className, ro);
} catch (Exception e) {
e.printStackTrace();
hasMoreClasses = false;
}
classIndex++;
}
}
//create picture
XYDataset data = createRocDataset(tmFoldsROCs);
String title = outfile.replace(Main.project + ".", "").replace(Main.wd, "");
JFreeChart chart = createChart(data, title);
exportChartToPNG(chart, outfile + extension);
}
public static XYDataset createRocDataset(TreeMap<String, RocObject> tm) {
XYSeriesCollection dataset = new XYSeriesCollection();
//for each fold
for (String classe : tm.keySet()) {
XYSeries serie = new XYSeries(classe, false);
RocObject ro = tm.get(classe);
for (int i = 0; i < ro.FPRvalues.size(); i++) {
//XYSeries serie = new XYSeries(classe + "_fold" + i,false);
for (int j = 0; j < ro.FPRvalues.get(i).size(); j++) {
serie.add(ro.FPRvalues.get(i).get(j), ro.TPRvalues.get(i).get(j));
}
//dataset.addSeries(serie);
}
dataset.addSeries(serie);
}
return dataset;
}
public static void exportChartToPNG(JFreeChart chart, String outfile) {
try {
ChartUtilities.saveChartAsPNG(new File(outfile), chart, 1024, 768);
} catch (Exception e) {
e.printStackTrace();
}
}
/**
* Creates a chart.
*
* @param dataset the data for the chart.
* @return a chart.
*/
private static JFreeChart createChart(XYDataset dataset, String title) {
// create the chart...
JFreeChart chart = ChartFactory.createXYLineChart(
title, // chart title
"False positive rate (1 - specificity)", // x axis label
"True positive rate (sensitivity)", // y axis label
dataset, // data
PlotOrientation.VERTICAL,
true, // include legend
true, // tooltips
false // urls
);
chart.setBackgroundPaint(Color.white);
LegendTitle legend = chart.getLegend();
legend.setPosition(RectangleEdge.RIGHT);
// get a reference to the plot for further customisation...
XYPlot plot = (XYPlot) chart.getPlot();
plot.setBackgroundPaint(Color.white);
plot.setAxisOffset(new RectangleInsets(5.0, 5.0, 5.0, 5.0));
plot.setDomainGridlinePaint(Color.black);
plot.setRangeGridlinePaint(Color.black);
//deviation renderer
//color codes: http://htmlcolorcodes.com/
DeviationRenderer renderer = new DeviationRenderer(true, false);
// renderer.setSeriesStroke(0, new BasicStroke(3.0f, BasicStroke.CAP_ROUND,
// BasicStroke.JOIN_ROUND));
// renderer.setSeriesStroke(0, new BasicStroke(3.0f));
// renderer.setSeriesStroke(1, new BasicStroke(3.0f));
// renderer.setSeriesFillPaint(0, new Color(255, 118, 118)); //light red
// renderer.setSeriesFillPaint(1, new Color(118, 180, 255)); //light blue
plot.setRenderer(renderer);
// //smooth line renderer
// XYSplineRenderer rend = new XYSplineRenderer();
// rend.setPrecision(2);
// plot.setRenderer(rend);
//Axis modifications
NumberAxis yAxis = (NumberAxis) plot.getRangeAxis();
yAxis.setRange(-0.01, 1.01);
NumberAxis xAxis = (NumberAxis) plot.getDomainAxis();
xAxis.setRange(-0.01, 1.01);
return chart;
}
private static class RocObject {
public String classe;
public ArrayList<ArrayList<Double>> TPRvalues;
public ArrayList<ArrayList<Double>> FPRvalues;
public RocObject() {
TPRvalues = new ArrayList<>();
FPRvalues = new ArrayList<>();
}
}
}