--- a +++ b/supporters.py @@ -0,0 +1,153 @@ +import numpy as np +import matplotlib.pyplot as plt +from ipywidgets import interact +import SimpleITK as sitk +import cv2 + +def explore_3D_array(arr: np.ndarray, cmap: str = 'gray'): + ''' + Given a 3D array with shape (Z,X,Y) This function will create an interactive + widget to check out all the 2D arrays with shape (X,Y) inside the 3D array. + The purpose of this function to visual inspect the 2D arrays in the image. + + Args: + arr : 3D array with shape (Z,X,Y) that represents the volume of a MRI image + cmap : Which color map use to plot the slices in matplotlib.pyplot + ''' + + def fn(SLICE): + plt.figure(figsize=(7,7)) + plt.imshow(arr[SLICE, :, :], cmap=cmap) + + interact(fn, SLICE=(0, arr.shape[0]-1)) + + +def explore_3D_array_axis(arr: np.ndarray, aspect: str = 'axial', cmap: str = 'gray'): + ''' + Given a 3D array with shape (Z,X,Y) This function will create an interactive + widget to check out all the 2D arrays with shape (X,Y) inside the 3D array. + The purpose of this function to visual inspect the 2D arrays in the image. + + Args: + arr : 3D array with shape (Z,X,Y) that represents the volume of a MRI image + aspect : Which aspect to view: sagittal, axial, or coronal + cmap : Which color map use to plot the slices in matplotlib.pyplot + ''' + + def fn(SLICE): + plt.figure(figsize=(7,7)) + if aspect == 'sagittal': + plt.imshow(arr[:, SLICE, :], cmap=cmap) + elif aspect == 'axial': + plt.imshow(arr[SLICE, :, :], cmap=cmap) + elif aspect == 'coronal': + plt.imshow(arr[:, :, SLICE], cmap=cmap) + else: + print('Invalid aspect') + + interact(fn, SLICE=(0, arr.shape[0]-1)) + + + +def explore_3D_array_comparison(arr_before: np.ndarray, arr_after: np.ndarray, cmap: str = 'gray'): + ''' + Given two 3D arrays with shape (Z,X,Y) This function will create an interactive + widget to check out all the 2D arrays with shape (X,Y) inside the 3D arrays. + The purpose of this function to visual compare the 2D arrays after some transformation. + + Args: + arr_before : 3D array with shape (Z,X,Y) that represents the volume of a MRI image, before any transform + arr_after : 3D array with shape (Z,X,Y) that represents the volume of a MRI image, after some transform + cmap : Which color map use to plot the slices in matplotlib.pyplot + ''' + + assert arr_after.shape == arr_before.shape + + def fn(SLICE): + fig, (ax1, ax2) = plt.subplots(1, 2, sharex='col', sharey='row', figsize=(10,10)) + + ax1.set_title('Label', fontsize=15) + ax1.imshow(arr_before[SLICE, :, :], cmap=cmap) + + ax2.set_title('Prediction', fontsize=15) + ax2.imshow(arr_after[SLICE, :, :], cmap=cmap) + + plt.tight_layout() + + interact(fn, SLICE=(0, arr_before.shape[0]-1)) + + +def show_sitk_img_info(img: sitk.Image): + ''' + Given a sitk.Image instance prints the information about the MRI image contained. + + Args: + img : instance of the sitk.Image to check out + ''' + pixel_type = img.GetPixelIDTypeAsString() + origin = img.GetOrigin() + dimensions = img.GetSize() + spacing = img.GetSpacing() + direction = img.GetDirection() + + info = {'Pixel Type' : pixel_type, 'Dimensions': dimensions, 'Spacing': spacing, 'Origin': origin, 'Direction' : direction} + for k,v in info.items(): + print(f' {k} : {v}') + + +def add_suffix_to_filename(filename: str, suffix:str) -> str: + ''' + Takes a NIfTI filename and appends a suffix. + + Args: + filename : NIfTI filename + suffix : suffix to append + + Returns: + str : filename after append the suffix + ''' + if filename.endswith('.nii'): + result = filename.replace('.nii', f'_{suffix}.nii') + return result + elif filename.endswith('.nii.gz'): + result = filename.replace('.nii.gz', f'_{suffix}.nii.gz') + return result + else: + raise RuntimeError('filename with unknown extension') + + +def rescale_linear(array: np.ndarray, new_min: int, new_max: int): + '''Rescale an array linearly.''' + minimum, maximum = np.min(array), np.max(array) + m = (new_max - new_min) / (maximum - minimum) + b = new_min - m * minimum + return m * array + b + + +def explore_3D_array_with_mask_contour(arr: np.ndarray, mask: np.ndarray, thickness: int = 1): + ''' + Given a 3D array with shape (Z,X,Y) This function will create an interactive + widget to check out all the 2D arrays with shape (X,Y) inside the 3D array. The binary + mask provided will be used to overlay contours of the region of interest over the + array. The purpose of this function is to visual inspect the region delimited by the mask. + + Args: + arr : 3D array with shape (Z,X,Y) that represents the volume of a MRI image + mask : binary mask to obtain the region of interest + ''' + assert arr.shape == mask.shape + + _arr = rescale_linear(arr,0,1) + _mask = rescale_linear(mask,0,1) + _mask = _mask.astype(np.uint8) + + def fn(SLICE): + arr_rgb = cv2.cvtColor(_arr[SLICE, :, :], cv2.COLOR_GRAY2RGB) + contours, _ = cv2.findContours(_mask[SLICE, :, :], cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE) + + arr_with_contours = cv2.drawContours(arr_rgb, contours, -1, (0,1,0), thickness) + + plt.figure(figsize=(7,7)) + plt.imshow(arr_with_contours) + + interact(fn, SLICE=(0, arr.shape[0]-1)) \ No newline at end of file