a b/utils.py
1
#!/usr/bin/python
2
import numpy as np
3
from keras.models import *
4
from keras.layers import *
5
from keras.optimizers import *
6
from keras.callbacks import *
7
from keras.losses import *
8
from keras.preprocessing.image import *
9
from os.path import isfile
10
from tqdm import tqdm
11
import random
12
from glob import glob
13
import skimage.io as io
14
import skimage.transform as tr
15
import skimage.morphology as mo
16
import SimpleITK as sitk
17
from pushover import Client
18
import matplotlib.pyplot as plt
19
20
# img helper functions
21
22
def print_info(x):
    """Print one summary line for array ``x``: its shape, min, mean and max."""
    parts = [
        str(x.shape),
        'Min: ' + str(x.min()),
        'Mean: ' + str(x.mean()),
        'Max: ' + str(x.max()),
    ]
    print(' - '.join(parts))
24
    
25
def show_samples(x, y, num):
    """Plot ``num`` randomly chosen image/label pairs, two pairs per figure.

    If ``x`` has rank 4 it is treated as a batch of 2-D images and each of the
    two pairs in a figure comes from a different random sample. Otherwise it is
    treated as a batch of volumes: both pairs come from the same random volume
    and show slices 8 and 16 along axis 1.

    Labels whose last axis has size 1 have that axis dropped before plotting.
    """
    is_2d = len(x.shape) == 4
    order = np.random.permutation(len(x))
    for start in range(0, num, 2):
        plt.figure(figsize=(15, 5))
        for k in (0, 1):
            depth = 8 + 8 * k

            # image panel
            plt.subplot(1, 4, 1 + 2 * k)
            if is_2d:
                panel = x[order[start + k], ..., 0]
            else:
                panel = x[order[start], depth, ..., 0]
            plt.axis('off')
            plt.imshow(panel.astype('float32'))

            # label panel
            plt.subplot(1, 4, 2 + 2 * k)
            single_channel = y[order[start]].shape[-1] == 1
            if is_2d:
                mask = y[order[start + k]]
            else:
                mask = y[order[start], depth]
            if single_channel:
                mask = mask[..., 0]
            plt.axis('off')
            plt.imshow(mask.astype('float32'))
        plt.show()
43
        
44
def show_samples_2d(x, num, titles=None, axis_off=True, size=(20,20)):
    """Show ``num`` random rows, one figure per row, one column per array in ``x``.

    Args:
        x: non-empty sequence of image batches that share the same sample order;
            the same random sample index is shown from every batch in a row.
        num: number of rows (figures) to draw.
        titles: optional per-column titles; must match ``len(x)``.
        axis_off: hide the axes of every subplot.
        size: matplotlib figure size.
    """
    assert(len(x) >= 1)
    if titles:
        assert(len(titles) == len(x))
    order = np.random.permutation(len(x[0]))
    columns = len(x)
    for row in range(num):
        sample = order[row]
        plt.figure(figsize=size)
        for col in range(columns):
            plt.subplot(1, columns, col + 1)
            picked = x[col][sample]
            # drop a trailing single-channel axis so imshow gets a 2-D array
            if picked.shape[-1] == 1:
                picked = x[col][sample, ..., 0]
            if axis_off:
                plt.axis('off')
            if titles:
                plt.title(titles[col])
            plt.imshow(picked.astype('float32'), cmap='gray')
        plt.show()
60
61
def shuffle(x, y):
    """Return ``x`` and ``y`` reordered by one shared random permutation."""
    order = np.random.permutation(len(x))
    return x[order], y[order]
66
67
def split(x, y, tr_size):
    """Split ``x``/``y`` into a leading training part and a trailing test part.

    Args:
        tr_size: fraction of samples (of ``len(x)``) that go to training.

    Returns:
        (x_train, y_train, x_test, y_test)
    """
    cut = int(len(x) * tr_size)
    return x[:cut], y[:cut], x[cut:], y[cut:]
