Diff of /utils.py [000000] .. [9ff54e]

Switch to unified view

a b/utils.py
1
from collections import defaultdict
2
from os.path import join
3
from random import randint
4
from scipy import ndimage
5
from statistics import median
6
import numpy
7
import os
8
import shutil
9
import sys
10
11
from torch import nn
12
import torch
13
import nibabel as nib
14
15
16
def transfer_weights(target_model, saved_model):
    """
    Copy every compatible weight from ``saved_model`` into ``target_model``.

    target_model: a model instance whose weight params are to be overwritten
    saved_model: a model whose weight params will be transferred to target.
        saved_model can be a string (path to a snapshot), an instance of model
        or a state dict of a model

    Only entries whose key exists in the target and whose tensor size matches
    the target's are copied; every other target entry keeps its current value.
    Nothing is returned, the target model is updated in place.
    """
    # Fetch the target state once and reuse it for the membership and size
    # checks below instead of calling state_dict() again for every key.
    target_dict = target_model.state_dict()
    if isinstance(saved_model, str):
        # Deserialize onto the CPU so that a snapshot written on a GPU machine
        # can still be read on a host without CUDA; load_state_dict() copies
        # the values into the target's own parameters afterwards.
        source_dict = torch.load(saved_model, map_location="cpu")
    else:
        source_dict = saved_model
    if not isinstance(source_dict, dict):
        # A model instance (passed in or unpickled): take its state dict.
        source_dict = source_dict.state_dict()
    compatible = {
        key: value
        for key, value in source_dict.items()
        if key in target_dict and value.size() == target_dict[key].size()
    }
    target_dict.update(compatible)
    target_model.load_state_dict(target_dict)
34
35
36
def generate_ex_list(directory):
    """
    Walk ``directory`` recursively and collect the .nii.gz files found.

    Hidden files (name starting with '.') and files without the .nii.gz
    suffix are ignored.

    Returns a pair (inputs, labels):
        inputs: flat list with the path of every remaining file whose name
            contains neither "Lesion" nor "mask"
        labels: list of lists, one per visited folder that holds at least one
            file with "Lesion" in its name, containing those files' paths
    """
    scans = []
    lesion_groups = []
    for folder, _subdirs, filenames in os.walk(directory):
        folder_lesions = []
        for name in filenames:
            if name.startswith('.') or not name.endswith('.nii.gz'):
                continue
            path = join(folder, name)
            if "Lesion" in name:
                folder_lesions.append(path)
            elif "mask" not in name:
                scans.append(path)
        # Folders without any lesion file contribute nothing to the labels.
        if folder_lesions:
            lesion_groups.append(folder_lesions)

    return scans, lesion_groups
54
55
56
def gen_mask(lesion_files):
    """
    Given a list of lesion files, generate a mask
    that incorporates data from all of them

    lesion_files: non-empty list of paths that nibabel can load
        (an empty list raises IndexError)

    With a single file, that file's volume is returned. With several files
    the result is the element-wise maximum over all volumes and zero.
    Every volume, including a lone one, goes through correct_dims() first.
    """
    def load_volume(path):
        # get_data() is deprecated in nibabel and removed in 5.0;
        # numpy.asanyarray(img.dataobj) is its documented equivalent.
        return correct_dims(numpy.asanyarray(nib.load(path).dataobj))

    # The first volume is loaded once and reused in the multi-file case.
    first_lesion = load_volume(lesion_files[0])
    if len(lesion_files) == 1:
        return first_lesion
    # Starting from zeros keeps the original result (floored at 0).
    lesion_data = numpy.maximum(first_lesion, numpy.zeros(first_lesion.shape[:3]))
    for path in lesion_files[1:]:
        lesion_data = numpy.maximum(load_volume(path), lesion_data)
    return lesion_data
