Diff of /utils.py [000000] .. [94d9b6]

Switch to unified view

a b/utils.py
1
import torch
2
import torch.nn as nn
3
from torch.autograd import Variable
4
import torch.optim as optim
5
import torchvision
6
from torchvision import datasets, models
7
from torchvision import transforms as T
8
from torch.utils.data import DataLoader, Dataset
9
import numpy as np
10
import matplotlib.pyplot as plt
11
import os
12
import time
13
import pandas as pd
14
from skimage import io, transform
15
import matplotlib.image as mpimg
16
from PIL import Image
17
from sklearn.metrics import roc_auc_score
18
import torch.nn.functional as F
19
import scipy
20
import random
21
import pickle
22
import scipy.io as sio
23
import itertools
24
from scipy.ndimage.interpolation import shift
25
import copy
26
import warnings
27
warnings.filterwarnings("ignore")
28
plt.ion()
29
30
from dataloader_2d import *
31
from dataloader_3d import *
32
33
# Absolute dataset roots on the cluster filesystem.
train_path = '/beegfs/ark576/new_knee_data/train'
val_path = '/beegfs/ark576/new_knee_data/val'
test_path = '/beegfs/ark576/new_knee_data/test'

# Pickled lists of per-scan file names; sorted for a deterministic ordering.
# NOTE(review): the file handles passed to pickle.load are never closed.
train_file_names = sorted(pickle.load(open(train_path + '/train_file_names.p','rb')))
val_file_names = sorted(pickle.load(open(val_path + '/val_file_names.p','rb')))
test_file_names = sorted(pickle.load(open(test_path + '/test_file_names.p','rb')))

# KneeMRIDataset comes from the dataloader_2d/dataloader_3d star imports above
# (TODO confirm which module defines it). Augmentation (flipping) is disabled
# even for training; intensities are normalized for every split.
transformed_dataset = {'train': KneeMRIDataset(train_path,train_file_names, train_data= True, flipping=False, normalize= True),
                       'validate': KneeMRIDataset(val_path,val_file_names, normalize= True),
                       'test': KneeMRIDataset(test_path,test_file_names, normalize= True)
                                          }

# NOTE(review): shuffle=True is applied to validation/test as well, which makes
# the i*15 slice grouping in the save_* helpers below order-dependent — verify.
dataloader = {x: DataLoader(transformed_dataset[x], batch_size=5,
                        shuffle=True, num_workers=0) for x in ['train', 'validate','test']}
# Number of samples per split.
data_sizes ={x: len(transformed_dataset[x]) for x in ['train', 'validate','test']}
50
def plot_hist(hist_dict, hist_type, chart_type = 'semi-log'):
    """Plot training vs. validation history curves with a log-scaled y-axis.

    Args:
        hist_dict: dict with 'train' and 'validate' keys, each a sequence of
            per-epoch values (e.g. losses or dice scores).
        hist_type: label for the metric being plotted (used on axis/legend).
        chart_type: 'semi-log' (log y only) or 'log-log' (log both axes);
            any other value plots nothing but still draws labels/legend.
    """
    if chart_type == 'log-log':
        plot_fn = plt.loglog
    elif chart_type == 'semi-log':
        plot_fn = plt.semilogy
    else:
        plot_fn = None
    if plot_fn is not None:
        for split, prefix in (('train', 'Train '), ('validate', 'Validation ')):
            series = hist_dict[split]
            plot_fn(range(len(series)), series, label=prefix + hist_type)
    plt.xlabel('Epochs')
    plt.ylabel(hist_type)
    plt.legend()
    plt.show()