74
75
def _shift_horizontal(a, shft, edge_mode):
    """Shift batch ``a`` by ``shft`` pixels along axis 2 while keeping its shape.

    A positive ``shft`` drops columns on the right and pads on the left. A
    negative ``shft`` does the opposite. Vacated columns are filled by
    ``np.pad`` with ``mode=edge_mode``.
    """
    pad = [(0, 0)] * a.ndim
    if shft > 0:
        pad[2] = (shft, 0)
        return np.pad(a[:, :, :-shft], pad, edge_mode)
    pad[2] = (0, -shft)
    return np.pad(a[:, :, -shft:], pad, edge_mode)


def augment(x, y, h_shift=None, v_flip=False, h_flip=False, rot90=False, edge_mode='minimum'):
    """Grow a batch by appending transformed copies of its samples.

    Each enabled step appends transformed copies of the *current* batch, so the
    steps compound in this order: horizontal shifts, vertical flip, horizontal
    flip, 90-degree rotation.

    Args:
        x: image batch of rank 4. Axes 1 and 2 are the axes that get
            flipped / shifted / rotated.
        y: labels. If ``y`` has rank <= 2 it is duplicated unchanged for every
            new copy. Otherwise it is treated as a mask and receives the same
            spatial transform as ``x``.
        h_shift: pixel shifts along axis 2. Pass a list or a single int; each
            shift appends one shifted copy. ``None``, ``[]`` or ``0`` disables
            shifting.
        v_flip: append a copy flipped along axis 1.
        h_flip: append a copy flipped along axis 2.
        rot90: append a copy rotated in the (1, 2) plane. Axes 1 and 2 must
            have equal length, otherwise the concatenation fails.
        edge_mode: ``np.pad`` mode used to fill columns vacated by a shift.

    Returns:
        (x, y) with the augmented samples appended after the originals.
    """
    assert len(x.shape) == 4
    seg = len(y.shape) > 2

    # Normalise h_shift to a list. A None default avoids a shared mutable
    # default argument; a lone int is accepted (it used to raise TypeError).
    if h_shift is None:
        shifts = []
    elif np.isscalar(h_shift):
        shifts = [h_shift] if h_shift else []
    else:
        shifts = list(h_shift)

    if shifts:
        new_x = [_shift_horizontal(x, s, edge_mode) for s in shifts]
        new_y = [_shift_horizontal(y, s, edge_mode) if seg else y for s in shifts]
        x = np.concatenate((x, np.concatenate(new_x)))
        y = np.concatenate((y, np.concatenate(new_y)))

    transforms = (
        (v_flip, lambda a: np.flip(a, axis=1)),
        (h_flip, lambda a: np.flip(a, axis=2)),
        (rot90, lambda a: np.rot90(a, axes=(1, 2))),
    )
    for enabled, transform in transforms:
        if enabled:
            x = np.concatenate((x, transform(x)))
            y = np.concatenate((y, transform(y) if seg else y))
    return x, y
124
125
def resize_3d(img, size):
    """Resize every slice ``img[i]`` to ``(size[0], size[1])``.

    The result is a new float64 array of shape
    ``(img.shape[0], size[0], size[1], img.shape[-1])``. Only the first two
    entries of ``size`` are used. Intensities keep their original range
    (``preserve_range=True``).
    """
    out = np.zeros((img.shape[0], size[0], size[1], img.shape[-1]))
    target = out.shape[1:3]
    for idx, plane in enumerate(img):
        out[idx] = tr.resize(plane, target, mode='constant', preserve_range=True)
    return out
130
131
def to_2d(x):
    """Flatten a batch of volumes (#, Z, Y, X, C) into a batch of slices (#*Z, Y, X, C)."""
    assert len(x.shape) == 5 # Shape: (#, Z, Y, X, C)
    n, depth, height, width, channels = x.shape
    return np.reshape(x, (n * depth, height, width, channels))
