--- a +++ b/utils.py @@ -0,0 +1,117 @@ +import numpy as np +from medpy.filter.binary import largest_connected_component +from skimage.exposure import rescale_intensity +from skimage.transform import resize + + +def dsc(y_pred, y_true, lcc=True): + if lcc and np.any(y_pred): + y_pred = np.round(y_pred).astype(int) + y_true = np.round(y_true).astype(int) + y_pred = largest_connected_component(y_pred) + return np.sum(y_pred[y_true == 1]) * 2.0 / (np.sum(y_pred) + np.sum(y_true)) + + +def crop_sample(x): + volume, mask = x + volume[volume < np.max(volume) * 0.1] = 0 + z_projection = np.max(np.max(np.max(volume, axis=-1), axis=-1), axis=-1) + z_nonzero = np.nonzero(z_projection) + z_min = np.min(z_nonzero) + z_max = np.max(z_nonzero) + 1 + y_projection = np.max(np.max(np.max(volume, axis=0), axis=-1), axis=-1) + y_nonzero = np.nonzero(y_projection) + y_min = np.min(y_nonzero) + y_max = np.max(y_nonzero) + 1 + x_projection = np.max(np.max(np.max(volume, axis=0), axis=0), axis=-1) + x_nonzero = np.nonzero(x_projection) + x_min = np.min(x_nonzero) + x_max = np.max(x_nonzero) + 1 + return ( + volume[z_min:z_max, y_min:y_max, x_min:x_max], + mask[z_min:z_max, y_min:y_max, x_min:x_max], + ) + + +def pad_sample(x): + volume, mask = x + a = volume.shape[1] + b = volume.shape[2] + if a == b: + return volume, mask + diff = (max(a, b) - min(a, b)) / 2.0 + if a > b: + padding = ((0, 0), (0, 0), (int(np.floor(diff)), int(np.ceil(diff)))) + else: + padding = ((0, 0), (int(np.floor(diff)), int(np.ceil(diff))), (0, 0)) + mask = np.pad(mask, padding, mode="constant", constant_values=0) + padding = padding + ((0, 0),) + volume = np.pad(volume, padding, mode="constant", constant_values=0) + return volume, mask + + +def resize_sample(x, size=256): + volume, mask = x + v_shape = volume.shape + out_shape = (v_shape[0], size, size) + mask = resize( + mask, + output_shape=out_shape, + order=0, + mode="constant", + cval=0, + anti_aliasing=False, + ) + out_shape = out_shape + (v_shape[3],) + volume = resize( + volume, + output_shape=out_shape, + order=2, + mode="constant", + cval=0, + anti_aliasing=False, + ) + return volume, mask + + +def normalize_volume(volume): + p10 = np.percentile(volume, 10) + p99 = np.percentile(volume, 99) + volume = rescale_intensity(volume, in_range=(p10, p99)) + m = np.mean(volume, axis=(0, 1, 2)) + s = np.std(volume, axis=(0, 1, 2)) + volume = (volume - m) / s + return volume + + +def log_images(x, y_true, y_pred, channel=1): + images = [] + x_np = x[:, channel].cpu().numpy() + y_true_np = y_true[:, 0].cpu().numpy() + y_pred_np = y_pred[:, 0].cpu().numpy() + for i in range(x_np.shape[0]): + image = gray2rgb(np.squeeze(x_np[i])) + image = outline(image, y_pred_np[i], color=[255, 0, 0]) + image = outline(image, y_true_np[i], color=[0, 255, 0]) + images.append(image) + return images + + +def gray2rgb(image): + w, h = image.shape + image += np.abs(np.min(image)) + image_max = np.abs(np.max(image)) + if image_max > 0: + image /= image_max + ret = np.empty((w, h, 3), dtype=np.uint8) + ret[:, :, 2] = ret[:, :, 1] = ret[:, :, 0] = image * 255 + return ret + + +def outline(image, mask, color): + mask = np.round(mask) + yy, xx = np.nonzero(mask) + for y, x in zip(yy, xx): + if 0.0 < np.mean(mask[max(0, y - 1) : y + 2, max(0, x - 1) : x + 2]) < 1.0: + image[max(0, y) : y + 1, max(0, x) : x + 1] = color + return image