#!/usr/bin/python
"""Image/label preprocessing, data loading, U-Net builders and segmentation metrics."""
# NOTE: the import order below is kept as in the original file because the
# star imports can rebind names; do not reorder without checking for shadowing.
import numpy as np
from keras.models import *
from keras.layers import *
from keras.optimizers import *
from keras.callbacks import *
from keras.losses import *
from keras.preprocessing.image import *
# `K` is used by the metric functions in this file; import it explicitly rather
# than relying on one of the star imports above leaking it.
from keras import backend as K
from os.path import isfile
from tqdm import tqdm
import random
from glob import glob
import skimage.io as io
import skimage.transform as tr
import skimage.morphology as mo
import SimpleITK as sitk
from pushover import Client
import matplotlib.pyplot as plt

# img helper functions


def print_info(x):
    """Print the shape, minimum, mean and maximum of array ``x``."""
    print('{} - Min: {} - Mean: {} - Max: {}'.format(x.shape, x.min(), x.mean(), x.max()))


def show_samples(x, y, num):
    """Plot randomly chosen image/target pairs, two pairs per figure.

    4-D ``x`` is treated as a stack of 2-D samples: each figure shows two
    different random samples. Otherwise each figure shows slices 8 and 16 of
    one random volume. One figure is drawn for every two requested samples.

    Args:
        x: image array; channel 0 is displayed.
        y: target array indexed like ``x``; channel 0 is displayed when the
            last axis has length 1, otherwise the full entry is shown.
        num: number of samples to show (stepped through in pairs).
    """
    two_d = len(x.shape) == 4
    rnd = np.random.permutation(len(x))
    for i in range(0, num, 2):
        plt.figure(figsize=(15, 5))
        for j in range(2):
            depth = 8 + 8 * j  # slice index used for volumetric input
            plt.subplot(1, 4, 1 + j * 2)
            img = x[rnd[i + j], ..., 0] if two_d else x[rnd[i], depth, ..., 0]
            plt.axis('off')
            plt.imshow(img.astype('float32'))
            plt.subplot(1, 4, 2 + j * 2)
            if y[rnd[i]].shape[-1] == 1:
                img = y[rnd[i + j], ..., 0] if two_d else y[rnd[i], depth, ..., 0]
            else:
                img = y[rnd[i + j]] if two_d else y[rnd[i], depth]
            plt.axis('off')
            plt.imshow(img.astype('float32'))
        plt.show()


def show_samples_2d(x, num, titles=None, axis_off=True, size=(20,20)):
    """Plot ``num`` random rows, each showing the same sample from every array in ``x``.

    Args:
        x: non-empty sequence of arrays that share their first-axis indexing.
        num: number of rows (figures) to draw.
        titles: optional sequence of column titles, one per array in ``x``.
        axis_off: hide the axes when true.
        size: figure size passed to matplotlib.
    """
    assert(len(x) >= 1)
    if titles:
        assert(len(titles) == len(x))
    n_cols = len(x)
    rnd = np.random.permutation(len(x[0]))
    for row in range(num):
        plt.figure(figsize=size)
        for col in range(n_cols):
            plt.subplot(1, n_cols, col + 1)
            sample = x[col][rnd[row]]
            # Drop a singleton channel axis so imshow receives a 2-D image.
            img = sample[..., 0] if sample.shape[-1] == 1 else sample
            if axis_off:
                plt.axis('off')
            if titles:
                plt.title(titles[col])
            plt.imshow(img.astype('float32'), cmap='gray')
        plt.show()


def shuffle(x, y):
    """Return ``x`` and ``y`` reordered by one shared random permutation."""
    perm = np.random.permutation(len(x))
    return x[perm], y[perm]


def split(x, y, tr_size):
    """Split ``x`` and ``y`` into leading (train) and trailing (test) parts.

    Args:
        x, y: sliceable sequences of equal length.
        tr_size: fraction of ``len(x)`` that goes into the first part.

    Returns:
        ``(x_tr, y_tr, x_te, y_te)``.
    """
    cut = int(len(x) * tr_size)
    return x[:cut], y[:cut], x[cut:], y[cut:]


def _append_transformed(x, y, seg, transform):
    """Append ``transform(x)`` to ``x``; append ``transform(y)`` (or ``y`` itself if not ``seg``) to ``y``."""
    x_out = np.concatenate((x, transform(x)))
    y_out = np.concatenate((y, transform(y) if seg else y))
    return x_out, y_out