61
62
def dice_loss(true, scores, epsilon = 1e-4, p = 2):
    """Class-weighted soft-dice loss.

    Args:
        true: one-hot ground-truth masks, shape (N, C, H, W).
        scores: raw (pre-softmax) network outputs, shape (N, C, H, W).
        epsilon: numerical-stability constant.
        p: exponent for the inverse-frequency class weights.

    Returns:
        Scalar tensor: mean (1 - weighted dice) over the batch.
    """
    # Fix: make the channel softmax explicit. The old F.softmax(scores) relied
    # on the deprecated implicit-dim behavior (dim=1 for 4-D input).
    preds = F.softmax(scores, dim=1)
    N, C, sh1, sh2 = true.size()
    true = true.view(N, C, -1)
    preds = preds.view(N, C, -1)
    # Inverse-frequency weights: 1 / (pixel count)^p, so rare classes count more.
    wts = torch.sum(true, dim = 2) + epsilon
    wts = 1/torch.pow(wts, p)
    # Classes that are absent (or nearly so) saturate the 0.1 ceiling and are
    # then zeroed out, i.e. excluded from the loss for that sample.
    wts = torch.clamp(wts, 0, 0.1)
    wts[wts == 0.1] = 0
    # Normalize weights to sum to 1 per sample.
    wts = wts/(torch.sum(wts, dim = 1)[:, None])
    prod = torch.sum(true*preds, dim = 2)       # per-class intersection
    sum_tnp = torch.sum(true + preds, dim = 2)  # per-class mass of both masks
    num = torch.sum(wts * prod, dim = 1)
    denom = torch.sum(wts * sum_tnp, dim = 1) + epsilon
    loss = 1 - 2*(num/denom)
    return torch.mean(loss)
78
79
def dice_loss_2(true, scores, epsilon = 1e-4, p = 2):
    """Variant of dice_loss that keeps (tiny) weight on saturated classes and
    nearly ignores the last channel (background).

    Args:
        true: one-hot ground-truth masks, shape (N, C, H, W); channel C-1 is
            treated as background.
        scores: raw (pre-softmax) network outputs, shape (N, C, H, W).
        epsilon: numerical-stability constant.
        p: exponent for the inverse-frequency class weights.

    Returns:
        Scalar tensor: mean (1 - weighted dice) over the batch.
    """
    # Fix: explicit channel softmax instead of the deprecated implicit dim.
    preds = F.softmax(scores, dim=1)
    N, C, sh1, sh2 = true.size()
    true = true.view(N, C, -1)
    preds = preds.view(N, C, -1)
    # Inverse-frequency class weights, capped at 0.1.
    wts = torch.sum(true, dim = 2) + epsilon
    wts = 1/torch.pow(wts, p)
    wts = torch.clamp(wts, 0, 0.1)
    # Unlike dice_loss, saturated classes keep a tiny (not zero) weight ...
    wts[wts == 0.1] = 1e-6
    # ... and the background channel is almost fully suppressed.
    wts[:, -1] = 1e-15
    wts = wts/(torch.sum(wts, dim = 1)[:, None])
    prod = torch.sum(true*preds, dim = 2)       # per-class intersection
    sum_tnp = torch.sum(true + preds, dim = 2)  # per-class mass of both masks
    num = torch.sum(wts * prod, dim = 1)
    denom = torch.sum(wts * sum_tnp, dim = 1) + epsilon
    loss = 1 - 2*(num/denom)
    return torch.mean(loss)
96
97
def segments(seg_1, seg_2, seg_3):
    """Stack three binary tissue masks with a derived background channel.

    Args:
        seg_1, seg_2, seg_3: float masks of shape (N, H, W).

    Returns:
        Tensor of shape (N, 4, H, W): the three masks plus a fourth channel
        that is 1 wherever none of the three masks is set.
    """
    background = ((seg_1 + seg_2 + seg_3) == 0).type(torch.FloatTensor)
    channels = [m.unsqueeze(1) for m in (seg_1, seg_2, seg_3, background)]
    return torch.cat(channels, dim=1)
102
103
# Accumulate total ground-truth pixel counts per tissue class (femur,
# patella, tibia) over the training set, then average per batch. The result
# seeds the class weights of dice_loss_3 / entropy_loss; since those weights
# are re-normalized there, only the per-class ratios matter.
seg_sum = torch.zeros(3)
num_batches = 0
for i, data in enumerate(dataloader['train']):
    volume, segF, segP, segT, _ = data
    seg_sum[0] += torch.sum(segF)
    seg_sum[1] += torch.sum(segP)
    seg_sum[2] += torch.sum(segT)
    num_batches = i + 1
