Switch to unified view

a b/1 - Methods with Improved Results/SegmentationFunctions.py
1
import os
2
import cv2 
3
import math
4
5
import numpy as np
6
import pandas as pd
7
import matplotlib.pyplot as plt
8
from matplotlib.patches import Rectangle
9
10
from sklearn.cluster import KMeans
11
from skimage.morphology import erosion, opening, closing, square, \
12
                               disk, convex_hull_image, remove_small_holes
13
from skimage.measure import label, regionprops
14
15
from skimage.filters import sobel
16
from skimage.feature import canny
17
from scipy import ndimage as ndi
18
from skimage.segmentation import watershed
19
    
20
SMALL_FONT = 13
21
MEDIUM_FONT = 15
22
LARGE_FONT = 18
23
24
plt.rc('font', size=SMALL_FONT)          # controls default text sizes
25
plt.rc('axes', titlesize=SMALL_FONT)     # fontsize of the axes title
26
plt.rc('axes', labelsize=MEDIUM_FONT)    # fontsize of the x and y labels
27
plt.rc('xtick', labelsize=SMALL_FONT)    # fontsize of the tick labels
28
plt.rc('ytick', labelsize=SMALL_FONT)    # fontsize of the tick labels
29
plt.rc('legend', fontsize=MEDIUM_FONT)   # legend fontsize
30
plt.rc('figure', titlesize=LARGE_FONT)   # fontsize of the figure title
31
32
plt.rcParams["figure.figsize"] = (10, 5)
33
34
def readSortedSlices(path):
35
    
36
    slices = []
37
    for s in os.listdir(path):
38
        slices.append(path + '/' + s)       
39
    slices.sort(key = lambda s: int(s[s.find('_') + 1 : s.find('.')]))
40
    ID = slices[0][slices[0].find('/') + 1 : slices[0].find('_')]
41
    print('CT scan of Patient %s consists of %d slices.' % (ID, len(slices)))  
42
    return (slices, ID)
43
44
def getSliceImages(slices):
45
    
46
    return list(map(readImg, slices))
47
48
def readImg(path, showOutput=0):
49
    
50
    img = cv2.imread(path)
51
    img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
52
    
53
    if showOutput:
54
        plt.title('A CT Scan Image Slice')
55
        plt.imshow(img, cmap='gray')
56
    return img
57
    
58
def imgKMeans(img, K, showOutput=0, showHistogram=0):
59
    '''
60
    Apply KMeans on an image with the number of clusters K
61
    Input: Image, Number of clusters K
62
    Output: Dictionary of cluster center labels and points, Output segmented image
63
    '''
64
65
    imgflat = np.reshape(img, img.shape[0] * img.shape[1]).reshape(-1, 1)
66
        
67
    kmeans = KMeans(n_clusters=K, verbose=0)
68
    
69
    kmmodel = kmeans.fit(imgflat)
70
    
71
    labels = kmmodel.labels_
72
    centers = kmmodel.cluster_centers_
73
    center_labels = dict(zip(np.arange(K), centers))
74
    
75
    output = np.array([center_labels[label] for label in labels])
76
    output = output.reshape(img.shape[0], img.shape[1]).astype(int)
77
    
78
    if showOutput:
79
80
        fig, axes = plt.subplots(1, 2, sharex=True, sharey=True, figsize=(10, 5))
81
        axes = axes.ravel()
82
83
        axes[0].imshow(img, cmap='gray')
84
        axes[0].set_title('Original Image')
85
86
        axes[1].imshow(output)
87
        axes[1].set_title('Image after KMeans (K = ' + str(K) + ')')
88
    
89
    return center_labels, output
90
91
def preprocessImage(img, showOutput=0):
92
    '''
93
    Preprocess the image by applying truncated thresholding using KMeans
94
    Input: Image
95
    Output: Preprocessed image
96
    '''
97
    centroids, segmented_img = imgKMeans(img, 3, showOutput=0)
98
    
99
    sorted_center_values = sorted([i[0] for i in centroids.values()])
