Diff of /utils.py [000000] .. [7e66db]

Switch to side-by-side view

--- a
+++ b/utils.py
@@ -0,0 +1,461 @@
+#!/usr/bin/python
+import numpy as np
+from keras.models import *
+from keras.layers import *
+from keras.optimizers import *
+from keras.callbacks import *
+from keras.losses import *
+from keras.preprocessing.image import *
+from os.path import isfile
+from tqdm import tqdm
+import random
+from glob import glob
+import skimage.io as io
+import skimage.transform as tr
+import skimage.morphology as mo
+import SimpleITK as sitk
+from pushover import Client
+import matplotlib.pyplot as plt
+
+# img helper functions
+
+def print_info(x):
+    print(str(x.shape) + ' - Min: ' + str(x.min()) + ' - Mean: ' + str(x.mean()) + ' - Max: ' + str(x.max()))
+    
+def show_samples(x, y, num):
+    two_d = True if len(x.shape) == 4 else False
+    rnd = np.random.permutation(len(x))
+    for i in range(0, num, 2):
+        plt.figure(figsize=(15, 5))
+        for j in range(2):
+            plt.subplot(1,4,1+j*2)
+            img = x[rnd[i+j], ..., 0] if two_d else x[rnd[i], 8+8*j, ..., 0]
+            plt.axis('off')
+            plt.imshow(img.astype('float32'))
+            plt.subplot(1,4,2+j*2)
+            if y[rnd[i]].shape[-1] == 1:
+                img = y[rnd[i+j], ..., 0] if two_d else y[rnd[i], 8+8*j, ..., 0]
+            else:
+                img = y[rnd[i+j]] if two_d else y[rnd[i], 8+8*j]
+            plt.axis('off')
+            plt.imshow(img.astype('float32'))
+        plt.show()
+        
+def show_samples_2d(x, num, titles=None, axis_off=True, size=(20,20)):
+    assert(len(x) >= 1)
+    if titles:
+        assert(len(titles) == len(x))
+    rnd = np.random.permutation(len(x[0]))
+    for row in range(num):
+        plt.figure(figsize=size)
+        for col in range(len(x)):
+            plt.subplot(1,len(x), col+1)
+            img = x[col][rnd[row], ..., 0] if x[col][rnd[row]].shape[-1] == 1 else x[col][rnd[row]]
+            if axis_off:
+                plt.axis('off')
+            if titles:
+                plt.title(titles[col])
+            plt.imshow(img.astype('float32'), cmap='gray')
+        plt.show()
+
+def shuffle(x, y):
+    perm = np.random.permutation(len(x))
+    x = x[perm]
+    y = y[perm]
+    return x, y
+
+def split(x, y, tr_size):
+    tr_size = int(len(x) * tr_size)
+    x_tr = x[:tr_size]
+    y_tr = y[:tr_size]
+    x_te = x[tr_size:]
+    y_te = y[tr_size:]
+    return x_tr, y_tr, x_te, y_te
+
+def augment(x, y, h_shift=[], v_flip=False, h_flip=False, rot90=False, edge_mode='minimum'):
+    assert(len(x.shape) == 4)
+    seg = False if len(y.shape) <= 2 else True
+    if h_shift and h_shift != 0 and len(h_shift) != 0:
+        tmp_x, tmp_y = [], []
+        for shft in h_shift:
+            if shft > 0:
+                tmp = np.lib.pad(x[:, :, :-shft], ((0,0), (0,0), (shft,0), (0,0)), edge_mode)
+                tmp_x.append(tmp)
+                if seg:
+                    tmp = np.lib.pad(y[:, :, :-shft], ((0,0), (0,0), (shft,0), (0,0)), edge_mode)
+                else:
+                    tmp = y
+                tmp_y.append(tmp)
+            else:
+                tmp = np.lib.pad(x[:, :, -shft:], ((0,0), (0,0), (0,-shft), (0,0)), edge_mode)
+                tmp_x.append(tmp)
+                if seg:
+                    tmp = np.lib.pad(y[:, :, -shft:], ((0,0), (0,0), (0,-shft), (0,0)), edge_mode)
+                else:
+                    tmp = y
+                tmp_y.append(tmp)
+        x = np.concatenate((x, np.concatenate(tmp_x)))
+        y = np.concatenate((y, np.concatenate(tmp_y)))
+    if v_flip:
+        tmp = np.flip(x, axis=1)
+        x = np.concatenate((x, tmp))
+        if seg:
+            tmp = np.flip(y, axis=1)
+            y = np.concatenate((y, tmp))
+        else:
+            y = np.concatenate((y, y))
+    if h_flip:
+        tmp = np.flip(x, axis=2)
+        x = np.concatenate((x, tmp))
+        if seg:
+            tmp = np.flip(y, axis=2)
+            y = np.concatenate((y, tmp))
+        else:
+            y = np.concatenate((y, y))
+    if rot90:
+        tmp = np.rot90(x, axes=(1,2))
+        x = np.concatenate((x, tmp))
+        if seg:
+            tmp = np.rot90(y, axes=(1,2))
+            y = np.concatenate((y, tmp))
+        else:
+            y = np.concatenate((y, y))
+    return x, y
+
+def resize_3d(img, size):
+    img2 = np.zeros((img.shape[0], size[0], size[1], img.shape[-1]))
+    for i in range(img.shape[0]):
+        img2[i] = tr.resize(img[i], (size[0], size[1]), mode='constant', preserve_range=True)
+    return img2
+
+def to_2d(x):
+    assert len(x.shape) == 5 # Shape: (#, Z, Y, X, C)
+    return np.reshape(x, (x.shape[0]*x.shape[1], x.shape[2], x.shape[3], x.shape[4]))
+
+def to_3d(imgs, z):
+    assert len(imgs.shape) == 4 # Shape: (#, Y, X, C)
+    return np.reshape(imgs, (imgs.shape[0] / z, z, imgs.shape[1], imgs.shape[2], imgs.shape[3]))
+
+def get_crop_area(img, threshold=0):
+    y_arr = np.where(img.sum(axis=0) > threshold)[0]
+    size = y_arr[-1] - y_arr[0] + 1
+    y = y_arr[0]
+    x_arr = np.where(img.sum(axis=0).sum(axis=0) > threshold)[0]
+    x = (x_arr[0] + x_arr[-1]) // 2 - size // 2
+    return y, x, size
+
+def n4_bias_correction(img):
+    img = sitk.GetImageFromArray(img[..., 0].astype('float32'))
+    mask = sitk.OtsuThreshold(img, 0, 1, 200)
+    img = sitk.N4BiasFieldCorrection(img, mask)
+    return sitk.GetArrayFromImage(img)[..., np.newaxis]
+
+def handle_specials(img):
+    if img.shape[0] == 26:
+        img = img[1:-1]
+    elif img.shape[0] == 20:
+        img = np.lib.pad(img, ((2,2), (0,0), (0,0), (0,0)), 'minimum')
+    return img
+
+def erode(imgs, amount=3):
+    imgs = imgs.sum(axis=-1)
+    for i in range(len(imgs)):
+        imgs[i] = mo.erosion(imgs[i], mo.square(amount))
+    return imgs[..., np.newaxis]
+
+def add_noise(imgs, amount=3):
+    imgs = imgs.sum(axis=-1)
+    for i in range(len(imgs)):
+        if i % 2 == 0:
+            imgs[i] = mo.dilation(imgs[i], mo.square(amount))
+        else:
+            imgs[i] = mo.erosion(imgs[i], mo.square(amount))
+    return imgs[..., np.newaxis]
+            
+
+# Label helper functions
+
+def to_classes(y, start, end, step=1):
+    age_range = end - start
+    num_classes = int(round(age_range / step))
+    labels = np.zeros((len(y), num_classes))
+    idx = (y - start) / step
+    for i in range(len(idx)):
+        labels[i, int(idx[i])] = 1
+    return labels
+
+def y_center(img, smooth=20, crop=100):
+    # Get Sum of y-axis values
+    y = img.sum(axis=-1).sum(axis=-1).sum(axis=0)
+    # Smooth the values and apply the crop region
+    y_vec = np.convolve(y, np.ones(smooth)/smooth, mode='same')[crop:-crop]
+    # 2nd derivative of min will be max - get its index
+    return np.gradient(np.gradient(y_vec)).argmax() + crop
+
+def lengthen(y, factor):
+    arr = []
+    for el in y:
+        for i in range(factor):
+            arr.append(el)
+    return np.array(arr)
+
+def shorten(y, factor):
+    arr = []
+    for i in range(0, len(y), factor):
+        arr.append(y[i])
+    return np.array(arr)
+
+def normalize(x, mean, std):
+    return (x - x.mean()) / x.std()
+
+def multilabel(img, channel):
+    if channel == 1:
+        img[img > 0.01] = 1
+        img[img < 0.01] = 0
+        return img
+    else:
+        step = img.max() // channel
+        divider = img.max() * 0.99
+        img2 = np.zeros((img.shape[0], img.shape[1], img.shape[2], channel))
+        for c in range(channel):
+            img2[img[..., 0] > divider, c] = 1
+            img[img[..., 0] > divider, 0] = 0
+            divider -= step
+        return img2
+
+def read_mhd(path, label=0, crop=None, size=None, bias=False, norm=False):
+    img = io.imread(path, plugin='simpleitk')[..., np.newaxis].astype('float64')
+    img = handle_specials(img)
+    img = multilabel(img, label) if label > 0 else img
+    img = img[:, crop[0]:crop[0]+crop[2], crop[1]:crop[1]+crop[2]] if crop else img
+    #img = img[:, crop[0]:-2*crop[1]+crop[0], crop[1]:-1*crop[1]] if crop else img
+    img = resize_3d(img, size) if size else img
+    img = n4_bias_correction(img) if bias else img
+    img = (img - img.mean()) / img.std() if norm else img
+    return img.astype('float32')
+
+def load_data(path, label=0, size=(24,224,224), bias=False, norm=False, to2d=False):
+    files = glob(path)
+    x, y = [], []
+    for i in tqdm(range(len(files))):
+        img = read_mhd(files[i])
+        top, left, dim = get_crop_area(img)
+        img = read_mhd(files[i], label=label, crop=(top, left, dim), size=size)
+        if to2d:
+            for layer in img:
+                y.append(layer)
+        else:
+            y.append(img)
+        files[i] = files[i].replace('/VOI_LABEL/', '/MHD/', 1)
+        files[i] = files[i].replace('_LABEL.', '_ORIG.', 1)
+        img = read_mhd(files[i], crop=(top, left, dim), size=size, bias=bias, norm=norm)
+        if to2d:
+            for layer in img:
+                x.append(layer)
+        else:
+            x.append(img)
+    x = np.array(x)
+    y = np.array(y)
+    return x, y
+
+def load_data_age(files, size=None, crop=None, bias=False, norm=False, 
+                  to2d=False, smart_crop=False):
+    files = glob(files)
+    x, y = [], []
+    for i in tqdm(range(len(files))):
+        if crop:
+            if smart_crop:
+                img = read_mhd(files[i])
+                c = y_center(img)
+                crop[0] = c - crop[2] // 2
+        img = read_mhd(files[i], crop=crop, size=size, bias=bias, norm=norm)
+        f = files[i].split('_')
+        age = int(f[3]) + int(f[4]) / 12.
+        if to2d:
+            for layer in img:
+                x.append(layer)
+                y.append(age)
+        else:
+            x.append(img)
+            y.append(age)
+    x = np.array(x)
+    y = np.array(y)
+    return x, y
+
+def print_weights(weight_file_path):
+    """
+    Prints out the structure of HDF5 file.
+
+    Args:
+      weight_file_path (str) : Path to the file to analyze
+    """
+    f = h5py.File(weight_file_path)
+    try:
+        if len(f.attrs.items()):
+            print("{} contains: ".format(weight_file_path))
+            print("Root attributes:")
+        for key, value in f.attrs.items():
+            print("  {}: {}".format(key, value))
+
+        if len(f.items())==0:
+            return 
+
+        for layer, g in f.items():
+            print("  {}".format(layer))
+            print("    Attributes:")
+            for key, value in g.attrs.items():
+                print("      {}: {}".format(key, value))
+
+            print("    Dataset:")
+            for p_name in g.keys():
+                param = g[p_name]
+                print("      {}: {}".format(p_name, param.shape)) #try only "param"
+    finally:
+        f.close()
+
+# Models
+
+def conv_block(m, dim, acti, bn, res, do=0):
+    n = Conv2D(dim, 3, activation=acti, padding='same')(m)
+    n = BatchNormalization()(n) if bn else n
+    n = Dropout(do)(n) if do else n
+    n = Conv2D(dim, 3, activation=acti, padding='same')(n)
+    n = BatchNormalization()(n) if bn else n
+    return Add()([m, n]) if res else n
+
+def level_block(m, dim, depth, inc, acti, do, bn, mp, up, res):
+    if depth > 0:
+        n = conv_block(m, dim, acti, bn, res)
+        m = MaxPooling2D()(n) if mp else Conv2D(dim, 3, strides=2, padding='same')(n)
+        m = level_block(m, int(inc*dim), depth-1, inc, acti, do, bn, mp, up, res)
+        if up:
+            m = UpSampling2D()(m)
+            m = Conv2D(dim, 2, activation=acti, padding='same')(m)
+        else:
+            m = Conv2DTranspose(dim, 3, strides=2, activation=acti, padding='same')(m)
+        n = Add()([n, m])
+        m = conv_block(n, dim, acti, bn, res)
+    else:
+        m = conv_block(m, dim, acti, bn, res, do)
+    return m
+
+def UNet(img_shape, out_ch=1, start_ch=32, depth=4, inc_rate=1., activation='elu', 
+         dropout=0.5, batchnorm=False, maxpool=True, upconv=True, residual=False):
+    i = Input(shape=img_shape)
+    o = level_block(i, start_ch, depth, inc_rate, activation, dropout, batchnorm, maxpool, upconv, residual)
+    o = Conv2D(out_ch, 1, activation='sigmoid')(o)
+    return Model(inputs=i, outputs=o)
+
+def level_block_3d(m, dim, depth, factor, acti, dropout):
+    if depth > 0:
+        n = Conv3D(dim, 3, activation=acti, padding='same')(m)
+        n = Dropout(dropout)(n) if dropout else n
+        n = Conv3D(dim, 3, activation=acti, padding='same')(n)
+        m = MaxPooling3D((1,2,2))(n)
+        m = level_block_3d(m, int(factor*dim), depth-1, factor, acti, dropout)
+        m = UpSampling3D((1,2,2))(m)
+        m = Conv3D(dim, 2, activation=acti, padding='same')(m)
+        m = Concatenate(axis=4)([n, m])
+    m = Conv3D(dim, 3, activation=acti, padding='same')(m)
+    return Conv3D(dim, 3, activation=acti, padding='same')(m)
+
+def UNet_3D(img_shape, n_out=1, dim=8, depth=3, factor=1.5, acti='elu', dropout=None):
+    i = Input(shape=img_shape)
+    o = level_block_3d(i, dim, depth, factor, acti, dropout)
+    o = Conv3D(n_out, 1, activation='sigmoid')(o)
+    return Model(inputs=i, outputs=o)
+
+# Loss Functions
+
+# 2TP / (2TP + FP + FN)
+def f1(y_true, y_pred):
+    y_true_f = K.flatten(y_true)
+    y_pred_f = K.flatten(y_pred)
+    intersection = K.sum(y_true_f * y_pred_f)
+    return (2. * intersection + 1.) / (K.sum(y_true_f) + K.sum(y_pred_f) + 1.)
+
+def f1_np(y_true, y_pred):
+    return (2. * (y_true * y_pred).sum() + 1.) / (y_true.sum() + y_pred.sum() + 1.)
+
+def f1_loss(y_true, y_pred):
+    return 1-f1(y_true, y_pred)
+
+def f2(y_true, y_pred):
+    y_true_f = K.flatten(y_true)
+    y_pred_f = K.flatten(y_pred)
+    intersection = K.sum(y_true_f * y_pred_f)
+    return (5. * intersection + 1.) / (4. * K.sum(y_true_f) + K.sum(y_pred_f) + 1.)
+
+def f2_loss(y_true, y_pred):
+    return 1-f2(y_true, y_pred)
+
+dice = f1
+dice_loss = f1_loss
+
+def iou(y_true, y_pred):
+    y_true_f = K.flatten(y_true)
+    y_pred_f = K.flatten(y_pred)
+    intersection = K.sum(y_true_f * y_pred_f)
+    return (intersection + 1.) / (K.sum(y_true_f) + K.sum(y_pred_f) + 1. - intersection)
+
+def iou_np(y_true, y_pred):
+    intersection = (y_true * y_pred).sum()
+    return (intersection + 1.) / (y_true.sum() + y_pred.sum() + 1. - intersection)
+
+def iou_loss(y_true, y_pred):
+    return -iou(y_true, y_pred)
+
+def precision(y_true, y_pred):
+    y_true_f = K.flatten(y_true)
+    y_pred_f = K.flatten(y_pred)
+    intersection = K.sum(y_true_f * y_pred_f)
+    return (intersection + 1.) / (K.sum(y_pred_f) + 1.)
+
+def precision_np(y_true, y_pred):
+    return ((y_true * y_pred).sum() + 1.) / (y_pred.sum() + 1.)
+
+def recall(y_true, y_pred):
+    y_true_f = K.flatten(y_true)
+    y_pred_f = K.flatten(y_pred)
+    intersection = K.sum(y_true_f * y_pred_f)
+    return (intersection + 1.) / (K.sum(y_true_f) + 1.)
+
+def recall_np(y_true, y_pred):
+    return ((y_true * y_pred).sum() + 1.) / (y_true.sum() + 1.)
+
+def mae_img(y_true, y_pred):
+    y_true_f = K.flatten(y_true)
+    y_pred_f = K.flatten(y_pred)
+    return mae(y_true_f, y_pred_f)
+
+def bce_img(y_true, y_pred):
+    y_true_f = K.flatten(y_true)
+    y_pred_f = K.flatten(y_pred)
+    return binary_crossentropy(y_true_f, y_pred_f)
+
+def f1_bce(y_true, y_pred):
+    return f1_loss(y_true, y_pred) + bce_img(y_true, y_pred)
+
+# FP + FN
+def error(y_true, y_pred):
+    y_true_f = K.flatten(y_true)
+    y_pred_f = K.flatten(y_pred)
+    return K.sum(K.abs(y_true_f - y_pred_f)) / float(224*224)
+
+def error_np(y_true, y_pred):
+    return (abs(y_true - y_pred)).sum() / float(len(y_true.flatten()))
+
+# Notifications
+    
+def pushover(title, message):
+    user = "u96ub3t5wu1nexmgi22xjs31jeb8y6"
+    api = "avfytsyktracxood45myebobtry6yd"
+    client = Client(user, api_token=api)
+    client.send_message(message, title=title)
+    
+#from nipype.interfaces.ants import N4BiasFieldCorrection
+#correct = N4BiasFieldCorrection()
+#correct.inputs.input_image = in_file
+#correct.inputs.output_image = out_file
+#done = correct.run()
+#img done.outputs.output_image
\ No newline at end of file