a b/supporters.py
1
import numpy as np
2
import matplotlib.pyplot as plt
3
from ipywidgets import interact
4
import SimpleITK as sitk
5
import cv2
6
7
def explore_3D_array(arr: np.ndarray, cmap: str = 'gray'):
  '''
  Build an interactive slider widget for browsing a 3D array slice by slice.

  Every slider position selects one index along the first axis of the array
  and draws the matching 2D slice in a new 7x7 inch matplotlib figure.

  Args:
    arr : 3D array with shape (Z,X,Y); the slider runs over the Z axis
    cmap : Color map handed to matplotlib.pyplot when drawing a slice
  '''
  last_slice = arr.shape[0] - 1

  def draw_slice(SLICE):
    # interact passes the current slider value in as the SLICE keyword.
    plt.figure(figsize=(7, 7))
    plane = arr[SLICE, :, :]
    plt.imshow(plane, cmap=cmap)

  interact(draw_slice, SLICE=(0, last_slice))
23
  
24
25
def explore_3D_array_axis(arr: np.ndarray, aspect: str = 'axial', cmap: str = 'gray'):
  '''
  Given a 3D array this function will create an interactive widget to browse
  its 2D slices along the axis selected by `aspect`.

  The aspect decides which axis of the array the slider indexes:
  'axial' indexes axis 0, 'sagittal' indexes axis 1 and 'coronal' indexes
  axis 2. The slider range follows the length of that axis, so volumes whose
  three dimensions differ can be browsed completely and without indexing
  past the end of the array.

  Any other aspect value is not an error: the widget is still created (with a
  slider sized after axis 0) and 'Invalid aspect' is printed on every update.

  Args:
    arr : 3D array that represents the volume of a MRI image
    aspect : Which aspect to view: sagittal, axial, or coronal
    cmap : Which color map use to plot the slices in matplotlib.pyplot
  '''
  axis_by_aspect = {'axial': 0, 'sagittal': 1, 'coronal': 2}
  axis = axis_by_aspect.get(aspect)  # None for an unknown aspect

  def fn(SLICE):
    plt.figure(figsize=(7,7))
    if axis is None:
      print('Invalid aspect')
      return
    # Take every element on the other two axes and a single index on `axis`.
    index = [slice(None)] * 3
    index[axis] = SLICE
    plt.imshow(arr[tuple(index)], cmap=cmap)

  # Size the slider from the axis that is actually indexed; fall back to
  # axis 0 for an unknown aspect so the widget is still shown.
  n_slices = arr.shape[axis if axis is not None else 0]
  interact(fn, SLICE=(0, n_slices-1))
49
50
51
52
def explore_3D_array_comparison(arr_before: np.ndarray, arr_after: np.ndarray, cmap: str = 'gray'):
  '''
  Build an interactive slider widget that shows the same slice of two 3D
  arrays side by side, so the effect of a transformation can be inspected.

  The left panel is titled 'Label' and shows `arr_before`; the right panel is
  titled 'Prediction' and shows `arr_after`. The slider runs over the first
  axis of the arrays. Both arrays must have exactly the same shape, otherwise
  an AssertionError is raised before any widget is created.

  Args:
    arr_before : 3D array with shape (Z,X,Y), the volume before the transform
    arr_after : 3D array with shape (Z,X,Y), the volume after the transform
    cmap : Color map handed to matplotlib.pyplot when drawing the slices
  '''

  assert arr_after.shape == arr_before.shape

  def draw_pair(SLICE):
    _, axes = plt.subplots(1, 2, sharex='col', sharey='row', figsize=(10,10))

    # Left panel first, then right panel, each with its own title and volume.
    panels = (('Label', arr_before), ('Prediction', arr_after))
    for axis, (title, volume) in zip(axes, panels):
      axis.set_title(title, fontsize=15)
      axis.imshow(volume[SLICE, :, :], cmap=cmap)

    plt.tight_layout()

  last_slice = arr_before.shape[0] - 1
  interact(draw_pair, SLICE=(0, last_slice))