69
70
71
def correct_dims(img):
    """
    Reshape an image with more than three axes to its first three extents.

    Images that already have three or fewer axes are returned unchanged.
    The reshape only succeeds when the dropped axes hold no extra elements
    (e.g. a trailing axis of length 1).
    """
    if len(img.shape) <= 3:
        return img
    dim0, dim1, dim2 = img.shape[:3]
    return img.reshape(dim0, dim1, dim2)
78
79
80
def get_weight_vector(labels, weight, is_cuda):
    """ Build the per-element weight tensor for a weighted BCE loss.

    Only the positive weight is configurable; the negative weight is
    always 1. Each label value l is mapped to l * (weight - 1) + 1, so a
    label of 1 gets `weight` and a label of 0 gets 1.
    Example: with a positive:negative ratio of 1:3, pass weight=3 and this
    function returns 3 for positive and 1 for negative samples.

    labels: tensor of labels
    weight: weight given to positive samples
    is_cuda: when true, labels are moved to the CPU for the computation and
        the result is moved to CUDA before it is returned
    Returns a FloatTensor with the same shape as labels.
    """
    host_labels = labels.cpu() if is_cuda else labels
    scaled = host_labels.data.numpy() * (weight - 1) + 1
    weight_label = torch.from_numpy(scaled).type(torch.FloatTensor)
    return weight_label.cuda() if is_cuda else weight_label
96
97
98
def resize_img(input_img, label_img, size):
    """
    Rescale an image and its label volume to the requested size.

    size: int or list of int
        when it's a list, it should include x, y, z values;
        an int n means (n x n x n)

    The zoom factors are derived from the shape of input_img and the same
    factors are applied to label_img. Returns (resized input, resized label).
    """
    target = [size] * 3 if isinstance(size, int) else size
    assert len(target) == 3
    factors = tuple(float(target[axis]) / input_img.shape[axis] for axis in range(3))
    resized_input = ndimage.zoom(input_img, factors)
    resized_label = ndimage.zoom(label_img, factors)
    return resized_input, resized_label