# Fix: divide by the batch count, not the last loop index (which is
# num_batches - 1 — an off-by-one that biased the mean upward).
mean_s_sum = seg_sum/num_batches
110
111
def dice_loss_3(true, scores, epsilon = 1e-4, p = 2, mean=mean_s_sum):
    """Dice loss whose class weights come from dataset-wide class frequencies
    (``mean``) instead of per-sample pixel counts; background gets weight 0.

    Args:
        true: one-hot ground-truth masks, shape (N, C, H, W); channel C-1 is
            treated as background.
        scores: raw (pre-softmax) network outputs, shape (N, C, H, W).
        epsilon: numerical-stability constant.
        p: exponent for the inverse-frequency weights.
        mean: per-class mean pixel counts (length C-1), default computed from
            the training set at import time.

    Returns:
        Scalar tensor: mean (1 - weighted dice) over the batch.
    """
    # Fix: explicit channel softmax instead of the deprecated implicit dim.
    preds = F.softmax(scores, dim=1)
    N, C, sh1, sh2 = true.size()
    true = true.view(N, C, -1)
    preds = preds.view(N, C, -1)
    wts = torch.sum(true, dim = 2) + epsilon
    # Dataset-level inverse-frequency weights, shared across the batch.
    mean = 1/torch.pow(mean, p)
    wts[:, :-1] = mean[None].repeat(N, 1)
    wts[:, -1] = 0  # background excluded entirely
    wts = wts/(torch.sum(wts, dim = 1)[:, None])
    prod = torch.sum(true*preds, dim = 2)       # per-class intersection
    sum_tnp = torch.sum(true + preds, dim = 2)  # per-class mass of both masks
    num = torch.sum(wts * prod, dim = 1)
    denom = torch.sum(wts * sum_tnp, dim = 1) + epsilon
    loss = 1 - 2*(num/denom)
    return torch.mean(loss)
127
128
def predict(scores, smooth = False, filter_size = 3):
    """Convert raw scores to three hard (0/1) per-class masks on the GPU.

    Args:
        scores: raw network outputs, shape (N, C, H, W); classes 0/1/2 are the
            tissue classes (class 3 — background — gets no mask here).
        smooth: if True, majority-smooth each mask with an average pool
            followed by a 0.5 threshold.
        filter_size: odd kernel size for the smoothing; padding keeps the
            spatial size unchanged.

    Returns:
        Tuple of three float CUDA tensors, shape (N, H, W) each.
    """
    # Fix: explicit channel softmax instead of the deprecated implicit dim.
    preds = F.softmax(scores, dim=1)
    pred_class = (torch.max(preds, dim = 1)[1])  # argmax over classes
    class_0_pred_seg = (pred_class == 0).type(torch.cuda.FloatTensor)
    class_1_pred_seg = (pred_class == 1).type(torch.cuda.FloatTensor)
    class_2_pred_seg = (pred_class == 2).type(torch.cuda.FloatTensor)
    if smooth:
        # Average-pool then threshold: a pixel survives only if >half of its
        # neighborhood is set (simple majority filter).
        class_0_pred_seg = F.avg_pool2d(class_0_pred_seg,filter_size,1,int((filter_size-1)/2))>0.5
        class_1_pred_seg = F.avg_pool2d(class_1_pred_seg,filter_size,1,int((filter_size-1)/2))>0.5
        class_2_pred_seg = F.avg_pool2d(class_2_pred_seg,filter_size,1,int((filter_size-1)/2))>0.5
    return class_0_pred_seg.data.type(torch.cuda.FloatTensor), class_1_pred_seg.data.type(torch.cuda.FloatTensor)\
, class_2_pred_seg.data.type(torch.cuda.FloatTensor)
140
141
142
def dice_score(true, scores, smooth = False, filter_size = 3, epsilon = 1e-7):
    """Per-class hard dice scores between ground truth and predictions.

    Args:
        true: one-hot ground-truth masks, shape (N, C, H, W).
        scores: raw network outputs passed to predict().
        smooth, filter_size: forwarded to predict() for mask smoothing.
        epsilon: numerical-stability constant (epsilon/2 in the numerator so a
            doubly-empty class scores ~1 rather than 0).

    Returns:
        Tuple of three scalar tensors: dice for classes 0, 1 and 2.
    """
    N, C, sh1, sh2 = true.size()
    true = true.view(N, C, -1)
    hard_masks = predict(scores, smooth = smooth, filter_size = filter_size)
    hard_masks = [mask.view(N, -1) for mask in hard_masks]
    true = true.data.type(torch.cuda.FloatTensor)

    def class_dice(idx, pred):
        # 2 * |A ∩ B| / (|A| + |B|), averaged over the batch.
        inter = torch.sum(true[:, idx, :] * pred, dim = 1) + epsilon/2
        total = torch.sum(true[:, idx, :] + pred, dim = 1) + epsilon
        return torch.mean(2*inter/total)

    return (class_dice(0, hard_masks[0]),
            class_dice(1, hard_masks[1]),
            class_dice(2, hard_masks[2]))
