b/dataloaders/utils.py

import os
import torch
import numpy as np
import torch.nn as nn
import matplotlib.pyplot as plt
from skimage import measure
import scipy.ndimage as nd
from scipy.ndimage import distance_transform_edt as distance
from skimage import segmentation as skimage_seg


def recursive_glob(rootdir='.', suffix=''):
    """Perform a recursive glob with the given suffix and rootdir.

    :param rootdir: the root directory to search
    :param suffix: the file suffix to search for
    """
    return [os.path.join(looproot, filename)
            for looproot, _, filenames in os.walk(rootdir)
            for filename in filenames if filename.endswith(suffix)]
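
# Illustrative usage (a sketch; the directory name and suffix are hypothetical):
#   image_paths = recursive_glob(rootdir='./data', suffix='.png')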


def get_cityscapes_labels():
    """Return the 19-class Cityscapes colour palette as a (19, 3) array of RGB values."""
    return np.array([
        # [0, 0, 0],
        [128, 64, 128],
        [244, 35, 232],
        [70, 70, 70],
        [102, 102, 156],
        [190, 153, 153],
        [153, 153, 153],
        [250, 170, 30],
        [220, 220, 0],
        [107, 142, 35],
        [152, 251, 152],
        [0, 130, 180],
        [220, 20, 60],
        [255, 0, 0],
        [0, 0, 142],
        [0, 0, 70],
        [0, 60, 100],
        [0, 80, 100],
        [0, 0, 230],
        [119, 11, 32]])


def get_pascal_labels():
    """Load the mapping that associates PASCAL VOC classes with label colours.

    Returns:
        np.ndarray with dimensions (21, 3)
    """
    return np.asarray([[0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0],
                       [0, 0, 128], [128, 0, 128], [0, 128, 128], [128, 128, 128],
                       [64, 0, 0], [192, 0, 0], [64, 128, 0], [192, 128, 0],
                       [64, 0, 128], [192, 0, 128], [64, 128, 128], [192, 128, 128],
                       [0, 64, 0], [128, 64, 0], [0, 192, 0], [128, 192, 0],
                       [0, 64, 128]])


def encode_segmap(mask):
    """Encode a segmentation label image as PASCAL class indices.

    Args:
        mask (np.ndarray): raw segmentation label image of dimension
            (M, N, 3), in which the PASCAL classes are encoded as colours.
    Returns:
        (np.ndarray): class map with dimensions (M, N), where the value at
            a given location is the integer denoting the class index.
    """
    mask = mask.astype(int)
    label_mask = np.zeros((mask.shape[0], mask.shape[1]), dtype=np.int16)
    for ii, label in enumerate(get_pascal_labels()):
        label_mask[np.where(np.all(mask == label, axis=-1))[:2]] = ii
    label_mask = label_mask.astype(int)
    return label_mask
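
# Illustrative usage (a sketch, assuming `rgb_label` is an (M, N, 3) NumPy array
# holding a colour-coded PASCAL ground-truth image):
#   class_map = encode_segmap(rgb_label)  # -> (M, N) array of class indices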


def decode_seg_map_sequence(label_masks, dataset='pascal'):
    """Decode a batch of label masks into a (B, 3, M, N) tensor of RGB images."""
    rgb_masks = []
    for label_mask in label_masks:
        rgb_mask = decode_segmap(label_mask, dataset)
        rgb_masks.append(rgb_mask)
    rgb_masks = torch.from_numpy(np.array(rgb_masks).transpose([0, 3, 1, 2]))
    return rgb_masks


def decode_segmap(label_mask, dataset, plot=False):
    """Decode segmentation class labels into a colour image.

    Args:
        label_mask (np.ndarray): an (M, N) array of integer values denoting
            the class label at each spatial location.
        dataset (str): either 'pascal' or 'cityscapes', selecting the palette.
        plot (bool, optional): whether to show the resulting colour image
            in a figure instead of returning it.
    Returns:
        (np.ndarray, optional): the resulting decoded colour image.
    """
    if dataset == 'pascal':
        n_classes = 21
        label_colours = get_pascal_labels()
    elif dataset == 'cityscapes':
        n_classes = 19
        label_colours = get_cityscapes_labels()
    else:
        raise NotImplementedError

    r = label_mask.copy()
    g = label_mask.copy()
    b = label_mask.copy()
    for ll in range(0, n_classes):
        r[label_mask == ll] = label_colours[ll, 0]
        g[label_mask == ll] = label_colours[ll, 1]
        b[label_mask == ll] = label_colours[ll, 2]
    rgb = np.zeros((label_mask.shape[0], label_mask.shape[1], 3))
    rgb[:, :, 0] = r / 255.0
    rgb[:, :, 1] = g / 255.0
    rgb[:, :, 2] = b / 255.0
    if plot:
        plt.imshow(rgb)
        plt.show()
    else:
        return rgb
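
# Illustrative usage (a sketch, assuming `pred` is a NumPy batch of predicted
# label maps, e.g. the argmax over the class dimension of the network output):
#   rgb_batch = decode_seg_map_sequence(pred, dataset='pascal')
#   decode_segmap(pred[0], dataset='pascal', plot=True)  # show a single mask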


def generate_param_report(logfile, param):
    """Write the experiment parameters `param` to `logfile` as plain text."""
    log_file = open(logfile, 'w')
    # for key, val in param.items():
    #     log_file.write(key + ':' + str(val) + '\n')
    log_file.write(str(param))
    log_file.close()


def cross_entropy2d(logit, target, ignore_index=255, weight=None, size_average=True, batch_average=True):
    """2D cross-entropy loss, optionally averaged over the spatial and batch dimensions."""
    n, c, h, w = logit.size()
    # logit = logit.permute(0, 2, 3, 1)
    target = target.squeeze(1)
    # sum the per-pixel losses and normalise manually below
    if weight is None:
        criterion = nn.CrossEntropyLoss(ignore_index=ignore_index, reduction='sum')
    else:
        criterion = nn.CrossEntropyLoss(weight=torch.from_numpy(np.array(weight)).float().cuda(),
                                        ignore_index=ignore_index, reduction='sum')
    loss = criterion(logit, target.long())

    if size_average:
        loss /= (h * w)

    if batch_average:
        loss /= n

    return loss
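
# Illustrative usage (a sketch, assuming `logit` has shape (N, C, H, W) and
# `target` holds integer class labels of shape (N, 1, H, W)):
#   loss = cross_entropy2d(logit, target, ignore_index=255)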


def lr_poly(base_lr, iter_, max_iter=100, power=0.9):
    """Polynomial ("poly") learning-rate decay: base_lr * (1 - iter_ / max_iter) ** power."""
    return base_lr * ((1 - float(iter_) / max_iter) ** power)
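
# Illustrative usage (a sketch; `optimizer`, `cur_iter` and `total_iters` are
# hypothetical names from a training loop):
#   lr = lr_poly(base_lr=0.01, iter_=cur_iter, max_iter=total_iters, power=0.9)
#   for param_group in optimizer.param_groups:
#       param_group['lr'] = lr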


def get_iou(pred, gt, n_classes=21):
    """Sum of per-image mean IoU over a batch of predicted and ground-truth label maps."""
    total_iou = 0.0
    for i in range(len(pred)):
        pred_tmp = pred[i]
        gt_tmp = gt[i]

        intersect = [0] * n_classes
        union = [0] * n_classes
        for j in range(n_classes):
            # cast the boolean masks to integers so that the sum distinguishes
            # the intersection (== 2) from the union (> 0)
            match = (pred_tmp == j).long() + (gt_tmp == j).long()

            it = torch.sum(match == 2).item()
            un = torch.sum(match > 0).item()

            intersect[j] += it
            union[j] += un

        iou = []
        for k in range(n_classes):
            if union[k] == 0:
                continue
            iou.append(intersect[k] / union[k])

        img_iou = (sum(iou) / len(iou))
        total_iou += img_iou

    return total_iou


def get_dice(pred, gt):
    """Sum of per-image Dice scores between binary prediction and ground truth over a batch."""
    total_dice = 0.0
    pred = pred.long()
    gt = gt.long()
    for i in range(len(pred)):
        pred_tmp = pred[i]
        gt_tmp = gt[i]
        dice = 2.0 * torch.sum(pred_tmp * gt_tmp).item() / (1.0 + torch.sum(pred_tmp ** 2) + torch.sum(gt_tmp ** 2)).item()
        print(dice)
        total_dice += dice

    return total_dice


def get_mc_dice(pred, gt, num=2):
    """Per-class Dice scores summed over a batch.

    `num` is the total number of classes, including the background.
    """
    total_dice = np.zeros(num - 1)
    pred = pred.long()
    gt = gt.long()
    for i in range(len(pred)):
        for j in range(1, num):
            # cast the per-class boolean masks to integers before the arithmetic
            pred_tmp = (pred[i] == j).long()
            gt_tmp = (gt[i] == j).long()
            dice = 2.0 * torch.sum(pred_tmp * gt_tmp).item() / (1.0 + torch.sum(pred_tmp ** 2) + torch.sum(gt_tmp ** 2)).item()
            total_dice[j - 1] += dice
    return total_dice
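
# Illustrative usage (a sketch, assuming `pred` and `gt` are integer label tensors
# of identical shape with the batch along dimension 0):
#   dice_sum = get_dice(pred, gt)                  # binary segmentation
#   dice_per_class = get_mc_dice(pred, gt, num=3)  # multi-class, returns num-1 values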


def post_processing(prediction):
    """Fill holes in a binary prediction and drop connected components smaller
    than 20% of the total foreground volume."""
    prediction = nd.binary_fill_holes(prediction)
    label_cc, num_cc = measure.label(prediction, return_num=True)
    total_vol = np.sum(prediction)
    for cc in range(1, num_cc + 1):
        single_cc = (label_cc == cc)
        single_vol = np.sum(single_cc)
        if single_vol / total_vol < 0.2:
            prediction[single_cc] = 0

    return prediction
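
# Illustrative usage (a sketch, assuming `pred_mask` is a binary NumPy array):
#   cleaned = post_processing(pred_mask)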


def compute_sdf(img_gt, out_shape):
    """Compute the signed distance map (SDM) of a batch of binary masks.

    input: segmentation, shape = (batch_size, x, y, z)
    output: the Signed Distance Map (SDM)
            sdf(x) =  0                          if x is on the segmentation boundary
                     -inf_{y in boundary} |x-y|  if x is inside the segmentation
                     +inf_{y in boundary} |x-y|  if x is outside the segmentation
            the SDM is normalised to [-1, 1]
    """
    img_gt = img_gt.astype(np.uint8)
    normalized_sdf = np.zeros(out_shape)

    for b in range(out_shape[0]):  # batch size
        posmask = img_gt[b].astype(bool)
        if posmask.any():
            negmask = ~posmask
            posdis = distance(posmask)
            negdis = distance(negmask)
            boundary = skimage_seg.find_boundaries(posmask, mode='inner').astype(np.uint8)
            # normalise the inside and outside distances to [0, 1] separately,
            # then combine them so the result lies in [-1, 1]
            sdf = (negdis - np.min(negdis)) / (np.max(negdis) - np.min(negdis)) \
                - (posdis - np.min(posdis)) / (np.max(posdis) - np.min(posdis))
            sdf[boundary == 1] = 0
            normalized_sdf[b] = sdf
            assert np.min(sdf) == -1.0, 'posdis: [{}, {}], negdis: [{}, {}]'.format(
                np.min(posdis), np.max(posdis), np.min(negdis), np.max(negdis))
            assert np.max(sdf) == 1.0, 'posdis: [{}, {}], negdis: [{}, {}]'.format(
                np.min(posdis), np.max(posdis), np.min(negdis), np.max(negdis))

    return normalized_sdf
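
# Illustrative usage (a sketch, assuming `gt` is a binary ground-truth batch,
# e.g. a torch tensor of shape (batch_size, x, y, z)):
#   gt_sdf = compute_sdf(gt.cpu().numpy(), gt.shape)
# Values are close to -1 deep inside the object, 0 on the boundary, and close
# to +1 far outside it.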