100
    threshold = (sorted_center_values[-1] + sorted_center_values[-2]) / 2
101
    
102
    retval, procImg = cv2.threshold(img, threshold, 255, cv2.THRESH_TOZERO) 
103
    
104
    if showOutput:
105
        
106
        fig, axes = plt.subplots(1, 2, sharex=True, sharey=True, figsize=(10, 5))
107
        axes = axes.ravel()
108
109
        axes[0].imshow(img, cmap='gray')
110
        axes[0].set_title('Original Image')
111
112
        axes[1].imshow(procImg, cmap='gray')
113
        axes[1].set_title('Processed Image - After Thresholding')
114
    
115
    return procImg, threshold
116
117
118
def getForegroundMask(img, fg_threshold, showOutput=0):
119
    
120
    retval, init_fg_mask = cv2.threshold(img, fg_threshold, 255, cv2.THRESH_BINARY)
121
122
    # Morphological operations to clean the mask
123
    fg_mask_opened = opening(init_fg_mask, square(3))
124
    fg_mask_opened2 = opening(fg_mask_opened, disk(4))
125
    
126
    # Perform edge-based segmentation of the foreground...
127
    
128
    # Detect contours that delineate the foreground with the Canny edge detector
129
    edges = canny(fg_mask_opened2) # Background is uniform - edges are on the boundary/inside ROI
130
    
131
    # Fill the inner part of the boundary using morphology ops
132
    fg_mask = ndi.binary_fill_holes(fg_mask_opened2)
133
    
134
    # Plot all steps
135
    if showOutput:
136
        fig, axes = plt.subplots(2, 3, sharex=True, sharey=True, figsize=(16, 10))
137
        axes = axes.ravel()
138
139
        axes[0].set_title('Original Image')
140
        axes[0].imshow(img, cmap='gray')
141
        axes[1].set_title('On Performing Thresholding\'s')
142
        axes[1].imshow(init_fg_mask, cmap='gray')
143
        axes[2].set_title('On Opening with Square SE (3)')
144
        axes[2].imshow(fg_mask_opened, cmap='gray')
145
        axes[3].set_title('On Opening with Disk SE (4)')
146
        axes[3].imshow(fg_mask_opened2, cmap='gray')
147
        axes[4].set_title('Outer Boundary Delineation with Canny\'s')
148
        axes[4].imshow(edges, cmap='gray')
149
        axes[5].set_title('Foreground Mask')
150
        axes[5].imshow(fg_mask, cmap='gray')
151
152
    return fg_mask
153
154
155
def getLungTracheaMasks(img, fg_mask, fg_threshold, showOutput=0):
156
    
157
    # Distinguish black pixels of the background from those of the lungs
158
    enhanced = img.copy()
159
    
160
    for i in range(img.shape[0]):
161
        for j in range(img.shape[1]):
162
            
163
            if fg_mask[i][j] == 0:
164
                enhanced[i][j] = 255
165
         
166
    # Extract lungs from the foreground mask
167
    retval, initial_lung_mask = cv2.threshold(enhanced, fg_threshold, 255, cv2.THRESH_BINARY_INV) 
168
    
169
    # Clean up the lung mask with morphological operations
170
    lung_mask_op = opening(initial_lung_mask, square(2))
171
    lung_mask_opcl = closing(lung_mask_op, disk(6))
172
    lung_mask_opclrm = ndi.binary_fill_holes(lung_mask_opcl)
173
174
    # Get connected components of the segmented image and label them
175
    label_img = label(lung_mask_opclrm)
176
    lung_regions = regionprops(label_img)
177
    
178
    # Upon experimentation: areas of regions < 1500 are wind pipe structures
179
    trachea_labels = []
180
    for i in lung_regions:
181
        if i.area < 1500:
182
            trachea_labels.append(i.label)
183
            
184
    # Create trachea mask as a summation of all those regions
185
    trachea_mask = np.zeros(img.shape, dtype=np.uint8)