160
161
def entropy_loss(true, scores, mean = mean_s_sum, epsilon = 1e-4, p = 2):
    """Class-weighted cross-entropy with dataset-level inverse-frequency
    weights; the background class (last of 4) is almost fully suppressed.

    Args:
        true: one-hot ground-truth masks, shape (N, C, H, W) with C == 4.
        scores: raw network outputs, shape (N, C, H, W).
        mean: per-class mean pixel counts (length 3), default computed from
            the training set at import time.
        epsilon: numerical-stability constant.
        p: exponent for the inverse-frequency weights.

    Returns:
        Scalar tensor loss.
    """
    N, C, sh1, sh2 = true.size()
    # Per-class weights, normalized to sum to 1; background ~0.
    wts = Variable(torch.zeros(4).cuda()) + epsilon
    mean = 1/torch.pow(mean, p)
    wts[:-1] = mean
    wts[-1] = 1e-9
    wts = wts/(torch.sum(wts))
    # Fix: explicit channel log-softmax instead of the deprecated implicit dim.
    log_prob = F.log_softmax(scores, dim=1)
    prod = (log_prob*true).view(N, C, -1)
    # Transpose so the class axis is last and broadcasts against wts.
    prod_t = torch.transpose(prod, 1, 2)
    loss = -torch.mean(prod_t*wts)
    return loss
173
174
def image_to_mask(img, femur, patellar, tibia, femur_pr, patellar_pr, tibia_pr, cm = None):
    """Display ground-truth and predicted cartilage masks overlaid on a slice.

    Draws three panels (raw image, image + GT masks, image + predicted masks)
    and, when ``cm`` is given, a separate certainty-map figure.

    Args:
        img: 2-D grayscale slice.
        femur, patellar, tibia: ground-truth binary masks.
        femur_pr, patellar_pr, tibia_pr: predicted binary masks.
        cm: optional certainty map where -1000 marks background pixels
            (see make_certainity_maps); masked out of the display.
    """
    # Masked arrays make zero-valued pixels transparent in imshow overlays.
    masked_1 = np.ma.masked_where(femur == 0, femur)
    masked_2 = np.ma.masked_where(patellar == 0, patellar)
    masked_3 = np.ma.masked_where(tibia == 0, tibia)

    masked_1_pr = np.ma.masked_where(femur_pr == 0, femur_pr)
    masked_2_pr = np.ma.masked_where(patellar_pr == 0, patellar_pr)
    masked_3_pr = np.ma.masked_where(tibia_pr == 0, tibia_pr)
    x = 3
    plt.figure(figsize=(20,10))
    plt.subplot(1,x,1)
    plt.imshow(img, 'gray', interpolation='none')
    plt.subplot(1,x,2)
    plt.imshow(img, 'gray', interpolation='none')
    # Only overlay non-empty masks (an all-masked imshow draws nothing useful).
    if np.sum(femur) != 0:
        plt.imshow(masked_1, 'spring', interpolation='none', alpha=0.9)
    if np.sum(patellar) != 0:
        plt.imshow(masked_2, 'coolwarm_r', interpolation='none', alpha=0.9)
    if np.sum(tibia) != 0:
        plt.imshow(masked_3, 'Wistia', interpolation='none', alpha=0.9)
    plt.subplot(1,x,3)
    plt.imshow(img, 'gray', interpolation='none')
    if np.sum(femur_pr) != 0:
        plt.imshow(masked_1_pr, 'spring', interpolation='none', alpha=0.9)
    if np.sum(patellar_pr) != 0:
        plt.imshow(masked_2_pr, 'coolwarm_r', interpolation='none', alpha=0.9)
    if np.sum(tibia_pr) != 0:
        plt.imshow(masked_3_pr, 'Wistia', interpolation='none', alpha=0.9)
    plt.show()
    if cm is not None:
        # Fix: build the masked certainty map only when cm is supplied.
        # Previously np.ma.masked_where ran on the default cm=None even
        # though the result was never used in that case.
        masked_cm = np.ma.masked_where(cm == -1000, cm)
        plt.figure(figsize=(20,20))
        plt.imshow(masked_cm,'coolwarm_r')
        plt.colorbar()
        plt.show()