113
114
115
def center_crop(input_img, label_img, size):
    """
    Crop center section from image
    size: int or list of int
        when it's a list, it should include x, y, z values
    Use for testing.

    The same window is cut from input_img and label_img; its position is
    computed from the shape of input_img. Along an axis where the requested
    size exceeds the image, the crop starts at 0 and covers the whole axis.
    Returns (cropped input, cropped label).
    """
    if isinstance(size, int):
        size = [size] * 3
    assert len(size) == 3
    # Clamp at 0: a negative start would be read as an offset from the end of
    # the axis and select the wrong (mostly empty) region.
    starts = [max(int((input_img.shape[i] - size[i]) // 2), 0) for i in range(3)]
    window = tuple(slice(start, start + extent) for start, extent in zip(starts, size))
    return input_img[window], label_img[window]
132
133
134
def find_and_crop_lesions(input_img, label_img, size, deterministic=False):
    """
    Find and crop image based on center of lesions
    size: int or list of int
        when it's a list, it should include x, y, z values
    deterministic: when False, the window is shifted per axis by a random
        offset of at most a quarter of the crop size in either direction
    Use for validation.

    The window is centred on the median index of the non-zero label voxels
    along each axis and then clamped so that it stays inside input_img.
    The median computation fails if label_img has no non-zero voxel.
    Returns (cropped input, cropped label).
    """
    if isinstance(size, int):
        size = [size] * 3
    assert len(size) == 3
    nonzeros = label_img.nonzero()

    starts = []
    for i in range(3):
        jitter = 0
        if not deterministic:
            # Negate after the floor division so the range is symmetric
            # (-size // 4 floors towards minus infinity for odd quotients).
            quarter = size[i] // 4
            jitter = randint(-quarter, quarter)
        centred = int(median(nonzeros[i])) - (size[i] // 2) + jitter
        # Largest start that still fits a full window is shape - size.
        last_valid = input_img.shape[i] - size[i]
        starts.append(max(min(centred, last_valid), 0))

    window = tuple(slice(start, start + extent) for start, extent in zip(starts, size))
    return input_img[window], label_img[window]
157
158
159
def random_crop(input_img, label_img, size, remove_background=False):
    """
    Crop random section from image
    size: int or list of int
        when it's a list, it should include x, y, z values
    remove_background: boolean
        use this option when input contains larger background or crop size is very small;
        crops are redrawn until at least 70% of the voxels are non-zero
    Use for training

    The same window is cut from input_img and label_img.
    Returns (cropped input, cropped label).
    NOTE(review): with remove_background=True the loop has no retry limit, so
    it never ends if no window of input_img reaches the 70% threshold.
    """
    if isinstance(size, int):
        size = [size] * 3
    assert len(size) == 3
    voxel_count = float(size[0] * size[1] * size[2])
    while True:
        # randint's upper bound is exclusive, so "+ 1" makes the last valid
        # start (shape - size) reachable and lets a crop as large as the
        # image itself succeed (the only possible start is then 0).
        starts = [numpy.random.randint(0, input_img.shape[i] - size[i] + 1) for i in range(3)]
        window = tuple(slice(start, start + extent) for start, extent in zip(starts, size))
        ex = input_img[window]
        if not remove_background:
            break
        if numpy.count_nonzero(ex) / voxel_count >= 0.7:
            break

    label = label_img[window]
    return ex, label
188
189
190
class Report:
    """
    Accumulates binary prediction statistics over successive feed() calls and
    derives accuracy, hard/soft Dice and per-class recall, precision and F-1.
    """
    # Lower bound for denominators so that empty reports do not divide by zero.
    EPS = sys.float_info.epsilon
    # Codes written into the feedback tensor returned by feed().
    TP_KEY = 0
    TN_KEY = 1
    FP_KEY = 2
    FN_KEY = 3

    def __init__(self, threshold=0.5, smooth=sys.float_info.epsilon, apply_square=False, need_feedback=False):
        """
        threshold: predictions strictly above this value count as positive
        smooth: constant added to numerator and denominator of both Dice scores
        apply_square: use squared elements in the denominator of soft Dice
        need_feedback: returns a tensor storing KEYS(0 to 3) for each output element
        """
        self.pos = 0
        self.neg = 0
        self.false_pos = 0
        self.false_neg = 0
        self.true_pos = 0
        self.true_neg = 0
        self.soft_I = 0
        self.soft_U = 0
        self.hard_I = 0
        self.hard_U = 0
        self.smooth = smooth
        self.apply_square = apply_square  # this variable: mainly for testing
        self.need_feedback = need_feedback
        self.threshold = threshold
        # Maps "TP"/"TN"/"FP"/"FN" to the paths that received that outcome.
        self.pathdic = defaultdict(list)

    def feed(self, pred, label, paths=None):
        """ Add one batch of predictions to the running statistics.

            pred size: batch x dim1 x dim2 x...
            label size: batch x dim1 x dim2 x...
            First dim should be a batch size
            paths: optional sequence indexed by flattened element position;
                only used when need_feedback is set

            Returns the feedback tensor (one KEY per flattened element) when
            need_feedback is set, otherwise None.
        """
        self.soft_I += (pred * label).sum().item()
        if self.apply_square:
            self.soft_U += (pred.pow(2).sum() + label.pow(2).sum()).item()
        else:
            self.soft_U += (pred.sum() + label.sum()).item()

        # reshape() instead of view() so non-contiguous tensors flatten too.
        # squeeze() is kept so that a single-element batch still yields the
        # same 0-dim feedback tensor as before.
        pred = (pred.reshape(-1) > self.threshold).squeeze()
        not_pred = pred == 0
        label = label.reshape(-1).byte().squeeze()
        not_label = label == 0

        positives = label.sum().item()
        self.pos += positives
        self.neg += not_label.sum().item()

        pxl = pred * label            # predicted positive, labelled positive
        pxnl = pred * not_label       # predicted positive, labelled negative
        npxl = not_pred * label       # predicted negative, labelled positive
        npxnl = not_pred * not_label  # predicted negative, labelled negative

        # The hard intersection and the true-positive count are the same
        # reduction, so it is computed only once.
        tp = pxl.sum().item()
        self.hard_I += tp
        self.hard_U += pred.sum().item() + positives
        self.false_pos += pxnl.sum().item()
        self.false_neg += npxl.sum().item()
        self.true_pos += tp
        self.true_neg += npxnl.sum().item()

        if not self.need_feedback:
            return None

        feedback = pxl*self.TP_KEY +\
            npxnl*self.TN_KEY +\
            pxnl*self.FP_KEY +\
            npxl*self.FN_KEY
        if paths is not None:
            outcome_names = {self.TP_KEY: "TP", self.TN_KEY: "TN",
                             self.FP_KEY: "FP", self.FN_KEY: "FN"}
            # Flatten before tolist() so a 0-dim feedback tensor (a batch of
            # one element) is handled instead of failing on indexing.
            for i, code in enumerate(feedback.reshape(-1).tolist()):
                outcome = outcome_names.get(int(code))
                if outcome is not None:
                    self.pathdic[outcome].append(paths[i])
        return feedback

    def stats(self):
        """Return the raw accumulated counts as a multi-line string."""
        text = ("Total Positives: {}".format(self.pos),
                "Total Negatives: {}".format(self.neg),
                "Total TruePos: {}".format(self.true_pos),
                "Total TrueNeg: {}".format(self.true_neg),
                "Total FalsePos: {}".format(self.false_pos),
                "Total FalseNeg: {}".format(self.false_neg))
        return "\n".join(text)

    def accuracy(self):
        """Return (TP + TN) / (positives + negatives)."""
        return (self.true_pos+self.true_neg) / max((self.pos+self.neg), self.EPS)

    def hard_dice(self):
        """Return the Dice score computed from thresholded predictions."""
        numer = 2 * self.hard_I + self.smooth
        denom = self.hard_U + self.smooth
        return numer / denom

    def soft_dice(self):
        """Return the Dice score computed from the raw prediction values."""
        numer = 2 * self.soft_I + self.smooth
        denom = self.soft_U + self.smooth
        return numer / denom

    def __summarize(self):
        """Compute all derived metrics and store them as attributes."""
        self.ACC = self.accuracy()
        self.HD = self.hard_dice()
        self.SD = self.soft_dice()

        # Positive class: recall, precision, F-1.
        self.P_TPR = self.true_pos / max(self.pos, self.EPS)
        self.P_PPV = self.true_pos / max((self.true_pos + self.false_pos), self.EPS)
        self.P_F1 = 2*self.true_pos / max((2*self.true_pos + self.false_pos + self.false_neg), self.EPS)

        # Negative ("normal") class: the same metrics with the roles swapped.
        self.N_TPR = self.true_neg / max(self.neg, self.EPS)
        self.N_PPV = self.true_neg / max((self.true_neg + self.false_neg), self.EPS)
        self.N_F1 = 2*self.true_neg / max((2*self.true_neg + self.false_neg + self.false_pos), self.EPS)

    def __str__(self):
        """Return a formatted summary; recomputes the derived metrics first."""
        self.__summarize()
        summary = ("Accuracy: {:.4f}".format(self.ACC),
                   "Hard Dice: {:.4f}".format(self.HD),
                   "Soft Dice: {:.4f}".format(self.SD),
                   "For positive class:",
                   "TP(sensitivity,recall): {:.4f}".format(self.P_TPR),
                   "PPV(precision): {:.4f}".format(self.P_PPV),
                   "F-1: {:.4f}".format(self.P_F1),
                   "",
                   "For normal class:",
                   "TP(sensitivity,recall): {:.4f}".format(self.N_TPR),
                   "PPV(precision): {:.4f}".format(self.N_PPV),
                   "F-1: {:.4f}".format(self.N_F1)
                   )
        return "\n".join(summary)