def augment(x, y, h_shift=None, v_flip=False, h_flip=False, rot90=False, edge_mode='minimum'):
    """Enlarge a data set with shifted, flipped and rotated copies.

    The augmentations are applied in sequence (shift, vertical flip,
    horizontal flip, rotation), each one operating on the already enlarged
    set. ``y`` is transformed together with ``x`` when it has more than two
    dimensions; otherwise it is treated as per-sample labels and repeated.

    Args:
        x: 4-D image array.
        y: targets; transformed like ``x`` when ``len(y.shape) > 2``.
        h_shift: shift amount(s) along axis 2; a single number or a sequence.
            Positive values pad at the start of the axis, others at the end.
            ``None``, ``0`` or an empty sequence disables shifting.
        v_flip: add copies flipped along axis 1.
        h_flip: add copies flipped along axis 2.
        rot90: add copies rotated by 90 degrees in the (1, 2) plane.
            NOTE(review): for non-square images the rotated copies have a
            different shape and the concatenation will fail.
        edge_mode: ``numpy.pad`` mode used to fill the shifted-in border.

    Returns:
        ``(x, y)`` including the originals followed by the augmented copies.
    """
    assert(len(x.shape) == 4)
    seg = len(y.shape) > 2

    # Normalise h_shift: previously a mutable default list, and a bare int crashed on len().
    if h_shift is None:
        shifts = []
    elif np.isscalar(h_shift):
        shifts = [h_shift] if h_shift != 0 else []
    else:
        shifts = list(h_shift)

    if shifts:
        shifted_x, shifted_y = [], []
        for shft in shifts:
            if shft > 0:
                keep, border = slice(None, -shft), (shft, 0)
            else:
                keep, border = slice(-shft, None), (0, -shft)
            widths = ((0, 0), (0, 0), border, (0, 0))
            shifted_x.append(np.pad(x[:, :, keep], widths, edge_mode))
            shifted_y.append(np.pad(y[:, :, keep], widths, edge_mode) if seg else y)
        x = np.concatenate((x, np.concatenate(shifted_x)))
        y = np.concatenate((y, np.concatenate(shifted_y)))
    if v_flip:
        x, y = _append_transformed(x, y, seg, lambda a: np.flip(a, axis=1))
    if h_flip:
        x, y = _append_transformed(x, y, seg, lambda a: np.flip(a, axis=2))
    if rot90:
        x, y = _append_transformed(x, y, seg, lambda a: np.rot90(a, axes=(1, 2)))
    return x, y


def resize_3d(img, size):
    """Resize every entry along the first axis of a 4-D array.

    Args:
        img: array indexed ``[slice, row, col, channel]``.
        size: target size; the LAST two entries are used as (rows, cols), so
            both ``(rows, cols)`` and ``(depth, rows, cols)`` are accepted.
            (Previously the first two entries were used, which turned a
            ``(24, 224, 224)`` request into 24x224 slices.)

    Returns:
        A float64 array of shape ``(img.shape[0], rows, cols, img.shape[-1])``.
    """
    rows, cols = size[-2], size[-1]
    out = np.zeros((img.shape[0], rows, cols, img.shape[-1]))
    for i in range(img.shape[0]):
        out[i] = tr.resize(img[i], (rows, cols), mode='constant', preserve_range=True)
    return out


def to_2d(x):
    """Merge the first two axes of a 5-D array, turning volumes into a stack of slices."""
    assert len(x.shape) == 5 # Shape: (#, Z, Y, X, C)
    n, z = x.shape[0], x.shape[1]
    return np.reshape(x, (n * z,) + tuple(x.shape[2:]))


