Diff of /lungs/preprocess.py [000000] .. [eac570]

import time
from tqdm import tqdm
import numpy as np
import pydicom as dicom
import os
from scipy import ndimage
import matplotlib.pyplot as plt
from pathlib import Path
from skimage import measure
from collections import defaultdict
from sys import argv
from random import shuffle

from lungs.utils import apply_window


# The pixel size/coarseness of the scan differs from scan to scan (e.g. the distance between slices may differ),
# which can hurt the performance of CNN approaches. We deal with this by isotropic resampling.
# Below is code to load a scan, which consists of multiple slices that we simply keep in a Python list.
# Every folder in the dataset is one scan (so one patient). One metadata field is missing: the pixel size in
# the Z direction, i.e. the slice thickness. Fortunately we can infer it and add it to the metadata.

def is_dcm_file(path):
    name, ext = os.path.splitext(path)
    # Standard DICOM file extension
    if 'dcm' in ext:
        return True
    # NLST format: numeric file names without an extension
    if name.isdigit() and not ext:
        return True
    return False

# Load a volume (one scan) from the given folder path
def load_scan(path):
    slices = [dicom.dcmread(path + '/' + scan) for scan in os.listdir(path) if is_dcm_file(scan)]
    slices.sort(key=lambda x: float(x.ImagePositionPatient[2]))

    # Infer the slice thickness from slice positions if it is missing from the metadata
    if not getattr(slices[0], 'SliceThickness', None):
        try:
            slice_thickness = np.abs(slices[0].ImagePositionPatient[2] - slices[1].ImagePositionPatient[2])
        except Exception:
            slice_thickness = np.abs(slices[0].SliceLocation - slices[1].SliceLocation)

        for s in slices:
            s.SliceThickness = slice_thickness
    return slices
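
# A minimal usage sketch of load_scan, assuming a folder of DICOM slices
# (the path below is a hypothetical placeholder, not a path from this repo):
#
#     slices = load_scan('/data/scans/patient_0001')
#     print(len(slices), slices[0].SliceThickness, slices[0].PixelSpacing)
#
# Each element is a pydicom Dataset; slices come back sorted by their z position
# (ImagePositionPatient[2]), with SliceThickness filled in if it was missing.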


# The unit of measurement in CT scans is the **Hounsfield Unit (HU)**, which is a measure of radiodensity.
# CT scanners are carefully calibrated to accurately measure this.
# By default, however, the returned pixel values are not in this unit. Let's fix this.
# Some scanners have cylindrical scanning bounds, but the output volume is square.
# The pixels that fall outside of these bounds get the fixed value -2000. The first step is setting these
# values to 0, which currently corresponds to air.
# Next, we go back to HU by multiplying with the rescale slope and adding the intercept
# (both conveniently stored in the metadata of the scans!).

def get_pixels_hu(slices):
    volume = np.stack([s.pixel_array for s in slices])
    # Convert to int16 (from sometimes uint16),
    # should be possible as values should always be low enough (<32k)
    volume = volume.astype(np.int16)

    # Set outside-of-scan pixels to 0
    # The intercept is usually -1024, so air is approximately 0
    volume[volume == -2000] = 0

    # Convert to Hounsfield units (HU)
    for slice_number in range(len(slices)):
        intercept = slices[slice_number].RescaleIntercept
        slope = slices[slice_number].RescaleSlope

        if slope != 1:
            volume[slice_number] = slope * volume[slice_number].astype(np.float64)
            volume[slice_number] = volume[slice_number].astype(np.int16)

        volume[slice_number] += np.int16(intercept)

    return np.array(volume, dtype=np.int16)
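
# A worked example of the conversion above (values are illustrative, not taken from a real scan):
# with RescaleSlope = 1 and RescaleIntercept = -1024, a raw pixel value of 1064 maps to
# 1064 * 1 + (-1024) = 40 HU (soft tissue), a raw value of 0 maps to -1024 HU (air), and the
# out-of-bounds padding value -2000 is zeroed first so it also ends up at -1024 HU rather than
# an artificial -3024 HU.
#
#     scan_hu = get_pixels_hu(load_scan('/data/scans/patient_0001'))  # placeholder path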

# # Resampling
# A scan may have a pixel spacing of `[2.5, 0.5, 0.5]`, which means that the distance between slices is
# `2.5` millimeters. For a different scan this may be `[1.5, 0.725, 0.725]`, which can be problematic
# for automatic analysis (e.g. using ConvNets).
# A common method of dealing with this is resampling the full dataset to a certain isotropic resolution.
# If we resample everything to 1.5mm x 1.5mm x 1.5mm voxels, we can use 3D ConvNets without worrying about
# learning zoom/slice-thickness invariance.
# While this may seem like a very simple step, it has quite a few edge cases due to rounding. It also takes
# quite a while.

def resample(scan_hu, scan_file, scan, new_spacing, verbose=False):
    # Determine current pixel spacing
    spacing = np.array([scan_file[0].SliceThickness] + list(scan_file[0].PixelSpacing), dtype=np.float32)
    if verbose:
        print('Spacing:', spacing)
    resize_factor = spacing / new_spacing
    new_real_shape = scan_hu.shape * resize_factor
    new_shape = np.round(new_real_shape)
    real_resize_factor = new_shape / scan_hu.shape
    new_spacing = spacing / real_resize_factor

    scan_hu = ndimage.zoom(scan_hu, real_resize_factor, mode='nearest')
    return scan_hu, new_spacing
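
# A worked example of the resampling arithmetic, assuming a scan with spacing [2.5, 0.7, 0.7] mm
# and shape (120, 512, 512), resampled to 1.5 mm isotropic voxels (illustrative numbers only):
# resize_factor = [2.5, 0.7, 0.7] / 1.5 ~= [1.667, 0.467, 0.467], so the ideal new shape is
# ~= (200.0, 238.9, 238.9), which is rounded to (200, 239, 239); the spacing actually achieved
# is then recomputed from the rounded shape, hence the returned new_spacing.
#
#     resampled, new_spacing = resample(scan_hu, slices, None, [1.5, 1.5, 1.5])
#
# (The third argument is not used by resample; preprocess() below passes the raw pixel volume there.)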

def largest_label_volume(im, bg=-1):
    vals, counts = np.unique(im, return_counts=True)
    counts = counts[vals != bg]
    vals = vals[vals != bg]
    if len(counts) > 0:
        return vals[np.argmax(counts)]
    else:
        return None
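
# For example, for a labeled array where label 0 (the background) covers 400 voxels, label 1
# covers 300 and label 2 covers 50, largest_label_volume(labels, bg=0) returns 1, the most
# common non-background label.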

# # Lung segmentation
# In order to reduce the problem space, we segment the lungs (and usually some tissue around them).
# Lung segmentation normally consists of a series of region-growing and morphological operations;
# here we use only connected component analysis.
#
# The steps:
# * Threshold the volume (-320 HU is a good threshold, but it doesn't matter much for this approach)
# * Do connected components, determine the label of the air around the person, fill this with 1s in
#   the binary volume
# * Optionally: for every axial slice in the scan, determine the largest solid connected component
#   (the body + air around the person), and set the others to 0. This fills the structures in the
#   lungs in the mask.
# * Keep only the largest air pocket (the human body has other pockets of air here and there).
def segment_lung_mask(volume, fill_lung_structures=True):
    # not actually binary, but 1 and 2.
    # 0 is treated as background, which we do not want
    binary_volume = np.array(volume > -320, dtype=np.int8) + 1
    labels = measure.label(binary_volume)

    # Pick the pixel in the very corner to determine which label is air.
    #   Improvement: pick multiple background labels from around the patient,
    #   which is more resistant to "trays" on which the patient lies cutting
    #   the air around the person in half
    background_label = labels[0, 0, 0]

    # Fill the air around the person
    binary_volume[background_label == labels] = 2

    # Method of filling the lung structures (that is superior to something like
    # morphological closing)
    if fill_lung_structures:
        # For every slice we determine the largest solid structure
        for i, axial_slice in enumerate(binary_volume):
            axial_slice = axial_slice - 1
            labeling = measure.label(axial_slice)
            l_max = largest_label_volume(labeling, bg=0)
            if l_max is not None:  # This slice contains some lung
                binary_volume[i][labeling != l_max] = 1

    binary_volume -= 1  # Make the volume actually binary
    binary_volume = 1 - binary_volume  # Invert it, lungs are now 1

    # Remove other air pockets inside the body
    labels = measure.label(binary_volume, background=0)
    l_max = largest_label_volume(labels, bg=0)
    if l_max is not None:  # There are air pockets
        binary_volume[labels != l_max] = 0
    return binary_volume
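
# A minimal usage sketch (resampled is a HU volume, e.g. the output of resample above):
#
#     lung_mask = segment_lung_mask(resampled, fill_lung_structures=True)
#     print(lung_mask.dtype, lung_mask.min(), lung_mask.max())   # int8 0 1
#
# preprocess() below only uses the mask to find the lungs' bounding box; it crops the HU
# volume itself, not a masked volume.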

# Bounding box of the nonzero region of a 3D volume: (min, max) indices along each axis
def bbox2_3D(volume):
    r = np.any(volume, axis=(1, 2))
    c = np.any(volume, axis=(0, 2))
    z = np.any(volume, axis=(0, 1))
    rmin, rmax = np.where(r)[0][[0, -1]]
    cmin, cmax = np.where(c)[0][[0, -1]]
    zmin, zmax = np.where(z)[0][[0, -1]]
    return rmin, rmax, cmin, cmax, zmin, zmax
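
# For example, for a mask whose nonzero voxels span slices 40-170, rows 80-400 and columns 90-420
# (illustrative numbers), bbox2_3D returns (40, 170, 80, 400, 90, 420), i.e. one (min, max) index
# pair per axis.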

def preprocess(scan, errors_map, num_slices=224, crop_size=224, voxel_size=1.5, windowing=False, sample_volume=True, verbose=True):
    orig_scan = load_scan(scan)
    num_orig_slices = len(orig_scan)
    if num_orig_slices < 50:
        errors_map['insufficient_slices'] += 1
        raise ValueError(scan[-4:] + ': Insufficient number of slices (<50).')
    orig_scan_np = np.stack([s.pixel_array for s in orig_scan]).astype(np.int16)

    scan_hu = get_pixels_hu(orig_scan)

    # Resample the patient's volume to an isotropic resolution
    resampled_scan, _ = resample(scan_hu, orig_scan, orig_scan_np, [voxel_size, voxel_size, voxel_size], verbose=verbose)
    if verbose:
        print("Shape before resampling:", scan_hu.shape)
        print("Shape after resampling:", resampled_scan.shape)

    if resampled_scan.shape[0] < 180:
        errors_map['small_z'] += 1
        raise ValueError(scan[-4:] + ': Insufficient number of resampled slices (<180).')

    lung_mask = segment_lung_mask(resampled_scan, True)

    z_min, z_max, x_min, x_max, y_min, y_max = bbox2_3D(lung_mask)
    box_size = (z_max - z_min, x_max - x_min, y_max - y_min)
    if verbose:
        print('Lung bounding box (min, max):', (z_min, z_max), (x_min, x_max), (y_min, y_max))
        print('Bounding box size:', box_size)

    for dim in box_size:
        if dim < 100:
            errors_map['seg_error'] += 1
            raise ValueError(scan[-4:] + ': Segmentation error.')

    # Crop a fixed-size volume (num_slices x crop_size x crop_size) centred on the lungs,
    # zero-padding where the crop extends beyond the resampled volume
    lung_center = np.array([z_min + z_max, x_min + x_max, y_min + y_max]) // 2
    context = np.array([num_slices, crop_size, crop_size])

    volume_starts = np.array([max(0, lung_center[i] - context[i] // 2) for i in range(3)])
    volume_ends = np.array([min(resampled_scan.shape[i], lung_center[i] + context[i] // 2) for i in range(3)])
    volume_size = volume_ends - volume_starts

    starts = context // 2 - volume_size // 2
    ends = starts + volume_size

    lungs_padded = np.zeros((num_slices, crop_size, crop_size))
    lungs_padded[starts[0]: ends[0], starts[1]: ends[1], starts[2]: ends[2]] = \
            resampled_scan[volume_starts[0]: volume_ends[0], volume_starts[1]: volume_ends[1], volume_starts[2]: volume_ends[2]]

    if verbose:
        print("Final shape", lungs_padded.shape)

    if sample_volume:
        # Generate an RGB slice for display
        lungs_rgb = np.stack((lungs_padded, lungs_padded, lungs_padded), axis=3)
        lungs_sample_slice = lungs_rgb[lungs_rgb.shape[0] // 2]
    else:
        lungs_sample_slice = None

    return lungs_padded, lungs_sample_slice
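
# A minimal usage sketch, assuming a DICOM scan folder at a hypothetical placeholder path:
#
#     errors = defaultdict(int)
#     volume, sample_slice = preprocess('/data/scans/patient_0001', errors)
#     print(volume.shape)                                     # (224, 224, 224) with the defaults
#     plt.imshow(sample_slice[..., 0], cmap='gray'); plt.show()
#
# The lungs are centred in a fixed-size crop; regions outside the original volume are zero-padded,
# and scans that are too short or badly segmented raise ValueError.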

def walk_dicom_dirs(base_in, base_out=None, print_dirs=True):
    for root, _, files in os.walk(base_in):
        path = root.split(os.sep)
        if print_dirs:
            print((len(path) - 1) * '---', os.path.basename(root))
        # sample_filename = os.path.splitext(files[0])
        if len(files) >= 50: # and (sample_filename[0].isdigit() or 'dcm' in sample_filename[1]):
            if base_out:
                yield root, base_out + os.path.relpath(root, base_in)
            else:
                yield root

def walk_np_files(base_in, print_dirs=True):
    pathlist = Path(base_in).glob('**/*.np*')
    for path in pathlist:
        np_path = str(path)
        print(np_path)
        yield np_path

def preprocess_all(input_dir, overwrite=False, num_slices=224, crop_size=224, voxel_size=1.5):
    start = time.time()
    scans = os.listdir(input_dir)
    scans.sort()
    errors_map = defaultdict(int)
    base_out = input_dir.rstrip('/') + '_preprocessed/'
    valid_scans = 0

    scans_num = len(list(walk_dicom_dirs(input_dir, base_out, False)))
    for scan_dir_path, out_path in tqdm(walk_dicom_dirs(input_dir, base_out), total=scans_num):
        try:
            out_dir = os.path.dirname(out_path)
            os.makedirs(out_dir, exist_ok=True)

            # Skip scans whose output directory already has content, unless overwrite is set
            if overwrite or not os.listdir(out_dir):
                preprocessed_scan, scan_rgb_sample = \
                    preprocess(scan_dir_path, errors_map, num_slices, crop_size, voxel_size)
                plt.imshow(scan_rgb_sample)
                plt.savefig(out_path + '.png', bbox_inches='tight')
                np.savez_compressed(out_path + '.npz', data=preprocessed_scan)

            valid_scans += 1
            print('\n++++++++++++++++++++++++\nDiagnostics:')
            print(dict(errors_map))

        except FileExistsError:
            valid_scans += 1
            print('Exists:', out_path)

        except ValueError as e:
            print('\nERROR!')
            print(e)

    print('Total scans: {}'.format(scans_num))
    print('Valid scans: {}'.format(valid_scans))
    print('Scans with insufficient slices: {}'.format(errors_map['insufficient_slices']))
    print('Scans with bad segmentation: {}'.format(errors_map['seg_error']))
    print('Scans with small resampled z dimension: {}'.format(errors_map['small_z']))
    print((time.time() - start) / scans_num, 'sec/volume')
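
# A minimal usage sketch (the directory is a placeholder; outputs go to a sibling
# '<input_dir>_preprocessed/' tree with one .npz volume and one .png sample slice per scan):
#
#     preprocess_all('/data/nlst_scans/', overwrite=False)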

def split(positives, negatives, lists_dir, print_dirs=False, split_ratio=0.7):
    positive_paths = []
    negative_paths = []

    for preprocessed_dir, path_list, label in (positives, positive_paths, '1'), (negatives, negative_paths, '0'):
        for root, _, files in os.walk(preprocessed_dir):
            path = root.split(os.sep)
            if print_dirs:
                print((len(path) - 1) * '---', os.path.basename(root))
            for f in files:
                if '.np' in f:
                    path_list.append((root + '/' + f, label))
        print('\nINFO: volumes with label', label, len(path_list))

    shuffle(positive_paths)
    split_pos = round(split_ratio * len(positive_paths))
    shuffle(negative_paths)
    split_neg = round(split_ratio * len(negative_paths))
    train_list = positive_paths[:split_pos] + negative_paths[:split_neg]
    val_list = positive_paths[split_pos:] + negative_paths[split_neg:]
    shuffle(train_list)
    shuffle(val_list)

    os.makedirs(lists_dir, exist_ok=True)
    with open(lists_dir + '/val.list', 'w') as val_f:
        for path, label in val_list:
            val_f.write(path + ' ' + label + '\n')

    with open(lists_dir + '/train.list', 'w') as train_f:
        for path, label in train_list:
            train_f.write(path + ' ' + label + '\n')
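
# A minimal usage sketch, assuming two preprocessed trees and an output directory
# (all paths are hypothetical placeholders):
#
#     split('/data/positives_preprocessed', '/data/negatives_preprocessed', '/data/lists')
#
# This writes train.list and val.list, one "<path> <label>" line per volume, with a 70/30
# train/validation split per class by default.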