Diff of /utils.py [000000] .. [7b5b9f]

""" Utility functions. """
import os
import random
import logging
import itertools

import numpy as np
import tensorflow as tf
import SimpleITK as sitk
from scipy import ndimage
from scipy.ndimage import _ni_support
from scipy.ndimage.morphology import distance_transform_edt, binary_erosion, \
    generate_binary_structure
from tensorflow.contrib.layers.python import layers as tf_layers
from tensorflow.python.platform import flags
from tensorflow.contrib import slim

FLAGS = flags.FLAGS

## Image reader
def get_images(paths, labels, nb_samples=None, shuffle=True):
    # Optionally subsample nb_samples images per class folder.
    if nb_samples is not None:
        sampler = lambda x: random.sample(x, nb_samples)
    else:
        sampler = lambda x: x
    # Pair each label with the image paths inside its folder.
    images = [(i, os.path.join(path, image))
              for i, path in zip(labels, paths)
              for image in sampler(os.listdir(path))]
    if shuffle:
        random.shuffle(images)
    return images

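# A minimal usage sketch for get_images (the folder names and labels here
# are hypothetical, assuming each entry in `paths` is one class folder):
#
#   folders = ['data/class_a', 'data/class_b']
#   pairs = get_images(folders, labels=[0, 1], nb_samples=16)
#   # -> shuffled list of (label, image_path) tuples, 16 drawn per folder
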
## Loss functions
def mse(pred, label):
    pred = tf.reshape(pred, [-1])
    label = tf.reshape(label, [-1])
    return tf.reduce_mean(tf.square(pred - label))

def xent(pred, label):
    return tf.nn.softmax_cross_entropy_with_logits_v2(logits=pred, labels=label)

def kd(data1, label1, data2, label2, bool_indicator, n_class=7, temperature=2.0):
    # Symmetrised KL divergence between the two domains' class-averaged,
    # temperature-softened logits.
    kd_loss = 0.0
    eps = 1e-16

    prob1s = []
    prob2s = []

    for cls in range(n_class):
        mask1 = tf.tile(tf.expand_dims(label1[:, cls], -1), [1, n_class])
        logits_sum1 = tf.reduce_sum(tf.multiply(data1, mask1), axis=0)
        num1 = tf.reduce_sum(label1[:, cls])
        activations1 = logits_sum1 * 1.0 / (num1 + eps)  # eps prevents an un-sampled class from producing NaN
        prob1 = tf.nn.softmax(activations1 / temperature)
        prob1 = tf.clip_by_value(prob1, clip_value_min=1e-8, clip_value_max=1.0)  # prevents prob=0 from producing NaN in the log

        mask2 = tf.tile(tf.expand_dims(label2[:, cls], -1), [1, n_class])
        logits_sum2 = tf.reduce_sum(tf.multiply(data2, mask2), axis=0)
        num2 = tf.reduce_sum(label2[:, cls])
        activations2 = logits_sum2 * 1.0 / (num2 + eps)
        prob2 = tf.nn.softmax(activations2 / temperature)
        prob2 = tf.clip_by_value(prob2, clip_value_min=1e-8, clip_value_max=1.0)

        KL_div = (tf.reduce_sum(prob1 * tf.log(prob1 / prob2)) + tf.reduce_sum(prob2 * tf.log(prob2 / prob1))) / 2.0
        kd_loss += KL_div * bool_indicator[cls]

        prob1s.append(prob1)
        prob2s.append(prob2)

    kd_loss = kd_loss / n_class

    return kd_loss, prob1s, prob2s

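# For reference, the per-class quantity above is the symmetrised KL
# divergence. A plain-NumPy sketch (p, q stand for the clipped softmax
# vectors; the names are illustrative only):
#
#   def sym_kl(p, q):
#       return 0.5 * (np.sum(p * np.log(p / q)) + np.sum(q * np.log(q / p)))
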
def JS(data1, label1, data2, label2, bool_indicator, n_class=7, temperature=2.0):
    # Jensen-Shannon-style divergence: like kd() above, but each
    # distribution is compared against their mean instead of each other.
    kd_loss = 0.0
    eps = 1e-16

    prob1s = []
    prob2s = []

    for cls in range(n_class):
        mask1 = tf.tile(tf.expand_dims(label1[:, cls], -1), [1, n_class])
        logits_sum1 = tf.reduce_sum(tf.multiply(data1, mask1), axis=0)
        num1 = tf.reduce_sum(label1[:, cls])
        activations1 = logits_sum1 * 1.0 / (num1 + eps)  # eps prevents an un-sampled class from producing NaN
        prob1 = tf.nn.softmax(activations1 / temperature)
        prob1 = tf.clip_by_value(prob1, clip_value_min=1e-8, clip_value_max=1.0)  # prevents prob=0 from producing NaN in the log

        mask2 = tf.tile(tf.expand_dims(label2[:, cls], -1), [1, n_class])
        logits_sum2 = tf.reduce_sum(tf.multiply(data2, mask2), axis=0)
        num2 = tf.reduce_sum(label2[:, cls])
        activations2 = logits_sum2 * 1.0 / (num2 + eps)
        prob2 = tf.nn.softmax(activations2 / temperature)
        prob2 = tf.clip_by_value(prob2, clip_value_min=1e-8, clip_value_max=1.0)

        mean_prob = (prob1 + prob2) / 2

        JS_div = (tf.reduce_sum(prob1 * tf.log(prob1 / mean_prob)) + tf.reduce_sum(prob2 * tf.log(prob2 / mean_prob))) / 2.0
        kd_loss += JS_div * bool_indicator[cls]

        prob1s.append(prob1)
        prob2s.append(prob2)

    kd_loss = kd_loss / n_class

    return kd_loss, prob1s, prob2s

def contrastive(feature1, label1, feature2, label2, bool_indicator=None, margin=50):
    # Contrastive loss: pull matched pairs together, push mismatched pairs
    # apart until their distance exceeds the margin.
    l1 = tf.argmax(label1, axis=1)
    l2 = tf.argmax(label2, axis=1)
    pair = tf.to_float(tf.equal(l1, l2))

    delta = tf.reduce_sum(tf.square(feature1 - feature2), 1) + 1e-10
    match_loss = delta

    delta_sqrt = tf.sqrt(delta + 1e-10)
    mismatch_loss = tf.square(tf.nn.relu(margin - delta_sqrt))

    if bool_indicator is None:
        loss = tf.reduce_mean(0.5 * (pair * match_loss + (1 - pair) * mismatch_loss))
    else:
        loss = 0.5 * tf.reduce_sum(match_loss * pair) / tf.reduce_sum(pair)

    debug_dist_positive = tf.reduce_sum(delta_sqrt * pair) / tf.reduce_sum(pair)
    debug_dist_negative = tf.reduce_sum(delta_sqrt * (1 - pair)) / tf.reduce_sum(1 - pair)

    return loss, pair, delta, debug_dist_positive, debug_dist_negative

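# Quick intuition check with hypothetical scalars: at margin=50, a matched
# pair with squared distance 9 has match term 9, while a mismatched pair at
# distance sqrt(9) = 3 has mismatch term (50 - 3)**2 = 2209, which only
# reaches zero once the pair is pushed past the margin.
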
def compute_distance(feature1, label1, feature2, label2):
    l1 = tf.argmax(label1, axis=1)
    l2 = tf.argmax(label2, axis=1)
    pair = tf.to_float(tf.equal(l1, l2))

    delta = tf.reduce_sum(tf.square(feature1 - feature2), 1)
    delta_sqrt = tf.sqrt(delta + 1e-16)

    dist_positive_pair = tf.reduce_sum(delta_sqrt * pair) / tf.reduce_sum(pair)
    dist_negative_pair = tf.reduce_sum(delta_sqrt * (1 - pair)) / tf.reduce_sum(1 - pair)

    return dist_positive_pair, dist_negative_pair

def _get_segmentation_cost(softmaxpred, seg_gt, n_class=2):
    """
    Calculate the loss for the segmentation prediction.
    :param softmaxpred: probability maps from the segmentation network
    :param seg_gt: ground truth segmentation mask
    :return: segmentation loss; currently the dice loss only (the weighted
             cross-entropy term is kept commented out below)
    """
    dice = 0

    for i in range(n_class):
        inse = tf.reduce_sum(softmaxpred[:, :, :, i] * seg_gt[:, :, :, i])
        l = tf.reduce_sum(softmaxpred[:, :, :, i])
        r = tf.reduce_sum(seg_gt[:, :, :, i])
        dice += 2.0 * inse / (l + r + 1e-7)  # 1e-7 is a relaxation eps
    dice_loss = 1 - 1.0 * dice / n_class

    # ce_weighted = 0
    # for i in range(n_class):
    #     gti = seg_gt[:,:,:,i]
    #     predi = softmaxpred[:,:,:,i]
    #     ce_weighted += -1.0 * gti * tf.log(tf.clip_by_value(predi, 0.005, 1))
    # ce_weighted_loss = tf.reduce_mean(ce_weighted)

    # total_loss = dice_loss

    return dice_loss  #, dice_loss, ce_weighted_loss

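# The per-class term above is the soft dice coefficient
#   dice_i = 2 * sum(p_i * g_i) / (sum(p_i) + sum(g_i) + eps),
# e.g. with toy hard masks p = [1, 1, 0, 0] and g = [1, 0, 0, 0] (made-up
# values), dice = 2*1 / (2 + 1) = 0.667, so the loss is 1 - dice = 0.333.
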
def _get_compactness_cost(y_pred, y_true):
    """
    Compactness (length) term.
    y_pred: BxHxWxC
    """
    # y_pred = tf.one_hot(y_pred, depth=2)
    # print (y_true.shape)
    # print (y_pred.shape)
    y_pred = y_pred[..., 1]
    y_true = y_true[..., 1]

    x = y_pred[:, 1:, :] - y_pred[:, :-1, :]  # finite differences in the vertical and horizontal directions
    y = y_pred[:, :, 1:] - y_pred[:, :, :-1]

    delta_x = x[:, :, 1:]**2
    delta_y = y[:, 1:, :]**2

    delta_u = tf.abs(delta_x + delta_y)

    epsilon = 0.00000001  # small eps to avoid taking the square root of zero
    w = 0.01
    length = w * tf.reduce_sum(tf.sqrt(delta_u + epsilon), [1, 2])

    area = tf.reduce_sum(y_pred, [1, 2])

    # iso-perimetric quotient: length**2 / (4 * pi * area)
    compactness_loss = tf.reduce_sum(length ** 2 / (area * 4 * 3.1415926))

    return compactness_loss, tf.reduce_sum(length), tf.reduce_sum(area), delta_u

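# Ignoring the w weighting, the quotient length**2 / (4 * pi * area) equals
# 1 for a perfect circle and grows for less compact shapes, so the loss
# rewards round masks. Sanity check with a hypothetical disc of radius 10:
# perimeter ~ 62.8, area ~ 314.2, and 62.8**2 / (4 * pi * 314.2) ~ 1.0.
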
# def _get_sample_masf(y_true):
#     """
#     y_pred: BxHxWx2
#     """
#     positive_mask = np.expand_dims(y_true[..., 1], axis=3)
#     metrix_label_group = np.expand_dims(np.array([1, 0, 1, 1, 0]), axis = 1)
#     # print (positive_mask.shape)
#     coutour_group = np.zeros(positive_mask.shape)

#     for i in range(positive_mask.shape[0]):
#         slice_i = positive_mask[i]

#         if metrix_label_group[i] == 1:
#             sample = (slice_i == 1)
#         elif metrix_label_group[i] == 0:
#             sample = (slice_i == 0)

#         coutour_group[i] = sample

#     return coutour_group, metrix_label_group

def _get_coutour_sample(y_true):
    """
    y_true: BxHxWx2
    """
    positive_mask = np.expand_dims(y_true[..., 1], axis=3)
    # hardcoded contour (1) / background (0) assignment for a batch of 5
    metrix_label_group = np.expand_dims(np.array([1, 0, 1, 1, 0]), axis=1)
    coutour_group = np.zeros(positive_mask.shape)

    for i in range(positive_mask.shape[0]):
        slice_i = positive_mask[i]

        if metrix_label_group[i] == 1:
            # generate contour mask: foreground minus its erosion
            erosion = ndimage.binary_erosion(slice_i[..., 0], iterations=1).astype(slice_i.dtype)
            sample = np.expand_dims(slice_i[..., 0] - erosion, axis=2)

        elif metrix_label_group[i] == 0:
            # generate background mask: dilation minus the foreground
            dilation = ndimage.binary_dilation(slice_i, iterations=5).astype(slice_i.dtype)
            sample = dilation - slice_i

        coutour_group[i] = sample
    return coutour_group, metrix_label_group

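# Minimal sketch of the erosion-based contour trick on a toy square
# (hypothetical 2D array):
#
#   sq = np.zeros((5, 5))
#   sq[1:4, 1:4] = 1
#   contour = sq - ndimage.binary_erosion(sq, iterations=1)
#   # -> a one-pixel ring around the border of the square
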
# def _get_negative(y_true):
def _get_boundary_cost(y_pred, y_true):
    """
    Boundary (length) term.
    y_pred: BxHxWxC
    """
    # y_pred = tf.one_hot(y_pred, depth=2)
    # print (y_true.shape)
    # print (y_pred.shape)
    y_pred = y_pred[..., 1]
    y_true = y_true[..., 1]

    x = y_pred[:, 1:, :] - y_pred[:, :-1, :]  # finite differences in the vertical and horizontal directions
    y = y_pred[:, :, 1:] - y_pred[:, :, :-1]

    delta_x = x[:, :, 1:]**2
    delta_y = y[:, 1:, :]**2

    delta_u = tf.abs(delta_x + delta_y)

    epsilon = 0.00000001  # small eps to avoid taking the square root of zero
    w = 0.01
    length = w * tf.reduce_sum(tf.sqrt(delta_u + epsilon), [1, 2])  # equ.(11) in the paper

    area = tf.reduce_sum(y_pred, [1, 2])

    compactness_loss = tf.reduce_sum(length ** 2 / (area * 4 * 3.1415926))

    return compactness_loss, tf.reduce_sum(length), tf.reduce_sum(area)

def check_folder(log_dir):
    if not os.path.exists(log_dir):
        print("Allocating '{:}'".format(log_dir))
        os.makedirs(log_dir)
    return log_dir

def _eval_dice(gt_y, pred_y, detail=False):

    class_map = {  # maps a label value to its name, used for output
        "0": "bg",
        "1": "CZ",
        "2": "prostate"
    }

    dice = []

    for cls in range(1, 2):

        gt = np.zeros(gt_y.shape)
        pred = np.zeros(pred_y.shape)

        gt[gt_y == cls] = 1
        pred[pred_y == cls] = 1

        dice_this = 2 * np.sum(gt * pred) / (np.sum(gt) + np.sum(pred))
        dice.append(dice_this)

        if detail is True:
            logging.info("class {}, dice is {:2f}".format(class_map[str(cls)], dice_this))
    return dice

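# Toy check with hypothetical label maps: two 4-voxel masks that overlap in
# 3 voxels give dice = 2*3 / (4 + 4) = 0.75.
#
#   gt = np.array([1, 1, 1, 1, 0])
#   pred = np.array([1, 1, 1, 0, 1])
#   # _eval_dice(gt, pred) -> [0.75]
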
def __surface_distances(result, reference, voxelspacing=None, connectivity=1):
    """
    The distances between the surface voxels of binary objects in result and their
    nearest partner surface voxels of a binary object in reference.
    """
    result = np.atleast_1d(result.astype(bool))
    reference = np.atleast_1d(reference.astype(bool))
    if voxelspacing is not None:
        voxelspacing = _ni_support._normalize_sequence(voxelspacing, result.ndim)
        voxelspacing = np.asarray(voxelspacing, dtype=np.float64)
        if not voxelspacing.flags.contiguous:
            voxelspacing = voxelspacing.copy()

    # binary structure
    footprint = generate_binary_structure(result.ndim, connectivity)

    # test for emptiness
    if 0 == np.count_nonzero(result):
        raise RuntimeError('The first supplied array does not contain any binary object.')
    if 0 == np.count_nonzero(reference):
        raise RuntimeError('The second supplied array does not contain any binary object.')

    # extract only the 1-pixel border line of the objects
    result_border = result ^ binary_erosion(result, structure=footprint, iterations=1)
    reference_border = reference ^ binary_erosion(reference, structure=footprint, iterations=1)

    # compute the average surface distance
    # Note: scipy's distance transform is calculated only inside the borders of the
    #       foreground objects, therefore the input has to be inverted
    dt = distance_transform_edt(~reference_border, sampling=voxelspacing)
    sds = dt[result_border]

    return sds

def asd(result, reference, voxelspacing=None, connectivity=1):

    sds = __surface_distances(result, reference, voxelspacing, connectivity)
    asd = sds.mean()
    return asd

def calculate_hausdorff(lP, lT, spacing):
    # Note: despite its name, this returns the average surface distance
    # computed by asd(), not the Hausdorff distance.
    return asd(lP, lT, spacing)

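# Usage sketch (hypothetical binary volumes and 1mm isotropic spacing):
#
#   pred_bin = prediction == 1
#   gt_bin = ground_truth == 1
#   mean_sd = asd(pred_bin, gt_bin, voxelspacing=(1.0, 1.0, 1.0))
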
def _eval_haus(pred, gt, spacing, detail=False):
    '''
    :param pred: whole brain prediction
    :param gt: whole brain ground truth
    :param detail:
    :return: a list, indicating the surface distance of each class for one case
    '''
    haus = []

    for cls in range(1, 2):
        pred_i = np.zeros(pred.shape)
        pred_i[pred == cls] = 1
        gt_i = np.zeros(gt.shape)
        gt_i[gt == cls] = 1

        # hausdorff_distance_filter = sitk.HausdorffDistanceImageFilter()
        # hausdorff_distance_filter.Execute(gt_i, pred_i)

        haus_cls = calculate_hausdorff(gt_i, pred_i, spacing)

        haus.append(haus_cls)

        if detail is True:
            logging.info("class {}, haus is {:4f}".format(cls, haus_cls))
    # logging.info("4 class average haus is {:4f}".format(np.mean(haus)))

    return haus

def _connectivity_region_analysis(mask):
    # Keep only the largest connected component of the binary mask.
    s = [[0, 1, 0],
         [1, 1, 1],
         [0, 1, 0]]
    label_im, nb_labels = ndimage.label(mask)  #, structure=s)

    sizes = ndimage.sum(mask, label_im, range(nb_labels + 1))

    # plt.imshow(label_im)
    label_im[label_im != np.argmax(sizes)] = 0
    label_im[label_im == np.argmax(sizes)] = 1

    return label_im

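# Toy check (hypothetical mask): a 3-pixel blob plus one isolated pixel;
# only the blob survives.
#
#   m = np.array([[1, 1, 0, 0],
#                 [1, 0, 0, 1]])
#   # _connectivity_region_analysis(m) zeroes out the isolated pixel at (1, 3)
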
def _crop_object_region(mask, prediction):
    # Zero out the prediction outside the axial extent of the mask.
    limX, limY, limZ = np.where(mask > 0)
    min_z = np.min(limZ)
    max_z = np.max(limZ)

    prediction[..., :min_z] = 0
    prediction[..., max_z + 1:] = 0

    return prediction

def parse_fn(data_path):
    '''
    :param data_path: comma-separated image path and label path of one patient
    :return: normalized entire image with its corresponding label
    In an image, the air region is 0, so we only calculate the mean and std within the brain area
    For any image-level normalization, do it here
    '''
    path = data_path.split(",")
    image_path = path[0]
    label_path = path[1]
    # itk_image = zoom2shape(image_path, [512,512])  # os.path.join(image_path, 'T1_unbiased_brain_rigid_to_mni.nii.gz')
    # itk_mask = zoom2shape(label_path, [512,512], label=True)  # os.path.join(image_path, 'T1_brain_seg_rigid_to_mni.nii.gz')
    itk_image = sitk.ReadImage(image_path)
    itk_mask = sitk.ReadImage(label_path)
    # itk_image = sitk.ReadImage(os.path.join(image_path, 'T2_FLAIR_unbiased_brain_rigid_to_mni.nii.gz'))

    image = sitk.GetArrayFromImage(itk_image)
    mask = sitk.GetArrayFromImage(itk_mask)
    # image[image >= 1000] = 1000
    binary_mask = np.ones(mask.shape)
    mean = np.sum(image * binary_mask) / np.sum(binary_mask)
    std = np.sqrt(np.sum(np.square(image - mean) * binary_mask) / np.sum(binary_mask))
    image = (image - mean) / std  # normalize per image, using statistics within the brain, but apply to the whole image

    mask[mask == 2] = 1  # merge label 2 into label 1, giving a binary mask

    return image.transpose([1, 2, 0]), mask.transpose([1, 2, 0])  # transpose from DxHxW to HxWxD orientation

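# Usage sketch (hypothetical file layout; the paths are illustrative only):
#
#   img, lbl = parse_fn('/data/case01_img.nii.gz,/data/case01_lbl.nii.gz')
#   # img, lbl: HxWxD arrays; img is zero-mean / unit-std, lbl is binary
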
def parse_fn_haus(data_path):
    '''
    Same as parse_fn, but additionally returns the voxel spacing for the
    surface-distance evaluation.
    :param data_path: comma-separated image path and label path of one patient
    :return: normalized entire image with its corresponding label and spacing
    In an image, the air region is 0, so we only calculate the mean and std within the brain area
    For any image-level normalization, do it here
    '''
    path = data_path.split(",")
    image_path = path[0]
    label_path = path[1]
    # itk_image = zoom2shape(image_path, [512,512])  # os.path.join(image_path, 'T1_unbiased_brain_rigid_to_mni.nii.gz')
    # itk_mask = zoom2shape(label_path, [512,512], label=True)  # os.path.join(image_path, 'T1_brain_seg_rigid_to_mni.nii.gz')
    itk_image = sitk.ReadImage(image_path)
    itk_mask = sitk.ReadImage(label_path)
    # itk_image = sitk.ReadImage(os.path.join(image_path, 'T2_FLAIR_unbiased_brain_rigid_to_mni.nii.gz'))
    spacing = itk_mask.GetSpacing()

    image = sitk.GetArrayFromImage(itk_image)
    mask = sitk.GetArrayFromImage(itk_mask)
    # image[image >= 1000] = 1000
    binary_mask = np.ones(mask.shape)
    mean = np.sum(image * binary_mask) / np.sum(binary_mask)
    std = np.sqrt(np.sum(np.square(image - mean) * binary_mask) / np.sum(binary_mask))
    image = (image - mean) / std  # normalize per image, using statistics within the brain, but apply to the whole image

    mask[mask == 2] = 1  # merge label 2 into label 1, giving a binary mask

    return image.transpose([1, 2, 0]), mask.transpose([1, 2, 0]), spacing

def show_all_variables():
    model_vars = tf.trainable_variables()
    slim.model_analyzer.analyze_vars(model_vars, print_info=True)
474