Switch to unified view

a b/dsb2018_topcoders/selim/metric.py
1
from multiprocessing.pool import Pool
2
3
import numpy as np
4
from skimage import measure
5
6
7
def calculate_cell_score_kaggle(y_true, y_pred, num_threads=32):
    """Return the mean ``score_kaggle`` result over paired masks.

    Every ``(y_true[m], y_pred[m])`` pair is copied and scored by
    ``score_kaggle`` inside a pool of worker processes.

    Args:
        y_true: Indexable collection of ground-truth arrays; each item must
            provide ``.copy()``.
        y_pred: Predictions, indexed in step with ``y_true``.
        num_threads: Number of worker processes handed to ``Pool``.

    Returns:
        ``np.mean`` of the per-pair scores.
    """
    # Index-based on purpose: a y_pred shorter than y_true must still raise
    # IndexError instead of being silently truncated by zip().
    pairs = [(y_true[m].copy(), y_pred[m].copy()) for m in range(len(y_true))]
    # map() blocks until every result is in, so it is safe for the context
    # manager to tear the workers down on exit; without it they leak.
    with Pool(num_threads) as pool:
        results = pool.map(score_kaggle, pairs)
    return np.mean(results)
14
15
16
def calculate_cell_score_selim(y_true, y_pred, num_threads=32, ids=None):
    """Return the mean ``calculate_jaccard`` result over paired masks.

    Every ``(y_true[m], y_pred[m])`` pair is copied and scored by
    ``calculate_jaccard`` inside a pool of worker processes. When ``ids`` is
    supplied, the per-image scores are additionally written to
    ``gt_vs_oof.csv`` (columns ``ID`` and ``METRIC_SCORE``), sorted by
    ascending score.

    Args:
        y_true: Indexable collection of ground-truth arrays; each item must
            provide ``.copy()``.
        y_pred: Predictions, indexed in step with ``y_true``.
        num_threads: Number of worker processes handed to ``Pool``.
        ids: Optional sequence of identifiers, indexed in step with
            ``y_true``. ``None`` or an empty sequence disables the CSV dump.

    Returns:
        Mean of the per-pair scores.
    """
    # Index-based on purpose: a y_pred shorter than y_true must still raise
    # IndexError instead of being silently truncated by zip().
    pairs = [(y_true[m].copy(), y_pred[m].copy()) for m in range(len(y_true))]
    # map() blocks until every result is in, so it is safe for the context
    # manager to tear the workers down on exit; without it they leak.
    with Pool(num_threads) as pool:
        results = pool.map(calculate_jaccard, pairs)

    # Explicit None/length test: plain truthiness raises for multi-element
    # numpy arrays and pandas Series.
    if ids is not None and len(ids) > 0:
        import pandas as pd

        order = np.argsort(results)
        rows = [[ids[k], results[k]] for k in order]
        # Write the report once, after all rows are collected; skip it when
        # there is nothing to report.
        if rows:
            pd.DataFrame(rows, columns=["ID", "METRIC_SCORE"]).to_csv("gt_vs_oof.csv", index=False)

    return np.array(results).mean()
33
34
35
def get_cells(mask):
    """Label ``mask`` with ``measure.label`` and return ``(label_image, region_count)``."""
    labelled, region_count = measure.label(mask, return_num=True)
    return labelled, region_count
37
38
def score_kaggle(yp):
    """Score one ``(truth, prediction)`` pair with ``calc_score``.

    Both arrays get a new leading axis of length 1 before being passed on.
    """
    truth, prediction = yp
    batched_truth = np.expand_dims(truth, 0)
    batched_prediction = np.expand_dims(prediction, 0)
    return calc_score(batched_truth, batched_prediction)
41
42
43
44
def _dense_label_indices(label_image):
    """Map a label image onto consecutive indices with background pinned to 0.

    Returns:
        A tuple ``(indices, count)``. ``indices`` is a flat integer array in
        which pixels equal to 0 stay 0 and every other distinct value is
        replaced by its 1-based rank among the non-zero values. ``count`` is
        the number of distinct non-zero values plus one (the background slot).
    """
    flat = np.asarray(label_image).ravel()
    object_ids = np.unique(flat[flat != 0])
    indices = np.searchsorted(object_ids, flat) + 1
    indices[flat == 0] = 0
    return indices, object_ids.size + 1


