Diff of /supporters.py [000000] .. [70e190]

Switch to side-by-side view

--- a
+++ b/supporters.py
@@ -0,0 +1,153 @@
+import numpy as np
+import matplotlib.pyplot as plt
+from ipywidgets import interact
+import SimpleITK as sitk
+import cv2
+
+def explore_3D_array(arr: np.ndarray, cmap: str = 'gray'):
+  '''
+  Given a 3D array with shape (Z,X,Y) This function will create an interactive
+  widget to check out all the 2D arrays with shape (X,Y) inside the 3D array. 
+  The purpose of this function to visual inspect the 2D arrays in the image. 
+
+  Args:
+    arr : 3D array with shape (Z,X,Y) that represents the volume of a MRI image
+    cmap : Which color map use to plot the slices in matplotlib.pyplot
+  '''
+
+  def fn(SLICE):
+    plt.figure(figsize=(7,7))
+    plt.imshow(arr[SLICE, :, :], cmap=cmap)
+
+  interact(fn, SLICE=(0, arr.shape[0]-1))
+  
+
+def explore_3D_array_axis(arr: np.ndarray, aspect: str = 'axial', cmap: str = 'gray'):
+  '''
+  Given a 3D array with shape (Z,X,Y) This function will create an interactive
+  widget to check out all the 2D arrays with shape (X,Y) inside the 3D array. 
+  The purpose of this function to visual inspect the 2D arrays in the image. 
+
+  Args:
+    arr : 3D array with shape (Z,X,Y) that represents the volume of a MRI image
+    aspect : Which aspect to view: sagittal, axial, or coronal
+    cmap : Which color map use to plot the slices in matplotlib.pyplot
+  '''
+
+  def fn(SLICE):
+    plt.figure(figsize=(7,7))
+    if aspect == 'sagittal':
+      plt.imshow(arr[:, SLICE, :], cmap=cmap)
+    elif aspect == 'axial':
+      plt.imshow(arr[SLICE, :, :], cmap=cmap)
+    elif aspect == 'coronal':
+      plt.imshow(arr[:, :, SLICE], cmap=cmap)
+    else:
+      print('Invalid aspect')
+
+  interact(fn, SLICE=(0, arr.shape[0]-1))
+
+
+
+def explore_3D_array_comparison(arr_before: np.ndarray, arr_after: np.ndarray, cmap: str = 'gray'):
+  '''
+  Given two 3D arrays with shape (Z,X,Y) This function will create an interactive
+  widget to check out all the 2D arrays with shape (X,Y) inside the 3D arrays.
+  The purpose of this function to visual compare the 2D arrays after some transformation. 
+
+  Args:
+    arr_before : 3D array with shape (Z,X,Y) that represents the volume of a MRI image, before any transform
+    arr_after : 3D array with shape (Z,X,Y) that represents the volume of a MRI image, after some transform    
+    cmap : Which color map use to plot the slices in matplotlib.pyplot
+  '''
+
+  assert arr_after.shape == arr_before.shape
+
+  def fn(SLICE):
+    fig, (ax1, ax2) = plt.subplots(1, 2, sharex='col', sharey='row', figsize=(10,10))
+
+    ax1.set_title('Label', fontsize=15)
+    ax1.imshow(arr_before[SLICE, :, :], cmap=cmap)
+
+    ax2.set_title('Prediction', fontsize=15)
+    ax2.imshow(arr_after[SLICE, :, :], cmap=cmap)
+
+    plt.tight_layout()
+  
+  interact(fn, SLICE=(0, arr_before.shape[0]-1))
+
+
+def show_sitk_img_info(img: sitk.Image):
+  '''
+  Given a sitk.Image instance prints the information about the MRI image contained.
+
+  Args:
+    img : instance of the sitk.Image to check out
+  '''
+  pixel_type = img.GetPixelIDTypeAsString()
+  origin = img.GetOrigin()
+  dimensions = img.GetSize()
+  spacing = img.GetSpacing()
+  direction = img.GetDirection()
+
+  info = {'Pixel Type' : pixel_type, 'Dimensions': dimensions, 'Spacing': spacing, 'Origin': origin,  'Direction' : direction}
+  for k,v in info.items():
+    print(f' {k} : {v}')
+
+
+def add_suffix_to_filename(filename: str, suffix:str) -> str:
+  '''
+  Takes a NIfTI filename and appends a suffix.
+
+  Args:
+      filename : NIfTI filename
+      suffix : suffix to append
+
+  Returns:
+      str : filename after append the suffix
+  '''
+  if filename.endswith('.nii'):
+      result = filename.replace('.nii', f'_{suffix}.nii')
+      return result
+  elif filename.endswith('.nii.gz'):
+      result = filename.replace('.nii.gz', f'_{suffix}.nii.gz')
+      return result
+  else:
+      raise RuntimeError('filename with unknown extension')
+
+
+def rescale_linear(array: np.ndarray, new_min: int, new_max: int):
+  '''Rescale an array linearly.'''
+  minimum, maximum = np.min(array), np.max(array)
+  m = (new_max - new_min) / (maximum - minimum)
+  b = new_min - m * minimum
+  return m * array + b
+
+
+def explore_3D_array_with_mask_contour(arr: np.ndarray, mask: np.ndarray, thickness: int = 1):
+  '''
+  Given a 3D array with shape (Z,X,Y) This function will create an interactive
+  widget to check out all the 2D arrays with shape (X,Y) inside the 3D array. The binary
+  mask provided will be used to overlay contours of the region of interest over the 
+  array. The purpose of this function is to visual inspect the region delimited by the mask.
+
+  Args:
+    arr : 3D array with shape (Z,X,Y) that represents the volume of a MRI image
+    mask : binary mask to obtain the region of interest
+  '''
+  assert arr.shape == mask.shape
+  
+  _arr = rescale_linear(arr,0,1)
+  _mask = rescale_linear(mask,0,1)
+  _mask = _mask.astype(np.uint8)
+
+  def fn(SLICE):
+    arr_rgb = cv2.cvtColor(_arr[SLICE, :, :], cv2.COLOR_GRAY2RGB)
+    contours, _ = cv2.findContours(_mask[SLICE, :, :], cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
+    
+    arr_with_contours = cv2.drawContours(arr_rgb, contours, -1, (0,1,0), thickness)
+
+    plt.figure(figsize=(7,7))
+    plt.imshow(arr_with_contours)
+
+  interact(fn, SLICE=(0, arr.shape[0]-1))
\ No newline at end of file