209
210
def generate_noise(true):
    """Return uniform noise in [-0.1, 0.1) with the same shape as ``true``,
    wrapped in a Variable and moved to the GPU."""
    noise = (torch.rand(true.size())*2 - 1)*0.1
    return Variable(noise).cuda()
212
213
def save_segmentations_2d(model,prediction_models, dataloader,data_sizes,batch_size,phase,model_name,\
                          num_samples = 7, smooth = False, filter_size = 3):
    """Run the 2-stage 2D cascade over a split and save hard masks to .mat.

    Stage 1: each network in ``prediction_models`` processes the input slices
    and its output goes through predict_pp (NOTE(review): predict_pp is not
    defined in this file — presumably comes from the dataloader_* star
    imports; confirm). Stage 2: ``model`` fuses the concatenated stage-1
    outputs and predict() turns the result into three binary masks.

    Args:
        model: fusion network applied to the concatenated stage-1 outputs.
        prediction_models: iterable of stage-1 networks; their parameters'
            requires_grad is switched off in place below.
        dataloader: dict of DataLoaders keyed by phase.
        data_sizes, batch_size: unused here — kept, presumably, for a uniform
            signature with the other save_* helpers.
        phase: 'train' | 'validate' | 'test'; selects the source .mat folder.
        model_name: subfolder under .../predictions/ to write into.
        num_samples: number of scans to save; assumes 15 slices per scan
            (the i*15 grouping below) — TODO confirm.
        smooth, filter_size: forwarded to predict() for mask smoothing.
    """
    y_preds = []
    name_list = []
    num_samples = num_samples  # no-op self-assignment (kept as-is)
    if phase == 'train':
        path = '/beegfs/ark576/Knee Cartilage Data/Train Data/'
    if phase == 'validate':
        path = '/beegfs/ark576/Knee Cartilage Data/Validation Data/'
    if phase == 'test':
        path = '/beegfs/ark576/Knee Cartilage Data/Test Data/'
    # Freeze all stage-1 models; this mutates them permanently for the caller.
    for i in prediction_models:
        for param in i.parameters():
            param.requires_grad = False
    
    for i,data in enumerate(dataloader[phase]):
        input, segF,segP, segT,variable_name = data
        input = Variable(input).cuda()
        input_pp = []
        # Stage 1: collect each model's post-processed prediction as channels.
        for j in prediction_models:
            output = j(input)
            preds_m = predict_pp(output)
            input_pp.append(preds_m)
        input_pp = torch.cat(input_pp,dim = 1)
        # Stage 2: fuse and binarize.
        output_pp = model(input_pp)
        preds = predict(output_pp,smooth = smooth, filter_size=filter_size)
        preds = torch.cat((preds[0][:,None],preds[1][:,None],preds[2][:,None]),dim = 1)
        y_preds.append(preds.cpu().numpy())
        name_list.append(variable_name)
    list_of_names = list(itertools.chain(*name_list))
    y_preds = np.concatenate(y_preds).astype(np.uint8)
    for i in range(num_samples):
        # One scan = 15 consecutive slices; names repeat per slice, so the
        # first name of the group identifies the scan ('[:-3]' strips what
        # looks like a slice suffix — TODO confirm).
        name = list_of_names[i*15][:-3]
        pred_segment = y_preds[i*15:(i+1)*15]
        file_name = path + name
        variable = sio.loadmat(file_name)
        temp_variable = {}
        temp_variable['MDnr'] = variable['MDnr']
        # Reorder slices; appears to undo an interleaved slice ordering
        # produced by the dataloader — TODO confirm against dataloader_2d.
        preds_all = pred_segment[[0,1,7,8,9,10,11,12,13,14,2,3,4,5,6],:]
        # Save as (H, W, slice) volumes, MATLAB-style.
        temp_variable['Predicted_segment_F'] = np.transpose(preds_all[:,0,:,:],(1,2,0))
        temp_variable['Predicted_segment_P'] = np.transpose(preds_all[:,1,:,:],(1,2,0))
        temp_variable['Predicted_segment_T'] = np.transpose(preds_all[:,2,:,:],(1,2,0))
        save_path = '/beegfs/ark576/knee-segments/predictions/'+ model_name +'/'+phase+'/'
        sio.savemat(save_path+name,temp_variable,appendmat=False, do_compression=True)
