Diff of /matlab/plot_prcurve.m [000000] .. [b758a2]

Switch to unified view

a b/matlab/plot_prcurve.m
1
function plot_prcurve(res_folder, gt_folder)
2
3
% folders must end with /
4
5
close all
6
7
% parameters
8
%d = 32.573;         % a detection within d pixels of the ground truth is considered correct
9
d = 20.358;
10
%d = 10;
11
threshold = 0.1:0.01:1;
12
precision = zeros(1,size(threshold,2));
13
recall = zeros(1,size(threshold,2));
14
f1 = zeros(1,size(threshold,2));
15
16
res_listing = dir([res_folder '*.csv']);
17
18
g = waitbar(0);
19
for h=1:size(threshold,2)
20
    true_pos = 0;
21
    false_neg = 0;
22
    total_pos = 0;
23
    for i=1:size(res_listing)
24
        results = csvread([res_folder res_listing(i).name]);
25
        % cut results off at the threshold
26
        for j=1:size(results,1)
27
            if results(j,3) < threshold(h)
28
                results = results(1:j-1,:);
29
                break
30
            end
31
        end
32
        total_pos = total_pos + size(results,1);
33
34
        [slide, remain] = strtok(res_listing(i).name, '_');
35
        subfolder = [gt_folder slide '_v2/'];
36
        ground_truth = csvread([subfolder res_listing(i).name]);
37
38
        % find the centroids
39
        len_gt = size(ground_truth,1);
40
        tmp_gt = zeros(len_gt,2);
41
        for j=1:len_gt;
42
            tmp = ground_truth(j,:);
43
            tmp = tmp(tmp~=0);
44
            len_tmp = size(tmp,2);
45
            sum_X = 0;
46
            sum_Y = 0;
47
            for k=1:len_tmp/2
48
                sum_X = sum_X + tmp(2*k-1);
49
                sum_Y = sum_Y + tmp(2*k);
50
            end
51
            mean_X = 2*sum_X/len_tmp;
52
            mean_Y = 2*sum_Y/len_tmp;
53
            tmp_gt(j,:) = [mean_Y mean_X];
54
        end
55
        ground_truth = round(tmp_gt);
56
57
        len_res = size(results,1);
58
        res_tally = zeros(len_res,1);
59
        for j=1:len_gt
60
            false_neg = false_neg + 1;
61
            for k=1:len_res
62
                if (ground_truth(j,1) - results(k,1))^2 + ...
63
                        (ground_truth(j,2) - results(k,2))^2 < d^2 ...
64
                        && res_tally(k) == 0
65
                    true_pos = true_pos + 1;
66
                    false_neg = false_neg - 1;
67
                    res_tally(k) = 1;
68
                    break
69
                end
70
            end
71
        end
72
    end
73
74
    precision(h) = true_pos/total_pos;
75
    recall(h) = true_pos/(true_pos + false_neg);
76
    f1(h) = 2*precision(h)*recall(h)/(precision(h) + recall(h));
77
    
78
    waitbar(h/size(threshold,2));
79
end
80
close(g)
81
82
[M,I] = max(f1);
83
disp(['The F1 score reaches a maximum of ' num2str(M) ' when the threshold is ' num2str(threshold(I))])
84
85
figure
86
plot(precision, recall)
87
title('precision-recall curve')
88
xlabel('precision')
89
ylabel('recall')
90
xlim([0 1])
91
ylim([0 1])
92
93
figure
94
plot(threshold, f1)
95
title('f1 score')
96
xlabel('threshold')
97
ylabel('f1 score')
98
xlim([0 1])
99
ylim([0 1])
100
101
figure
102
plot(threshold, precision)
103
title('precision')
104
xlabel('threshold')
105
ylabel('precision')
106
xlim([0 1])
107
ylim([0 1])
108
109
figure
110
plot(threshold, recall)
111
title('recall')
112
xlabel('threshold')
113
ylabel('recall')
114
xlim([0 1])
115
ylim([0 1])