186
    for row in range(label_img.shape[0]):
187
        for col in range(label_img.shape[1]):
188
            if label_img[row][col] in trachea_labels:
189
                trachea_mask[row][col] = 255
190
            
191
    # Lung mask is made of all the other regions
192
    lung_mask = lung_mask_opclrm * np.invert(trachea_mask) 
193
    
194
    # Lung mask is all black? Convex hull set to 0, since convex hull op on empty img errors out
195
    if sum(sum(lung_mask)) > 0:
196
        ch_lung_mask = convex_hull_image(lung_mask)
197
    else:
198
        ch_lung_mask = lung_mask.copy()
199
    
200
    
201
    initial_int_heart_mask = ch_lung_mask * np.invert(lung_mask) * np.invert(trachea_mask)
202
    
203
    int_heart_mask_op1 = opening(initial_int_heart_mask, square(5))
204
    int_heart_mask_op2 = opening(int_heart_mask_op1, disk(4))
205
    
206
    heart_label_img = label(int_heart_mask_op2)
207
    heart_regions = regionprops(heart_label_img)
208
    
209
    areas = {}
210
    for i in heart_regions:
211
        areas[i.label] = i.area
212
    
213
    if areas:
214
        heart_label = max(areas, key=areas.get)
215
        int_heart_mask = np.where(heart_label_img==heart_label, np.uint8(255), np.uint8(0))
216
    else:
217
        int_heart_mask = np.zeros(img.shape, dtype=np.uint8)
218
        
219
    if showOutput:
220
        
221
        fig, axes = plt.subplots(4, 3, sharex=True, sharey=True, figsize=(20, 20))
222
        axes = axes.ravel()
223
224
        axes[0].set_title('Original Image')
225
        axes[0].imshow(img, cmap='gray')
226
227
        axes[1].set_title('Enhanced Image')
228
        axes[1].imshow(enhanced, cmap='gray')
229
230
        axes[2].set_title('Initial Lung Mask')
231
        axes[2].imshow(initial_lung_mask, cmap='gray')
232
233
        axes[3].set_title('On Opening with Square SE (2)')
234
        axes[3].imshow(lung_mask_op, cmap='gray')
235
236
        axes[4].set_title('On Closing with Disk SE (6)')
237
        axes[4].imshow(lung_mask_opcl, cmap='gray')
238
239
        axes[5].set_title('On Filling Regions')
240
        axes[5].imshow(lung_mask_opclrm, cmap='gray')
241
242
        axes[6].set_title('Trachea Mask/Primary Bronchi')
243
        axes[6].imshow(trachea_mask, cmap='gray')
244
245
        axes[7].set_title('Lung Mask')
246
        axes[7].imshow(lung_mask, cmap='gray')
247
248
        axes[8].set_title('Convex Hull of Lung Mask')
249
        axes[8].imshow(ch_lung_mask, cmap='gray')
250
251
        axes[9].set_title('Initial Intermediate Heart Mask')
252
        axes[9].imshow(initial_int_heart_mask, cmap='gray')
253
254
        axes[10].set_title('On Opening with Square SE (3)')
255
        axes[10].imshow(int_heart_mask_op1, cmap='gray')
256
257
        axes[11].set_title('Intermediate heart mask')
258
        axes[11].imshow(int_heart_mask, cmap='gray')
259
    
260
    return trachea_mask, lung_mask, ch_lung_mask, int_heart_mask
261
262
    
263
def chullSpineMask(img, int_heart_mask, showOutput=0):
264
    
265
    # If no heart
266
    if not int_heart_mask.any():
267
        return int_heart_mask, int_heart_mask
268
    
269
    int_heart_pixel = img.copy()
270
    
271
    for i in range(img.shape[0]):
272
        for j in range(img.shape[1]):
273
            if int_heart_mask[i][j] == 0:
274
                int_heart_pixel[i][j] = 0
275
                
276
    centroids, segmented_heart_img = imgKMeans(int_heart_pixel, 3, showOutput=0)
277
    