257
258
def save_segmentations_2d_prob(model,prediction_models, dataloader,data_sizes,batch_size,phase,model_name,\
                          num_samples = 7, smooth = False, filter_size = 3):
    """Like save_segmentations_2d, but stores the full softmax probability
    volume (plus ground-truth masks) instead of hard binary masks.

    Args:
        model: fusion network applied to the concatenated stage-1 outputs.
        prediction_models: iterable of stage-1 networks; their parameters'
            requires_grad is switched off in place below.
        dataloader: dict of DataLoaders keyed by phase.
        data_sizes, batch_size: unused here — kept, presumably, for a uniform
            signature with the other save_* helpers.
        phase: 'train' | 'validate' | 'test'; selects the source .mat folder.
        model_name: subfolder under .../predictions/ to write into.
        num_samples: number of scans to save; assumes 15 slices per scan.
        smooth, filter_size: accepted but unused here (no hard thresholding).
    """
    y_preds = []
    name_list = []
    num_samples = num_samples  # no-op self-assignment (kept as-is)
    if phase == 'train':
        path = '/beegfs/ark576/Knee Cartilage Data/Train Data/'
    if phase == 'validate':
        path = '/beegfs/ark576/Knee Cartilage Data/Validation Data/'
    if phase == 'test':
        path = '/beegfs/ark576/Knee Cartilage Data/Test Data/'
    # Freeze all stage-1 models; this mutates them permanently for the caller.
    for i in prediction_models:
        for param in i.parameters():
            param.requires_grad = False
    
    for i,data in enumerate(dataloader[phase]):
        input, segF,segP, segT,variable_name = data
        input = Variable(input).cuda()
        input_pp = []
        # Stage 1: collect each model's post-processed prediction as channels.
        # NOTE(review): predict_pp is not defined in this file — presumably
        # from the dataloader_* star imports; confirm.
        for j in prediction_models:
            output = j(input)
            preds_m = predict_pp(output)
            input_pp.append(preds_m)
        input_pp = torch.cat(input_pp,dim = 1)
        output_pp = model(input_pp)
        # Keep the soft class probabilities (implicit softmax dim — legacy).
        preds = F.softmax(output_pp)
        y_preds.append(preds.data.cpu().numpy())
        name_list.append(variable_name)
    list_of_names = list(itertools.chain(*name_list))
    y_preds = np.concatenate(y_preds)
    for i in range(num_samples):
        # One scan = 15 consecutive slices; '[:-3]' strips what looks like a
        # slice suffix from the repeated name — TODO confirm.
        name = list_of_names[i*15][:-3]
        pred_segment = y_preds[i*15:(i+1)*15]
        file_name = path + name
        variable = sio.loadmat(file_name)
        temp_variable = {}
        # NOTE(review): keys differ from save_segmentations_2d ('NUFnr' vs
        # 'MDnr') — verify the source .mat schema.
        temp_variable['NUFnr'] = variable['NUFnr']
        temp_variable['GT_F'] = variable['SegmentationF']
        temp_variable['GT_P'] = variable['SegmentationP']
        temp_variable['GT_T'] = variable['SegmentationT']
        # Reorder slices; appears to undo an interleaved slice ordering
        # produced by the dataloader — TODO confirm against dataloader_2d.
        preds_all = pred_segment[[0,1,7,8,9,10,11,12,13,14,2,3,4,5,6],:]
        temp_variable['Predicted_prob'] = preds_all
        save_path = '/beegfs/ark576/knee-segments/predictions/'+ model_name +'/'+phase+'/'
        sio.savemat(save_path+name+'_prob',temp_variable,appendmat=False, do_compression=True)
302
303
from sklearn.metrics import confusion_matrix
304
def dice_score_image(pred, true, epsilon = 1e-5):
    """Dice overlap between two binary numpy masks.

    Args:
        pred: predicted binary mask.
        true: ground-truth binary mask.
        epsilon: numerical-stability constant.

    Returns:
        2*|pred ∩ true| / (|pred| + |true|) (epsilon-stabilized), or None
        when either mask is entirely empty.
    """
    intersection_term = 2*np.sum(pred*true) + epsilon
    pred_total = np.sum(pred)
    true_total = np.sum(true)
    if pred_total == 0 or true_total == 0:
        return None
    return intersection_term/(pred_total + true_total + epsilon)