def calc_score(labels, y_pred):
    """Return the mean precision of ``y_pred`` against ``labels`` over IoU thresholds.

    Both inputs are treated as label images in which 0 is background and each
    distinct non-zero value is one object. For every threshold ``t`` in
    ``np.arange(0.5, 1.0, 0.05)`` the value ``tp / (tp + fp + fn)`` is
    computed from the pairwise object IoU matrix, and the mean over all
    thresholds is returned.

    Label ids do not need to be consecutive and the background does not need
    to be present: ids are remapped to dense indices first.

    Args:
        labels: Ground-truth label image (any shape).
        y_pred: Predicted label image with the same number of pixels.

    Returns:
        Mean precision over the thresholds. A threshold at which there are no
        true and no predicted objects contributes 1.0 (nothing to find,
        nothing reported) instead of NaN.

    Raises:
        ValueError: If the two inputs differ in pixel count.
    """
    true_idx, true_objects = _dense_label_indices(labels)
    pred_idx, pred_objects = _dense_label_indices(y_pred)
    if true_idx.size != pred_idx.size:
        raise ValueError("labels and y_pred must contain the same number of pixels")

    # Joint pixel counts: entry [i, j] is the overlap of true object i with
    # predicted object j (row/column 0 is background).
    intersection = np.bincount(
        true_idx * pred_objects + pred_idx,
        minlength=true_objects * pred_objects,
    ).reshape(true_objects, pred_objects).astype(np.float64)

    # Per-object areas are the marginals of the joint counts
    area_true = intersection.sum(axis=1, keepdims=True)
    area_pred = intersection.sum(axis=0, keepdims=True)

    # Compute union
    union = area_true + area_pred - intersection

    # Exclude background from the analysis
    intersection = intersection[1:, 1:]
    union = union[1:, 1:]
    union[union == 0] = 1e-9

    # Compute the intersection over union
    iou = intersection / union

    # Precision helper function
    def precision_at(threshold, iou):
        matches = iou > threshold
        true_positives = np.sum(matches, axis=1) == 1  # True objects with exactly one match
        false_positives = np.sum(matches, axis=0) == 0  # Predicted objects matching nothing
        false_negatives = np.sum(matches, axis=1) == 0  # True objects left unmatched
        tp, fp, fn = np.sum(true_positives), np.sum(false_positives), np.sum(false_negatives)
        return tp, fp, fn

    # Loop over IoU thresholds
    prec = []
    for t in np.arange(0.5, 1.0, 0.05):
        tp, fp, fn = precision_at(t, iou)
        total = tp + fp + fn
        # total == 0 means both images are empty; avoid 0/0 -> NaN
        prec.append(tp / total if total > 0 else 1.0)
    return np.mean(prec)
88
89
90
def calculate_jaccard(yps):
    """Return the mean of ``tp / (tp + fp + fn)`` over ten IoU thresholds.

    Both masks are reshaped to the first two dimensions of the prediction and
    split into regions with ``get_cells``. At each threshold every predicted
    region is matched to the first not-yet-matched ground-truth region whose
    IoU exceeds the threshold. Matched predictions are true positives,
    unmatched predictions are false positives, and unmatched ground-truth
    regions are false negatives.

    Args:
        yps: ``(y, p)`` pair of ground-truth and predicted mask arrays.

    Returns:
        Mean score over thresholds 0.5, 0.55, ..., 0.95. When neither mask
        contains any region, every threshold contributes 1.0 rather than
        raising ``ZeroDivisionError``.
    """
    y, p = yps
    iou_thresholds = [0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95]

    # Labelling and region extraction do not depend on the threshold, so do
    # them once instead of once per threshold.
    size = p.shape[0], p.shape[1]
    predicted_labels, predicted_count = get_cells(np.reshape(p, size))
    gt_labels, gt_count = get_cells(np.reshape(y, size))
    gt_cells = [to_point_set(rp.coords) for rp in measure.regionprops(gt_labels)]
    pred_cells = [to_point_set(rp.coords) for rp in measure.regionprops(predicted_labels)]

    jaccards = []
    for iou_threshold in iou_thresholds:
        tp = 0
        fp = 0
        processed_gt = set()  # indices of ground-truth regions already matched
        for j in range(predicted_count):
            pred_cell = pred_cells[j]
            match_found = False
            for i in range(gt_count):
                if i in processed_gt:
                    continue
                gt_cell = gt_cells[i]
                intersection = len(pred_cell.intersection(gt_cell))
                union = len(pred_cell) + len(gt_cell) - intersection
                if intersection / union > iou_threshold:
                    processed_gt.add(i)
                    match_found = True
                    tp += 1
                    break
            if not match_found:
                fp += 1
        fn = gt_count - len(processed_gt)
        total = tp + fp + fn
        # total == 0 only when both masks are empty; avoid dividing by zero
        jaccards.append(tp / total if total else 1.0)
    return np.mean(jaccards)
135
136
137
def to_point_set(cell):
    """Return the ``(row[0], row[1])`` pairs of ``cell`` as a set of tuples.

    Args:
        cell: Iterable of indexable rows with at least two entries each.

    Returns:
        Set of 2-tuples; duplicate rows collapse into one element.
    """
    # Set comprehension avoids building an intermediate list first
    return {(row[0], row[1]) for row in cell}