278
    spine_threshold = (max(centroids.values()))[0]
279
280
    retval, initial_spine_mask = cv2.threshold(int_heart_pixel, spine_threshold, 255, cv2.THRESH_BINARY) 
281
    
282
    bone_mask = closing(initial_spine_mask, disk(20))  
283
    
284
    label_spine = label(bone_mask)
285
    spine_regions = regionprops(label_spine)
286
287
    # Assumption: The spine's area is greater than that of any calcium deposits
288
    labels = []
289
    areas = {}
290
    geometric_measures = {}
291
    for i in spine_regions:
292
        labels.append(i.label)
293
        areas[i.label] = i.area
294
        geometric_measures[i.label] = [i.centroid, i.orientation, i.axis_major_length]
295
    
296
    spine_label = max(areas, key=areas.get)
297
    labels.remove(spine_label)
298
    spine_mask = np.where(label_spine==spine_label, np.uint8(255), np.uint8(0))
299
    
300
#     if labels:
301
#         calcium_deposit_mask = np.where(label_spine==labels[0], np.uint8(255), np.uint8(0))
302
#     else:
303
#         calcium_deposit_mask = np.zeros(img.shape, dtype=np.uint8)
304
    
305
    label_heart = label(int_heart_mask)
306
    heart_regions = regionprops(label_heart)
307
    heart_region_area = heart_regions[0].area
308
    spine_region_area = areas[spine_label] 
309
    
310
    frac_heart = (heart_region_area - spine_region_area)/heart_region_area
311
    
312
    if frac_heart < 0.5:
313
        heart_mask = np.zeros(img.shape, dtype=np.uint8)
314
        make_spine_mask = 0
315
    else:
316
        make_spine_mask = 1
317
        
318
        # Center point of the spine - get the centroid 
319
        y0, x0 = geometric_measures[spine_label][0]
320
321
        orientation = geometric_measures[spine_label][1]
322
323
        # Top-most point of the spine
324
        # top_most_coordinate = centroid - (slightly_more_than_half * axis_major_length * sin(angle))
325
        x2 = x0 - math.sin(orientation) * 0.6 * geometric_measures[spine_label][2]
326
        y2 = y0 - math.cos(orientation) * 0.6 * geometric_measures[spine_label][2]
327
328
        chull_spine_mask = spine_mask.copy()
329
330
        # Vertical axis
331
        for i in range(math.ceil(y2), img.shape[1]):
332
333
            if i > math.ceil(y0):
334
                # Horizontal axis
335
                for j in range(img.shape[0]):
336
                    chull_spine_mask[i][j] = 255
337
            else:
338
                # Horizontal axis
339
                for j in range(math.ceil(x0)):
340
                    chull_spine_mask[i][j] = 255
341
342
343
        heart_mask = int_heart_mask * np.invert(chull_spine_mask)
344
    
345
    if showOutput:
346
        
347
        fig, axes = plt.subplots(3, 3, sharex=True, sharey=True, figsize=(15, 15))
348
        axes = axes.ravel()
349
350
        axes[0].set_title('Intermediate Heart Mask')
351
        axes[0].imshow(int_heart_mask, cmap='gray')
352
353
        axes[1].set_title('Intermediate Heart Segment')
354
        axes[1].imshow(int_heart_pixel, cmap='gray')
355
        
356
        axes[2].set_title('Intermediate Heart Segment on K-Means (K = 3)')
357
        axes[2].imshow(segmented_heart_img)
358
359
        axes[3].set_title('Spine Mask')
360
        axes[3].imshow(initial_spine_mask, cmap='gray')
361
        
362
        axes[4].set_title('On Closing with Disk SE (20)')
363
        axes[4].imshow(spine_mask, cmap='gray')
364
        
365
        axes[5].set_title('On Opening with Square SE (4)')
366
        axes[5].imshow(spine_mask, cmap='gray')
367
        
368
        axes[6].set_title('Centroid and uppermost point')
369
        axes[6].imshow(spine_mask, cmap='gray')
