Diff of /dataloaders/utils.py [000000] .. [903821]

Switch to unified view

a b/dataloaders/utils.py
1
import os
2
import torch
3
import numpy as np
4
import torch.nn as nn
5
import matplotlib.pyplot as plt
6
from skimage import measure
7
import scipy.ndimage as nd
8
from scipy.ndimage import distance_transform_edt as distance
9
from skimage import segmentation as skimage_seg
10
11
def recursive_glob(rootdir='.', suffix=''):
12
    """Performs recursive glob with given suffix and rootdir
13
        :param rootdir is the root directory
14
        :param suffix is the suffix to be searched
15
    """
16
    return [os.path.join(looproot, filename)
17
        for looproot, _, filenames in os.walk(rootdir)
18
        for filename in filenames if filename.endswith(suffix)]
19
20
def get_cityscapes_labels():
21
    return np.array([
22
        # [  0,   0,   0],
23
        [128, 64, 128],
24
        [244, 35, 232],
25
        [70, 70, 70],
26
        [102, 102, 156],
27
        [190, 153, 153],
28
        [153, 153, 153],
29
        [250, 170, 30],
30
        [220, 220, 0],
31
        [107, 142, 35],
32
        [152, 251, 152],
33
        [0, 130, 180],
34
        [220, 20, 60],
35
        [255, 0, 0],
36
        [0, 0, 142],
37
        [0, 0, 70],
38
        [0, 60, 100],
39
        [0, 80, 100],
40
        [0, 0, 230],
41
        [119, 11, 32]])
42
43
def get_pascal_labels():
44
    """Load the mapping that associates pascal classes with label colors
45
    Returns:
46
        np.ndarray with dimensions (21, 3)
47
    """
48
    return np.asarray([[0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0],
49
                       [0, 0, 128], [128, 0, 128], [0, 128, 128], [128, 128, 128],
50
                       [64, 0, 0], [192, 0, 0], [64, 128, 0], [192, 128, 0],
51
                       [64, 0, 128], [192, 0, 128], [64, 128, 128], [192, 128, 128],
52
                       [0, 64, 0], [128, 64, 0], [0, 192, 0], [128, 192, 0],
53
                       [0, 64, 128]])
54
55
56
def encode_segmap(mask):
57
    """Encode segmentation label images as pascal classes
58
    Args:
59
        mask (np.ndarray): raw segmentation label image of dimension
60
          (M, N, 3), in which the Pascal classes are encoded as colours.
61
    Returns:
62
        (np.ndarray): class map with dimensions (M,N), where the value at
63
        a given location is the integer denoting the class index.
64
    """
65
    mask = mask.astype(int)
66
    label_mask = np.zeros((mask.shape[0], mask.shape[1]), dtype=np.int16)
67
    for ii, label in enumerate(get_pascal_labels()):
68
        label_mask[np.where(np.all(mask == label, axis=-1))[:2]] = ii
69
    label_mask = label_mask.astype(int)
70
    return label_mask
71
72
73
def decode_seg_map_sequence(label_masks, dataset='pascal'):
74
    rgb_masks = []
75
    for label_mask in label_masks:
76
        rgb_mask = decode_segmap(label_mask, dataset)
77
        rgb_masks.append(rgb_mask)
78
    rgb_masks = torch.from_numpy(np.array(rgb_masks).transpose([0, 3, 1, 2]))
79
    return rgb_masks
80
81
def decode_segmap(label_mask, dataset, plot=False):
82
    """Decode segmentation class labels into a color image
83
    Args:
84
        label_mask (np.ndarray): an (M,N) array of integer values denoting
85
          the class label at each spatial location.
86
        plot (bool, optional): whether to show the resulting color image
87
          in a figure.
88
    Returns:
89
        (np.ndarray, optional): the resulting decoded color image.
90
    """
91
    if dataset == 'pascal':
92
        n_classes = 21
93
        label_colours = get_pascal_labels()
94
    elif dataset == 'cityscapes':
95
        n_classes = 19
96
        label_colours = get_cityscapes_labels()
97
    else:
98
        raise NotImplementedError
99
100
    r = label_mask.copy()
101
    g = label_mask.copy()
102
    b = label_mask.copy()
103
    for ll in range(0, n_classes):
104
        r[label_mask == ll] = label_colours[ll, 0]
105
        g[label_mask == ll] = label_colours[ll, 1]
106
        b[label_mask == ll] = label_colours[ll, 2]
107
    rgb = np.zeros((label_mask.shape[0], label_mask.shape[1], 3))
108
    rgb[:, :, 0] = r / 255.0
109
    rgb[:, :, 1] = g / 255.0
110
    rgb[:, :, 2] = b / 255.0
111
    if plot:
112
        plt.imshow(rgb)
113
        plt.show()
114
    else:
115
        return rgb
116
117
def generate_param_report(logfile, param):
118
    log_file = open(logfile, 'w')
119
    # for key, val in param.items():
120
    #     log_file.write(key + ':' + str(val) + '\n')