def to_3d(imgs, z):
    """Regroup a 4-D stack of slices into volumes of ``z`` slices each.

    ``imgs.shape[0]`` must be a multiple of ``z``; otherwise ``np.reshape`` raises.
    """
    assert len(imgs.shape) == 4 # Shape: (#, Y, X, C)
    # Integer division: a float dimension makes np.reshape fail on Python 3.
    return np.reshape(imgs, (imgs.shape[0] // z, z) + tuple(imgs.shape[1:]))


def get_crop_area(img, threshold=0):
    """Locate a square crop around the content of ``img``.

    The image is summed over its first axis; positions whose sum exceeds
    ``threshold`` count as content. The square's side is the content extent
    along the next axis, and it is centred on the content along the axis after.

    Returns:
        ``(y, x, size)``: start offsets and side length of the square.
        NOTE(review): ``x`` is not clamped and can become negative when the
        content is taller than it is wide near the border. Raises IndexError
        when nothing exceeds ``threshold``.
    """
    projected = img.sum(axis=0)
    rows = np.where(projected > threshold)[0]
    cols = np.where(projected.sum(axis=0) > threshold)[0]
    size = rows[-1] - rows[0] + 1
    y = rows[0]
    x = (cols[0] + cols[-1]) // 2 - size // 2
    return y, x, size


def n4_bias_correction(img):
    """Apply SimpleITK N4 bias field correction to channel 0 of ``img``.

    An Otsu threshold of the image is used as the correction mask. The result
    is returned with a new trailing channel axis.
    """
    itk_img = sitk.GetImageFromArray(img[..., 0].astype('float32'))
    mask = sitk.OtsuThreshold(itk_img, 0, 1, 200)
    corrected = sitk.N4BiasFieldCorrection(itk_img, mask)
    return sitk.GetArrayFromImage(corrected)[..., np.newaxis]


def handle_specials(img):
    """Bring volumes with 26 or 20 slices to 24 slices.

    26-slice volumes lose their first and last slice; 20-slice volumes are
    padded with two slices on each side using the 'minimum' pad mode. Any
    other slice count is returned unchanged.
    """
    n_slices = img.shape[0]
    if n_slices == 26:
        return img[1:-1]
    if n_slices == 20:
        return np.pad(img, ((2, 2), (0, 0), (0, 0), (0, 0)), 'minimum')
    return img


def erode(imgs, amount=3):
    """Sum the channel axis, erode each entry with a square footprint, re-add a channel axis."""
    flat = imgs.sum(axis=-1)
    for i in range(len(flat)):
        flat[i] = mo.erosion(flat[i], mo.square(amount))
    return flat[..., np.newaxis]


def add_noise(imgs, amount=3):
    """Perturb masks: dilate even-indexed entries and erode odd-indexed ones.

    The channel axis is summed first and a single channel axis is re-added to
    the result.
    """
    flat = imgs.sum(axis=-1)
    for i in range(len(flat)):
        morph = mo.dilation if i % 2 == 0 else mo.erosion
        flat[i] = morph(flat[i], mo.square(amount))
    return flat[..., np.newaxis]


# Label helper functions

def to_classes(y, start, end, step=1):
    """One-hot encode values into bins of width ``step`` starting at ``start``.

    Args:
        y: array-like of values.
        start, end: range covered by the bins.
        step: bin width.

    Returns:
        Array of shape ``(len(y), round((end - start) / step))``.
        NOTE(review): values >= ``end`` raise IndexError and values below
        ``start`` wrap around to the last bins; confirm callers stay in range.
    """
    num_classes = int(round((end - start) / step))
    labels = np.zeros((len(y), num_classes))
    idx = (np.asarray(y) - start) / step
    for row, position in enumerate(idx):
        labels[row, int(position)] = 1
    return labels


def y_center(img, smooth=20, crop=100):
    """Estimate a centre index along axis 1 of a 4-D image.

    The intensity profile along axis 1 is smoothed with a moving average of
    width ``smooth``; ``crop`` entries are ignored at both ends. The index of
    the largest second derivative of that profile is returned.
    """
    # Get Sum of y-axis values
    profile = img.sum(axis=-1).sum(axis=-1).sum(axis=0)
    # Smooth the values and apply the crop region
    kernel = np.ones(smooth) / smooth
    window = np.convolve(profile, kernel, mode='same')[crop:-crop]
    # 2nd derivative of min will be max - get its index
    return np.gradient(np.gradient(window)).argmax() + crop


def lengthen(y, factor):
    """Return an array in which every element of ``y`` is repeated ``factor`` times in a row."""
    return np.array([el for el in y for _ in range(factor)])


def shorten(y, factor):
    """Return an array holding every ``factor``-th element of ``y``, starting with the first."""
    return np.array([y[i] for i in range(0, len(y), factor)])


def normalize(x, mean, std):
    """Standardise ``x`` with the given statistics: ``(x - mean) / std``.

    The original implementation ignored ``mean`` and ``std`` and always used
    ``x``'s own statistics; pass ``x.mean()`` and ``x.std()`` to get that.
    """
    return (x - mean) / std


def multilabel(img, channel):
    """Convert an intensity-coded label image into binary channel masks.

    With ``channel == 1`` the image is binarised IN PLACE at 0.01 and
    returned. Otherwise ``channel`` masks are produced by repeatedly taking
    everything above a falling threshold (starting at 99% of the maximum and
    dropping by ``max // channel`` per step) and zeroing those pixels in
    ``img`` (also in place) so they are not assigned twice.

    Args:
        img: 4-D array; only channel 0 is read in the multi-channel case.
        channel: number of output channels.
    """
    if channel == 1:
        img[img > 0.01] = 1
        img[img < 0.01] = 0
        return img

    peak = img.max()
    step = peak // channel
    divider = peak * 0.99
    masks = np.zeros((img.shape[0], img.shape[1], img.shape[2], channel))
    for c in range(channel):
        hit = img[..., 0] > divider
        masks[hit, c] = 1
        img[hit, 0] = 0  # consume these pixels before lowering the threshold
        divider -= step
    return masks


def read_mhd(path, label=0, crop=None, size=None, bias=False, norm=False):
    """Read an image through skimage's SimpleITK plugin and preprocess it.

    Steps, in order: add a channel axis, fix special slice counts, optional
    label conversion, optional square crop, optional resize, optional N4
    bias correction, optional zero-mean/unit-variance normalisation.

    Args:
        path: file to read.
        label: if > 0, number of label channels passed to ``multilabel``.
        crop: ``(top, left, side)`` of a square crop in the slice plane.
        size: target size passed to ``resize_3d``.
        bias: apply ``n4_bias_correction``.
        norm: normalise with the image's own mean and standard deviation.

    Returns:
        float32 array.
    """
    img = io.imread(path, plugin='simpleitk')[..., np.newaxis].astype('float64')
    img = handle_specials(img)
    if label > 0:
        img = multilabel(img, label)
    if crop:
        top, left, side = crop[0], crop[1], crop[2]
        img = img[:, top:top + side, left:left + side]
    if size:
        img = resize_3d(img, size)
    if bias:
        img = n4_bias_correction(img)
    if norm:
        img = (img - img.mean()) / img.std()
    return img.astype('float32')


def load_data(path, label=0, size=(24,224,224), bias=False, norm=False, to2d=False):
    """Load label volumes matching ``path`` together with their source images.

    The crop area is derived from each label file. The image path is obtained
    from the label path by replacing ``/VOI_LABEL/`` with ``/MHD/`` and
    ``_LABEL.`` with ``_ORIG.``.

    Args:
        path: glob pattern for the label files.
        label: number of label channels (see ``read_mhd``).
        size: target size passed on to ``read_mhd``.
        bias, norm: preprocessing flags applied to the images only.
        to2d: collect individual slices instead of whole volumes.

    Returns:
        ``(x, y)``: images and labels as arrays.
    """
    x, y = [], []
    for label_path in tqdm(glob(path)):
        crop = tuple(get_crop_area(read_mhd(label_path)))
        target = read_mhd(label_path, label=label, crop=crop, size=size)
        image_path = label_path.replace('/VOI_LABEL/', '/MHD/', 1)
        image_path = image_path.replace('_LABEL.', '_ORIG.', 1)
        image = read_mhd(image_path, crop=crop, size=size, bias=bias, norm=norm)
        if to2d:
            y.extend(target)
            x.extend(image)
        else:
            y.append(target)
            x.append(image)
    return np.array(x), np.array(y)


def load_data_age(files, size=None, crop=None, bias=False, norm=False,
                  to2d=False, smart_crop=False):
    """Load images matching a glob pattern with an age label taken from the file name.

    The path is split on ``_``; field 3 is read as years and field 4 as
    months. NOTE(review): the split is applied to the whole path, so
    underscores in directory names shift the fields.

    Args:
        files: glob pattern.
        size: target size passed to ``read_mhd``.
        crop: ``(top, left, side)`` crop; may be a list or a tuple and is
            never modified.
        bias, norm: preprocessing flags passed to ``read_mhd``.
        to2d: collect individual slices, each with the volume's age.
        smart_crop: with ``crop`` set, recentre the crop's top per file on
            ``y_center`` of the uncropped image.

    Returns:
        ``(x, y)``: images and ages in years as arrays.
    """
    x, y = [], []
    for path in tqdm(glob(files)):
        file_crop = crop
        if crop and smart_crop:
            # Work on a copy: the original wrote into the caller's crop list
            # and failed outright when crop was a tuple.
            file_crop = list(crop)
            file_crop[0] = y_center(read_mhd(path)) - crop[2] // 2
        img = read_mhd(path, crop=file_crop, size=size, bias=bias, norm=norm)
        fields = path.split('_')
        age = int(fields[3]) + int(fields[4]) / 12.
        if to2d:
            for layer in img:
                x.append(layer)
                y.append(age)
        else:
            x.append(img)
            y.append(age)
    return np.array(x), np.array(y)


def print_weights(weight_file_path):
    """
    Prints out the structure of HDF5 file.

    Args:
      weight_file_path (str) : Path to the file to analyze
    """
    # h5py was referenced without ever being imported; import it where it is needed.
    import h5py

    # Read-only: inspecting a file must never create or modify it.
    with h5py.File(weight_file_path, 'r') as f:
        if len(f.attrs.items()):
            print("{} contains: ".format(weight_file_path))
            print("Root attributes:")
        for key, value in f.attrs.items():
            print(" {}: {}".format(key, value))

        if len(f.items()) == 0:
            return

        for layer, g in f.items():
            print(" {}".format(layer))
            print(" Attributes:")
            for key, value in g.attrs.items():
                print(" {}: {}".format(key, value))

            print(" Dataset:")
            for p_name in g.keys():
                param = g[p_name]
                print(" {}: {}".format(p_name, param.shape)) #try only "param"


# Models

def conv_block(m, dim, acti, bn, res, do=0):
    """Two 3x3 same-padded convolutions with optional batch norm, dropout and residual add.

    Args:
        m: input tensor.
        dim: number of filters of both convolutions.
        acti: activation of both convolutions.
        bn: add batch normalisation after each convolution.
        res: return ``m + block(m)`` instead of ``block(m)``.
        do: dropout rate applied between the convolutions (0 disables it).
    """
    n = Conv2D(dim, 3, activation=acti, padding='same')(m)
    if bn:
        n = BatchNormalization()(n)
    if do:
        n = Dropout(do)(n)
    n = Conv2D(dim, 3, activation=acti, padding='same')(n)
    if bn:
        n = BatchNormalization()(n)
    if res:
        return Add()([m, n])
    return n


def level_block(m, dim, depth, inc, acti, do, bn, mp, up, res):
    """Recursively build one encoder/decoder level of the 2-D U-Net.

    Args:
        m: input tensor.
        dim: filters at this level.
        depth: remaining levels below this one; 0 builds the bottleneck.
        inc: factor applied to ``dim`` for each deeper level.
        acti: activation function.
        do: dropout rate; only passed to the bottleneck block.
        bn: use batch normalisation.
        mp: downsample with max pooling (else a strided convolution).
        up: upsample with UpSampling2D + convolution (else a transposed convolution).
        res: use residual convolution blocks.
    """
    if depth <= 0:
        return conv_block(m, dim, acti, bn, res, do)

    n = conv_block(m, dim, acti, bn, res)
    if mp:
        m = MaxPooling2D()(n)
    else:
        m = Conv2D(dim, 3, strides=2, padding='same')(n)
    m = level_block(m, int(inc*dim), depth-1, inc, acti, do, bn, mp, up, res)
    if up:
        m = UpSampling2D()(m)
        m = Conv2D(dim, 2, activation=acti, padding='same')(m)
    else:
        m = Conv2DTranspose(dim, 3, strides=2, activation=acti, padding='same')(m)
    # Skip connection: the encoder output is added to the upsampled tensor.
    n = Add()([n, m])
    return conv_block(n, dim, acti, bn, res)


def UNet(img_shape, out_ch=1, start_ch=32, depth=4, inc_rate=1., activation='elu',
         dropout=0.5, batchnorm=False, maxpool=True, upconv=True, residual=False):
    """Build a 2-D U-Net style Keras model with a sigmoid 1x1 output convolution.

    Args:
        img_shape: input shape passed to ``Input``.
        out_ch: number of output channels.
        start_ch: filters of the first level.
        depth: number of down/up-sampling levels.
        inc_rate: filter growth factor per level.
        activation: activation of the convolution blocks.
        dropout: dropout rate used in the bottleneck.
        batchnorm, maxpool, upconv, residual: see ``level_block``
            (``bn``, ``mp``, ``up``, ``res``).
    """
    inputs = Input(shape=img_shape)
    features = level_block(inputs, start_ch, depth, inc_rate, activation,
                           dropout, batchnorm, maxpool, upconv, residual)
    outputs = Conv2D(out_ch, 1, activation='sigmoid')(features)
    return Model(inputs=inputs, outputs=outputs)


def level_block_3d(m, dim, depth, factor, acti, dropout):
    """Recursively build one level of the 3-D U-Net.

    Pooling and upsampling use ``(1, 2, 2)``, so the first spatial axis keeps
    its length. Skip connections are concatenated along axis 4. Every level,
    including the deepest, ends with two 3x3x3 convolutions.

    Args:
        m: input tensor.
        dim: filters at this level.
        depth: remaining levels below this one.
        factor: filter growth factor per level.
        acti: activation function.
        dropout: dropout rate between the two encoder convolutions (falsy disables it).
    """
    if depth > 0:
        n = Conv3D(dim, 3, activation=acti, padding='same')(m)
        if dropout:
            n = Dropout(dropout)(n)
        n = Conv3D(dim, 3, activation=acti, padding='same')(n)
        m = MaxPooling3D((1,2,2))(n)
        m = level_block_3d(m, int(factor*dim), depth-1, factor, acti, dropout)
        m = UpSampling3D((1,2,2))(m)
        m = Conv3D(dim, 2, activation=acti, padding='same')(m)
        m = Concatenate(axis=4)([n, m])
    m = Conv3D(dim, 3, activation=acti, padding='same')(m)
    return Conv3D(dim, 3, activation=acti, padding='same')(m)


def UNet_3D(img_shape, n_out=1, dim=8, depth=3, factor=1.5, acti='elu', dropout=None):
    """Build a 3-D U-Net style Keras model with a sigmoid 1x1x1 output convolution.

    Args:
        img_shape: input shape passed to ``Input``.
        n_out: number of output channels.
        dim: filters of the first level.
        depth: number of down/up-sampling levels.
        factor: filter growth factor per level.
        acti: activation function.
        dropout: optional dropout rate (see ``level_block_3d``).
    """
    inputs = Input(shape=img_shape)
    features = level_block_3d(inputs, dim, depth, factor, acti, dropout)
    outputs = Conv3D(n_out, 1, activation='sigmoid')(features)
    return Model(inputs=inputs, outputs=outputs)


# Loss Functions

# 2TP / (2TP + FP + FN)
def f1(y_true, y_pred):
    """Soft F1 (Dice) score of two tensors, smoothed by adding 1 to numerator and denominator."""
    truth = K.flatten(y_true)
    pred = K.flatten(y_pred)
    overlap = K.sum(truth * pred)
    return (2. * overlap + 1.) / (K.sum(truth) + K.sum(pred) + 1.)


def f1_np(y_true, y_pred):
    """NumPy counterpart of ``f1``."""
    overlap = (y_true * y_pred).sum()
    return (2. * overlap + 1.) / (y_true.sum() + y_pred.sum() + 1.)


def f1_loss(y_true, y_pred):
    """Loss form of ``f1``: ``1 - f1``."""
    return 1-f1(y_true, y_pred)


def f2(y_true, y_pred):
    """Soft F2 score (recall weighted four times as much as precision), smoothed by 1."""
    truth = K.flatten(y_true)
    pred = K.flatten(y_pred)
    overlap = K.sum(truth * pred)
    return (5. * overlap + 1.) / (4. * K.sum(truth) + K.sum(pred) + 1.)

def f2_loss(y_true, y_pred):
    """Loss form of the F2 score: one minus ``f2``."""
    score = f2(y_true, y_pred)
    return 1 - score


# Dice is the same quantity as the F1 score; expose both names.
dice = f1
dice_loss = f1_loss


def iou(y_true, y_pred):
    """Soft intersection-over-union of two tensors, smoothed by adding 1 to both terms."""
    truth = K.flatten(y_true)
    pred = K.flatten(y_pred)
    overlap = K.sum(truth * pred)
    union = K.sum(truth) + K.sum(pred) + 1. - overlap
    return (overlap + 1.) / union


def iou_np(y_true, y_pred):
    """NumPy counterpart of ``iou``."""
    overlap = (y_true * y_pred).sum()
    union = y_true.sum() + y_pred.sum() + 1. - overlap
    return (overlap + 1.) / union


def iou_loss(y_true, y_pred):
    """Loss form of ``iou``.

    Unlike ``f2_loss`` this is the negated score (``-iou``), not ``1 - iou``.
    """
    score = iou(y_true, y_pred)
    return -score


def precision(y_true, y_pred):
    """Soft precision: smoothed overlap divided by the smoothed sum of the prediction."""
    truth = K.flatten(y_true)
    pred = K.flatten(y_pred)
    overlap = K.sum(truth * pred)
    predicted_total = K.sum(pred) + 1.
    return (overlap + 1.) / predicted_total


def precision_np(y_true, y_pred):
    """NumPy counterpart of ``precision``."""
    overlap = (y_true * y_pred).sum()
    return (overlap + 1.) / (y_pred.sum() + 1.)


def recall(y_true, y_pred):
    """Soft recall: smoothed overlap divided by the smoothed sum of the ground truth."""
    truth = K.flatten(y_true)
    pred = K.flatten(y_pred)
    overlap = K.sum(truth * pred)
    actual_total = K.sum(truth) + 1.
    return (overlap + 1.) / actual_total


def recall_np(y_true, y_pred):
    """NumPy counterpart of ``recall``."""
    overlap = (y_true * y_pred).sum()
    return (overlap + 1.) / (y_true.sum() + 1.)

def mae_img(y_true, y_pred):
    """Mean absolute error computed over the flattened tensors."""
    return mae(K.flatten(y_true), K.flatten(y_pred))


def bce_img(y_true, y_pred):
    """Binary cross-entropy computed over the flattened tensors."""
    return binary_crossentropy(K.flatten(y_true), K.flatten(y_pred))


def f1_bce(y_true, y_pred):
    """Combined loss: ``f1_loss`` plus ``bce_img``."""
    return f1_loss(y_true, y_pred) + bce_img(y_true, y_pred)


# FP + FN
def error(y_true, y_pred):
    """Summed absolute difference of two tensors divided by ``224 * 224``.

    NOTE(review): the divisor is a fixed constant, not the actual number of
    elements, so the value scales with batch size and is only a per-pixel
    rate for a single 224x224 single-channel image. ``error_np`` divides by
    the true element count instead.
    """
    diff = K.abs(K.flatten(y_true) - K.flatten(y_pred))
    return K.sum(diff) / float(224*224)


def error_np(y_true, y_pred):
    """Mean absolute difference between two arrays (NumPy counterpart of ``error``).

    Inputs are converted to float arrays first. This makes plain lists and
    boolean masks usable and prevents unsigned-integer masks from wrapping
    around on subtraction (e.g. ``uint8`` 0 - 1 == 255).

    Args:
        y_true: array-like ground truth; its element count is the divisor.
        y_pred: array-like prediction, broadcastable against ``y_true``.
    """
    truth = np.asarray(y_true, dtype=float)
    pred = np.asarray(y_pred, dtype=float)
    return np.abs(truth - pred).sum() / float(truth.size)


# Notifications

def pushover(title, message):
    """Send a push notification through the Pushover client.

    Credentials are taken from the ``PUSHOVER_USER_KEY`` and
    ``PUSHOVER_API_TOKEN`` environment variables when set, otherwise the
    values below are used.

    Args:
        title: notification title.
        message: notification body.
    """
    import os

    # SECURITY: hard-coded credentials in source control. Kept only as a
    # fallback so existing callers keep working; rotate them and rely on the
    # environment variables instead.
    user = os.environ.get("PUSHOVER_USER_KEY", "u96ub3t5wu1nexmgi22xjs31jeb8y6")
    api = os.environ.get("PUSHOVER_API_TOKEN", "avfytsyktracxood45myebobtry6yd")
    client = Client(user, api_token=api)
    client.send_message(message, title=title)


# Unused alternative kept for reference: N4 bias correction through nipype/ANTs.
#from nipype.interfaces.ants import N4BiasFieldCorrection
#correct = N4BiasFieldCorrection()
#correct.inputs.input_image = in_file
#correct.inputs.output_image = out_file
#done = correct.run()
#img = done.outputs.output_image