--- a +++ b/utils.py @@ -0,0 +1,322 @@ +from collections import defaultdict +from os.path import join +from random import randint +from scipy import ndimage +from statistics import median +import numpy +import os +import shutil +import sys + +from torch import nn +import torch +import nibabel as nib + + +def transfer_weights(target_model, saved_model): + """ + target_model: a model instance whose weight params are to be overwritten + saved_model: a model whose weight params will be transfered to target. + saved_model can be a string(path to a snapshot), an instance of model + or a state dict of a model + """ + target_dict = target_model.state_dict() + if isinstance(saved_model, str): + source_dict = torch.load(saved_model) + else: + source_dict = saved_model + if not isinstance(source_dict, dict): + source_dict = source_dict.state_dict() + source_dict = {k: v for k, v in source_dict.items() if + k in target_model.state_dict() and source_dict[k].size() == target_model.state_dict()[k].size()} + target_dict.update(source_dict) + target_model.load_state_dict(target_dict) + + +def generate_ex_list(directory): + """ + Generate list of MRI objects + """ + inputs = [] + labels = [] + for dirpath, dirs, files in os.walk(directory): + label_list = list() + for file in files: + if not file.startswith('.') and file.endswith('.nii.gz'): + if ("Lesion" in file): + label_list.append(join(dirpath, file)) + elif ("mask" not in file): + inputs.append(join(dirpath, file)) + if label_list: + labels.append(label_list) + + return inputs, labels + + +def gen_mask(lesion_files): + """ + Given a list of lesion files, generate a mask + that incorporates data from all of them + """ + first_lesion = nib.load(lesion_files[0]).get_data() + if len(lesion_files) == 1: + return first_lesion + lesion_data = numpy.zeros((first_lesion.shape[0], first_lesion.shape[1], first_lesion.shape[2])) + for file in lesion_files: + l_file = correct_dims(nib.load(file).get_data()) + lesion_data = numpy.maximum(l_file, lesion_data) + return lesion_data + + +def correct_dims(img): + """ + Fix the dimension of the image, if necessary + """ + if len(img.shape) > 3: + img = img.reshape(img.shape[0], img.shape[1], img.shape[2]) + return img + + +def get_weight_vector(labels, weight, is_cuda): + """ Generates the weight vector for BCE loss + You can only control positive weight, and negative weight is + default to 1. + So if ratio of positive and negative samples are 1:3, + then give weight 3, and this functio returns 3 for positive and + 1 for negative samples. + """ + if is_cuda: + labels = labels.cpu() + labels = labels.data.numpy() + labels = labels * (weight-1) + 1 + weight_label = torch.from_numpy(labels).type(torch.FloatTensor) + if is_cuda: + weight_label = weight_label.cuda() + return weight_label + + +def resize_img(input_img, label_img, size): + """ + size: int or list of int + when it's a list, it should include x, y, z values + Resize image to (size x size x size) + """ + if isinstance(size, int): + size = [size]*3 + assert len(size) == 3 + ax1 = float(size[0]) / input_img.shape[0] + ax2 = float(size[1]) / input_img.shape[1] + ax3 = float(size[2]) / input_img.shape[2] + ex = ndimage.zoom(input_img, (ax1, ax2, ax3)) + label = ndimage.zoom(label_img, (ax1, ax2, ax3)) + return ex, label + + +def center_crop(input_img, label_img, size): + """ + Crop center section from image + size: int or list of int + when it's a list, it should include x, y, z values + Use for testing. + """ + if isinstance(size, int): + size = [size]*3 + assert len(size) == 3 + coords = [0]*3 + for i in range(3): + coords[i] = int((input_img.shape[i]-size[i])//2) + x, y, z = coords + ex = input_img[x:x+size[0], y:y+size[1], z:z+size[2]] + label = label_img[x:x+size[0], y:y+size[1], z:z+size[2]] + return ex, label + + +def find_and_crop_lesions(input_img, label_img, size, deterministic=False): + """ + Find and crop image based on center of lesions + size: int or list of int + when it's a list, it should include x, y, z values + Use for validation. + """ + if isinstance(size, int): + size = [size]*3 + assert len(size) == 3 + nonzeros = label_img.nonzero() + d = [0]*3 + if not deterministic: + for i in range(3): + d[i] = randint(-size[i]//4, size[i]//4) + + coords = [0]*3 + for i in range(3): + coords[i] = max(min(int(median(nonzeros[i])) - (size[i] // 2) + d[i], input_img.shape[i] - size[i] - 1), 0) + x, y, z = coords + ex = input_img[x:x+size[0], y:y+size[1], z:z+size[2]] + label = label_img[x:x+size[0], y:y+size[1], z:z+size[2]] + return ex, label + + +def random_crop(input_img, label_img, size, remove_background=False): + """ + Crop random section from image + size: int or list of int + when it's a list, it should include x, y, z values + remove_background: boolean + use this option when input contains larger background or crop size is very small + Use for training + """ + if isinstance(size, int): + size = [size]*3 + assert len(size) == 3 + non_zero_percentage = 0 + while non_zero_percentage < 0.7: + """draw x,y,z coords + """ + coords = [0]*3 + for i in range(3): + coords[i] = numpy.random.choice(input_img.shape[i] - size[i]) + x, y, z = coords + ex = input_img[x:x+size[0], y:y+size[1], z:z+size[2]] + non_zero_percentage = numpy.count_nonzero(ex) / float(size[0]*size[1]*size[2]) + if not remove_background: + break + if non_zero_percentage < 0.7: + del ex + + label = label_img[x:x+size[0], y:y+size[1], z:z+size[2]] + return ex, label + + +class Report: + EPS = sys.float_info.epsilon + TP_KEY = 0 + TN_KEY = 1 + FP_KEY = 2 + FN_KEY = 3 + + def __init__(self, threshold=0.5, smooth=sys.float_info.epsilon, apply_square=False, need_feedback=False): + """ + apply_square: use squared elements in the denominator of soft Dice + need_feedback: returns a tensor storing KEYS(0 to 3) for each output element + """ + self.pos = 0 + self.neg = 0 + self.false_pos = 0 + self.false_neg = 0 + self.true_pos = 0 + self.true_neg = 0 + self.soft_I = 0 + self.soft_U = 0 + self.hard_I = 0 + self.hard_U = 0 + self.smooth = smooth + self.apply_square = apply_square # this variable: mainly for testing + self.need_feedback = need_feedback + self.threshold = threshold + self.pathdic = defaultdict(list) + + def feed(self, pred, label, paths=None): + """ pred size: batch x dim1 x dim2 x... + label size: batch x dim1 x dim2 x... + First dim should be a batch size + """ + self.soft_I += (pred * label).sum().item() + power_coeff = 2 if self.apply_square else 1 + if power_coeff == 1: + self.soft_U += (pred.sum() + label.sum()).item() + else: + self.soft_U += (pred.pow(power_coeff).sum() + label.pow(power_coeff).sum()).item() + pred = pred.view(-1) + label = label.view(-1) + pred = (pred > self.threshold).squeeze() + not_pred = (pred == 0).squeeze() + label = label.byte().squeeze() + not_label = (label == 0).squeeze() + self.pos += label.sum().item() + self.neg += not_label.sum().item() + pxl = pred * label + self.hard_I += (pxl).sum().item() + self.hard_U += (pred.sum() + label.sum()).item() + pxnl = pred * not_label + fp = (pxnl).sum().item() + self.false_pos += fp + npxl = not_pred * label + fn = (npxl).sum().item() + self.false_neg += fn + tp = (pxl).sum().item() + self.true_pos += tp + npxnl = not_pred * not_label + tn = (npxnl).sum().item() + self.true_neg += tn + + feedback = None + if self.need_feedback: + feedback = pxl*self.TP_KEY +\ + npxnl*self.TN_KEY +\ + pxnl*self.FP_KEY +\ + npxl*self.FN_KEY + if paths is not None: + # Variable -> list of int + feedback_int = [int(feedback.data[i]) for i in range(feedback.numel())] + for i in range(len(feedback_int)): + if feedback_int[i] == self.TP_KEY: + self.pathdic["TP"].append(paths[i]) + elif feedback_int[i] == self.TN_KEY: + self.pathdic["TN"].append(paths[i]) + elif feedback_int[i] == self.FP_KEY: + self.pathdic["FP"].append(paths[i]) + elif feedback_int[i] == self.FN_KEY: + self.pathdic["FN"].append(paths[i]) + return feedback + + def stats(self): + text = ("Total Positives: {}".format(self.pos), + "Total Negatives: {}".format(self.neg), + "Total TruePos: {}".format(self.true_pos), + "Total TrueNeg: {}".format(self.true_neg), + "Total FalsePos: {}".format(self.false_pos), + "Total FalseNeg: {}".format(self.false_neg)) + return "\n".join(text) + + def accuracy(self): + return (self.true_pos+self.true_neg) / max((self.pos+self.neg), self.EPS) + + def hard_dice(self): + numer = 2 * self.hard_I + self.smooth + denom = self.hard_U + self.smooth + return numer / denom + + def soft_dice(self): + numer = 2 * self.soft_I + self.smooth + denom = self.soft_U + self.smooth + return numer / denom + + def __summarize(self): + self.ACC = self.accuracy() + self.HD = self.hard_dice() + self.SD = self.soft_dice() + + self.P_TPR = self.true_pos / max(self.pos, self.EPS) + self.P_PPV = self.true_pos / max((self.true_pos + self.false_pos), self.EPS) + self.P_F1 = 2*self.true_pos / max((2*self.true_pos + self.false_pos + self.false_neg), self.EPS) + + self.N_TPR = self.true_neg / max(self.neg, self.EPS) + self.N_PPV = self.true_neg / max((self.true_neg + self.false_neg), self.EPS) + self.N_F1 = 2*self.true_neg / max((2*self.true_neg + self.false_neg + self.false_pos), self.EPS) + + def __str__(self): + self.__summarize() + summary = ("Accuracy: {:.4f}".format(self.ACC), + "Hard Dice: {:.4f}".format(self.HD), + "Soft Dice: {:.4f}".format(self.SD), + "For positive class:", + "TP(sensitivity,recall): {:.4f}".format(self.P_TPR), + "PPV(precision): {:.4f}".format(self.P_PPV), + "F-1: {:.4f}".format(self.P_F1), + "", + "For normal class:", + "TP(sensitivity,recall): {:.4f}".format(self.N_TPR), + "PPV(precision): {:.4f}".format(self.N_PPV), + "F-1: {:.4f}".format(self.N_F1) + ) + return "\n".join(summary)