313
314
def save_segmentations_3D(model, dataloader,data_sizes,batch_size,phase,model_name, num_samples = 7):
    """Run a 3D model over a split and save hard per-tissue masks to .mat.

    The model's 5-D output is reshaped into a stack of 2-D slices, binarized
    with predict(), regrouped 15 slices per scan, and written next to the
    original scan's metadata.

    Args:
        model: 3D segmentation network; output is transposed/reshaped to
            (-1, 4, 256, 256), so 4 classes and 256x256 slices are assumed.
        dataloader: dict of DataLoaders keyed by phase; yields
            (input, segments, variable_name).
        data_sizes, batch_size: unused here — kept, presumably, for a uniform
            signature with the other save_* helpers.
        phase: 'train' | 'validate' | 'test'; selects the source .mat folder.
        model_name: subfolder under .../predictions/ to write into.
        num_samples: number of scans to save; assumes 15 slices per scan.
    """
    y_preds = []
    name_list = []
    num_samples = num_samples  # no-op self-assignment (kept as-is)
    if phase == 'train':
        path = '/beegfs/ark576/Knee Cartilage Data/Train Data/'
    if phase == 'validate':
        path = '/beegfs/ark576/Knee Cartilage Data/Validation Data/'
    if phase == 'test':
        path = '/beegfs/ark576/Knee Cartilage Data/Test Data/'
    for data in dataloader[phase]:
        input, segments, variable_name = data
        input = Variable(input).cuda()
        output = model(input)
        # Flatten the depth axis into the batch axis so predict() (a 2D
        # helper) can binarize slice by slice.
        output_reshaped = torch.transpose(output,2,1).contiguous().view(-1,4,256,256)
        preds = predict(output_reshaped)
        preds = torch.cat((preds[0][:,None],preds[1][:,None],preds[2][:,None]),dim = 1)
        y_preds.append(preds.cpu().numpy())
        name_list.append(variable_name)
    list_of_names = list(itertools.chain(*name_list))
    y_preds = np.concatenate(y_preds).astype(np.uint8)
    for i in range(num_samples):
        # One name per scan here (unlike the 2D savers, no [:-3] trimming).
        name = list_of_names[i]
        preds_all = y_preds[i*15:(i+1)*15]
        file_name = path + name
        variable = sio.loadmat(file_name)
        temp_variable = {}
        temp_variable['MDnr'] = variable['MDnr']
        # Save as (H, W, slice) volumes, MATLAB-style.
        # NOTE(review): channel 1 is labeled T and channel 2 P here — the
        # opposite of save_segmentations_2d (P=1, T=2). Verify which mapping
        # is correct; one of the two savers likely mislabels the tissues.
        temp_variable['Predicted_segment_F'] = np.transpose(preds_all[:,0,:,:],(1,2,0))
        temp_variable['Predicted_segment_T'] = np.transpose(preds_all[:,1,:,:],(1,2,0))
        temp_variable['Predicted_segment_P'] = np.transpose(preds_all[:,2,:,:],(1,2,0))
        save_path = '/beegfs/ark576/knee-segments/predictions/'+ model_name +'/'+phase+'/'
        sio.savemat(save_path+name,temp_variable,appendmat=False, do_compression=True)
347
348
def make_certainity_maps(scores):
    """Per-pixel log-odds confidence map of the winning class.

    Args:
        scores: raw network outputs, shape (N, C, H, W); class 3 is treated
            as background.

    Returns:
        Tensor of shape (N, H, W): log(p/(1-p)) of the max-probability class,
        with p clamped to [1e-5, 0.999999]; background pixels (argmax == 3)
        are set to the sentinel -1000 (masked out by image_to_mask).
    """
    # Fix: explicit channel softmax instead of the deprecated implicit dim.
    probs = F.softmax(scores, dim=1)
    pred_prob, idx = torch.max(probs, dim = 1)
    # Clamp away 0/1 so the log-odds stay finite.
    pred_prob_c = torch.clamp(pred_prob, 0.00001, 0.999999)
    ret_value = torch.log(pred_prob_c) - torch.log((1-pred_prob_c))
    ret_value[idx == 3] = -1000
    return ret_value
355
356