[b758a2]: / matlab / plot_prcurve.m

Download this file

116 lines (100 with data), 2.9 kB

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
function varargout = plot_prcurve(res_folder, gt_folder, d, threshold)
%PLOT_PRCURVE Precision/recall/F1 of point detections swept over a score threshold.
%   PLOT_PRCURVE(RES_FOLDER, GT_FOLDER) reads every *.csv file in RES_FOLDER
%   and the identically named file in GT_FOLDER/<slide>_v2/, where <slide>
%   is the part of the file name before the first '_'. For each threshold
%   in 0.1:0.01:1 it keeps the detection rows whose 3rd column is >= the
%   threshold and greedily matches them to ground-truth centroids: each
%   centroid takes the first not-yet-used detection (in file order) whose
%   columns 1-2 lie strictly closer than D = 20.358 pixels. It closes all
%   existing figures, prints the best F1 score and the threshold where it
%   occurs, and opens four figures (precision vs. recall, and F1,
%   precision, recall vs. threshold).
%
%   PLOT_PRCURVE(RES_FOLDER, GT_FOLDER, D, THRESHOLD) overrides the match
%   radius and/or the threshold vector; pass [] to keep a default.
%
%   [PRECISION, RECALL, F1, THRESHOLD] = PLOT_PRCURVE(...) also returns the
%   curves (one entry per threshold).
%
%   The folders may be given with or without a trailing file separator.
%   Ground-truth rows hold a vertex list whose odd entries are treated as X
%   and even entries as Y (csvread zero-pads short rows on the right); the
%   centroid is stored as round([mean(Y) mean(X)]) and compared against
%   detection columns [1 2]. NOTE(review): detection columns are therefore
%   presumably ordered (Y, X) - confirm against the detector output.
%
%   Precision is NaN at thresholds with no detections. F1 is computed as
%   2*TP/(#detections + #ground truth), which equals 2*P*R/(P+R) whenever
%   that is defined and is 0 (not NaN) when no detection is correct.
if nargin < 3 || isempty(d)
    d = 20.358;  % earlier settings tried: 32.573 and 10
end
if nargin < 4 || isempty(threshold)
    threshold = 0.1:0.01:1;
end
close all

% The CSV contents do not depend on the threshold, so read them only once.
[results, centroids] = load_slides(res_folder, gt_folder);

n_thr = numel(threshold);
precision = zeros(1, n_thr);
recall = zeros(1, n_thr);
f1 = zeros(1, n_thr);
g = waitbar(0);
for h = 1:n_thr
    true_pos = 0;
    total_pos = 0;
    total_gt = 0;
    for i = 1:numel(results)
        [tp, n_det] = count_matches(results{i}, centroids{i}, threshold(h), d);
        true_pos = true_pos + tp;
        total_pos = total_pos + n_det;
        total_gt = total_gt + size(centroids{i}, 1);
    end
    precision(h) = true_pos / total_pos;
    recall(h) = true_pos / total_gt;
    f1(h) = 2 * true_pos / (total_pos + total_gt);
    waitbar(h / n_thr, g);
end
close(g)

[best_f1, best_idx] = max(f1);
disp(['The F1 score reaches a maximum of ' num2str(best_f1) ' when the threshold is ' num2str(threshold(best_idx))])
plot_curve(precision, recall, 'precision-recall curve', 'precision', 'recall')
plot_curve(threshold, f1, 'f1 score', 'threshold', 'f1 score')
plot_curve(threshold, precision, 'precision', 'threshold', 'precision')
plot_curve(threshold, recall, 'recall', 'threshold', 'recall')

% Only assign outputs when requested, so a bare call prints nothing extra.
if nargout > 0
    outputs = {precision, recall, f1, threshold};
    varargout = outputs(1:nargout);
end
end

function [results, centroids] = load_slides(res_folder, gt_folder)
% Read each detection CSV and the ground-truth centroids of its paired file.
listing = dir(fullfile(res_folder, '*.csv'));
n = numel(listing);
results = cell(n, 1);
centroids = cell(n, 1);
for i = 1:n
    name = listing(i).name;
    r = csvread(fullfile(res_folder, name));
    if isempty(r)
        r = zeros(0, 3);  % keep column indexing valid if csvread returns []
    end
    results{i} = r;
    slide = strtok(name, '_');
    centroids{i} = gt_centroids(csvread(fullfile(gt_folder, [slide '_v2'], name)));
end
end

function centroids = gt_centroids(ground_truth)
% Return one rounded [mean_Y mean_X] row per ground-truth vertex-list row.
n = size(ground_truth, 1);
centroids = zeros(n, 2);
for j = 1:n
    row = ground_truth(j, :);
    % Strip only the right-hand zero padding; a genuine 0 coordinate inside
    % the row must stay, otherwise the X/Y pairing shifts.
    last = find(row ~= 0, 1, 'last');
    if isempty(last)
        last = 0;
    end
    last = min(last + mod(last, 2), numel(row));  % keep a trailing Y of 0
    coords = row(1:last);
    % An all-zero row gives NaN, which never matches (counted as missed).
    centroids(j, :) = [mean(coords(2:2:end)) mean(coords(1:2:end))];
end
centroids = round(centroids);
end

function [true_pos, n_det] = count_matches(results, centroids, thr, d)
% Count centroids matched by detections whose 3rd column is >= THR.
det = results(results(:, 3) >= thr, 1:2);  % logical mask keeps file order
n_det = size(det, 1);
used = false(n_det, 1);
true_pos = 0;
for j = 1:size(centroids, 1)
    dist2 = (centroids(j, 1) - det(:, 1)).^2 + (centroids(j, 2) - det(:, 2)).^2;
    k = find(dist2 < d^2 & ~used, 1);  % first unused detection within d
    if ~isempty(k)
        used(k) = true;
        true_pos = true_pos + 1;
    end
end
end

function plot_curve(x, y, ttl, x_label, y_label)
% Draw one curve in a new figure with both axes fixed to [0 1].
figure
plot(x, y)
title(ttl)
xlabel(x_label)
ylabel(y_label)
xlim([0 1])
ylim([0 1])
end