--- a/dataloaders/utils.py
+++ b/dataloaders/utils.py
@@ -0,0 +1,240 @@
+import os
+import torch
+import numpy as np
+import torch.nn as nn
+import matplotlib.pyplot as plt
+from skimage import measure
+import scipy.ndimage as nd
+from scipy.ndimage import distance_transform_edt as distance
+from skimage import segmentation as skimage_seg
+
+def recursive_glob(rootdir='.', suffix=''):
+    """Performs a recursive glob with the given suffix under rootdir.
+    :param rootdir: the root directory to walk
+    :param suffix: the filename suffix to match
+    """
+    return [os.path.join(looproot, filename)
+            for looproot, _, filenames in os.walk(rootdir)
+            for filename in filenames if filename.endswith(suffix)]
+
+def get_cityscapes_labels():
+    return np.array([
+        # [0, 0, 0],
+        [128, 64, 128],
+        [244, 35, 232],
+        [70, 70, 70],
+        [102, 102, 156],
+        [190, 153, 153],
+        [153, 153, 153],
+        [250, 170, 30],
+        [220, 220, 0],
+        [107, 142, 35],
+        [152, 251, 152],
+        [0, 130, 180],
+        [220, 20, 60],
+        [255, 0, 0],
+        [0, 0, 142],
+        [0, 0, 70],
+        [0, 60, 100],
+        [0, 80, 100],
+        [0, 0, 230],
+        [119, 11, 32]])
+
+def get_pascal_labels():
+    """Load the mapping that associates PASCAL VOC classes with label colors.
+    Returns:
+        np.ndarray with dimensions (21, 3)
+    """
+    return np.asarray([[0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0],
+                       [0, 0, 128], [128, 0, 128], [0, 128, 128], [128, 128, 128],
+                       [64, 0, 0], [192, 0, 0], [64, 128, 0], [192, 128, 0],
+                       [64, 0, 128], [192, 0, 128], [64, 128, 128], [192, 128, 128],
+                       [0, 64, 0], [128, 64, 0], [0, 192, 0], [128, 192, 0],
+                       [0, 64, 128]])
+
+def encode_segmap(mask):
+    """Encode a segmentation label image as PASCAL class indices.
+    Args:
+        mask (np.ndarray): raw segmentation label image of dimension
+            (M, N, 3), in which the PASCAL classes are encoded as colours.
+    Returns:
+        (np.ndarray): class map with dimensions (M, N), where the value at
+            a given location is the integer denoting the class index.
+    """
+    mask = mask.astype(int)
+    label_mask = np.zeros((mask.shape[0], mask.shape[1]), dtype=np.int16)
+    for ii, label in enumerate(get_pascal_labels()):
+        label_mask[np.where(np.all(mask == label, axis=-1))[:2]] = ii
+    label_mask = label_mask.astype(int)
+    return label_mask
+
+def decode_seg_map_sequence(label_masks, dataset='pascal'):
+    rgb_masks = []
+    for label_mask in label_masks:
+        rgb_mask = decode_segmap(label_mask, dataset)
+        rgb_masks.append(rgb_mask)
+    rgb_masks = torch.from_numpy(np.array(rgb_masks).transpose([0, 3, 1, 2]))
+    return rgb_masks
+
+def decode_segmap(label_mask, dataset, plot=False):
+    """Decode segmentation class labels into a color image.
+    Args:
+        label_mask (np.ndarray): an (M, N) array of integer values denoting
+            the class label at each spatial location.
+        dataset (str): either 'pascal' or 'cityscapes', selecting the palette.
+        plot (bool, optional): whether to show the resulting color image
+            in a figure.
+    Returns:
+        (np.ndarray): the resulting decoded color image.
+ """ + if dataset == 'pascal': + n_classes = 21 + label_colours = get_pascal_labels() + elif dataset == 'cityscapes': + n_classes = 19 + label_colours = get_cityscapes_labels() + else: + raise NotImplementedError + + r = label_mask.copy() + g = label_mask.copy() + b = label_mask.copy() + for ll in range(0, n_classes): + r[label_mask == ll] = label_colours[ll, 0] + g[label_mask == ll] = label_colours[ll, 1] + b[label_mask == ll] = label_colours[ll, 2] + rgb = np.zeros((label_mask.shape[0], label_mask.shape[1], 3)) + rgb[:, :, 0] = r / 255.0 + rgb[:, :, 1] = g / 255.0 + rgb[:, :, 2] = b / 255.0 + if plot: + plt.imshow(rgb) + plt.show() + else: + return rgb + +def generate_param_report(logfile, param): + log_file = open(logfile, 'w') + # for key, val in param.items(): + # log_file.write(key + ':' + str(val) + '\n') + log_file.write(str(param)) + log_file.close() + +def cross_entropy2d(logit, target, ignore_index=255, weight=None, size_average=True, batch_average=True): + n, c, h, w = logit.size() + # logit = logit.permute(0, 2, 3, 1) + target = target.squeeze(1) + if weight is None: + criterion = nn.CrossEntropyLoss(weight=weight, ignore_index=ignore_index, size_average=False) + else: + criterion = nn.CrossEntropyLoss(weight=torch.from_numpy(np.array(weight)).float().cuda(), ignore_index=ignore_index, size_average=False) + loss = criterion(logit, target.long()) + + if size_average: + loss /= (h * w) + + if batch_average: + loss /= n + + return loss + +def lr_poly(base_lr, iter_, max_iter=100, power=0.9): + return base_lr * ((1 - float(iter_) / max_iter) ** power) + + +def get_iou(pred, gt, n_classes=21): + total_iou = 0.0 + for i in range(len(pred)): + pred_tmp = pred[i] + gt_tmp = gt[i] + + intersect = [0] * n_classes + union = [0] * n_classes + for j in range(n_classes): + match = (pred_tmp == j) + (gt_tmp == j) + + it = torch.sum(match == 2).item() + un = torch.sum(match > 0).item() + + intersect[j] += it + union[j] += un + + iou = [] + for k in range(n_classes): + if union[k] == 0: + continue + iou.append(intersect[k] / union[k]) + + img_iou = (sum(iou) / len(iou)) + total_iou += img_iou + + return total_iou + +def get_dice(pred, gt): + total_dice = 0.0 + pred = pred.long() + gt = gt.long() + for i in range(len(pred)): + pred_tmp = pred[i] + gt_tmp = gt[i] + dice = 2.0*torch.sum(pred_tmp*gt_tmp).item()/(1.0+torch.sum(pred_tmp**2)+torch.sum(gt_tmp**2)).item() + print(dice) + total_dice += dice + + return total_dice + +def get_mc_dice(pred, gt, num=2): + # num is the total number of classes, include the background + total_dice = np.zeros(num-1) + pred = pred.long() + gt = gt.long() + for i in range(len(pred)): + for j in range(1, num): + pred_tmp = (pred[i]==j) + gt_tmp = (gt[i]==j) + dice = 2.0*torch.sum(pred_tmp*gt_tmp).item()/(1.0+torch.sum(pred_tmp**2)+torch.sum(gt_tmp**2)).item() + total_dice[j-1] +=dice + return total_dice + +def post_processing(prediction): + prediction = nd.binary_fill_holes(prediction) + label_cc, num_cc = measure.label(prediction,return_num=True) + total_cc = np.sum(prediction) + measure.regionprops(label_cc) + for cc in range(1,num_cc+1): + single_cc = (label_cc==cc) + single_vol = np.sum(single_cc) + if single_vol/total_cc<0.2: + prediction[single_cc]=0 + + return prediction + +def compute_sdf(img_gt, out_shape): + """ + compute the signed distance map of binary mask + input: segmentation, shape = (batch_size, x, y, z) + output: the Signed Distance Map (SDM) + sdf(x) = 0; x in segmentation boundary + -inf|x-y|; x in segmentation + +inf|x-y|; x out of 
+def compute_sdf(img_gt, out_shape):
+    """
+    Compute the signed distance map (SDM) of a binary mask.
+    input: segmentation, shape = (batch_size, x, y, z)
+    output: the signed distance map, normalized to [-1, 1], where
+        sdf(x) = 0                              if x lies on the segmentation boundary
+        sdf(x) = -inf_{y on boundary} |x - y|   if x is inside the segmentation
+        sdf(x) = +inf_{y on boundary} |x - y|   if x is outside the segmentation
+    """
+    img_gt = img_gt.astype(np.uint8)
+    normalized_sdf = np.zeros(out_shape)
+
+    for b in range(out_shape[0]):  # batch size
+        posmask = img_gt[b].astype(bool)  # np.bool is removed in recent NumPy
+        if posmask.any():
+            negmask = ~posmask
+            posdis = distance(posmask)
+            negdis = distance(negmask)
+            boundary = skimage_seg.find_boundaries(posmask, mode='inner').astype(np.uint8)
+            sdf = (negdis - np.min(negdis)) / (np.max(negdis) - np.min(negdis)) \
+                - (posdis - np.min(posdis)) / (np.max(posdis) - np.min(posdis))
+            sdf[boundary == 1] = 0
+            normalized_sdf[b] = sdf
+            assert np.min(sdf) == -1.0, (
+                f'min(sdf)={np.min(sdf)}, posdis in [{np.min(posdis)}, {np.max(posdis)}], '
+                f'negdis in [{np.min(negdis)}, {np.max(negdis)}]')
+            assert np.max(sdf) == 1.0, (
+                f'max(sdf)={np.max(sdf)}, posdis in [{np.min(posdis)}, {np.max(posdis)}], '
+                f'negdis in [{np.min(negdis)}, {np.max(negdis)}]')
+
+    return normalized_sdf
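
A quick sanity check for compute_sdf, outside the patch itself. This is only a minimal sketch with a made-up toy volume (the shapes below are assumptions, and it presumes dataloaders is importable as a package): it builds a one-sample binary batch, computes the normalized SDM, and verifies the sign convention.

import numpy as np
from dataloaders.utils import compute_sdf

# toy batch: one 8x8x8 binary mask with a centered 4x4x4 cube as foreground
gt = np.zeros((1, 8, 8, 8), dtype=np.uint8)
gt[0, 2:6, 2:6, 2:6] = 1

sdm = compute_sdf(gt, gt.shape)   # values normalized to [-1, 1]
assert sdm.shape == gt.shape
assert sdm[0, 3, 3, 3] < 0        # inside the mask: negative
assert sdm[0, 0, 0, 0] > 0        # outside the mask: positive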
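
And a similar sketch (again not part of the diff, using a hypothetical 4x4 input) showing the encode_segmap / decode_segmap round trip on a PASCAL-colored label image:

import numpy as np
from dataloaders.utils import encode_segmap, decode_segmap, get_pascal_labels

# 4x4 RGB label image: top half background (class 0), bottom half class 1
palette = get_pascal_labels()
rgb_label = np.zeros((4, 4, 3), dtype=np.uint8)
rgb_label[2:, :, :] = palette[1]                 # [128, 0, 0] -> class index 1

class_map = encode_segmap(rgb_label)             # (4, 4) integer class indices
assert class_map[0, 0] == 0 and class_map[3, 0] == 1

rgb_again = decode_segmap(class_map, dataset='pascal')   # (4, 4, 3) floats in [0, 1]
assert np.allclose(rgb_again[3, 0] * 255.0, palette[1])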