370
        
371
        if make_spine_mask:
372
            axes[6].plot((x0, x2), (y0, y2), '-r', linewidth=1.5)
373
            axes[6].plot(x0, y0, '.g', markersize=5)
374
            axes[6].plot(x2, y2, '.b', markersize=5)
375
        
376
            axes[7].set_title('Convex Hull of Spine Mask')
377
            axes[7].imshow(chull_spine_mask, cmap='gray')
378
        else:
379
            axes[7].set_title('Convex Hull of Spine Mask')
380
            axes[7].imshow(chull_spine_mask, cmap='gray')
381
        
382
        axes[8].set_title('Heart Mask')
383
        axes[8].imshow(heart_mask, cmap='gray')
384
        
385
    return spine_mask, heart_mask
386
387
def segmentHeart(img, heart_mask, showOutput=0):
388
389
    seg_heart = cv2.bitwise_and(img, img, mask=heart_mask)
390
    
391
    if showOutput:
392
        plt.figure(figsize=(10, 5))
393
        plt.title('Segmented Heart')
394
        plt.imshow(seg_heart, cmap='gray')
395
    return seg_heart
396
    
397
def segmentHeartLungsTrachea(img, heart_mask, lung_mask, trachea_mask, showOutput=0):
398
399
    seg_heart = cv2.bitwise_and(img, img, mask=heart_mask)
400
    seg_lungs = cv2.bitwise_and(img, img, mask=lung_mask)
401
    seg_trachea = cv2.bitwise_and(img, img, mask=trachea_mask)
402
    
403
    if showOutput:
404
        fig, [ax1, ax2, ax3] = plt.subplots(1, 3, figsize=(12, 6), sharex=True, sharey=False)
405
        
406
        ax1.set_title('Segmented Heart')
407
        ax1.imshow(seg_heart, cmap='gray')
408
        
409
        ax2.set_title('Segmented Lungs')
410
        ax2.imshow(seg_lungs, cmap='gray')
411
        
412
        ax3.set_title('Segmented Trachea')
413
        ax3.imshow(seg_trachea, cmap='gray')
414
        
415
    return seg_heart, seg_lungs, seg_trachea
416
    
417
418
def applyMaskColor(mask, mask_color):
419
    
420
    masked = np.concatenate(([mask[ ... , np.newaxis] * color for color in mask_color]), axis=2)
421
    
422
    # Matplotlib expects color intensities to range from 0 to 1 if a float
423
    maxValue = np.amax(masked)
424
    minValue = np.amin(masked)
425
426
    # Therefore, scale the color image accordingly
427
    if maxValue - minValue == 0:
428
        return masked
429
    else:
430
        masked = masked / (maxValue - minValue)
431
    
432
    return masked
433
434
def getColoredMasks(img, heart_mask, lung_mask, trachea_mask, showOutput=0):
435
    heart_mask_color = np.array([256, 0, 0])
436
    lung_mask_color = np.array([0, 256, 0])
437
    trachea_mask_color = np.array([0, 0, 256])
438
439
    heart_colored = applyMaskColor(heart_mask, heart_mask_color)
440
    lung_colored = applyMaskColor(lung_mask, lung_mask_color)
441
    trachea_colored = applyMaskColor(trachea_mask, trachea_mask_color)
442
    
443
    colored_masks = heart_colored + lung_colored + trachea_colored
444
445
    if showOutput:
446
        fig, axes = plt.subplots(2, 2, figsize=(10, 10))
447
        ax = axes.ravel()
448
        
449
        ax[0].set_title("Original Image")
450
        ax[0].imshow(img, cmap='gray')
451
        ax[1].set_title("Heart Mask")
452
        ax[1].imshow(heart_colored)
453
        ax[2].set_title("Lung Mask")
454
        ax[2].imshow(lung_colored)
455
        ax[3].set_title("Masks")
456
        ax[3].imshow(colored_masks)
457
    
458
    return heart_colored, lung_colored, trachea_colored, colored_masks