134
135
def to_3d(imgs, z):
    """Regroup a batch of slices (N*z, Y, X, C) into volumes (N, z, Y, X, C).

    This is the inverse of ``to_2d``.

    Raises:
        ValueError: if the number of slices is not a multiple of ``z``.
    """
    assert len(imgs.shape) == 4 # Shape: (#, Y, X, C)
    n_slices, height, width, channels = imgs.shape
    if n_slices % z:
        raise ValueError('cannot split {} slices into volumes of {}'.format(n_slices, z))
    # Integer division: under Python 3 '/' yields a float, which reshape rejects.
    return np.reshape(imgs, (n_slices // z, z, height, width, channels))
138
139
def get_crop_area(img, threshold=0):
    """Compute a square crop box from a volume's above-threshold content.

    ``img`` is summed over axis 0. The crop spans the first to last axis-1 index
    with a value above ``threshold``; that extent becomes the side length. The
    box is centred on the midpoint of the above-threshold extent along axis 2.

    Returns:
        (top, left, size). Raises IndexError if nothing exceeds ``threshold``.
    """
    collapsed = img.sum(axis=0)
    rows = np.nonzero(collapsed > threshold)[0]
    cols = np.nonzero(collapsed.sum(axis=0) > threshold)[0]
    top = rows[0]
    size = rows[-1] - top + 1
    left = (cols[0] + cols[-1]) // 2 - size // 2
    return top, left, size
146
147
def n4_bias_correction(img):
    """Run SimpleITK N4 bias-field correction on channel 0 of ``img``.

    The mask comes from ``sitk.OtsuThreshold`` (inside=0, outside=1,
    200 histogram bins). The corrected array is returned with a trailing
    channel axis restored.
    """
    volume = sitk.GetImageFromArray(img[..., 0].astype('float32'))
    otsu_mask = sitk.OtsuThreshold(volume, 0, 1, 200)
    corrected = sitk.N4BiasFieldCorrection(volume, otsu_mask)
    return np.expand_dims(sitk.GetArrayFromImage(corrected), -1)
152
153
def handle_specials(img):
    """Bring the two known non-standard slice counts to 24 slices (axis 0).

    * 26 slices: drop the first and last slice.
    * 20 slices: pad two slices on each side with the per-position minimum
      along axis 0 (``np.pad`` mode 'minimum').
    Any other depth is returned unchanged (the same object).
    """
    depth = img.shape[0]
    if depth == 26:
        return img[1:-1]
    if depth == 20:
        # np.pad is the public API; np.lib.pad is not available in NumPy 2.
        # Only axis 0 is padded, whatever the array's rank.
        pad = ((2, 2),) + ((0, 0),) * (img.ndim - 1)
        return np.pad(img, pad, 'minimum')
    return img
159
160
def erode(imgs, amount=3):
    """Sum over the last axis, grey-erode each item with an ``amount``-sized square.

    Returns the eroded array with a single trailing channel axis.
    """
    flat = imgs.sum(axis=-1)
    footprint = mo.square(amount)
    for idx in range(len(flat)):
        flat[idx] = mo.erosion(flat[idx], footprint)
    return np.expand_dims(flat, -1)
165
166
def add_noise(imgs, amount=3):
    """Perturb masks by alternately dilating and eroding them.

    After summing over the last axis, items at even indices are dilated and items
    at odd indices are eroded, both with an ``amount``-sized square. Returns the
    result with a single trailing channel axis.
    """
    flat = imgs.sum(axis=-1)
    footprint = mo.square(amount)
    for idx in range(len(flat)):
        morph = mo.dilation if idx % 2 == 0 else mo.erosion
        flat[idx] = morph(flat[idx], footprint)
    return np.expand_dims(flat, -1)
174
            
175
176
# Label helper functions
177
178
def to_classes(y, start, end, step=1):
    """One-hot encode values ``y`` into bins of width ``step`` starting at ``start``.

    There are ``round((end - start) / step)`` bins. Each value goes into bin
    ``int((value - start) / step)``. Values at or beyond ``end`` raise
    IndexError; values below ``start`` index from the end.
    """
    n_bins = int(round((end - start) / step))
    one_hot = np.zeros((len(y), n_bins))
    for row, pos in enumerate((y - start) / step):
        one_hot[row, int(pos)] = 1
    return one_hot
186
187
def y_center(img, smooth=20, crop=100):
    """Estimate a row (axis-1) position from the intensity profile of ``img``.

    The steps are:
    1. Sum ``img`` over every axis except axis 1.
    2. Smooth the profile with a moving average of width ``smooth``.
    3. Ignore ``crop`` entries at each end.
    4. Return the index of the largest second derivative, i.e. the point of
       strongest upward curvature.

    Args:
        img: array whose axis 1 is the profiled axis. Presumably
            (Z, Y, X, C) as produced by read_mhd; confirm with callers.
        smooth: moving-average window length.
        crop: number of entries ignored at each end (0 keeps the full range).

    Returns:
        Index into the full (uncropped) profile.
    """
    profile = img.sum(axis=-1).sum(axis=-1).sum(axis=0)
    smoothed = np.convolve(profile, np.ones(smooth) / smooth, mode='same')
    # Explicit stop index: [crop:-crop] is empty when crop == 0.
    window = smoothed[crop:len(smoothed) - crop]
    return np.gradient(np.gradient(window)).argmax() + crop
194
195
def lengthen(y, factor):
    """Repeat every element of ``y`` ``factor`` times in a row; return a numpy array."""
    return np.array([item for item in y for _ in range(factor)])
201
202
def shorten(y, factor):
    """Keep every ``factor``-th element of ``y`` (starting at index 0); inverse of ``lengthen``."""
    kept = [y[pos] for pos in range(0, len(y), factor)]
    return np.array(kept)
207
208
def normalize(x, mean=None, std=None):
    """Standardise ``x`` as ``(x - mean) / std``.

    Args:
        x: numpy array.
        mean: mean to subtract. Defaults to ``x.mean()``.
        std: scale to divide by. Defaults to ``x.std()``.

    The previous implementation ignored ``mean`` and ``std`` and always used
    ``x``'s own statistics. That remains the behaviour when they are omitted.
    """
    if mean is None:
        mean = x.mean()
    if std is None:
        std = x.std()
    return (x - mean) / std
210
211
def multilabel(img, channel):
    """Convert a grey-level label volume into binary mask channel(s).

    channel == 1:
        Binarise ``img`` IN PLACE: values > 0.01 become 1 and everything else
        becomes 0. The same array is returned.
    channel > 1:
        Split intensity bands into ``channel`` one-hot channels, from the
        brightest band down. Channel c receives the voxels of channel 0 above a
        threshold that starts at 0.99 * max and drops by ``max // channel``
        per channel. Each voxel is assigned to at most one channel. Returns a
        new array of shape ``img.shape[:-1] + (channel,)``; ``img`` is not
        modified.
    """
    if channel == 1:
        img[img > 0.01] = 1
        # '<=' (not '<'): values exactly equal to 0.01 used to survive unchanged.
        img[img <= 0.01] = 0
        return img
    # Work on a copy: assigned voxels are zeroed so later (lower) thresholds
    # skip them, and that bookkeeping must not leak into the caller's array.
    levels = img.copy()
    step = levels.max() // channel
    divider = levels.max() * 0.99
    out = np.zeros(levels.shape[:-1] + (channel,))
    for c in range(channel):
        hit = levels[..., 0] > divider
        out[hit, c] = 1
        levels[hit, 0] = 0
        divider -= step
    return out
225
226
def read_mhd(path, label=0, crop=None, size=None, bias=False, norm=False):
    """Load a volume via skimage's SimpleITK plugin and preprocess it.

    The pipeline runs in this order, each step only when its argument is set:
    1. Add a channel axis and cast to float64.
    2. handle_specials.
    3. multilabel (``label > 0``).
    4. Square crop ``(top, left, side)`` over axes 1 and 2.
    5. resize_3d (``size``).
    6. n4_bias_correction (``bias``).
    7. Global z-score (``norm``).

    Returns a float32 array.
    """
    img = io.imread(path, plugin='simpleitk')
    img = img[..., np.newaxis].astype('float64')
    img = handle_specials(img)
    if label > 0:
        img = multilabel(img, label)
    if crop:
        top, left, side = crop[0], crop[1], crop[2]
        img = img[:, top:top + side, left:left + side]
    if size:
        img = resize_3d(img, size)
    if bias:
        img = n4_bias_correction(img)
    if norm:
        img = (img - img.mean()) / img.std()
    return img.astype('float32')
236
237
def load_data(path, label=0, size=(24,224,224), bias=False, norm=False, to2d=False):
    """Load paired label/image volumes for every label file matching glob ``path``.

    For each label file:
    * A crop box is computed from the raw label with get_crop_area.
    * The box is applied to the label (read with ``label``) and to the matching
      original image.
    * The image path is derived by replacing '/VOI_LABEL/' with '/MHD/' and
      '_LABEL.' with '_ORIG.' (first occurrence of each).

    NOTE(review): ``size`` is forwarded to resize_3d, which only uses size[0]
    and size[1]. With the default (24, 224, 224), slices are resized to
    24x224. Confirm that this default is intended.

    Returns:
        (x, y): images and labels as numpy arrays. With ``to2d`` every volume
        contributes its individual slices (axis 0) instead of itself.
    """
    def _store(bucket, volume):
        # 2-D mode stores each slice separately, 3-D mode the whole volume
        if to2d:
            bucket.extend(volume)
        else:
            bucket.append(volume)

    images, labels = [], []
    for label_path in tqdm(glob(path)):
        box = get_crop_area(read_mhd(label_path))
        _store(labels, read_mhd(label_path, label=label, crop=box, size=size))
        image_path = label_path.replace('/VOI_LABEL/', '/MHD/', 1)
        image_path = image_path.replace('_LABEL.', '_ORIG.', 1)
        _store(images, read_mhd(image_path, crop=box, size=size, bias=bias, norm=norm))
    return np.array(images), np.array(labels)
260
261
def load_data_age(files, size=None, crop=None, bias=False, norm=False,
                  to2d=False, smart_crop=False):
    """Load volumes matching glob ``files`` together with an age target.

    The age is parsed from the path: ``int(part3) + int(part4) / 12`` where the
    parts come from ``path.split('_')``. This presumably encodes years and
    months. NOTE(review): the split runs over the whole path, so underscores
    in directory names shift the indices — confirm.

    Args:
        crop: ``(top, left, side)`` square crop, or None.
        smart_crop: when cropping, replace ``top`` per file with a value
            centred on ``y_center`` of the raw volume. The caller's ``crop``
            is no longer mutated, and tuples are now accepted.
        size, bias, norm: forwarded to read_mhd.
        to2d: store every slice (axis 0) as its own sample, repeating the age.

    Returns:
        (x, y) numpy arrays of images and ages.
    """
    x, y = [], []
    for path in tqdm(glob(files)):
        region = crop
        if crop and smart_crop:
            # private copy: never write into the caller's sequence
            region = list(crop)
            region[0] = y_center(read_mhd(path)) - crop[2] // 2
        img = read_mhd(path, crop=region, size=size, bias=bias, norm=norm)
        parts = path.split('_')
        age = int(parts[3]) + int(parts[4]) / 12.
        if to2d:
            for layer in img:
                x.append(layer)
                y.append(age)
        else:
            x.append(img)
            y.append(age)
    return np.array(x), np.array(y)
284
285
def print_weights(weight_file_path):
    """
    Prints out the structure of HDF5 file.

    The output lists the root attributes, then for every top-level group its
    attributes and the shapes of its datasets. The file is opened read-only.

    Args:
      weight_file_path (str) : Path to the file to analyze
    """
    # NOTE(review): h5py is not imported explicitly in this module; it
    # presumably reaches the namespace via one of the keras star imports — confirm.
    # Explicit 'r': h5py < 3.0 defaults to mode 'a', which opens the file
    # writable and silently creates an empty file for a wrong path.
    with h5py.File(weight_file_path, 'r') as f:
        if len(f.attrs.items()):
            print("{} contains: ".format(weight_file_path))
            print("Root attributes:")
        for key, value in f.attrs.items():
            print("  {}: {}".format(key, value))

        if len(f.items())==0:
            return

        for layer, g in f.items():
            print("  {}".format(layer))
            print("    Attributes:")
            for key, value in g.attrs.items():
                print("      {}: {}".format(key, value))

            print("    Dataset:")
            for p_name in g.keys():
                param = g[p_name]
                print("      {}: {}".format(p_name, param.shape)) #try only "param"
315
316
# Models
317
318
def conv_block(m, dim, acti, bn, res, do=0):
    """Two 3x3 'same' Conv2D layers with optional BatchNorm, Dropout and residual add.

    The optional pieces are added as follows:
    * BatchNormalization follows each convolution when ``bn`` is set.
    * Dropout(do) sits between the two convolutions when ``do`` is non-zero.
    * With ``res`` the block input is added to the output, which requires
      ``m`` to already have ``dim`` channels.

    Layer creation order is kept identical, so auto-generated layer names do
    not change.
    """
    out = Conv2D(dim, 3, activation=acti, padding='same')(m)
    if bn:
        out = BatchNormalization()(out)
    if do:
        out = Dropout(do)(out)
    out = Conv2D(dim, 3, activation=acti, padding='same')(out)
    if bn:
        out = BatchNormalization()(out)
    if res:
        return Add()([m, out])
    return out
325
326
def level_block(m, dim, depth, inc, acti, do, bn, mp, up, res):
    """Recursively build one U-Net level and everything below it.

    At ``depth > 0`` the level does the following:
    1. Apply a conv_block; its output is the skip connection.
    2. Downsample with MaxPooling2D (``mp``) or a stride-2 Conv2D.
    3. Recurse with ``int(inc * dim)`` filters.
    4. Upsample with UpSampling2D + 2x2 Conv2D (``up``) or a stride-2
       Conv2DTranspose.
    5. Merge with the skip connection by element-wise Add.
    6. Apply a final conv_block.

    At ``depth == 0`` (bottleneck) a single conv_block is applied; this is the
    only place dropout ``do`` is used.
    """
    if depth <= 0:
        return conv_block(m, dim, acti, bn, res, do)
    skip = conv_block(m, dim, acti, bn, res)
    if mp:
        down = MaxPooling2D()(skip)
    else:
        down = Conv2D(dim, 3, strides=2, padding='same')(skip)
    deeper = level_block(down, int(inc*dim), depth-1, inc, acti, do, bn, mp, up, res)
    if up:
        deeper = UpSampling2D()(deeper)
        deeper = Conv2D(dim, 2, activation=acti, padding='same')(deeper)
    else:
        deeper = Conv2DTranspose(dim, 3, strides=2, activation=acti, padding='same')(deeper)
    merged = Add()([skip, deeper])
    return conv_block(merged, dim, acti, bn, res)
341
342
def UNet(img_shape, out_ch=1, start_ch=32, depth=4, inc_rate=1., activation='elu', 
         dropout=0.5, batchnorm=False, maxpool=True, upconv=True, residual=False):
    """Build a 2-D U-Net (see level_block) ending in a 1x1 sigmoid Conv2D.

    Note that ``upconv=True`` selects UpSampling2D followed by a 2x2 Conv2D, and
    ``upconv=False`` selects Conv2DTranspose.
    """
    inputs = Input(shape=img_shape)
    features = level_block(inputs, start_ch, depth, inc_rate, activation,
                           dropout, batchnorm, maxpool, upconv, residual)
    outputs = Conv2D(out_ch, 1, activation='sigmoid')(features)
    return Model(inputs=inputs, outputs=outputs)
348
349
def level_block_3d(m, dim, depth, factor, acti, dropout):
    """Recursively build one level of the 3-D U-Net.

    Pooling and upsampling only act on the last two spatial axes, using
    (1, 2, 2). Skip connections are concatenated on axis 4. Every level ends
    with two 3x3x3 Conv3D layers, including the bottom level.
    """
    if depth > 0:
        skip = Conv3D(dim, 3, activation=acti, padding='same')(m)
        if dropout:
            skip = Dropout(dropout)(skip)
        skip = Conv3D(dim, 3, activation=acti, padding='same')(skip)
        down = MaxPooling3D((1,2,2))(skip)
        deeper = level_block_3d(down, int(factor*dim), depth-1, factor, acti, dropout)
        deeper = UpSampling3D((1,2,2))(deeper)
        deeper = Conv3D(dim, 2, activation=acti, padding='same')(deeper)
        m = Concatenate(axis=4)([skip, deeper])
    out = m
    for _ in range(2):
        out = Conv3D(dim, 3, activation=acti, padding='same')(out)
    return out
361
362
def UNet_3D(img_shape, n_out=1, dim=8, depth=3, factor=1.5, acti='elu', dropout=None):
    """Build a 3-D U-Net (see level_block_3d) with a 1x1x1 sigmoid output layer."""
    inputs = Input(shape=img_shape)
    features = level_block_3d(inputs, dim, depth, factor, acti, dropout)
    head = Conv3D(n_out, 1, activation='sigmoid')(features)
    return Model(inputs=inputs, outputs=head)
367
368
# Loss Functions
369
370
# 2TP / (2TP + FP + FN)
371
def f1(y_true, y_pred):
    """Smoothed soft Dice/F1 over the whole batch: (2*I + 1) / (|T| + |P| + 1)."""
    truth = K.flatten(y_true)
    guess = K.flatten(y_pred)
    overlap = K.sum(truth * guess)
    denominator = K.sum(truth) + K.sum(guess) + 1.
    return (2. * overlap + 1.) / denominator
376
377
def f1_np(y_true, y_pred):
    """NumPy version of ``f1``: (2*I + 1) / (|T| + |P| + 1)."""
    overlap = (y_true * y_pred).sum()
    denominator = y_true.sum() + y_pred.sum() + 1.
    return (2. * overlap + 1.) / denominator
379
380
def f1_loss(y_true, y_pred):
    """Loss form of ``f1``: 1 - F1, so a perfect overlap gives 0."""
    score = f1(y_true, y_pred)
    return 1 - score
382
383
def f2(y_true, y_pred):
    """Smoothed soft F-beta with beta=2, which weights recall above precision.

    Computes (5*I + 1) / (4*|T| + |P| + 1). With |T| = TP+FN and |P| = TP+FP
    this equals 5TP / (5TP + 4FN + FP) up to the +1 smoothing.
    """
    truth = K.flatten(y_true)
    guess = K.flatten(y_pred)
    overlap = K.sum(truth * guess)
    denominator = 4. * K.sum(truth) + K.sum(guess) + 1.
    return (5. * overlap + 1.) / denominator
388
389
def f2_loss(y_true, y_pred):
    """Loss form of ``f2``: 1 - F2."""
    score = f2(y_true, y_pred)
    return 1 - score
391
392
# The Dice coefficient is the same quantity as F1 (2TP / (2TP + FP + FN)), so
# these are plain aliases.
dice, dice_loss = f1, f1_loss
394
395
def iou(y_true, y_pred):
    """Smoothed soft Jaccard index over the batch: (I + 1) / (|T| + |P| + 1 - I)."""
    truth = K.flatten(y_true)
    guess = K.flatten(y_pred)
    overlap = K.sum(truth * guess)
    union = K.sum(truth) + K.sum(guess) + 1. - overlap
    return (overlap + 1.) / union
400
401
def iou_np(y_true, y_pred):
    """NumPy version of ``iou``: (I + 1) / (|T| + |P| + 1 - I)."""
    overlap = (y_true * y_pred).sum()
    union = y_true.sum() + y_pred.sum() + 1. - overlap
    return (overlap + 1.) / union
404
405
def iou_loss(y_true, y_pred):
    """Negated ``iou``, so that minimising the loss maximises overlap."""
    score = iou(y_true, y_pred)
    return -score
407
408
def precision(y_true, y_pred):
    """Smoothed soft precision over the batch: (I + 1) / (|P| + 1)."""
    truth = K.flatten(y_true)
    guess = K.flatten(y_pred)
    hits = K.sum(truth * guess)
    return (hits + 1.) / (K.sum(guess) + 1.)
413
414
def precision_np(y_true, y_pred):
    """NumPy version of ``precision``: (I + 1) / (|P| + 1)."""
    hits = (y_true * y_pred).sum()
    predicted = y_pred.sum()
    return (hits + 1.) / (predicted + 1.)
416
417
def recall(y_true, y_pred):
    """Smoothed soft recall over the batch: (I + 1) / (|T| + 1)."""
    truth = K.flatten(y_true)
    guess = K.flatten(y_pred)
    hits = K.sum(truth * guess)
    return (hits + 1.) / (K.sum(truth) + 1.)
422
423
def recall_np(y_true, y_pred):
    """NumPy version of ``recall``: (I + 1) / (|T| + 1)."""
    hits = (y_true * y_pred).sum()
    actual = y_true.sum()
    return (hits + 1.) / (actual + 1.)
425
426
def mae_img(y_true, y_pred):
    """Keras ``mae`` applied to the fully flattened tensors (one scalar per batch)."""
    return mae(K.flatten(y_true), K.flatten(y_pred))
430
431
def bce_img(y_true, y_pred):
    """Keras ``binary_crossentropy`` on the fully flattened tensors."""
    return binary_crossentropy(K.flatten(y_true), K.flatten(y_pred))
435
436
def f1_bce(y_true, y_pred):
    """Combined loss: (1 - F1) plus flattened binary cross-entropy."""
    overlap_term = f1_loss(y_true, y_pred)
    pixel_term = bce_img(y_true, y_pred)
    return overlap_term + pixel_term
438
439
# FP + FN
440
def error(y_true, y_pred):
    """Mean absolute per-element difference between ``y_true`` and ``y_pred``.

    For binary masks this is (FP + FN) / number_of_elements, the same quantity
    as ``error_np``. The old version divided by a hard-coded 224*224, which only
    matched for a single 224x224x1 image. ``K.mean`` normalises by the actual
    element count, so other image sizes and batch sizes work.
    """
    y_true_f = K.flatten(y_true)
    y_pred_f = K.flatten(y_pred)
    return K.mean(K.abs(y_true_f - y_pred_f))
444
445
def error_np(y_true, y_pred):
    """Sum of absolute differences divided by the element count of ``y_true``."""
    total = abs(y_true - y_pred).sum()
    return total / float(y_true.size)
447
448
# Notifications
449
    
450
def pushover(title, message, user=None, api_token=None):
    """Send a Pushover notification with ``title`` and ``message``.

    Credentials are resolved in this order:
    1. The ``user`` / ``api_token`` arguments.
    2. The PUSHOVER_USER / PUSHOVER_API_TOKEN environment variables.
    3. The legacy hard-coded values below, kept so existing callers keep
       working.
    """
    import os

    # SECURITY: these credentials are committed to source control and must be
    # treated as leaked. Rotate them, then delete this fallback once every
    # caller provides credentials through arguments or the environment.
    legacy_user = "u96ub3t5wu1nexmgi22xjs31jeb8y6"
    legacy_api = "avfytsyktracxood45myebobtry6yd"
    user = user or os.environ.get('PUSHOVER_USER', legacy_user)
    api_token = api_token or os.environ.get('PUSHOVER_API_TOKEN', legacy_api)
    client = Client(user, api_token=api_token)
    client.send_message(message, title=title)
455
    
456
#from nipype.interfaces.ants import N4BiasFieldCorrection
457
#correct = N4BiasFieldCorrection()
458
#correct.inputs.input_image = in_file
459
#correct.inputs.output_image = out_file
460
#done = correct.run()
461
#img done.outputs.output_image