78
79
80
def show_sitk_img_info(img: sitk.Image):
  '''
  Print the basic metadata of a sitk.Image, one ' name : value' line per field.

  The fields are printed in this order: pixel type, dimensions (the image
  size), spacing, origin and direction, each read from the matching getter
  of the image.

  Args:
    img : instance of the sitk.Image to describe
  '''
  # Read every field before printing anything.
  type_name = img.GetPixelIDTypeAsString()
  img_origin = img.GetOrigin()
  img_size = img.GetSize()
  img_spacing = img.GetSpacing()
  img_direction = img.GetDirection()

  rows = (
    ('Pixel Type', type_name),
    ('Dimensions', img_size),
    ('Spacing', img_spacing),
    ('Origin', img_origin),
    ('Direction', img_direction),
  )
  for label, value in rows:
    print(f' {label} : {value}')
96
97
98
def add_suffix_to_filename(filename: str, suffix:str) -> str:
  '''
  Takes a NIfTI filename and inserts a suffix right before its extension.

  Only the trailing extension is touched: any other '.nii' or '.nii.gz'
  appearing earlier in the name (for example in a directory component) is
  left as it is.

  Args:
      filename : NIfTI filename ending in '.nii' or '.nii.gz'
      suffix : suffix to append; it is joined to the name with an underscore

  Returns:
      str : filename after append the suffix, e.g. 'brain.nii.gz' with suffix
            'mask' becomes 'brain_mask.nii.gz'

  Raises:
      RuntimeError : if the filename ends in neither '.nii' nor '.nii.gz'
  '''
  for extension in ('.nii.gz', '.nii'):
    if filename.endswith(extension):
      # Slice the extension off the end instead of using str.replace, which
      # would also rewrite earlier occurrences of the extension text.
      stem = filename[:-len(extension)]
      return f'{stem}_{suffix}{extension}'

  raise RuntimeError('filename with unknown extension')
117
118
119
def rescale_linear(array: np.ndarray, new_min: int, new_max: int):
  '''
  Rescale an array linearly so its minimum maps to new_min and its maximum
  maps to new_max.

  The value range of the input is computed with Python floats, so narrow
  integer dtypes cannot overflow while subtracting and boolean arrays are
  accepted. When every element has the same value there is no range to
  stretch; in that case every element of the result is new_min.

  Args:
    array : array (or array-like) of values to rescale; must not be empty
    new_min : value the smallest element is mapped to
    new_max : value the largest element is mapped to

  Returns:
    array with the same shape as the input holding the rescaled values
  '''
  array = np.asarray(array)
  minimum = float(np.min(array))
  maximum = float(np.max(array))
  value_range = maximum - minimum

  # A constant array has a zero range; use a zero slope instead of dividing
  # by zero, which makes the whole result equal to new_min.
  m = (new_max - new_min) / value_range if value_range != 0 else 0.0
  b = new_min - m * minimum
  return m * array + b
125
126
127
def explore_3D_array_with_mask_contour(arr: np.ndarray, mask: np.ndarray, thickness: int = 1):
  '''
  Given a 3D array with shape (Z,X,Y) This function will create an interactive
  widget to check out all the 2D arrays with shape (X,Y) inside the 3D array. The binary
  mask provided will be used to overlay contours of the region of interest over the
  array. The purpose of this function is to visually inspect the region delimited by the mask.

  The array is rescaled to the range [0, 1] and each slice is converted to RGB,
  so the contours can be drawn in green on top of the gray image. The mask is
  rescaled to [0, 1] as well and then cast to uint8 before contours are
  searched. Both arrays must have the same shape, otherwise an AssertionError
  is raised before any widget is created.

  Args:
    arr : 3D array with shape (Z,X,Y) that represents the volume of a MRI image
    mask : binary mask to obtain the region of interest
    thickness : thickness of the contour lines, passed on to cv2.drawContours
  '''
  assert arr.shape == mask.shape

  # rescale_linear returns float64 for integer volumes, which cv2.cvtColor
  # rejects; float32 is accepted, so cast before handing slices to OpenCV.
  _arr = np.asarray(rescale_linear(arr,0,1), dtype=np.float32)
  _mask = rescale_linear(mask,0,1)
  _mask = _mask.astype(np.uint8)

  def fn(SLICE):
    arr_rgb = cv2.cvtColor(_arr[SLICE, :, :], cv2.COLOR_GRAY2RGB)
    contours, _ = cv2.findContours(_mask[SLICE, :, :], cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)

    # -1 draws every contour found; (0,1,0) is green on the [0, 1] RGB scale.
    arr_with_contours = cv2.drawContours(arr_rgb, contours, -1, (0,1,0), thickness)

    plt.figure(figsize=(7,7))
    plt.imshow(arr_with_contours)

  interact(fn, SLICE=(0, arr.shape[0]-1))