121
    log_file.write(str(param))
122
    log_file.close()
123
124
def cross_entropy2d(logit, target, ignore_index=255, weight=None, size_average=True, batch_average=True):
125
    n, c, h, w = logit.size()
126
    # logit = logit.permute(0, 2, 3, 1)
127
    target = target.squeeze(1)
128
    if weight is None:
129
        criterion = nn.CrossEntropyLoss(weight=weight, ignore_index=ignore_index, size_average=False)
130
    else:
131
        criterion = nn.CrossEntropyLoss(weight=torch.from_numpy(np.array(weight)).float().cuda(), ignore_index=ignore_index, size_average=False)
132
    loss = criterion(logit, target.long())
133
134
    if size_average:
135
        loss /= (h * w)
136
137
    if batch_average:
138
        loss /= n
139
140
    return loss
141
142
def lr_poly(base_lr, iter_, max_iter=100, power=0.9):
143
    return base_lr * ((1 - float(iter_) / max_iter) ** power)
144
145
146
def get_iou(pred, gt, n_classes=21):
147
    total_iou = 0.0
148
    for i in range(len(pred)):
149
        pred_tmp = pred[i]
150
        gt_tmp = gt[i]
151
152
        intersect = [0] * n_classes
153
        union = [0] * n_classes
154
        for j in range(n_classes):
155
            match = (pred_tmp == j) + (gt_tmp == j)
156
157
            it = torch.sum(match == 2).item()
158
            un = torch.sum(match > 0).item()
159
160
            intersect[j] += it
161
            union[j] += un
162
163
        iou = []
164
        for k in range(n_classes):
165
            if union[k] == 0:
166
                continue
167
            iou.append(intersect[k] / union[k])
168
169
        img_iou = (sum(iou) / len(iou))
170
        total_iou += img_iou
171
172
    return total_iou
173
174
def get_dice(pred, gt):
175
    total_dice = 0.0
176
    pred = pred.long()
177
    gt = gt.long()
178
    for i in range(len(pred)):
179
        pred_tmp = pred[i]
180
        gt_tmp = gt[i]
181
        dice = 2.0*torch.sum(pred_tmp*gt_tmp).item()/(1.0+torch.sum(pred_tmp**2)+torch.sum(gt_tmp**2)).item()
182
        print(dice)
183
        total_dice += dice
184
185
    return total_dice
186
187
def get_mc_dice(pred, gt, num=2):
188
    # num is the total number of classes, include the background
189
    total_dice = np.zeros(num-1)
190
    pred = pred.long()
191
    gt = gt.long()
192
    for i in range(len(pred)):
193
        for j in range(1, num):
194
            pred_tmp = (pred[i]==j)
195
            gt_tmp = (gt[i]==j)
196
            dice = 2.0*torch.sum(pred_tmp*gt_tmp).item()/(1.0+torch.sum(pred_tmp**2)+torch.sum(gt_tmp**2)).item()
197
            total_dice[j-1] +=dice
198
    return total_dice
199
200
def post_processing(prediction):
201
    prediction = nd.binary_fill_holes(prediction)
202
    label_cc, num_cc = measure.label(prediction,return_num=True)
203
    total_cc = np.sum(prediction)
204
    measure.regionprops(label_cc)
205
    for cc in range(1,num_cc+1):
206
        single_cc = (label_cc==cc)
207
        single_vol = np.sum(single_cc)
208
        if single_vol/total_cc<0.2:
209
            prediction[single_cc]=0
210
211
    return prediction
212
213
def compute_sdf(img_gt, out_shape):
214
    """
215
    compute the signed distance map of binary mask
216
    input: segmentation, shape = (batch_size, x, y, z)
217
    output: the Signed Distance Map (SDM)
218
    sdf(x) = 0; x in segmentation boundary
219
             -inf|x-y|; x in segmentation
220
             +inf|x-y|; x out of segmentation
221
    normalize sdf to [-1,1]
222
    """
223
224
    img_gt = img_gt.astype(np.uint8)
225
    normalized_sdf = np.zeros(out_shape)
226
227
    for b in range(out_shape[0]): # batch size
228
        posmask = img_gt[b].astype(np.bool)
229
        if posmask.any():
230
            negmask = ~posmask
231
            posdis = distance(posmask)
232
            negdis = distance(negmask)
233
            boundary = skimage_seg.find_boundaries(posmask, mode='inner').astype(np.uint8)
234
            sdf = (negdis-np.min(negdis))/(np.max(negdis)-np.min(negdis)) - (posdis-np.min(posdis))/(np.max(posdis)-np.min(posdis))
235
            sdf[boundary==1] = 0
236
            normalized_sdf[b] = sdf
237
            assert np.min(sdf) == -1.0, print(np.min(posdis), np.max(posdis), np.min(negdis), np.max(negdis))
238
            assert np.max(sdf) ==  1.0, print(np.min(posdis), np.min(negdis), np.max(posdis), np.max(negdis))
239
240
    return normalized_sdf